Skip to content

Commit 23a8c86

Browse files
erikeldridge authored and gsiddh committed
VinF Hybrid Inference #4: ChromeAdapter in stream methods (rebased) (#8949)
1 parent cdbe63e commit 23a8c86

File tree

7 files changed

+172
-31
lines changed

7 files changed

+172
-31
lines changed

e2e/sample-apps/modular.js

+8-8
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,7 @@ import {
5858
onValue,
5959
off
6060
} from 'firebase/database';
61-
import {
62-
getGenerativeModel,
63-
getVertexAI,
64-
InferenceMode,
65-
VertexAI
66-
} from 'firebase/vertexai';
61+
import { getGenerativeModel, getVertexAI } from 'firebase/vertexai';
6762
import { getDataConnect, DataConnect } from 'firebase/data-connect';
6863

6964
/**
@@ -318,8 +313,13 @@ function callPerformance(app) {
318313
async function callVertexAI(app) {
// Sample-app smoke test for the Vertex AI SDK: builds a model from `app`,
// issues one request and logs the outcome. Lines whose gutter ends in `-`
// are the removed countTokens variant; lines ending in `+` are the added
// streaming variant.
319314
console.log('[VERTEXAI] start');
320315
const vertexAI = getVertexAI(app);
321-
const model = getGenerativeModel(vertexAI, { model: 'gemini-1.5-flash' });
322-
const result = await model.countTokens('abcdefg');
// Added variant: no explicit model name is passed, only a `mode` option.
316+
const model = getGenerativeModel(vertexAI, {
317+
mode: 'prefer_in_cloud'
318+
});
319+
const result = await model.generateContentStream("What is Roko's Basalisk?");
// Logs the text of each chunk as it arrives on the async-iterable stream.
320+
for await (const chunk of result.stream) {
321+
console.log(chunk.text());
322+
}
// NOTE(review): this unchanged line was written for the removed countTokens
// call. In the added variant `result` comes from generateContentStream, so
// `result.totalTokens` is presumably undefined here - confirm and update.
323323
console.log(`[VERTEXAI] counted tokens: ${result.totalTokens}`);
324324
}
325325

packages/vertexai/src/methods/chat-session.ts

+1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ export class ChatSession {
149149
this._apiSettings,
150150
this.model,
151151
generateContentRequest,
152+
this.chromeAdapter,
152153
this.requestOptions
153154
);
154155

packages/vertexai/src/methods/chrome-adapter.test.ts

+70-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@ import { GenerateContentRequest } from '../types';
3030
use(sinonChai);
3131
use(chaiAsPromised);
3232

33+
/**
34+
* Converts the ReadableStream from response.body to an array of strings.
35+
*/
36+
async function toStringArray(
37+
stream: ReadableStream<Uint8Array>
38+
): Promise<string[]> {
// Test helper: drains `stream` and returns one decoded string per chunk read.
39+
const decoder = new TextDecoder();
40+
const actual = [];
41+
const reader = stream.getReader();
// Read until the reader reports `done`; chunk order is preserved in `actual`.
42+
while (true) {
43+
const { done, value } = await reader.read();
44+
if (done) {
45+
break;
46+
}
// NOTE(review): decode() is called without `{ stream: true }`, so each chunk
// is decoded on its own; a multi-byte character split across two chunks
// would not round-trip. The payload used in this file ('hi') is unaffected.
47+
actual.push(decoder.decode(value));
48+
}
49+
return actual;
50+
}
51+
3352
describe('ChromeAdapter', () => {
3453
describe('isAvailable', () => {
3554
it('returns false if mode is only cloud', async () => {
@@ -280,7 +299,7 @@ describe('ChromeAdapter', () => {
280299
const request = {
281300
contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
282301
} as GenerateContentRequest;
283-
const response = await adapter.generateContentOnDevice(request);
302+
const response = await adapter.generateContent(request);
284303
// Asserts initialization params are proxied.
285304
expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
286305
// Asserts Vertex input type is mapped to Chrome type.
@@ -325,6 +344,7 @@ describe('ChromeAdapter', () => {
325344
const createStub = stub(languageModelProvider, 'create').resolves(
326345
languageModel
327346
);
347+
328348
// overrides impl with stub method
329349
const measureInputUsageStub = stub(
330350
languageModel,
@@ -336,6 +356,7 @@ describe('ChromeAdapter', () => {
336356
'prefer_on_device',
337357
onDeviceParams
338358
);
359+
339360
const countTokenRequest = {
340361
contents: [{ role: 'user', parts: [{ text: inputText }] }]
341362
} as GenerateContentRequest;
@@ -359,4 +380,52 @@ describe('ChromeAdapter', () => {
359380
});
360381
});
361382
});
383+
describe('generateContentStreamOnDevice', () => {
// NOTE(review): the suite title says `generateContentStreamOnDevice`, but the
// method exercised below is `adapter.generateContentStream` - consider
// renaming the suite to match.
384+
it('generates content stream', async () => {
// Minimal fakes cast to LanguageModel; only the members stubbed below are used.
385+
const languageModelProvider = {
386+
create: () => Promise.resolve({})
387+
} as LanguageModel;
388+
const languageModel = {
389+
promptStreaming: _i => new ReadableStream()
390+
} as LanguageModel;
// `create` resolves to the fake model so the adapter's session is `languageModel`.
391+
const createStub = stub(languageModelProvider, 'create').resolves(
392+
languageModel
393+
);
394+
const part = 'hi';
// The stubbed stream emits a single chunk and then closes.
395+
const promptStub = stub(languageModel, 'promptStreaming').returns(
396+
new ReadableStream({
397+
start(controller) {
// NOTE(review): this enqueues an array (`[part]`), not the string itself.
// That is why the expected payload at the end of this test contains
// `"text":["hi"]`. Presumably a plain string chunk was intended - confirm.
398+
controller.enqueue([part]);
399+
controller.close();
400+
}
401+
})
402+
);
403+
const onDeviceParams = {} as LanguageModelCreateOptions;
404+
const adapter = new ChromeAdapter(
405+
languageModelProvider,
406+
'prefer_on_device',
407+
onDeviceParams
408+
);
409+
const request = {
410+
contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
411+
} as GenerateContentRequest;
412+
const response = await adapter.generateContentStream(request);
// Asserts initialization params are proxied to the provider's create().
413+
expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
// Asserts the Vertex request content is mapped to the Chrome message shape.
414+
expect(promptStub).to.have.been.calledOnceWith([
415+
{
416+
role: request.contents[0].role,
417+
content: [
418+
{
419+
type: 'text',
420+
content: request.contents[0].parts[0].text
421+
}
422+
]
423+
}
424+
]);
// Asserts the response body is a single `data: <json>` event followed by a
// blank line, with the chunk embedded as the candidate's text part.
425+
const actual = await toStringArray(response.body!);
426+
expect(actual).to.deep.equal([
427+
`data: {"candidates":[{"content":{"role":"model","parts":[{"text":["${part}"]}]}}]}\n\n`
428+
]);
429+
});
430+
});
362431
});

packages/vertexai/src/methods/chrome-adapter.ts

+64-14
Original file line numberDiff line numberDiff line change
@@ -95,27 +95,34 @@ export class ChromeAdapter {
9595
* @param request a standard Vertex {@link GenerateContentRequest}
9696
* @returns {@link Response}, so we can reuse common response formatting.
9797
*/
98-
async generateContentOnDevice(
98+
async generateContent(request: GenerateContentRequest): Promise<Response> {
// Non-streaming on-device path: create a session, prompt it once with the
// converted messages, and wrap the resulting text via toResponse().
99+
const session = await this.createSession(
100+
// TODO: normalize on-device params during construction.
101+
this.onDeviceParams || {}
102+
);
// Convert the request contents into the message format passed to the session.
103+
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
104+
const text = await session.prompt(messages);
105+
return ChromeAdapter.toResponse(text);
106+
}
107+
108+
/**
109+
* Generates content stream on device.
110+
*
111+
* <p>This is comparable to {@link GenerativeModel.generateContentStream} for generating content in
112+
* Cloud.</p>
113+
* @param request a standard Vertex {@link GenerateContentRequest}
114+
* @returns {@link Response}, so we can reuse common response formatting.
115+
*/
116+
async generateContentStream(
99117
request: GenerateContentRequest
100118
): Promise<Response> {
// Streaming on-device path. Session creation and message conversion mirror
// generateContent(); only the final prompt call and response wrapping differ.
101119
const session = await this.createSession(
102120
// TODO: normalize on-device params during construction.
103121
this.onDeviceParams || {}
104122
);
105123
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
// Lines with a `-` gutter below are the previous (removed) unary
// implementation, which built the response object inline.
106-
const text = await session.prompt(messages);
107-
return {
108-
json: () =>
109-
Promise.resolve({
110-
candidates: [
111-
{
112-
content: {
113-
parts: [{ text }]
114-
}
115-
}
116-
]
117-
})
118-
} as Response;
// Added implementation: obtain a stream from the session and let
// toStreamResponse() frame each chunk for the shared stream reader.
124+
const stream = await session.promptStreaming(messages);
125+
return ChromeAdapter.toStreamResponse(stream);
119126
}
120127

121128
async countTokens(request: CountTokensRequest): Promise<Response> {
@@ -240,4 +247,47 @@ export class ChromeAdapter {
240247
this.oldSession = newSession;
241248
return newSession;
242249
}
250+
251+
/**
252+
* Formats string returned by Chrome as a {@link Response} returned by Vertex.
253+
*/
254+
private static toResponse(text: string): Response {
// Only `json()` is implemented; the object is cast to Response, so any other
// Response member (status, ok, body, ...) is absent on the returned value.
255+
return {
// Resolves to a single candidate whose content has one text part.
256+
json: async () => ({
257+
candidates: [
258+
{
259+
content: {
260+
parts: [{ text }]
261+
}
262+
}
263+
]
264+
})
265+
} as Response;
266+
}
267+
268+
/**
269+
* Formats string stream returned by Chrome as SSE returned by Vertex.
270+
*/
271+
private static toStreamResponse(stream: ReadableStream<string>): Response {
// Only `body` is implemented; the object is cast to Response, so any other
// Response member is absent on the returned value.
272+
const encoder = new TextEncoder();
273+
return {
274+
body: stream.pipeThrough(
275+
new TransformStream({
// One output event per input chunk: the chunk is embedded as the text part
// of a single candidate with role 'model'.
276+
transform(chunk, controller) {
277+
const json = JSON.stringify({
278+
candidates: [
279+
{
280+
content: {
281+
role: 'model',
282+
parts: [{ text: chunk }]
283+
}
284+
}
285+
]
286+
});
// Frame as `data: <json>` followed by a blank line, encoded to bytes.
287+
controller.enqueue(encoder.encode(`data: ${json}\n\n`));
288+
}
289+
})
290+
)
291+
} as Response;
292+
}
243293
}

packages/vertexai/src/methods/generate-content.test.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,17 @@ describe('generateContent()', () => {
308308
);
309309
expect(mockFetch).to.be.called;
310310
});
311+
// TODO: define a similar test for generateContentStream
311312
it('on-device', async () => {
312313
const chromeAdapter = new ChromeAdapter();
313314
const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
314315
const mockResponse = getMockResponse(
315316
'vertexAI',
316317
'unary-success-basic-reply-short.json'
317318
);
318-
const generateContentStub = stub(
319-
chromeAdapter,
320-
'generateContentOnDevice'
321-
).resolves(mockResponse as Response);
319+
const generateContentStub = stub(chromeAdapter, 'generateContent').resolves(
320+
mockResponse as Response
321+
);
322322
const result = await generateContent(
323323
fakeApiSettings,
324324
'model',

packages/vertexai/src/methods/generate-content.ts

+24-4
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,40 @@ import { processStream } from '../requests/stream-reader';
2828
import { ApiSettings } from '../types/internal';
2929
import { ChromeAdapter } from './chrome-adapter';
3030

31-
export async function generateContentStream(
31+
async function generateContentStreamOnCloud(
3232
apiSettings: ApiSettings,
3333
model: string,
3434
params: GenerateContentRequest,
3535
requestOptions?: RequestOptions
// The `-` lines are the previous signature/body (returned a processed stream
// result); the `+` lines make this helper return the raw Response instead.
36-
): Promise<GenerateContentStreamResult> {
37-
const response = await makeRequest(
36+
): Promise<Response> {
// Issues the streaming generate-content request with `params` serialized as
// the JSON body; the result is returned unprocessed to the caller.
37+
return makeRequest(
3838
model,
3939
Task.STREAM_GENERATE_CONTENT,
4040
apiSettings,
4141
/* stream */ true,
4242
JSON.stringify(params),
4343
requestOptions
4444
);
45+
}
46+
47+
export async function generateContentStream(
48+
apiSettings: ApiSettings,
49+
model: string,
50+
params: GenerateContentRequest,
// NOTE(review): `chromeAdapter` is a new required positional parameter placed
// before `requestOptions`; every caller must pass it in this position.
51+
chromeAdapter: ChromeAdapter,
52+
requestOptions?: RequestOptions
53+
): Promise<GenerateContentStreamResult> {
// Picks the backend per request: the adapter when it reports itself available
// for `params`, otherwise the cloud request helper.
54+
let response;
55+
if (await chromeAdapter.isAvailable(params)) {
56+
response = await chromeAdapter.generateContentStream(params);
57+
} else {
58+
response = await generateContentStreamOnCloud(
59+
apiSettings,
60+
model,
61+
params,
62+
requestOptions
63+
);
64+
}
// Both branches yield a Response that is handed to the same stream processor.
4565
return processStream(response);
4666
}
4767

@@ -70,7 +90,7 @@ export async function generateContent(
7090
): Promise<GenerateContentResult> {
7191
let response;
7292
if (await chromeAdapter.isAvailable(params)) {
73-
response = await chromeAdapter.generateContentOnDevice(params);
93+
response = await chromeAdapter.generateContent(params);
7494
} else {
7595
response = await generateContentOnCloud(
7696
apiSettings,

packages/vertexai/src/models/generative-model.ts

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ export class GenerativeModel extends VertexAIModel {
123123
systemInstruction: this.systemInstruction,
124124
...formattedParams
125125
},
126+
this.chromeAdapter,
126127
this.requestOptions
127128
);
128129
}

0 commit comments

Comments
 (0)