diff --git a/.changeset/smooth-dragons-tell.md b/.changeset/smooth-dragons-tell.md
new file mode 100644
index 000000000000..a080cf2d1928
--- /dev/null
+++ b/.changeset/smooth-dragons-tell.md
@@ -0,0 +1,6 @@
+---
+'@ai-sdk/provider-utils': patch
+'@ai-sdk/fireworks': patch
+---
+
+feat (provider/fireworks): Support add'l image models.
diff --git a/content/docs/03-ai-sdk-core/35-image-generation.mdx b/content/docs/03-ai-sdk-core/35-image-generation.mdx
index 9eaf41b8407a..e9756d3d4d72 100644
--- a/content/docs/03-ai-sdk-core/35-image-generation.mdx
+++ b/content/docs/03-ai-sdk-core/35-image-generation.mdx
@@ -166,11 +166,18 @@ const { image, warnings } = await generateImage({
## Image Models
-| Provider | Model | Sizes | Aspect Ratios |
-| ----------------------------------------------------------------------- | ---------------------------------------------- | ------------------------------- | ----------------------------------------------- |
-| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 |
-| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 |
-| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | use size |
-| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 | use size |
-| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-dev-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 |
-| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-schnell-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 |
+| Provider | Model | Sizes | Aspect Ratios |
+| ----------------------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------- | ----------------------------------------------- |
+| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 |
+| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 |
+| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | use size |
+| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 | use size |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-dev-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-schnell-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/playground-v2-5-1024px-aesthetic` | 640x1536 to 1536x640\* | use size |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/japanese-stable-diffusion-xl` | 640x1536 to 1536x640\* | use size |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/playground-v2-1024px-aesthetic` | 640x1536 to 1536x640\* | use size |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/SSD-1B` | 640x1536 to 1536x640\* | use size |
+| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/stable-diffusion-xl-1024-v1-0` | 640x1536 to 1536x640\* | use size |
+
+\* Supported sizes: 640x1536, 768x1344, 832x1216, 896x1152, 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640
diff --git a/content/providers/01-ai-sdk-providers/26-fireworks.mdx b/content/providers/01-ai-sdk-providers/26-fireworks.mdx
index 16cc989f73b1..04017ff3ecb4 100644
--- a/content/providers/01-ai-sdk-providers/26-fireworks.mdx
+++ b/content/providers/01-ai-sdk-providers/26-fireworks.mdx
@@ -142,13 +142,42 @@ const { image } = await generateImage({
```
- Fireworks models do not support the `size` parameter. Use the `aspectRatio`
- parameter instead.
+ Model support for `size` and `aspectRatio` parameters varies. See the [Model
+ Capabilities](#model-capabilities-1) section below for supported dimensions,
+ or check the model's documentation on [Fireworks models
+ page](https://fireworks.ai/models) for more details.
### Model Capabilities
-| Model | Aspect Ratios |
-| ---------------------------------------------- | ----------------------------------------------- |
-| `accounts/fireworks/models/flux-1-dev-fp8` | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 |
-| `accounts/fireworks/models/flux-1-schnell-fp8` | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 |
+For all models supporting aspect ratios, the following aspect ratios are supported:
+
+`1:1 (default), 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9`
+
+For all models supporting size, the following sizes are supported:
+
+`640 x 1536, 768 x 1344, 832 x 1216, 896 x 1152, 1024x1024 (default), 1152 x 896, 1216 x 832, 1344 x 768, 1536 x 640`
+
+| Model | Dimensions Specification |
+| ------------------------------------------------------------ | ------------------------ |
+| `accounts/fireworks/models/flux-1-dev-fp8` | Aspect Ratio |
+| `accounts/fireworks/models/flux-1-schnell-fp8` | Aspect Ratio |
+| `accounts/fireworks/models/playground-v2-5-1024px-aesthetic` | Size |
+| `accounts/fireworks/models/japanese-stable-diffusion-xl` | Size |
+| `accounts/fireworks/models/playground-v2-1024px-aesthetic` | Size |
+| `accounts/fireworks/models/SSD-1B` | Size |
+| `accounts/fireworks/models/stable-diffusion-xl-1024-v1-0` | Size |
+
+For more details, see the [Fireworks models page](https://fireworks.ai/models).
+
+#### Stability AI Models
+
+Fireworks also presents several Stability AI models backed by Stability AI API
+keys and endpoint. The AI SDK Fireworks provider does not currently include
+support for these models:
+
+| Model ID |
+| -------------------------------------- |
+| `accounts/stability/models/sd3-turbo` |
+| `accounts/stability/models/sd3-medium` |
+| `accounts/stability/models/sd3` |
diff --git a/packages/fireworks/README.md b/packages/fireworks/README.md
index b715428839cc..a5bdc7e5cadd 100644
--- a/packages/fireworks/README.md
+++ b/packages/fireworks/README.md
@@ -1,34 +1,50 @@
# AI SDK - Fireworks Provider
-The **[Fireworks provider](https://sdk.vercel.ai/providers/ai-sdk-providers/fireworks)** for the [AI SDK](https://sdk.vercel.ai/docs) contains language model support for the [Fireworks](https://fireworks.ai) platform.
+The **[Fireworks provider](https://sdk.vercel.ai/providers/ai-sdk-providers/fireworks)** for the [AI SDK](https://sdk.vercel.ai/docs) contains language model and image model support for the [Fireworks](https://fireworks.ai) platform.
## Setup
The Fireworks provider is available in the `@ai-sdk/fireworks` module. You can install it with
-\```bash
+```bash
npm i @ai-sdk/fireworks
-\```
+```
## Provider Instance
You can import the default provider instance `fireworks` from `@ai-sdk/fireworks`:
-\```ts
+```ts
import { fireworks } from '@ai-sdk/fireworks';
-\```
+```
-## Example
+## Language Model Example
-\```ts
+```ts
import { fireworks } from '@ai-sdk/fireworks';
import { generateText } from 'ai';
const { text } = await generateText({
-model: fireworks('accounts/fireworks/models/llama-v2-13b-chat'),
-prompt: 'Write a JavaScript function that sorts a list:',
+ model: fireworks('accounts/fireworks/models/deepseek-v3'),
+ prompt: 'Write a JavaScript function that sorts a list:',
});
-\```
+```
+
+## Image Model Examples
+
+```ts
+import { fireworks } from '@ai-sdk/fireworks';
+import { experimental_generateImage as generateImage } from 'ai';
+import fs from 'fs';
+
+const { image } = await generateImage({
+ model: fireworks.image('accounts/fireworks/models/flux-1-dev-fp8'),
+ prompt: 'A serene mountain landscape at sunset',
+});
+const filename = `image-${Date.now()}.png`;
+fs.writeFileSync(filename, image.uint8Array);
+console.log(`Image saved to ${filename}`);
+```
## Documentation
diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts
index 3556a516e69d..f391b6046cdc 100644
--- a/packages/fireworks/src/fireworks-image-model.test.ts
+++ b/packages/fireworks/src/fireworks-image-model.test.ts
@@ -1,35 +1,65 @@
-import { APICallError } from '@ai-sdk/provider';
import { BinaryTestServer } from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { FireworksImageModel } from './fireworks-image-model';
+import { FetchFunction } from '@ai-sdk/provider-utils';
const prompt = 'A cute baby sea otter';
-const model = new FireworksImageModel(
- 'accounts/fireworks/models/flux-1-dev-fp8',
- {
+function createBasicModel({
+ headers,
+ fetch,
+}: {
+ headers?: () => Record;
+ fetch?: FetchFunction;
+} = {}) {
+ return new FireworksImageModel('accounts/fireworks/models/flux-1-dev-fp8', {
provider: 'fireworks',
baseURL: 'https://api.example.com',
+ headers: headers ?? (() => ({ 'api-key': 'test-key' })),
+ fetch,
+ });
+}
+
+function createSizeModel() {
+ return new FireworksImageModel(
+ 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic',
+ {
+ provider: 'fireworks',
+ baseURL: 'https://api.size-example.com',
+ headers: () => ({ 'api-key': 'test-key' }),
+ },
+ );
+}
+
+function createStabilityModel() {
+ return new FireworksImageModel('accounts/stability/models/sd3', {
+ provider: 'fireworks',
+ baseURL: 'https://api.stability.ai',
headers: () => ({ 'api-key': 'test-key' }),
- },
-);
+ });
+}
+
+const basicUrl =
+ 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image';
+const sizeUrl =
+ 'https://api.size-example.com/image_generation/accounts/fireworks/models/playground-v2-5-1024px-aesthetic';
+const stabilityUrl =
+ 'https://api.stability.ai/v2beta/stable-image/generate/sd3';
describe('FireworksImageModel', () => {
describe('doGenerate', () => {
- const server = new BinaryTestServer(
- 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image',
- );
-
+ const server = new BinaryTestServer([basicUrl, sizeUrl, stabilityUrl]);
server.setupTestEnvironment();
- function prepareBinaryResponse() {
+ function prepareBinaryResponse(url: string) {
const mockImageBuffer = Buffer.from('mock-image-data');
- server.responseBody = mockImageBuffer;
+ server.setResponseFor(url, { body: mockImageBuffer });
}
it('should pass the correct parameters including aspect ratio and seed', async () => {
- prepareBinaryResponse();
+ prepareBinaryResponse(basicUrl);
+ const model = createBasicModel();
await model.doGenerate({
prompt,
n: 1,
@@ -39,7 +69,8 @@ describe('FireworksImageModel', () => {
providerOptions: { fireworks: { additional_param: 'value' } },
});
- expect(await server.getRequestBodyJson()).toStrictEqual({
+ const request = await server.getRequestDataFor(basicUrl);
+ expect(await request.bodyJson()).toStrictEqual({
prompt,
aspect_ratio: '16:9',
seed: 42,
@@ -48,18 +79,13 @@ describe('FireworksImageModel', () => {
});
it('should pass headers', async () => {
- prepareBinaryResponse();
-
- const modelWithHeaders = new FireworksImageModel(
- 'accounts/fireworks/models/flux-1-dev-fp8',
- {
- provider: 'fireworks',
- baseURL: 'https://api.example.com',
- headers: () => ({
- 'Custom-Provider-Header': 'provider-header-value',
- }),
- },
- );
+ prepareBinaryResponse(basicUrl);
+
+ const modelWithHeaders = createBasicModel({
+ headers: () => ({
+ 'Custom-Provider-Header': 'provider-header-value',
+ }),
+ });
await modelWithHeaders.doGenerate({
prompt,
@@ -73,47 +99,18 @@ describe('FireworksImageModel', () => {
},
});
- const requestHeaders = await server.getRequestHeaders();
-
- expect(requestHeaders).toStrictEqual({
+ const request = await server.getRequestDataFor(basicUrl);
+ expect(request.headers()).toStrictEqual({
'content-type': 'application/json',
'custom-provider-header': 'provider-header-value',
'custom-request-header': 'request-header-value',
});
});
- it('should return binary image data', async () => {
- const mockImageBuffer = Buffer.from('mock-image-data');
- server.responseBody = mockImageBuffer;
-
- const result = await model.doGenerate({
- prompt,
- n: 1,
- size: undefined,
- aspectRatio: undefined,
- seed: undefined,
- providerOptions: {},
- });
-
- expect(result.images).toHaveLength(1);
- expect(result.images[0]).toBeInstanceOf(Uint8Array);
- expect(Buffer.from(result.images[0])).toEqual(mockImageBuffer);
- });
-
it('should handle empty response body', async () => {
- server.responseBody = null;
-
- await expect(
- model.doGenerate({
- prompt,
- n: 1,
- size: undefined,
- aspectRatio: undefined,
- seed: undefined,
- providerOptions: {},
- }),
- ).rejects.toThrow(APICallError);
+ server.setResponseFor(basicUrl, { body: null });
+ const model = createBasicModel();
await expect(
model.doGenerate({
prompt,
@@ -126,7 +123,7 @@ describe('FireworksImageModel', () => {
).rejects.toMatchObject({
message: 'Response body is empty',
statusCode: 200,
- url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image',
+ url: basicUrl,
requestBodyValues: {
prompt: 'A cute baby sea otter',
},
@@ -134,20 +131,12 @@ describe('FireworksImageModel', () => {
});
it('should handle API errors', async () => {
- server.responseStatus = 400;
- server.responseBody = Buffer.from('Bad Request');
-
- await expect(
- model.doGenerate({
- prompt,
- n: 1,
- size: undefined,
- aspectRatio: undefined,
- seed: undefined,
- providerOptions: {},
- }),
- ).rejects.toThrow(APICallError);
+ server.setResponseFor(basicUrl, {
+ status: 400,
+ body: Buffer.from('Bad Request'),
+ });
+ const model = createBasicModel();
await expect(
model.doGenerate({
prompt,
@@ -160,7 +149,7 @@ describe('FireworksImageModel', () => {
).rejects.toMatchObject({
message: 'Bad Request',
statusCode: 400,
- url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image',
+ url: basicUrl,
requestBodyValues: {
prompt: 'A cute baby sea otter',
},
@@ -168,11 +157,35 @@ describe('FireworksImageModel', () => {
});
});
- it('should return warnings for unsupported settings', async () => {
- const mockImageBuffer = Buffer.from('mock-image-data');
- server.responseBody = mockImageBuffer;
+ it('should handle size parameter for supported models', async () => {
+ prepareBinaryResponse(sizeUrl);
+
+ const sizeModel = createSizeModel();
- const result = await model.doGenerate({
+ await sizeModel.doGenerate({
+ prompt,
+ n: 1,
+ size: '1024x768',
+ aspectRatio: undefined,
+ seed: 42,
+ providerOptions: {},
+ });
+
+ const request = await server.getRequestDataFor(sizeUrl);
+ expect(await request.bodyJson()).toStrictEqual({
+ prompt,
+ width: '1024',
+ height: '768',
+ seed: 42,
+ });
+ });
+
+ it('should return appropriate warnings based on model capabilities', async () => {
+ prepareBinaryResponse(basicUrl);
+
+ // Test workflow model (supports aspectRatio but not size)
+ const model = createBasicModel();
+ const result1 = await model.doGenerate({
prompt,
n: 1,
size: '1024x1024',
@@ -181,14 +194,87 @@ describe('FireworksImageModel', () => {
providerOptions: {},
});
- expect(result.warnings).toStrictEqual([
- {
- type: 'unsupported-setting',
- setting: 'size',
- details:
- 'This model does not support the `size` option. Use `aspectRatio` instead.',
- },
- ]);
+ expect(result1.warnings).toContainEqual({
+ type: 'unsupported-setting',
+ setting: 'size',
+ details:
+ 'This model does not support the `size` option. Use `aspectRatio` instead.',
+ });
+
+ // Test size-supporting model
+ prepareBinaryResponse(sizeUrl);
+ const sizeModel = createSizeModel();
+
+ const result2 = await sizeModel.doGenerate({
+ prompt,
+ n: 1,
+ size: '1024x1024',
+ aspectRatio: '1:1',
+ seed: 123,
+ providerOptions: {},
+ });
+
+ expect(result2.warnings).toContainEqual({
+ type: 'unsupported-setting',
+ setting: 'aspectRatio',
+ details: 'This model does not support the `aspectRatio` option.',
+ });
+ });
+
+ it('should respect the abort signal', async () => {
+ prepareBinaryResponse(basicUrl);
+ const model = createBasicModel();
+ const controller = new AbortController();
+
+ const generatePromise = model.doGenerate({
+ prompt,
+ n: 1,
+ size: undefined,
+ aspectRatio: undefined,
+ seed: undefined,
+ providerOptions: {},
+ abortSignal: controller.signal,
+ });
+
+ controller.abort();
+
+ await expect(generatePromise).rejects.toThrow(
+ 'This operation was aborted',
+ );
+ });
+
+ it('should use custom fetch function when provided', async () => {
+ const mockFetch = vi.fn().mockResolvedValue(
+ new Response(Buffer.from('mock-image-data'), {
+ status: 200,
+ }),
+ );
+
+ const model = createBasicModel({
+ fetch: mockFetch,
+ });
+
+ await model.doGenerate({
+ prompt,
+ n: 1,
+ size: undefined,
+ aspectRatio: undefined,
+ seed: undefined,
+ providerOptions: {},
+ });
+
+ expect(mockFetch).toHaveBeenCalled();
+ });
+ });
+
+ describe('constructor', () => {
+ it('should expose correct provider and model information', () => {
+ const model = createBasicModel();
+
+ expect(model.provider).toBe('fireworks');
+ expect(model.modelId).toBe('accounts/fireworks/models/flux-1-dev-fp8');
+ expect(model.specificationVersion).toBe('v1');
+ expect(model.maxImagesPerCall).toBe(1);
});
});
});
diff --git a/packages/fireworks/src/fireworks-image-model.ts b/packages/fireworks/src/fireworks-image-model.ts
index 28602b0a76fe..b29b5057168e 100644
--- a/packages/fireworks/src/fireworks-image-model.ts
+++ b/packages/fireworks/src/fireworks-image-model.ts
@@ -15,8 +15,62 @@ import {
export type FireworksImageModelId =
| 'accounts/fireworks/models/flux-1-dev-fp8'
| 'accounts/fireworks/models/flux-1-schnell-fp8'
+ | 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic'
+ | 'accounts/fireworks/models/japanese-stable-diffusion-xl'
+ | 'accounts/fireworks/models/playground-v2-1024px-aesthetic'
+ | 'accounts/fireworks/models/SSD-1B'
+ | 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0'
| (string & {});
+interface FireworksImageModelBackendConfig {
+ urlFormat: 'workflows' | 'image_generation';
+ supportsSize?: boolean;
+}
+
+const modelToBackendConfig: Partial<
+ Record
+> = {
+ 'accounts/fireworks/models/flux-1-dev-fp8': {
+ urlFormat: 'workflows',
+ },
+ 'accounts/fireworks/models/flux-1-schnell-fp8': {
+ urlFormat: 'workflows',
+ },
+ 'accounts/fireworks/models/playground-v2-5-1024px-aesthetic': {
+ urlFormat: 'image_generation',
+ supportsSize: true,
+ },
+ 'accounts/fireworks/models/japanese-stable-diffusion-xl': {
+ urlFormat: 'image_generation',
+ supportsSize: true,
+ },
+ 'accounts/fireworks/models/playground-v2-1024px-aesthetic': {
+ urlFormat: 'image_generation',
+ supportsSize: true,
+ },
+ 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0': {
+ urlFormat: 'image_generation',
+ supportsSize: true,
+ },
+ 'accounts/fireworks/models/SSD-1B': {
+ urlFormat: 'image_generation',
+ supportsSize: true,
+ },
+};
+
+function getUrlForModel(
+ baseUrl: string,
+ modelId: FireworksImageModelId,
+): string {
+ switch (modelToBackendConfig[modelId]?.urlFormat) {
+ case 'image_generation':
+ return `${baseUrl}/image_generation/${modelId}`;
+ case 'workflows':
+ default:
+ return `${baseUrl}/workflows/${modelId}/text_to_image`;
+ }
+}
+
interface FireworksImageModelConfig {
provider: string;
baseURL: string;
@@ -80,6 +134,42 @@ const statusCodeErrorResponseHandler: ResponseHandler = async ({
};
};
+interface ImageRequestParams {
+ baseUrl: string;
+ modelId: FireworksImageModelId;
+ prompt: string;
+ aspectRatio?: string;
+ size?: string;
+ seed?: number;
+ providerOptions: Record;
+ headers: Record;
+ abortSignal?: AbortSignal;
+ fetch?: FetchFunction;
+}
+
+async function postImageToApi(
+ params: ImageRequestParams,
+): Promise {
+ const splitSize = params.size?.split('x');
+ const { value: response } = await postJsonToApi({
+ url: getUrlForModel(params.baseUrl, params.modelId),
+ headers: params.headers,
+ body: {
+ prompt: params.prompt,
+ aspect_ratio: params.aspectRatio,
+ seed: params.seed,
+ ...(splitSize && { width: splitSize[0], height: splitSize[1] }),
+ ...(params.providerOptions.fireworks ?? {}),
+ },
+ failedResponseHandler: statusCodeErrorResponseHandler,
+ successfulResponseHandler: createBinaryResponseHandler(),
+ abortSignal: params.abortSignal,
+ fetch: params.fetch,
+ });
+
+ return response;
+}
+
export class FireworksImageModel implements ImageModelV1 {
readonly specificationVersion = 'v1';
@@ -108,7 +198,8 @@ export class FireworksImageModel implements ImageModelV1 {
> {
const warnings: Array = [];
- if (size != null) {
+ const backendConfig = modelToBackendConfig[this.modelId];
+ if (!backendConfig?.supportsSize && size != null) {
warnings.push({
type: 'unsupported-setting',
setting: 'size',
@@ -117,20 +208,25 @@ export class FireworksImageModel implements ImageModelV1 {
});
}
- const url = `${this.config.baseURL}/workflows/${this.modelId}/text_to_image`;
- const body = {
+ // Use supportsSize as a proxy for whether the model does not support
+ // aspectRatio. This invariant holds for the current set of models.
+ if (backendConfig?.supportsSize && aspectRatio != null) {
+ warnings.push({
+ type: 'unsupported-setting',
+ setting: 'aspectRatio',
+ details: 'This model does not support the `aspectRatio` option.',
+ });
+ }
+
+ const response = await postImageToApi({
+ baseUrl: this.config.baseURL,
prompt,
- aspect_ratio: aspectRatio,
+ aspectRatio,
+ size,
seed,
- ...(providerOptions.fireworks ?? {}),
- };
-
- const { value: response } = await postJsonToApi({
- url,
+ modelId: this.modelId,
+ providerOptions,
headers: combineHeaders(this.config.headers(), headers),
- body,
- failedResponseHandler: statusCodeErrorResponseHandler,
- successfulResponseHandler: createBinaryResponseHandler(),
abortSignal,
fetch: this.config.fetch,
});
diff --git a/packages/fireworks/src/fireworks-provider.ts b/packages/fireworks/src/fireworks-provider.ts
index 92d1501a3405..1413765346b6 100644
--- a/packages/fireworks/src/fireworks-provider.ts
+++ b/packages/fireworks/src/fireworks-provider.ts
@@ -45,7 +45,8 @@ const fireworksErrorStructure: ProviderErrorStructure = {
export interface FireworksProviderSettings {
/**
-Fireworks API key.
+Fireworks API key. Default value is taken from the `FIREWORKS_API_KEY`
+environment variable.
*/
apiKey?: string;
/**
diff --git a/packages/provider-utils/src/test/binary-test-server.test.ts b/packages/provider-utils/src/test/binary-test-server.test.ts
new file mode 100644
index 000000000000..321cd49f2710
--- /dev/null
+++ b/packages/provider-utils/src/test/binary-test-server.test.ts
@@ -0,0 +1,186 @@
+import {
+ describe,
+ it,
+ expect,
+ beforeEach,
+ afterEach,
+ beforeAll,
+ afterAll,
+ vi,
+} from 'vitest';
+import { BinaryTestServer } from './binary-test-server';
+
+describe('BinaryTestServer', () => {
+ let server: BinaryTestServer;
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+ });
+
+ afterEach(() => {
+ vi.restoreAllMocks();
+ });
+
+ describe('constructor', () => {
+ it('should initialize with a single URL', () => {
+ const server = new BinaryTestServer('http://example.com');
+ expect(server.server).toBeDefined();
+ });
+
+ it('should initialize with multiple URLs', () => {
+ const server = new BinaryTestServer([
+ 'http://example.com',
+ 'http://test.com',
+ ]);
+ expect(server.server).toBeDefined();
+ });
+ });
+
+ describe('setResponseFor', () => {
+ beforeAll(() => {
+ server = new BinaryTestServer('http://example.com');
+ server.server.listen();
+ });
+
+ afterAll(() => {
+ server.server.close();
+ });
+
+ it('should set response options for a valid URL', () => {
+ const buffer = Buffer.from('test data');
+ server.setResponseFor('http://example.com/', {
+ body: buffer,
+ headers: { 'content-type': 'application/octet-stream' },
+ status: 201,
+ });
+ });
+
+ it('should throw error for invalid URL', () => {
+ expect(() =>
+ server.setResponseFor('http://invalid.com', { status: 200 }),
+ ).toThrow('No endpoint configured for URL');
+ });
+ });
+
+ describe('request handling', () => {
+ beforeAll(() => {
+ server = new BinaryTestServer('http://example.com');
+ server.server.listen();
+ });
+
+ afterAll(() => {
+ server.server.close();
+ });
+
+ beforeEach(() => {
+ server.server.resetHandlers();
+ });
+
+ it('should handle JSON requests', async () => {
+ const testData = { test: 'data' };
+ const fetchSpy = vi.spyOn(global, 'fetch');
+
+ await fetch('http://example.com', {
+ method: 'POST',
+ headers: { 'content-type': 'application/json' },
+ body: JSON.stringify(testData),
+ });
+
+ expect(fetchSpy).toHaveBeenCalledWith(
+ 'http://example.com',
+ expect.objectContaining({
+ method: 'POST',
+ headers: expect.objectContaining({
+ 'content-type': 'application/json',
+ }),
+ }),
+ );
+
+ const requestData = await server.getRequestDataFor('http://example.com/');
+ const bodyJson = await requestData.bodyJson();
+ expect(bodyJson).toEqual(testData);
+ });
+
+ it('should handle form data requests', async () => {
+ const formData = new FormData();
+ formData.append('field', 'value');
+ const fetchSpy = vi.spyOn(global, 'fetch');
+
+ await fetch('http://example.com', {
+ method: 'POST',
+ body: formData,
+ });
+
+ expect(fetchSpy).toHaveBeenCalledWith(
+ 'http://example.com',
+ expect.objectContaining({
+ method: 'POST',
+ body: expect.any(FormData),
+ }),
+ );
+
+ const requestData = await server.getRequestDataFor('http://example.com');
+ const formDataReceived = await requestData.bodyFormData();
+ expect(formDataReceived.get('field')).toBe('value');
+ });
+
+ it('should handle custom response configurations', async () => {
+ const responseBuffer = Buffer.from('test response');
+ const fetchSpy = vi.spyOn(global, 'fetch');
+
+ server.setResponseFor('http://example.com', {
+ body: responseBuffer,
+ headers: { 'x-custom': 'test' },
+ status: 201,
+ });
+
+ const response = await fetch('http://example.com', { method: 'POST' });
+
+ expect(fetchSpy).toHaveBeenCalledWith(
+ 'http://example.com',
+ expect.objectContaining({ method: 'POST' }),
+ );
+
+ expect(response.status).toBe(201);
+ expect(response.headers.get('x-custom')).toBe('test');
+ const responseData = await response.arrayBuffer();
+ expect(Buffer.from(responseData)).toEqual(responseBuffer);
+ });
+ });
+
+ describe('URL handling', () => {
+ let server: BinaryTestServer;
+
+ beforeEach(() => {
+ server = new BinaryTestServer('http://example.com');
+ server.server.listen();
+ // Set default response
+ server.setResponseFor('http://example.com', {
+ status: 200,
+ body: null,
+ });
+ });
+
+ afterEach(() => {
+ server.server.resetHandlers();
+ server.server.close();
+ });
+
+ it('should handle search params', async () => {
+ const response = await fetch('http://example.com?param=value', {
+ method: 'POST',
+ body: JSON.stringify({ test: true }),
+ });
+
+ expect(response.status).toBe(200);
+
+ const requestData = await server.getRequestDataFor('http://example.com');
+ expect(requestData.urlSearchParams().get('param')).toBe('value');
+ });
+
+ it('should handle relative URLs', () => {
+ const server = new BinaryTestServer('/api/endpoint');
+ expect(server.server).toBeDefined();
+ });
+ });
+});
diff --git a/packages/provider-utils/src/test/binary-test-server.ts b/packages/provider-utils/src/test/binary-test-server.ts
index 2dfad8c5d363..73fe5effb329 100644
--- a/packages/provider-utils/src/test/binary-test-server.ts
+++ b/packages/provider-utils/src/test/binary-test-server.ts
@@ -3,65 +3,127 @@ import { SetupServer, setupServer } from 'msw/node';
export class BinaryTestServer {
readonly server: SetupServer;
+ private endpoints: Map<
+ string,
+ {
+ responseBody: Buffer | null;
+ responseHeaders: Record;
+ responseStatus: number;
+ request: Request | undefined;
+ }
+ > = new Map();
- responseBody: Buffer | null = null;
- responseHeaders: Record = {};
- responseStatus = 200;
+ constructor(urls: string | string[]) {
+ const urlList = Array.isArray(urls) ? urls : [urls];
- request: Request | undefined;
+ // Initialize endpoints
+ urlList.forEach(url => {
+ const normalizedUrl = this.normalizeUrl(url);
+ this.endpoints.set(normalizedUrl, {
+ responseBody: null,
+ responseHeaders: {},
+ responseStatus: 200,
+ request: undefined,
+ });
+ });
- constructor(url: string) {
this.server = setupServer(
- http.post(url, ({ request }) => {
- this.request = request;
+ ...urlList.map(url =>
+ http.post(this.normalizeUrl(url), ({ request }) => {
+ const endpoint = this.endpoints.get(this.normalizeUrl(request.url));
+ if (!endpoint) {
+ return new HttpResponse(null, { status: 500 });
+ }
+ endpoint.request = request;
- if (this.responseBody === null) {
- return new HttpResponse(null, { status: this.responseStatus });
- }
+ if (endpoint.responseBody === null) {
+ return new HttpResponse(null, { status: endpoint.responseStatus });
+ }
- return new HttpResponse(this.responseBody, {
- status: this.responseStatus,
- headers: this.responseHeaders,
- });
- }),
+ return new HttpResponse(endpoint.responseBody, {
+ status: endpoint.responseStatus,
+ headers: endpoint.responseHeaders,
+ });
+ }),
+ ),
);
}
- async getRequestBodyJson() {
- expect(this.request).toBeDefined();
- return JSON.parse(await this.request!.text());
+ private normalizeUrl(url: string): string {
+ try {
+ // Parse URL and remove search params for endpoint matching
+ const urlObj = new URL(url);
+ urlObj.search = ''; // Clear search params for matching
+ const normalized = urlObj.toString();
+ return normalized.endsWith('/') ? normalized.slice(0, -1) : normalized;
+ } catch {
+ // If not a valid URL, assume it's a path and return as-is
+ return url.endsWith('/') ? url.slice(0, -1) : url;
+ }
}
- async getRequestHeaders() {
- expect(this.request).toBeDefined();
- const requestHeaders = this.request!.headers;
-
- // convert headers to object for easier comparison
- const headersObject: Record = {};
- requestHeaders.forEach((value, key) => {
- headersObject[key] = value;
- });
-
- return headersObject;
+ setResponseFor(
+ url: string,
+ options: {
+ body?: Buffer | null;
+ headers?: Record;
+ status?: number;
+ },
+ ) {
+ // Normalize the URL before lookup
+ const normalizedUrl = this.normalizeUrl(url);
+ const endpoint = this.endpoints.get(normalizedUrl);
+ if (!endpoint) {
+ throw new Error(`No endpoint configured for URL: ${url}`);
+ }
+ if (options.body !== undefined) endpoint.responseBody = options.body;
+ if (options.headers) endpoint.responseHeaders = options.headers;
+ if (options.status) endpoint.responseStatus = options.status;
}
- async getRequestUrlSearchParams() {
- expect(this.request).toBeDefined();
- return new URL(this.request!.url).searchParams;
- }
+ async getRequestDataFor(url: string) {
+ // Normalize the URL before lookup
+ const normalizedUrl = this.normalizeUrl(url);
+ const endpoint = this.endpoints.get(normalizedUrl);
+ if (!endpoint) {
+ throw new Error(`No endpoint configured for URL: ${url}`);
+ }
+ expect(endpoint.request).toBeDefined();
- async getRequestUrl() {
- expect(this.request).toBeDefined();
- return new URL(this.request!.url).toString();
+ return {
+ bodyJson: async () => {
+ const text = await endpoint.request!.text();
+ return JSON.parse(text);
+ },
+ bodyFormData: async () => {
+ const contentType = endpoint.request!.headers.get('content-type');
+ if (contentType?.includes('multipart/form-data')) {
+ return endpoint.request!.formData();
+ }
+ throw new Error('Request content-type is not multipart/form-data');
+ },
+ headers: () => {
+ const headersObject: Record = {};
+ endpoint.request!.headers.forEach((value, key) => {
+ headersObject[key] = value;
+ });
+ return headersObject;
+ },
+ urlSearchParams: () => new URL(endpoint.request!.url).searchParams,
+ url: () => new URL(endpoint.request!.url).toString(),
+ };
}
setupTestEnvironment() {
beforeAll(() => this.server.listen());
beforeEach(() => {
- this.responseBody = null;
- this.request = undefined;
- this.responseHeaders = {};
- this.responseStatus = 200;
+ // Reset all endpoints
+ this.endpoints.forEach(endpoint => {
+ endpoint.responseBody = null;
+ endpoint.request = undefined;
+ endpoint.responseHeaders = {};
+ endpoint.responseStatus = 200;
+ });
});
afterEach(() => this.server.resetHandlers());
afterAll(() => this.server.close());