Skip to content

Commit e279f8c

Browse files
feat(azure): batch api (openai#839)
1 parent 6e556d9 commit e279f8c

File tree

2 files changed

+279
-6
lines changed

2 files changed

+279
-6
lines changed

src/index.ts

+5 -6
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,7 @@ export interface AzureClientOptions extends ClientOptions {
346346
/** API Client for interfacing with the Azure OpenAI API. */
347347
export class AzureOpenAI extends OpenAI {
348348
private _azureADTokenProvider: (() => Promise<string>) | undefined;
349+
private _deployment: string | undefined;
349350
apiVersion: string = '';
350351
/**
351352
* API Client for interfacing with the Azure OpenAI API.
@@ -412,11 +413,7 @@ export class AzureOpenAI extends OpenAI {
412413
);
413414
}
414415

415-
if (deployment) {
416-
baseURL = `${endpoint}/openai/deployments/${deployment}`;
417-
} else {
418-
baseURL = `${endpoint}/openai`;
419-
}
416+
baseURL = `${endpoint}/openai`;
420417
} else {
421418
if (endpoint) {
422419
throw new Errors.OpenAIError('baseURL and endpoint are mutually exclusive');
@@ -432,6 +429,7 @@ export class AzureOpenAI extends OpenAI {
432429

433430
this._azureADTokenProvider = azureADTokenProvider;
434431
this.apiVersion = apiVersion;
432+
this._deployment = deployment;
435433
}
436434

437435
override buildRequest(options: Core.FinalRequestOptions<unknown>): {
@@ -443,7 +441,7 @@ export class AzureOpenAI extends OpenAI {
443441
if (!Core.isObj(options.body)) {
444442
throw new Error('Expected request body to be an object');
445443
}
446-
const model = options.body['model'];
444+
const model = this._deployment || options.body['model'];
447445
delete options.body['model'];
448446
if (model !== undefined && !this.baseURL.includes('/deployments')) {
449447
options.path = `/deployments/${model}${options.path}`;
@@ -494,6 +492,7 @@ const _deployments_endpoints = new Set([
494492
'/audio/translations',
495493
'/audio/speech',
496494
'/images/generations',
495+
'/batches',
497496
]);
498497

499498
const API_KEY_SENTINEL = '<Missing Key>';

tests/lib/azure.test.ts

+274
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@ import { Headers } from 'openai/core';
44
import defaultFetch, { Response, type RequestInit, type RequestInfo } from 'node-fetch';
55

66
const apiVersion = '2024-02-15-preview';
7+
const deployment = 'deployment';
8+
const model = 'unused model';
79

810
describe('instantiate azure client', () => {
911
const env = process.env;
@@ -275,6 +277,278 @@ describe('instantiate azure client', () => {
275277
describe('azure request building', () => {
276278
const client = new AzureOpenAI({ baseURL: 'https://example.com', apiKey: 'My API Key', apiVersion });
277279

280+
describe('model to deployment mapping', function () {
281+
const testFetch = async (url: RequestInfo): Promise<Response> => {
282+
return new Response(JSON.stringify({ url }), { headers: { 'content-type': 'application/json' } });
283+
};
284+
describe('with client-level deployment', function () {
285+
const client = new AzureOpenAI({
286+
endpoint: 'https://example.com',
287+
apiKey: 'My API Key',
288+
apiVersion,
289+
deployment,
290+
fetch: testFetch,
291+
});
292+
293+
test('handles Batch', async () => {
294+
expect(
295+
await client.batches.create({
296+
completion_window: '24h',
297+
endpoint: '/v1/chat/completions',
298+
input_file_id: 'file-id',
299+
}),
300+
).toStrictEqual({
301+
url: `https://example.com/openai/deployments/${deployment}/batches?api-version=${apiVersion}`,
302+
});
303+
});
304+
305+
test('handles completions', async () => {
306+
expect(
307+
await client.completions.create({
308+
model,
309+
prompt: 'prompt',
310+
}),
311+
).toStrictEqual({
312+
url: `https://example.com/openai/deployments/${deployment}/completions?api-version=${apiVersion}`,
313+
});
314+
});
315+
316+
test('handles chat completions', async () => {
317+
expect(
318+
await client.chat.completions.create({
319+
model,
320+
messages: [{ role: 'system', content: 'Hello' }],
321+
}),
322+
).toStrictEqual({
323+
url: `https://example.com/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`,
324+
});
325+
});
326+
327+
test('handles embeddings', async () => {
328+
expect(
329+
await client.embeddings.create({
330+
model,
331+
input: 'input',
332+
}),
333+
).toStrictEqual({
334+
url: `https://example.com/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
335+
});
336+
});
337+
338+
test('handles audio translations', async () => {
339+
expect(
340+
await client.audio.translations.create({
341+
model,
342+
file: { url: 'https://example.com', blob: () => 0 as any },
343+
}),
344+
).toStrictEqual({
345+
url: `https://example.com/openai/deployments/${deployment}/audio/translations?api-version=${apiVersion}`,
346+
});
347+
});
348+
349+
test('handles audio transcriptions', async () => {
350+
expect(
351+
await client.audio.transcriptions.create({
352+
model,
353+
file: { url: 'https://example.com', blob: () => 0 as any },
354+
}),
355+
).toStrictEqual({
356+
url: `https://example.com/openai/deployments/${deployment}/audio/transcriptions?api-version=${apiVersion}`,
357+
});
358+
});
359+
360+
test('handles text to speech', async () => {
361+
expect(
362+
await (
363+
await client.audio.speech.create({
364+
model,
365+
input: '',
366+
voice: 'alloy',
367+
})
368+
).json(),
369+
).toStrictEqual({
370+
url: `https://example.com/openai/deployments/${deployment}/audio/speech?api-version=${apiVersion}`,
371+
});
372+
});
373+
374+
test('handles image generation', async () => {
375+
expect(
376+
await client.images.generate({
377+
model,
378+
prompt: 'prompt',
379+
}),
380+
).toStrictEqual({
381+
url: `https://example.com/openai/deployments/${deployment}/images/generations?api-version=${apiVersion}`,
382+
});
383+
});
384+
385+
test('handles assistants', async () => {
386+
expect(
387+
await client.beta.assistants.create({
388+
model,
389+
}),
390+
).toStrictEqual({
391+
url: `https://example.com/openai/assistants?api-version=${apiVersion}`,
392+
});
393+
});
394+
395+
test('handles files', async () => {
396+
expect(
397+
await client.files.create({
398+
file: { url: 'https://example.com', blob: () => 0 as any },
399+
purpose: 'assistants',
400+
}),
401+
).toStrictEqual({
402+
url: `https://example.com/openai/files?api-version=${apiVersion}`,
403+
});
404+
});
405+
406+
test('handles fine tuning', async () => {
407+
expect(
408+
await client.fineTuning.jobs.create({
409+
model,
410+
training_file: '',
411+
}),
412+
).toStrictEqual({
413+
url: `https://example.com/openai/fine_tuning/jobs?api-version=${apiVersion}`,
414+
});
415+
});
416+
});
417+
418+
describe('with no client-level deployment', function () {
419+
const client = new AzureOpenAI({
420+
endpoint: 'https://example.com',
421+
apiKey: 'My API Key',
422+
apiVersion,
423+
fetch: testFetch,
424+
});
425+
426+
test('Batch is not handled', async () => {
427+
expect(
428+
await client.batches.create({
429+
completion_window: '24h',
430+
endpoint: '/v1/chat/completions',
431+
input_file_id: 'file-id',
432+
}),
433+
).toStrictEqual({
434+
url: `https://example.com/openai/batches?api-version=${apiVersion}`,
435+
});
436+
});
437+
438+
test('handles completions', async () => {
439+
expect(
440+
await client.completions.create({
441+
model: deployment,
442+
prompt: 'prompt',
443+
}),
444+
).toStrictEqual({
445+
url: `https://example.com/openai/deployments/${deployment}/completions?api-version=${apiVersion}`,
446+
});
447+
});
448+
449+
test('handles chat completions', async () => {
450+
expect(
451+
await client.chat.completions.create({
452+
model: deployment,
453+
messages: [{ role: 'system', content: 'Hello' }],
454+
}),
455+
).toStrictEqual({
456+
url: `https://example.com/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`,
457+
});
458+
});
459+
460+
test('handles embeddings', async () => {
461+
expect(
462+
await client.embeddings.create({
463+
model: deployment,
464+
input: 'input',
465+
}),
466+
).toStrictEqual({
467+
url: `https://example.com/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
468+
});
469+
});
470+
471+
test('Audio translations is not handled', async () => {
472+
expect(
473+
await client.audio.translations.create({
474+
model: deployment,
475+
file: { url: 'https://example.com', blob: () => 0 as any },
476+
}),
477+
).toStrictEqual({
478+
url: `https://example.com/openai/audio/translations?api-version=${apiVersion}`,
479+
});
480+
});
481+
482+
test('Audio transcriptions is not handled', async () => {
483+
expect(
484+
await client.audio.transcriptions.create({
485+
model: deployment,
486+
file: { url: 'https://example.com', blob: () => 0 as any },
487+
}),
488+
).toStrictEqual({
489+
url: `https://example.com/openai/audio/transcriptions?api-version=${apiVersion}`,
490+
});
491+
});
492+
493+
test('handles text to speech', async () => {
494+
expect(
495+
await (
496+
await client.audio.speech.create({
497+
model: deployment,
498+
input: '',
499+
voice: 'alloy',
500+
})
501+
).json(),
502+
).toStrictEqual({
503+
url: `https://example.com/openai/deployments/${deployment}/audio/speech?api-version=${apiVersion}`,
504+
});
505+
});
506+
507+
test('handles image generation', async () => {
508+
expect(
509+
await client.images.generate({
510+
model: deployment,
511+
prompt: 'prompt',
512+
}),
513+
).toStrictEqual({
514+
url: `https://example.com/openai/deployments/${deployment}/images/generations?api-version=${apiVersion}`,
515+
});
516+
});
517+
518+
test('handles assistants', async () => {
519+
expect(
520+
await client.beta.assistants.create({
521+
model,
522+
}),
523+
).toStrictEqual({
524+
url: `https://example.com/openai/assistants?api-version=${apiVersion}`,
525+
});
526+
});
527+
528+
test('handles files', async () => {
529+
expect(
530+
await client.files.create({
531+
file: { url: 'https://example.com', blob: () => 0 as any },
532+
purpose: 'assistants',
533+
}),
534+
).toStrictEqual({
535+
url: `https://example.com/openai/files?api-version=${apiVersion}`,
536+
});
537+
});
538+
539+
test('handles fine tuning', async () => {
540+
expect(
541+
await client.fineTuning.jobs.create({
542+
model,
543+
training_file: '',
544+
}),
545+
).toStrictEqual({
546+
url: `https://example.com/openai/fine_tuning/jobs?api-version=${apiVersion}`,
547+
});
548+
});
549+
});
550+
});
551+
278552
describe('Content-Length', () => {
279553
test('handles multi-byte characters', () => {
280554
const { req } = client.buildRequest({ path: '/foo', method: 'post', body: { value: '—' } });

0 commit comments

Comments
 (0)