Skip to content

Commit

Permalink
feat (ai/core): generate many images with parallel model calls (#4307)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Jan 7, 2025
1 parent db74f6e commit a92f5f6
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 23 deletions.
5 changes: 5 additions & 0 deletions .changeset/brown-shrimps-sparkle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): generate many images with parallel model calls
7 changes: 6 additions & 1 deletion content/docs/03-ai-sdk-core/35-image-generation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ const { image } = await generateImage({

### Generating Multiple Images

`generateImage` also supports generating multiple images at once for models that support it:
`generateImage` also supports generating multiple images at once:

```tsx highlight={"7"}
import { experimental_generateImage as generateImage } from 'ai';
Expand All @@ -78,6 +78,11 @@ const { images } = await generateImage({
});
```

<Note>
`generateImage` will automatically call the model as often as needed (in
parallel) to generate the requested number of images.
</Note>

### Providing a Seed

You can provide a seed to the `generateImage` function to control the output of the image generation process.
Expand Down
20 changes: 20 additions & 0 deletions examples/ai-core/src/generate-image/openai-many.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { openai } from '@ai-sdk/openai';
import { experimental_generateImage as generateImage } from 'ai';
import 'dotenv/config';
import fs from 'fs';

async function main() {
const { images } = await generateImage({
model: openai.image('dall-e-3'),
n: 3, // 3 calls; dall-e-3 can only generate 1 image at a time
prompt: 'Santa Claus driving a Cadillac',
});

for (const image of images) {
const filename = `image-${Date.now()}.png`;
fs.writeFileSync(filename, image.uint8Array);
console.log(`Image saved to ${filename}`);
}
}

main().catch(console.error);
123 changes: 123 additions & 0 deletions packages/ai/core/generate-image/generate-image.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,127 @@ describe('generateImage', () => {
]);
});
});

describe('when several calls are required', () => {
it('should generate images', async () => {
const base64Images = [
'SGVsbG8gV29ybGQ=', // "Hello World" in base64
'VGVzdGluZw==', // "Testing" in base64
'MTIz', // "123" in base64
];

let callCount = 0;

const result = await generateImage({
model: new MockImageModelV1({
maxImagesPerCall: 2,
doGenerate: async options => {
switch (callCount++) {
case 0:
expect(options).toStrictEqual({
prompt,
n: 2,
seed: 12345,
size: '1024x1024',
aspectRatio: '16:9',
providerOptions: { openai: { style: 'vivid' } },
headers: { 'custom-request-header': 'request-header-value' },
abortSignal: undefined,
});
return { images: base64Images.slice(0, 2), warnings: [] };
case 1:
expect(options).toStrictEqual({
prompt,
n: 1,
seed: 12345,
size: '1024x1024',
aspectRatio: '16:9',
providerOptions: { openai: { style: 'vivid' } },
headers: { 'custom-request-header': 'request-header-value' },
abortSignal: undefined,
});
return { images: base64Images.slice(2), warnings: [] };
default:
throw new Error('Unexpected call');
}
},
}),
prompt,
n: 3,
size: '1024x1024',
aspectRatio: '16:9',
seed: 12345,
providerOptions: { openai: { style: 'vivid' } },
headers: { 'custom-request-header': 'request-header-value' },
});

expect(result.images.map(image => image.base64)).toStrictEqual(
base64Images,
);
});

it('should aggregate warnings', async () => {
const base64Images = [
'SGVsbG8gV29ybGQ=', // "Hello World" in base64
'VGVzdGluZw==', // "Testing" in base64
'MTIz', // "123" in base64
];

let callCount = 0;

const result = await generateImage({
model: new MockImageModelV1({
maxImagesPerCall: 2,
doGenerate: async options => {
switch (callCount++) {
case 0:
expect(options).toStrictEqual({
prompt,
n: 2,
seed: 12345,
size: '1024x1024',
aspectRatio: '16:9',
providerOptions: { openai: { style: 'vivid' } },
headers: { 'custom-request-header': 'request-header-value' },
abortSignal: undefined,
});
return {
images: base64Images.slice(0, 2),
warnings: [{ type: 'other', message: '1' }],
};
case 1:
expect(options).toStrictEqual({
prompt,
n: 1,
seed: 12345,
size: '1024x1024',
aspectRatio: '16:9',
providerOptions: { openai: { style: 'vivid' } },
headers: { 'custom-request-header': 'request-header-value' },
abortSignal: undefined,
});
return {
images: base64Images.slice(2),
warnings: [{ type: 'other', message: '2' }],
};
default:
throw new Error('Unexpected call');
}
},
}),
prompt,
n: 3,
size: '1024x1024',
aspectRatio: '16:9',
seed: 12345,
providerOptions: { openai: { style: 'vivid' } },
headers: { 'custom-request-header': 'request-header-value' },
});

expect(result.warnings).toStrictEqual([
{ type: 'other', message: '1' },
{ type: 'other', message: '2' },
]);
});
});
});
69 changes: 48 additions & 21 deletions packages/ai/core/generate-image/generate-image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ as body parameters.
export async function generateImage({
model,
prompt,
n,
n = 1,
size,
aspectRatio,
seed,
Expand Down Expand Up @@ -102,33 +102,60 @@ Only applicable for HTTP-based providers.
}): Promise<GenerateImageResult> {
const { retry } = prepareRetries({ maxRetries: maxRetriesArg });

return new DefaultGenerateImageResult(
await retry(() =>
model.doGenerate({
prompt,
n: n ?? 1,
abortSignal,
headers,
size,
aspectRatio,
seed,
providerOptions: providerOptions ?? {},
}),
// default to 1 if the model has not specified limits on
// how many images can be generated in a single call
const maxImagesPerCall = model.maxImagesPerCall ?? 1;

// parallelize calls to the model:
const callCount = Math.ceil(n / maxImagesPerCall);
const callImageCounts = Array.from({ length: callCount }, (_, i) => {
if (i < callCount - 1) {
return maxImagesPerCall;
}

const remainder = n % maxImagesPerCall;
return remainder === 0 ? maxImagesPerCall : remainder;
});
const results = await Promise.all(
callImageCounts.map(async callImageCount =>
retry(() =>
model.doGenerate({
prompt,
n: callImageCount,
abortSignal,
headers,
size,
aspectRatio,
seed,
providerOptions: providerOptions ?? {},
}),
),
),
);

// collect result images & warnings
const images: Array<DefaultGeneratedImage> = [];
const warnings: Array<ImageGenerationWarning> = [];

for (const result of results) {
images.push(
...result.images.map(image => new DefaultGeneratedImage({ image })),
);
warnings.push(...result.warnings);
}

return new DefaultGenerateImageResult({ images, warnings });
}

class DefaultGenerateImageResult implements GenerateImageResult {
readonly images: Array<GeneratedImage>;
readonly warnings: Array<ImageGenerationWarning>;

constructor(options: {
images: Array<string> | Array<Uint8Array>;
images: Array<DefaultGeneratedImage>;
warnings: Array<ImageGenerationWarning>;
}) {
this.images = options.images.map(
image => new DefaultGeneratedImage({ imageData: image }),
);
this.images = options.images;
this.warnings = options.warnings;
}

Expand All @@ -141,11 +168,11 @@ class DefaultGeneratedImage implements GeneratedImage {
private base64Data: string | undefined;
private uint8ArrayData: Uint8Array | undefined;

constructor({ imageData }: { imageData: string | Uint8Array }) {
const isUint8Array = imageData instanceof Uint8Array;
constructor({ image }: { image: string | Uint8Array }) {
const isUint8Array = image instanceof Uint8Array;

this.base64Data = isUint8Array ? undefined : imageData;
this.uint8ArrayData = isUint8Array ? imageData : undefined;
this.base64Data = isUint8Array ? undefined : image;
this.uint8ArrayData = isUint8Array ? image : undefined;
}

// lazy conversion with caching to avoid unnecessary conversion overhead:
Expand Down
5 changes: 4 additions & 1 deletion packages/ai/core/test/mock-image-model-v1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@ export class MockImageModelV1 implements ImageModelV1 {
readonly specificationVersion = 'v1';
readonly provider: ImageModelV1['provider'];
readonly modelId: ImageModelV1['modelId'];
readonly maxImagesPerCall = 1;
readonly maxImagesPerCall: ImageModelV1['maxImagesPerCall'];

doGenerate: ImageModelV1['doGenerate'];

constructor({
provider = 'mock-provider',
modelId = 'mock-model-id',
maxImagesPerCall = 1,
doGenerate = notImplemented,
}: {
provider?: ImageModelV1['provider'];
modelId?: ImageModelV1['modelId'];
maxImagesPerCall?: ImageModelV1['maxImagesPerCall'];
doGenerate?: ImageModelV1['doGenerate'];
} = {}) {
this.provider = provider;
this.modelId = modelId;
this.maxImagesPerCall = maxImagesPerCall;
this.doGenerate = doGenerate;
}
}

0 comments on commit a92f5f6

Please sign in to comment.