Skip to content

Commit

Permalink
feat: change image generation errors to warnings (#4298)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Jan 7, 2025
1 parent 045f726 commit 6337688
Show file tree
Hide file tree
Showing 20 changed files with 287 additions and 175 deletions.
10 changes: 10 additions & 0 deletions .changeset/sour-months-give.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
'@ai-sdk/provider-utils': patch
'@ai-sdk/google-vertex': patch
'@ai-sdk/fireworks': patch
'@ai-sdk/provider': patch
'@ai-sdk/openai': patch
'ai': patch
---

feat: change image generation errors to warnings
11 changes: 11 additions & 0 deletions content/docs/03-ai-sdk-core/35-image-generation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,17 @@ const { image } = await generateImage({
});
```

### Warnings

If the model returns warnings (e.g. for unsupported parameters), the call still succeeds and the warnings are available in the `warnings` property of the result.

```tsx
const { image, warnings } = await generateImage({
model: openai.image('dall-e-3'),
prompt: 'Santa Claus driving a Cadillac',
});
```

## Image Models

| Provider | Model | Sizes | Aspect Ratios |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,11 @@ console.log(images);
},
],
},
{
name: 'warnings',
type: 'ImageGenerationWarning[]',
description:
'Warnings from the model provider (e.g. unsupported settings).',
},
]}
/>
7 changes: 7 additions & 0 deletions packages/ai/core/generate-image/generate-image-result.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { ImageGenerationWarning } from '../types/image-model';

/**
The result of a `generateImage` call.
It contains the images and additional information.
Expand All @@ -12,6 +14,11 @@ The first image that was generated.
The images that were generated.
*/
readonly images: Array<GeneratedImage>;

/**
Warnings for the call, e.g. unsupported settings.
*/
readonly warnings: Array<ImageGenerationWarning>;
}

export interface GeneratedImage {
Expand Down
35 changes: 31 additions & 4 deletions packages/ai/core/generate-image/generate-image.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ describe('generateImage', () => {
model: new MockImageModelV1({
doGenerate: async args => {
capturedArgs = args;
return { images: [] };
return { images: [], warnings: [] };
},
}),
prompt,
Expand All @@ -43,6 +43,30 @@ describe('generateImage', () => {
});
});

// Warnings reported by the model's `doGenerate` must be surfaced unchanged
// on the `warnings` property of the `generateImage` result.
it('should return warnings', async () => {
const result = await generateImage({
model: new MockImageModelV1({
// The mock produces no images; only the warning pass-through is under test.
doGenerate: async () => ({
images: [],
warnings: [
{
type: 'other',
message: 'Setting is not supported',
},
],
}),
}),
prompt,
});

// Strict equality: the warning list must match exactly, with no extra entries.
expect(result.warnings).toStrictEqual([
{
type: 'other',
message: 'Setting is not supported',
},
]);
});

describe('base64 image data', () => {
it('should return generated images', async () => {
const base64Images = [
Expand All @@ -52,7 +76,7 @@ describe('generateImage', () => {

const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: base64Images }),
doGenerate: async () => ({ images: base64Images, warnings: [] }),
}),
prompt,
});
Expand All @@ -79,7 +103,10 @@ describe('generateImage', () => {

const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: [base64Image, 'base64-image-2'] }),
doGenerate: async () => ({
images: [base64Image, 'base64-image-2'],
warnings: [],
}),
}),
prompt,
});
Expand All @@ -103,7 +130,7 @@ describe('generateImage', () => {

const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: uint8ArrayImages }),
doGenerate: async () => ({ images: uint8ArrayImages, warnings: [] }),
}),
prompt,
});
Expand Down
34 changes: 20 additions & 14 deletions packages/ai/core/generate-image/generate-image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
convertUint8ArrayToBase64,
} from '@ai-sdk/provider-utils';
import { prepareRetries } from '../prompt/prepare-retries';
import { ImageGenerationWarning } from '../types/image-model';
import { GeneratedImage, GenerateImageResult } from './generate-image-result';

/**
Expand Down Expand Up @@ -101,29 +102,34 @@ Only applicable for HTTP-based providers.
}): Promise<GenerateImageResult> {
const { retry } = prepareRetries({ maxRetries: maxRetriesArg });

const { images } = await retry(() =>
model.doGenerate({
prompt,
n: n ?? 1,
abortSignal,
headers,
size,
aspectRatio,
seed,
providerOptions: providerOptions ?? {},
}),
return new DefaultGenerateImageResult(
await retry(() =>
model.doGenerate({
prompt,
n: n ?? 1,
abortSignal,
headers,
size,
aspectRatio,
seed,
providerOptions: providerOptions ?? {},
}),
),
);

return new DefaultGenerateImageResult({ images });
}

class DefaultGenerateImageResult implements GenerateImageResult {
readonly images: Array<GeneratedImage>;
readonly warnings: Array<ImageGenerationWarning>;

constructor(options: { images: Array<string> | Array<Uint8Array> }) {
constructor(options: {
images: Array<string> | Array<Uint8Array>;
warnings: Array<ImageGenerationWarning>;
}) {
this.images = options.images.map(
image => new DefaultGeneratedImage({ imageData: image }),
);
this.warnings = options.warnings;
}

get image() {
Expand Down
12 changes: 12 additions & 0 deletions packages/ai/core/types/image-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { ImageModelV1, ImageModelV1CallWarning } from '@ai-sdk/provider';

/**
Image model that is used by the AI SDK Core functions.
Alias of `ImageModelV1` from `@ai-sdk/provider`.
*/
export type ImageModel = ImageModelV1;

/**
Warning from the model provider for this call. The call will proceed, but e.g.
some settings might not be supported, which can lead to suboptimal results.
Alias of `ImageModelV1CallWarning` from `@ai-sdk/provider`.
*/
export type ImageGenerationWarning = ImageModelV1CallWarning;
4 changes: 4 additions & 0 deletions packages/ai/core/types/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
export type { Embedding, EmbeddingModel } from './embedding-model';
export type {
ImageModel,
ImageGenerationWarning as ImageModelCallWarning,
} from './image-model';
export type {
CallWarning,
CoreToolChoice,
Expand Down
55 changes: 20 additions & 35 deletions packages/fireworks/src/fireworks-image-model.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { APICallError } from '@ai-sdk/provider';
import { BinaryTestServer } from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { FireworksImageModel } from './fireworks-image-model';
import { describe, it, expect } from 'vitest';
import { UnsupportedFunctionalityError } from '@ai-sdk/provider';

const prompt = 'A cute baby sea otter';

Expand Down Expand Up @@ -169,41 +168,27 @@ describe('FireworksImageModel', () => {
});
});

it('should throw error when requesting more than one image', async () => {
await expect(
model.doGenerate({
prompt,
n: 2,
size: undefined,
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
}),
).rejects.toThrowError(
new UnsupportedFunctionalityError({
functionality: 'generate multiple images',
message: `This model does not support generating more than 1 images at a time.`,
}),
);
});
it('should return warnings for unsupported settings', async () => {
const mockImageBuffer = Buffer.from('mock-image-data');
server.responseBody = mockImageBuffer;

it('should throw error when specifying image size', async () => {
await expect(
model.doGenerate({
prompt,
n: 1,
size: '512x512',
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
}),
).rejects.toThrowError(
new UnsupportedFunctionalityError({
functionality: 'image size',
message:
const result = await model.doGenerate({
prompt,
n: 1,
size: '1024x1024',
aspectRatio: '1:1',
seed: 123,
providerOptions: {},
});

expect(result.warnings).toStrictEqual([
{
type: 'unsupported-setting',
setting: 'size',
details:
'This model does not support the `size` option. Use `aspectRatio` instead.',
}),
);
},
]);
});
});
});
20 changes: 8 additions & 12 deletions packages/fireworks/src/fireworks-image-model.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
APICallError,
ImageModelV1,
UnsupportedFunctionalityError,
ImageModelV1CallWarning,
} from '@ai-sdk/provider';
import {
combineHeaders,
Expand Down Expand Up @@ -106,21 +106,17 @@ export class FireworksImageModel implements ImageModelV1 {
}: Parameters<ImageModelV1['doGenerate']>[0]): Promise<
Awaited<ReturnType<ImageModelV1['doGenerate']>>
> {
const warnings: Array<ImageModelV1CallWarning> = [];

if (size != null) {
throw new UnsupportedFunctionalityError({
functionality: 'image size',
message:
warnings.push({
type: 'unsupported-setting',
setting: 'size',
details:
'This model does not support the `size` option. Use `aspectRatio` instead.',
});
}

if (n > this.maxImagesPerCall) {
throw new UnsupportedFunctionalityError({
functionality: `generate more than ${this.maxImagesPerCall} images`,
message: `This model does not support generating more than ${this.maxImagesPerCall} images at a time.`,
});
}

const url = `${this.config.baseURL}/workflows/${this.modelId}/text_to_image`;
const body = {
prompt,
Expand All @@ -139,6 +135,6 @@ export class FireworksImageModel implements ImageModelV1 {
fetch: this.config.fetch,
});

return { images: [new Uint8Array(response)] };
return { images: [new Uint8Array(response)], warnings };
}
}
47 changes: 22 additions & 25 deletions packages/google-vertex/src/google-vertex-image-model.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { GoogleVertexImageModel } from './google-vertex-image-model';
import { UnsupportedFunctionalityError } from '@ai-sdk/provider';

const prompt = 'A cute baby sea otter';

Expand Down Expand Up @@ -99,30 +98,6 @@ describe('GoogleVertexImageModel', () => {
expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']);
});

it('throws when size is specified', async () => {
const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
provider: 'vertex',
baseURL: 'https://example.com',
});

await expect(
model.doGenerate({
prompt: 'test prompt',
n: 1,
size: '1024x1024',
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
}),
).rejects.toThrow(
new UnsupportedFunctionalityError({
functionality: 'image size',
message:
'This model does not support the `size` option. Use `aspectRatio` instead.',
}),
);
});

it('sends aspect ratio in the request', async () => {
prepareJsonResponse();

Expand Down Expand Up @@ -216,5 +191,27 @@ describe('GoogleVertexImageModel', () => {
},
});
});

// Passing `size` must resolve with an `unsupported-setting` warning instead of
// rejecting. The strict comparison also pins that `aspectRatio` and `seed`
// produce no warnings of their own.
it('should return warnings for unsupported settings', async () => {
prepareJsonResponse();

const result = await model.doGenerate({
prompt,
n: 1,
size: '1024x1024',
aspectRatio: '1:1',
seed: 123,
providerOptions: {},
});

// Exactly one warning is expected, and it is for `size`.
expect(result.warnings).toStrictEqual([
{
type: 'unsupported-setting',
setting: 'size',
details:
'This model does not support the `size` option. Use `aspectRatio` instead.',
},
]);
});
});
});
Loading

0 comments on commit 6337688

Please sign in to comment.