Skip to content

Commit 4051f86

Browse files
committed
Prototype LocalModel
1 parent 4e0f630 commit 4051f86

File tree

5 files changed

+136
-2
lines changed

5 files changed

+136
-2
lines changed

packages/vertexai/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"devDependencies": {
5959
"@firebase/app": "0.11.4",
6060
"@rollup/plugin-json": "6.1.0",
61+
"@types/dom-chromium-ai": "0.0.6",
6162
"rollup": "2.79.2",
6263
"rollup-plugin-replace": "2.2.0",
6364
"rollup-plugin-typescript2": "0.36.0",

packages/vertexai/src/api.test.ts

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
*/
1717
import { ImagenModelParams, ModelParams, VertexAIErrorCode } from './types';
1818
import { VertexAIError } from './errors';
19-
import { ImagenModel, getGenerativeModel, getImagenModel } from './api';
19+
import {
20+
ImagenModel,
21+
getGenerativeModel,
22+
getHybridModel,
23+
getImagenModel,
24+
getLocalModel
25+
} from './api';
2026
import { expect } from 'chai';
2127
import { VertexAI } from './public-types';
2228
import { GenerativeModel } from './models/generative-model';
@@ -167,4 +173,38 @@ describe('Top level API', () => {
167173
expect(genModel).to.be.an.instanceOf(ImagenModel);
168174
expect(genModel.model).to.equal('publishers/google/models/my-model');
169175
});
176+
it('getLocalModel', async () => {
177+
const languageModel = {
178+
prompt: (s: string) => Promise.resolve(s)
179+
} as AILanguageModel;
180+
const aiProvider = {
181+
languageModel: {
182+
create: () => Promise.resolve(languageModel)
183+
}
184+
} as AI;
185+
const model = getLocalModel(aiProvider);
186+
const expectedText = 'hello';
187+
const response = await model.generateContent(expectedText);
188+
expect(response.response.text()).to.equal(expectedText);
189+
});
190+
it('getHybridModel', async () => {
191+
const languageModel = {
192+
prompt: (s: string) => Promise.resolve(s)
193+
} as AILanguageModel;
194+
const aiProvider = {
195+
languageModel: {
196+
create: () => Promise.resolve(languageModel)
197+
}
198+
} as AI;
199+
const model = getHybridModel(
200+
getGenerativeModel(fakeVertexAI, {
201+
model: 'my-model'
202+
}),
203+
getLocalModel(aiProvider)
204+
);
205+
const chat = model.startChat();
206+
const expectedText = 'hello';
207+
const response = await chat.sendMessage(expectedText);
208+
expect(response.response.text()).to.equal(expectedText);
209+
});
170210
});

packages/vertexai/src/api.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ import {
2828
VertexAIErrorCode
2929
} from './types';
3030
import { VertexAIError } from './errors';
31-
import { VertexAIModel, GenerativeModel, ImagenModel } from './models';
31+
import {
32+
VertexAIModel,
33+
GenerativeModel,
34+
ImagenModel,
35+
HybridModel,
36+
LocalModel
37+
} from './models';
3238

3339
export { ChatSession } from './methods/chat-session';
3440
export * from './requests/schema-builder';
@@ -109,3 +115,14 @@ export function getImagenModel(
109115
}
110116
return new ImagenModel(vertexAI, modelParams, requestOptions);
111117
}
118+
119+
export function getLocalModel(aiProvider: AI = window.ai): LocalModel {
120+
return new LocalModel(aiProvider);
121+
}
122+
123+
export function getHybridModel(
124+
remoteModel: GenerativeModel,
125+
localModel: LocalModel
126+
): HybridModel {
127+
return new HybridModel(remoteModel, localModel);
128+
}

packages/vertexai/src/models/generative-model.ts

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,74 @@ export class GenerativeModel extends VertexAIModel {
148148
return countTokens(this._apiSettings, this.model, formattedParams);
149149
}
150150
}
151+
152+
interface ChatMethods {
153+
sendMessage(
154+
request: string | Array<string | Part>
155+
): Promise<GenerateContentResult>;
156+
}
157+
class HybridChat implements ChatMethods {
158+
constructor(
159+
private remoteModel: GenerativeModel,
160+
private localModel: LocalModel
161+
) {}
162+
async sendMessage(
163+
request: string | Array<string | Part>
164+
): Promise<GenerateContentResult> {
165+
if (await this.localModel.isSupported(request)) {
166+
return this.localModel.generateContent(request);
167+
}
168+
return this.remoteModel.generateContent(request);
169+
}
170+
}
171+
interface GenModelMethods {
172+
generateContent(
173+
request: GenerateContentRequest | string | Array<string | Part>
174+
): Promise<GenerateContentResult>;
175+
}
176+
/**
177+
* Normalizes Chrome API, if available, to Vertex API
178+
*/
179+
export class LocalModel implements GenModelMethods {
180+
constructor(private aiProvider?: AI) {}
181+
async generateContent(
182+
request: GenerateContentRequest | string | Array<string | Part>
183+
): Promise<GenerateContentResult> {
184+
const session = await this.session();
185+
if (typeof request !== 'string') {
186+
throw new Error('unsupported request format');
187+
}
188+
const result = await session.prompt(request);
189+
return {
190+
response: {
191+
text: () => result,
192+
functionCalls: () => undefined
193+
}
194+
} as GenerateContentResult;
195+
}
196+
async isSupported(
197+
request: string | Array<string | Part> | GenerateContentRequest
198+
): Promise<boolean> {
199+
return typeof request === 'string';
200+
}
201+
private async session(): Promise<AILanguageModel> {
202+
return this.aiProvider!.languageModel.create();
203+
}
204+
}
205+
export class HybridModel implements GenModelMethods {
206+
constructor(
207+
private remoteModel: GenerativeModel,
208+
private localModel: LocalModel
209+
) {}
210+
async generateContent(
211+
request: GenerateContentRequest | string | Array<string | Part>
212+
): Promise<GenerateContentResult> {
213+
if (await this.localModel.isSupported(request)) {
214+
return this.localModel.generateContent(request);
215+
}
216+
return this.remoteModel.generateContent(request);
217+
}
218+
startChat(): HybridChat {
219+
return new HybridChat(this.remoteModel, this.localModel);
220+
}
221+
}

yarn.lock

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2949,6 +2949,11 @@
29492949
resolved "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz#334311971d3a07121e7eb91b684a605e7eea9cbd"
29502950
integrity sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==
29512951

2952+
2953+
version "0.0.6"
2954+
resolved "https://registry.npmjs.org/@types/dom-chromium-ai/-/dom-chromium-ai-0.0.6.tgz#0c9e5712d8db3d26586cd9f175001b509cd2e514"
2955+
integrity sha512-/jUGe9a3BLzsjjg18Olk/Ul64PZ0P4aw8uNxrXeXVTni5PSxyCfyhHb4UohsXNVByOnwYGzlqUcb3vYKVsG4mg==
2956+
29522957
"@types/eslint-scope@^3.7.7":
29532958
version "3.7.7"
29542959
resolved "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.7.tgz#3108bd5f18b0cdb277c867b3dd449c9ed7079ac5"

0 commit comments

Comments
 (0)