diff --git a/.changeset/smooth-dragons-tell.md b/.changeset/smooth-dragons-tell.md new file mode 100644 index 000000000000..a080cf2d1928 --- /dev/null +++ b/.changeset/smooth-dragons-tell.md @@ -0,0 +1,6 @@ +--- +'@ai-sdk/provider-utils': patch +'@ai-sdk/fireworks': patch +--- + +feat (provider/fireworks): Support add'l image models. diff --git a/content/docs/03-ai-sdk-core/35-image-generation.mdx b/content/docs/03-ai-sdk-core/35-image-generation.mdx index 9eaf41b8407a..e9756d3d4d72 100644 --- a/content/docs/03-ai-sdk-core/35-image-generation.mdx +++ b/content/docs/03-ai-sdk-core/35-image-generation.mdx @@ -166,11 +166,18 @@ const { image, warnings } = await generateImage({ ## Image Models -| Provider | Model | Sizes | Aspect Ratios | -| ----------------------------------------------------------------------- | ---------------------------------------------- | ------------------------------- | ----------------------------------------------- | -| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 | -| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 | -| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | use size | -| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 | use size | -| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-dev-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | -| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-schnell-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | +| Provider | Model | Sizes | Aspect Ratios | +| ----------------------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------- | ----------------------------------------------- | +| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 | +| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 | +| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | use size | +| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 | use size | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-dev-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-schnell-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/playground-v2-5-1024px-aesthetic` | 640x1536 to 1536x640\* | use size | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/japanese-stable-diffusion-xl` | 640x1536 to 1536x640\* | use size | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/playground-v2-1024px-aesthetic` | 640x1536 to 1536x640\* | use size | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/SSD-1B` | 640x1536 to 1536x640\* | use size | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/stable-diffusion-xl-1024-v1-0` | 640x1536 to 1536x640\* | use size | + +\* Supported sizes: 640x1536, 768x1344, 832x1216, 896x1152, 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640 diff --git a/content/providers/01-ai-sdk-providers/26-fireworks.mdx b/content/providers/01-ai-sdk-providers/26-fireworks.mdx index 16cc989f73b1..04017ff3ecb4 100644 --- a/content/providers/01-ai-sdk-providers/26-fireworks.mdx +++ b/content/providers/01-ai-sdk-providers/26-fireworks.mdx @@ -142,13 +142,42 @@ const { image } = await generateImage({ ``` - Fireworks models do not support the `size` parameter. Use the `aspectRatio` - parameter instead. + Model support for `size` and `aspectRatio` parameters varies. See the [Model + Capabilities](#model-capabilities-1) section below for supported dimensions, + or check the model's documentation on [Fireworks models + page](https://fireworks.ai/models) for more details. ### Model Capabilities -| Model | Aspect Ratios | -| ---------------------------------------------- | ----------------------------------------------- | -| `accounts/fireworks/models/flux-1-dev-fp8` | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | -| `accounts/fireworks/models/flux-1-schnell-fp8` | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | +For all models supporting aspect ratios, the following aspect ratios are supported: + +`1:1 (default), 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9` + +For all models supporting size, the following sizes are supported: + +`640 x 1536, 768 x 1344, 832 x 1216, 896 x 1152, 1024x1024 (default), 1152 x 896, 1216 x 832, 1344 x 768, 1536 x 640` + +| Model | Dimensions Specification | +| ------------------------------------------------------------ | ------------------------ | +| `accounts/fireworks/models/flux-1-dev-fp8` | Aspect Ratio | +| `accounts/fireworks/models/flux-1-schnell-fp8` | Aspect Ratio | +| `accounts/fireworks/models/playground-v2-5-1024px-aesthetic` | Size | +| `accounts/fireworks/models/japanese-stable-diffusion-xl` | Size | +| `accounts/fireworks/models/playground-v2-1024px-aesthetic` | Size | +| `accounts/fireworks/models/SSD-1B` | Size | +| `accounts/fireworks/models/stable-diffusion-xl-1024-v1-0` | Size | + +For more details, see the [Fireworks models page](https://fireworks.ai/models). + +#### Stability AI Models + +Fireworks also presents several Stability AI models backed by Stability AI API +keys and endpoint. The AI SDK Fireworks provider does not currently include +support for these models: + +| Model ID | +| -------------------------------------- | +| `accounts/stability/models/sd3-turbo` | +| `accounts/stability/models/sd3-medium` | +| `accounts/stability/models/sd3` | diff --git a/packages/fireworks/README.md b/packages/fireworks/README.md index b715428839cc..a5bdc7e5cadd 100644 --- a/packages/fireworks/README.md +++ b/packages/fireworks/README.md @@ -1,34 +1,50 @@ # AI SDK - Fireworks Provider -The **[Fireworks provider](https://sdk.vercel.ai/providers/ai-sdk-providers/fireworks)** for the [AI SDK](https://sdk.vercel.ai/docs) contains language model support for the [Fireworks](https://fireworks.ai) platform. +The **[Fireworks provider](https://sdk.vercel.ai/providers/ai-sdk-providers/fireworks)** for the [AI SDK](https://sdk.vercel.ai/docs) contains language model and image model support for the [Fireworks](https://fireworks.ai) platform. ## Setup The Fireworks provider is available in the `@ai-sdk/fireworks` module. You can install it with -\```bash +```bash npm i @ai-sdk/fireworks -\``` +``` ## Provider Instance You can import the default provider instance `fireworks` from `@ai-sdk/fireworks`: -\```ts +```ts import { fireworks } from '@ai-sdk/fireworks'; -\``` +``` -## Example +## Language Model Example -\```ts +```ts import { fireworks } from '@ai-sdk/fireworks'; import { generateText } from 'ai'; const { text } = await generateText({ -model: fireworks('accounts/fireworks/models/llama-v2-13b-chat'), -prompt: 'Write a JavaScript function that sorts a list:', + model: fireworks('accounts/fireworks/models/deepseek-v3'), + prompt: 'Write a JavaScript function that sorts a list:', }); -\``` +``` + +## Image Model Examples + +```ts +import { fireworks } from '@ai-sdk/fireworks'; +import { experimental_generateImage as generateImage } from 'ai'; +import fs from 'fs'; + +const { image } = await generateImage({ + model: fireworks.image('accounts/fireworks/models/flux-1-dev-fp8'), + prompt: 'A serene mountain landscape at sunset', +}); +const filename = `image-${Date.now()}.png`; +fs.writeFileSync(filename, image.uint8Array); +console.log(`Image saved to ${filename}`); +``` ## Documentation diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts index 3556a516e69d..f391b6046cdc 100644 --- a/packages/fireworks/src/fireworks-image-model.test.ts +++ b/packages/fireworks/src/fireworks-image-model.test.ts @@ -1,35 +1,65 @@ -import { APICallError } from '@ai-sdk/provider'; import { BinaryTestServer } from '@ai-sdk/provider-utils/test'; import { describe, expect, it } from 'vitest'; import { FireworksImageModel } from './fireworks-image-model'; +import { FetchFunction } from '@ai-sdk/provider-utils'; const prompt = 'A cute baby sea otter'; -const model = new FireworksImageModel( - 'accounts/fireworks/models/flux-1-dev-fp8', - { +function createBasicModel({ + headers, + fetch, +}: { + headers?: () => Record; + fetch?: FetchFunction; +} = {}) { + return new FireworksImageModel('accounts/fireworks/models/flux-1-dev-fp8', { provider: 'fireworks', baseURL: 'https://api.example.com', + headers: headers ?? (() => ({ 'api-key': 'test-key' })), + fetch, + }); +} + +function createSizeModel() { + return new FireworksImageModel( + 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic', + { + provider: 'fireworks', + baseURL: 'https://api.size-example.com', + headers: () => ({ 'api-key': 'test-key' }), + }, + ); +} + +function createStabilityModel() { + return new FireworksImageModel('accounts/stability/models/sd3', { + provider: 'fireworks', + baseURL: 'https://api.stability.ai', headers: () => ({ 'api-key': 'test-key' }), - }, -); + }); +} + +const basicUrl = + 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image'; +const sizeUrl = + 'https://api.size-example.com/image_generation/accounts/fireworks/models/playground-v2-5-1024px-aesthetic'; +const stabilityUrl = + 'https://api.stability.ai/v2beta/stable-image/generate/sd3'; describe('FireworksImageModel', () => { describe('doGenerate', () => { - const server = new BinaryTestServer( - 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', - ); - + const server = new BinaryTestServer([basicUrl, sizeUrl, stabilityUrl]); server.setupTestEnvironment(); - function prepareBinaryResponse() { + function prepareBinaryResponse(url: string) { const mockImageBuffer = Buffer.from('mock-image-data'); - server.responseBody = mockImageBuffer; + server.setResponseFor(url, { body: mockImageBuffer }); } it('should pass the correct parameters including aspect ratio and seed', async () => { - prepareBinaryResponse(); + prepareBinaryResponse(basicUrl); + const model = createBasicModel(); await model.doGenerate({ prompt, n: 1, @@ -39,7 +69,8 @@ describe('FireworksImageModel', () => { providerOptions: { fireworks: { additional_param: 'value' } }, }); - expect(await server.getRequestBodyJson()).toStrictEqual({ + const request = await server.getRequestDataFor(basicUrl); + expect(await request.bodyJson()).toStrictEqual({ prompt, aspect_ratio: '16:9', seed: 42, @@ -48,18 +79,13 @@ describe('FireworksImageModel', () => { }); it('should pass headers', async () => { - prepareBinaryResponse(); - - const modelWithHeaders = new FireworksImageModel( - 'accounts/fireworks/models/flux-1-dev-fp8', - { - provider: 'fireworks', - baseURL: 'https://api.example.com', - headers: () => ({ - 'Custom-Provider-Header': 'provider-header-value', - }), - }, - ); + prepareBinaryResponse(basicUrl); + + const modelWithHeaders = createBasicModel({ + headers: () => ({ + 'Custom-Provider-Header': 'provider-header-value', + }), + }); await modelWithHeaders.doGenerate({ prompt, @@ -73,47 +99,18 @@ describe('FireworksImageModel', () => { }, }); - const requestHeaders = await server.getRequestHeaders(); - - expect(requestHeaders).toStrictEqual({ + const request = await server.getRequestDataFor(basicUrl); + expect(request.headers()).toStrictEqual({ 'content-type': 'application/json', 'custom-provider-header': 'provider-header-value', 'custom-request-header': 'request-header-value', }); }); - it('should return binary image data', async () => { - const mockImageBuffer = Buffer.from('mock-image-data'); - server.responseBody = mockImageBuffer; - - const result = await model.doGenerate({ - prompt, - n: 1, - size: undefined, - aspectRatio: undefined, - seed: undefined, - providerOptions: {}, - }); - - expect(result.images).toHaveLength(1); - expect(result.images[0]).toBeInstanceOf(Uint8Array); - expect(Buffer.from(result.images[0])).toEqual(mockImageBuffer); - }); - it('should handle empty response body', async () => { - server.responseBody = null; - - await expect( - model.doGenerate({ - prompt, - n: 1, - size: undefined, - aspectRatio: undefined, - seed: undefined, - providerOptions: {}, - }), - ).rejects.toThrow(APICallError); + server.setResponseFor(basicUrl, { body: null }); + const model = createBasicModel(); await expect( model.doGenerate({ prompt, @@ -126,7 +123,7 @@ describe('FireworksImageModel', () => { ).rejects.toMatchObject({ message: 'Response body is empty', statusCode: 200, - url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', + url: basicUrl, requestBodyValues: { prompt: 'A cute baby sea otter', }, @@ -134,20 +131,12 @@ describe('FireworksImageModel', () => { }); it('should handle API errors', async () => { - server.responseStatus = 400; - server.responseBody = Buffer.from('Bad Request'); - - await expect( - model.doGenerate({ - prompt, - n: 1, - size: undefined, - aspectRatio: undefined, - seed: undefined, - providerOptions: {}, - }), - ).rejects.toThrow(APICallError); + server.setResponseFor(basicUrl, { + status: 400, + body: Buffer.from('Bad Request'), + }); + const model = createBasicModel(); await expect( model.doGenerate({ prompt, @@ -160,7 +149,7 @@ describe('FireworksImageModel', () => { ).rejects.toMatchObject({ message: 'Bad Request', statusCode: 400, - url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', + url: basicUrl, requestBodyValues: { prompt: 'A cute baby sea otter', }, @@ -168,11 +157,35 @@ describe('FireworksImageModel', () => { }); }); - it('should return warnings for unsupported settings', async () => { - const mockImageBuffer = Buffer.from('mock-image-data'); - server.responseBody = mockImageBuffer; + it('should handle size parameter for supported models', async () => { + prepareBinaryResponse(sizeUrl); + + const sizeModel = createSizeModel(); - const result = await model.doGenerate({ + await sizeModel.doGenerate({ + prompt, + n: 1, + size: '1024x768', + aspectRatio: undefined, + seed: 42, + providerOptions: {}, + }); + + const request = await server.getRequestDataFor(sizeUrl); + expect(await request.bodyJson()).toStrictEqual({ + prompt, + width: '1024', + height: '768', + seed: 42, + }); + }); + + it('should return appropriate warnings based on model capabilities', async () => { + prepareBinaryResponse(basicUrl); + + // Test workflow model (supports aspectRatio but not size) + const model = createBasicModel(); + const result1 = await model.doGenerate({ prompt, n: 1, size: '1024x1024', @@ -181,14 +194,87 @@ describe('FireworksImageModel', () => { providerOptions: {}, }); - expect(result.warnings).toStrictEqual([ - { - type: 'unsupported-setting', - setting: 'size', - details: - 'This model does not support the `size` option. Use `aspectRatio` instead.', - }, - ]); + expect(result1.warnings).toContainEqual({ + type: 'unsupported-setting', + setting: 'size', + details: + 'This model does not support the `size` option. Use `aspectRatio` instead.', + }); + + // Test size-supporting model + prepareBinaryResponse(sizeUrl); + const sizeModel = createSizeModel(); + + const result2 = await sizeModel.doGenerate({ + prompt, + n: 1, + size: '1024x1024', + aspectRatio: '1:1', + seed: 123, + providerOptions: {}, + }); + + expect(result2.warnings).toContainEqual({ + type: 'unsupported-setting', + setting: 'aspectRatio', + details: 'This model does not support the `aspectRatio` option.', + }); + }); + + it('should respect the abort signal', async () => { + prepareBinaryResponse(basicUrl); + const model = createBasicModel(); + const controller = new AbortController(); + + const generatePromise = model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + abortSignal: controller.signal, + }); + + controller.abort(); + + await expect(generatePromise).rejects.toThrow( + 'This operation was aborted', + ); + }); + + it('should use custom fetch function when provided', async () => { + const mockFetch = vi.fn().mockResolvedValue( + new Response(Buffer.from('mock-image-data'), { + status: 200, + }), + ); + + const model = createBasicModel({ + fetch: mockFetch, + }); + + await model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }); + + expect(mockFetch).toHaveBeenCalled(); + }); + }); + + describe('constructor', () => { + it('should expose correct provider and model information', () => { + const model = createBasicModel(); + + expect(model.provider).toBe('fireworks'); + expect(model.modelId).toBe('accounts/fireworks/models/flux-1-dev-fp8'); + expect(model.specificationVersion).toBe('v1'); + expect(model.maxImagesPerCall).toBe(1); }); }); }); diff --git a/packages/fireworks/src/fireworks-image-model.ts b/packages/fireworks/src/fireworks-image-model.ts index 28602b0a76fe..b29b5057168e 100644 --- a/packages/fireworks/src/fireworks-image-model.ts +++ b/packages/fireworks/src/fireworks-image-model.ts @@ -15,8 +15,62 @@ import { export type FireworksImageModelId = | 'accounts/fireworks/models/flux-1-dev-fp8' | 'accounts/fireworks/models/flux-1-schnell-fp8' + | 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic' + | 'accounts/fireworks/models/japanese-stable-diffusion-xl' + | 'accounts/fireworks/models/playground-v2-1024px-aesthetic' + | 'accounts/fireworks/models/SSD-1B' + | 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0' | (string & {}); +interface FireworksImageModelBackendConfig { + urlFormat: 'workflows' | 'image_generation'; + supportsSize?: boolean; +} + +const modelToBackendConfig: Partial< + Record +> = { + 'accounts/fireworks/models/flux-1-dev-fp8': { + urlFormat: 'workflows', + }, + 'accounts/fireworks/models/flux-1-schnell-fp8': { + urlFormat: 'workflows', + }, + 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic': { + urlFormat: 'image_generation', + supportsSize: true, + }, + 'accounts/fireworks/models/japanese-stable-diffusion-xl': { + urlFormat: 'image_generation', + supportsSize: true, + }, + 'accounts/fireworks/models/playground-v2-1024px-aesthetic': { + urlFormat: 'image_generation', + supportsSize: true, + }, + 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0': { + urlFormat: 'image_generation', + supportsSize: true, + }, + 'accounts/fireworks/models/SSD-1B': { + urlFormat: 'image_generation', + supportsSize: true, + }, +}; + +function getUrlForModel( + baseUrl: string, + modelId: FireworksImageModelId, +): string { + switch (modelToBackendConfig[modelId]?.urlFormat) { + case 'image_generation': + return `${baseUrl}/image_generation/${modelId}`; + case 'workflows': + default: + return `${baseUrl}/workflows/${modelId}/text_to_image`; + } +} + interface FireworksImageModelConfig { provider: string; baseURL: string; @@ -80,6 +134,42 @@ const statusCodeErrorResponseHandler: ResponseHandler = async ({ }; }; +interface ImageRequestParams { + baseUrl: string; + modelId: FireworksImageModelId; + prompt: string; + aspectRatio?: string; + size?: string; + seed?: number; + providerOptions: Record; + headers: Record; + abortSignal?: AbortSignal; + fetch?: FetchFunction; +} + +async function postImageToApi( + params: ImageRequestParams, +): Promise { + const splitSize = params.size?.split('x'); + const { value: response } = await postJsonToApi({ + url: getUrlForModel(params.baseUrl, params.modelId), + headers: params.headers, + body: { + prompt: params.prompt, + aspect_ratio: params.aspectRatio, + seed: params.seed, + ...(splitSize && { width: splitSize[0], height: splitSize[1] }), + ...(params.providerOptions.fireworks ?? {}), + }, + failedResponseHandler: statusCodeErrorResponseHandler, + successfulResponseHandler: createBinaryResponseHandler(), + abortSignal: params.abortSignal, + fetch: params.fetch, + }); + + return response; +} + export class FireworksImageModel implements ImageModelV1 { readonly specificationVersion = 'v1'; @@ -108,7 +198,8 @@ export class FireworksImageModel implements ImageModelV1 { > { const warnings: Array = []; - if (size != null) { + const backendConfig = modelToBackendConfig[this.modelId]; + if (!backendConfig?.supportsSize && size != null) { warnings.push({ type: 'unsupported-setting', setting: 'size', @@ -117,20 +208,25 @@ export class FireworksImageModel implements ImageModelV1 { }); } - const url = `${this.config.baseURL}/workflows/${this.modelId}/text_to_image`; - const body = { + // Use supportsSize as a proxy for whether the model does not support + // aspectRatio. This invariant holds for the current set of models. + if (backendConfig?.supportsSize && aspectRatio != null) { + warnings.push({ + type: 'unsupported-setting', + setting: 'aspectRatio', + details: 'This model does not support the `aspectRatio` option.', + }); + } + + const response = await postImageToApi({ + baseUrl: this.config.baseURL, prompt, - aspect_ratio: aspectRatio, + aspectRatio, + size, seed, - ...(providerOptions.fireworks ?? {}), - }; - - const { value: response } = await postJsonToApi({ - url, + modelId: this.modelId, + providerOptions, headers: combineHeaders(this.config.headers(), headers), - body, - failedResponseHandler: statusCodeErrorResponseHandler, - successfulResponseHandler: createBinaryResponseHandler(), abortSignal, fetch: this.config.fetch, }); diff --git a/packages/fireworks/src/fireworks-provider.ts b/packages/fireworks/src/fireworks-provider.ts index 92d1501a3405..1413765346b6 100644 --- a/packages/fireworks/src/fireworks-provider.ts +++ b/packages/fireworks/src/fireworks-provider.ts @@ -45,7 +45,8 @@ const fireworksErrorStructure: ProviderErrorStructure = { export interface FireworksProviderSettings { /** -Fireworks API key. +Fireworks API key. Default value is taken from the `FIREWORKS_API_KEY` +environment variable. */ apiKey?: string; /** diff --git a/packages/provider-utils/src/test/binary-test-server.test.ts b/packages/provider-utils/src/test/binary-test-server.test.ts new file mode 100644 index 000000000000..321cd49f2710 --- /dev/null +++ b/packages/provider-utils/src/test/binary-test-server.test.ts @@ -0,0 +1,186 @@ +import { + describe, + it, + expect, + beforeEach, + afterEach, + beforeAll, + afterAll, + vi, +} from 'vitest'; +import { BinaryTestServer } from './binary-test-server'; + +describe('BinaryTestServer', () => { + let server: BinaryTestServer; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('constructor', () => { + it('should initialize with a single URL', () => { + const server = new BinaryTestServer('http://example.com'); + expect(server.server).toBeDefined(); + }); + + it('should initialize with multiple URLs', () => { + const server = new BinaryTestServer([ + 'http://example.com', + 'http://test.com', + ]); + expect(server.server).toBeDefined(); + }); + }); + + describe('setResponseFor', () => { + beforeAll(() => { + server = new BinaryTestServer('http://example.com'); + server.server.listen(); + }); + + afterAll(() => { + server.server.close(); + }); + + it('should set response options for a valid URL', () => { + const buffer = Buffer.from('test data'); + server.setResponseFor('http://example.com/', { + body: buffer, + headers: { 'content-type': 'application/octet-stream' }, + status: 201, + }); + }); + + it('should throw error for invalid URL', () => { + expect(() => + server.setResponseFor('http://invalid.com', { status: 200 }), + ).toThrow('No endpoint configured for URL'); + }); + }); + + describe('request handling', () => { + beforeAll(() => { + server = new BinaryTestServer('http://example.com'); + server.server.listen(); + }); + + afterAll(() => { + server.server.close(); + }); + + beforeEach(() => { + server.server.resetHandlers(); + }); + + it('should handle JSON requests', async () => { + const testData = { test: 'data' }; + const fetchSpy = vi.spyOn(global, 'fetch'); + + await fetch('http://example.com', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: JSON.stringify(testData), + }); + + expect(fetchSpy).toHaveBeenCalledWith( + 'http://example.com', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + 'content-type': 'application/json', + }), + }), + ); + + const requestData = await server.getRequestDataFor('http://example.com/'); + const bodyJson = await requestData.bodyJson(); + expect(bodyJson).toEqual(testData); + }); + + it('should handle form data requests', async () => { + const formData = new FormData(); + formData.append('field', 'value'); + const fetchSpy = vi.spyOn(global, 'fetch'); + + await fetch('http://example.com', { + method: 'POST', + body: formData, + }); + + expect(fetchSpy).toHaveBeenCalledWith( + 'http://example.com', + expect.objectContaining({ + method: 'POST', + body: expect.any(FormData), + }), + ); + + const requestData = await server.getRequestDataFor('http://example.com'); + const formDataReceived = await requestData.bodyFormData(); + expect(formDataReceived.get('field')).toBe('value'); + }); + + it('should handle custom response configurations', async () => { + const responseBuffer = Buffer.from('test response'); + const fetchSpy = vi.spyOn(global, 'fetch'); + + server.setResponseFor('http://example.com', { + body: responseBuffer, + headers: { 'x-custom': 'test' }, + status: 201, + }); + + const response = await fetch('http://example.com', { method: 'POST' }); + + expect(fetchSpy).toHaveBeenCalledWith( + 'http://example.com', + expect.objectContaining({ method: 'POST' }), + ); + + expect(response.status).toBe(201); + expect(response.headers.get('x-custom')).toBe('test'); + const responseData = await response.arrayBuffer(); + expect(Buffer.from(responseData)).toEqual(responseBuffer); + }); + }); + + describe('URL handling', () => { + let server: BinaryTestServer; + + beforeEach(() => { + server = new BinaryTestServer('http://example.com'); + server.server.listen(); + // Set default response + server.setResponseFor('http://example.com', { + status: 200, + body: null, + }); + }); + + afterEach(() => { + server.server.resetHandlers(); + server.server.close(); + }); + + it('should handle search params', async () => { + const response = await fetch('http://example.com?param=value', { + method: 'POST', + body: JSON.stringify({ test: true }), + }); + + expect(response.status).toBe(200); + + const requestData = await server.getRequestDataFor('http://example.com'); + expect(requestData.urlSearchParams().get('param')).toBe('value'); + }); + + it('should handle relative URLs', () => { + const server = new BinaryTestServer('/api/endpoint'); + expect(server.server).toBeDefined(); + }); + }); +}); diff --git a/packages/provider-utils/src/test/binary-test-server.ts b/packages/provider-utils/src/test/binary-test-server.ts index 2dfad8c5d363..73fe5effb329 100644 --- a/packages/provider-utils/src/test/binary-test-server.ts +++ b/packages/provider-utils/src/test/binary-test-server.ts @@ -3,65 +3,127 @@ import { SetupServer, setupServer } from 'msw/node'; export class BinaryTestServer { readonly server: SetupServer; + private endpoints: Map< + string, + { + responseBody: Buffer | null; + responseHeaders: Record; + responseStatus: number; + request: Request | undefined; + } + > = new Map(); - responseBody: Buffer | null = null; - responseHeaders: Record = {}; - responseStatus = 200; + constructor(urls: string | string[]) { + const urlList = Array.isArray(urls) ? urls : [urls]; - request: Request | undefined; + // Initialize endpoints + urlList.forEach(url => { + const normalizedUrl = this.normalizeUrl(url); + this.endpoints.set(normalizedUrl, { + responseBody: null, + responseHeaders: {}, + responseStatus: 200, + request: undefined, + }); + }); - constructor(url: string) { this.server = setupServer( - http.post(url, ({ request }) => { - this.request = request; + ...urlList.map(url => + http.post(this.normalizeUrl(url), ({ request }) => { + const endpoint = this.endpoints.get(this.normalizeUrl(request.url)); + if (!endpoint) { + return new HttpResponse(null, { status: 500 }); + } + endpoint.request = request; - if (this.responseBody === null) { - return new HttpResponse(null, { status: this.responseStatus }); - } + if (endpoint.responseBody === null) { + return new HttpResponse(null, { status: endpoint.responseStatus }); + } - return new HttpResponse(this.responseBody, { - status: this.responseStatus, - headers: this.responseHeaders, - }); - }), + return new HttpResponse(endpoint.responseBody, { + status: endpoint.responseStatus, + headers: endpoint.responseHeaders, + }); + }), + ), ); } - async getRequestBodyJson() { - expect(this.request).toBeDefined(); - return JSON.parse(await this.request!.text()); + private normalizeUrl(url: string): string { + try { + // Parse URL and remove search params for endpoint matching + const urlObj = new URL(url); + urlObj.search = ''; // Clear search params for matching + const normalized = urlObj.toString(); + return normalized.endsWith('/') ? normalized.slice(0, -1) : normalized; + } catch { + // If not a valid URL, assume it's a path and return as-is + return url.endsWith('/') ? url.slice(0, -1) : url; + } } - async getRequestHeaders() { - expect(this.request).toBeDefined(); - const requestHeaders = this.request!.headers; - - // convert headers to object for easier comparison - const headersObject: Record = {}; - requestHeaders.forEach((value, key) => { - headersObject[key] = value; - }); - - return headersObject; + setResponseFor( + url: string, + options: { + body?: Buffer | null; + headers?: Record; + status?: number; + }, + ) { + // Normalize the URL before lookup + const normalizedUrl = this.normalizeUrl(url); + const endpoint = this.endpoints.get(normalizedUrl); + if (!endpoint) { + throw new Error(`No endpoint configured for URL: ${url}`); + } + if (options.body !== undefined) endpoint.responseBody = options.body; + if (options.headers) endpoint.responseHeaders = options.headers; + if (options.status) endpoint.responseStatus = options.status; } - async getRequestUrlSearchParams() { - expect(this.request).toBeDefined(); - return new URL(this.request!.url).searchParams; - } + async getRequestDataFor(url: string) { + // Normalize the URL before lookup + const normalizedUrl = this.normalizeUrl(url); + const endpoint = this.endpoints.get(normalizedUrl); + if (!endpoint) { + throw new Error(`No endpoint configured for URL: ${url}`); + } + expect(endpoint.request).toBeDefined(); - async getRequestUrl() { - expect(this.request).toBeDefined(); - return new URL(this.request!.url).toString(); + return { + bodyJson: async () => { + const text = await endpoint.request!.text(); + return JSON.parse(text); + }, + bodyFormData: async () => { + const contentType = endpoint.request!.headers.get('content-type'); + if (contentType?.includes('multipart/form-data')) { + return endpoint.request!.formData(); + } + throw new Error('Request content-type is not multipart/form-data'); + }, + headers: () => { + const headersObject: Record = {}; + endpoint.request!.headers.forEach((value, key) => { + headersObject[key] = value; + }); + return headersObject; + }, + urlSearchParams: () => new URL(endpoint.request!.url).searchParams, + url: () => new URL(endpoint.request!.url).toString(), + }; } setupTestEnvironment() { beforeAll(() => this.server.listen()); beforeEach(() => { - this.responseBody = null; - this.request = undefined; - this.responseHeaders = {}; - this.responseStatus = 200; + // Reset all endpoints + this.endpoints.forEach(endpoint => { + endpoint.responseBody = null; + endpoint.request = undefined; + endpoint.responseHeaders = {}; + endpoint.responseStatus = 200; + }); }); afterEach(() => this.server.resetHandlers()); afterAll(() => this.server.close());