Skip to content

Commit 33cc9d3

Browse files
CorieWcabljac
authored andcommitted
refactor(plugins/vertexai): migrate to v2 API
1 parent 91775bf commit 33cc9d3

File tree

17 files changed

+217
-229
lines changed

17 files changed

+217
-229
lines changed

js/plugins/vertexai/src/embedder.ts

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
* limitations under the License.
1515
*/
1616

17-
import { z, type Document, type Genkit } from 'genkit';
17+
import { z, type Document } from 'genkit';
1818
import {
19-
embedderRef,
19+
embedderRef as createEmbedderRef,
2020
type EmbedderAction,
2121
type EmbedderReference,
2222
} from 'genkit/embedder';
2323
import type { GoogleAuth } from 'google-auth-library';
2424
import type { PluginOptions } from './common/types.js';
2525
import { predictModel, type PredictClient } from './predict.js';
26+
import { embedder } from 'genkit/plugin';
2627

2728
export const TaskTypeSchema = z.enum([
2829
'RETRIEVAL_DOCUMENT',
@@ -61,7 +62,7 @@ function commonRef(
6162
name: string,
6263
input?: InputType[]
6364
): EmbedderReference<typeof VertexEmbeddingConfigSchema> {
64-
return embedderRef({
65+
return createEmbedderRef({
6566
name: `vertexai/${name}`,
6667
configSchema: VertexEmbeddingConfigSchema,
6768
info: {
@@ -88,7 +89,7 @@ export const multimodalEmbedding001 = commonRef('multimodalembedding@001', [
8889
'image',
8990
'video',
9091
]);
91-
export const geminiEmbedding001 = embedderRef({
92+
export const geminiEmbedding001 = createEmbedderRef({
9293
name: 'vertexai/gemini-embedding-001',
9394
configSchema: VertexEmbeddingConfigSchema,
9495
info: {
@@ -254,14 +255,13 @@ type EmbeddingResult = {
254255
};
255256

256257
export function defineVertexAIEmbedder(
257-
ai: Genkit,
258258
name: string,
259259
client: GoogleAuth,
260260
options: PluginOptions
261261
): EmbedderAction<any> {
262-
const embedder =
262+
const embedderRef =
263263
SUPPORTED_EMBEDDER_MODELS[name] ??
264-
embedderRef({
264+
createEmbedderRef({
265265
name: `vertexai/${name}`,
266266
configSchema: VertexEmbeddingConfigSchema,
267267
info: {
@@ -298,18 +298,18 @@ export function defineVertexAIEmbedder(
298298
return predictClients[requestLocation];
299299
};
300300

301-
return ai.defineEmbedder(
301+
return embedder(
302302
{
303-
name: embedder.name,
304-
configSchema: embedder.configSchema,
305-
info: embedder.info!,
303+
name: embedderRef.name,
304+
configSchema: embedderRef.configSchema,
305+
info: embedderRef.info!,
306306
},
307307
async (input, options) => {
308-
const predictClient = predictClientFactory(options);
308+
const predictClient = predictClientFactory(embedderRef.config);
309309
const response = await predictClient(
310-
input.map((doc: Document) => {
310+
input.input.map((doc: Document) => {
311311
let instance: EmbeddingInstance;
312-
if (isMultiModal(embedder) && checkValidDocument(embedder, doc)) {
312+
if (isMultiModal(embedderRef) && checkValidDocument(embedderRef, doc)) {
313313
instance = {};
314314
if (doc.text) {
315315
instance.text = doc.text;
@@ -370,13 +370,13 @@ export function defineVertexAIEmbedder(
370370
// Text only embedder
371371
instance = {
372372
content: doc.text,
373-
task_type: options?.taskType,
374-
title: options?.title,
373+
task_type: embedderRef.config?.taskType,
374+
title: embedderRef.config?.title,
375375
};
376376
}
377377
return instance;
378378
}),
379-
{ outputDimensionality: options?.outputDimensionality }
379+
{ outputDimensionality: embedderRef.config?.outputDimensionality }
380380
);
381381
return {
382382
embeddings: response.predictions

js/plugins/vertexai/src/evaluation/evaluation.ts

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
import { z, type Action, type Genkit } from 'genkit';
17+
import { z, type Action } from 'genkit';
1818
import type { GoogleAuth } from 'google-auth-library';
1919
import { EvaluatorFactory } from './evaluator_factory.js';
2020

@@ -54,7 +54,6 @@ function stringify(input: unknown) {
5454
}
5555

5656
export function vertexEvaluators(
57-
ai: Genkit,
5857
auth: GoogleAuth,
5958
metrics: VertexAIEvaluationMetric[],
6059
projectId: string,
@@ -67,28 +66,28 @@ export function vertexEvaluators(
6766

6867
switch (metricType) {
6968
case VertexAIEvaluationMetricType.BLEU: {
70-
return createBleuEvaluator(ai, factory, metricSpec);
69+
return createBleuEvaluator(factory, metricSpec);
7170
}
7271
case VertexAIEvaluationMetricType.ROUGE: {
73-
return createRougeEvaluator(ai, factory, metricSpec);
72+
return createRougeEvaluator(factory, metricSpec);
7473
}
7574
case VertexAIEvaluationMetricType.FLUENCY: {
76-
return createFluencyEvaluator(ai, factory, metricSpec);
75+
return createFluencyEvaluator(factory, metricSpec);
7776
}
7877
case VertexAIEvaluationMetricType.SAFETY: {
79-
return createSafetyEvaluator(ai, factory, metricSpec);
78+
return createSafetyEvaluator(factory, metricSpec);
8079
}
8180
case VertexAIEvaluationMetricType.GROUNDEDNESS: {
82-
return createGroundednessEvaluator(ai, factory, metricSpec);
81+
return createGroundednessEvaluator(factory, metricSpec);
8382
}
8483
case VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY: {
85-
return createSummarizationQualityEvaluator(ai, factory, metricSpec);
84+
return createSummarizationQualityEvaluator(factory, metricSpec);
8685
}
8786
case VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS: {
88-
return createSummarizationHelpfulnessEvaluator(ai, factory, metricSpec);
87+
return createSummarizationHelpfulnessEvaluator(factory, metricSpec);
8988
}
9089
case VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY: {
91-
return createSummarizationVerbosityEvaluator(ai, factory, metricSpec);
90+
return createSummarizationVerbosityEvaluator(factory, metricSpec);
9291
}
9392
}
9493
});
@@ -108,12 +107,10 @@ const BleuResponseSchema = z.object({
108107

109108
// TODO: Add support for batch inputs
110109
function createBleuEvaluator(
111-
ai: Genkit,
112110
factory: EvaluatorFactory,
113111
metricSpec: any
114112
): Action {
115113
return factory.create(
116-
ai,
117114
{
118115
metric: VertexAIEvaluationMetricType.BLEU,
119116
displayName: 'BLEU',
@@ -150,12 +147,10 @@ const RougeResponseSchema = z.object({
150147

151148
// TODO: Add support for batch inputs
152149
function createRougeEvaluator(
153-
ai: Genkit,
154150
factory: EvaluatorFactory,
155151
metricSpec: any
156152
): Action {
157153
return factory.create(
158-
ai,
159154
{
160155
metric: VertexAIEvaluationMetricType.ROUGE,
161156
displayName: 'ROUGE',
@@ -191,12 +186,10 @@ const FluencyResponseSchema = z.object({
191186
});
192187

193188
function createFluencyEvaluator(
194-
ai: Genkit,
195189
factory: EvaluatorFactory,
196190
metricSpec: any
197191
): Action {
198192
return factory.create(
199-
ai,
200193
{
201194
metric: VertexAIEvaluationMetricType.FLUENCY,
202195
displayName: 'Fluency',
@@ -233,12 +226,10 @@ const SafetyResponseSchema = z.object({
233226
});
234227

235228
function createSafetyEvaluator(
236-
ai: Genkit,
237229
factory: EvaluatorFactory,
238230
metricSpec: any
239231
): Action {
240232
return factory.create(
241-
ai,
242233
{
243234
metric: VertexAIEvaluationMetricType.SAFETY,
244235
displayName: 'Safety',
@@ -275,12 +266,10 @@ const GroundednessResponseSchema = z.object({
275266
});
276267

277268
function createGroundednessEvaluator(
278-
ai: Genkit,
279269
factory: EvaluatorFactory,
280270
metricSpec: any
281271
): Action {
282272
return factory.create(
283-
ai,
284273
{
285274
metric: VertexAIEvaluationMetricType.GROUNDEDNESS,
286275
displayName: 'Groundedness',
@@ -319,12 +308,10 @@ const SummarizationQualityResponseSchema = z.object({
319308
});
320309

321310
function createSummarizationQualityEvaluator(
322-
ai: Genkit,
323311
factory: EvaluatorFactory,
324312
metricSpec: any
325313
): Action {
326314
return factory.create(
327-
ai,
328315
{
329316
metric: VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY,
330317
displayName: 'Summarization quality',
@@ -363,12 +350,10 @@ const SummarizationHelpfulnessResponseSchema = z.object({
363350
});
364351

365352
function createSummarizationHelpfulnessEvaluator(
366-
ai: Genkit,
367353
factory: EvaluatorFactory,
368354
metricSpec: any
369355
): Action {
370356
return factory.create(
371-
ai,
372357
{
373358
metric: VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS,
374359
displayName: 'Summarization helpfulness',
@@ -408,12 +393,10 @@ const SummarizationVerbositySchema = z.object({
408393
});
409394

410395
function createSummarizationVerbosityEvaluator(
411-
ai: Genkit,
412396
factory: EvaluatorFactory,
413397
metricSpec: any
414398
): Action {
415399
return factory.create(
416-
ai,
417400
{
418401
metric: VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY,
419402
displayName: 'Summarization verbosity',

js/plugins/vertexai/src/evaluation/evaluator_factory.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
* limitations under the License.
1515
*/
1616

17-
import { type Action, type Genkit, type z } from 'genkit';
17+
import { type Action, type z } from 'genkit';
1818
import type { BaseEvalDataPoint, Score } from 'genkit/evaluator';
1919
import { runInNewSpan } from 'genkit/tracing';
2020
import type { GoogleAuth } from 'google-auth-library';
2121
import { getGenkitClientHeader } from '../common/index.js';
2222
import type { VertexAIEvaluationMetricType } from './evaluation.js';
23+
import { evaluator } from 'genkit/plugin';
2324

2425
export class EvaluatorFactory {
2526
constructor(
@@ -29,7 +30,6 @@ export class EvaluatorFactory {
2930
) {}
3031

3132
create<ResponseType extends z.ZodTypeAny>(
32-
ai: Genkit,
3333
config: {
3434
metric: VertexAIEvaluationMetricType;
3535
displayName: string;
@@ -39,7 +39,7 @@ export class EvaluatorFactory {
3939
toRequest: (datapoint: BaseEvalDataPoint) => any,
4040
responseHandler: (response: z.infer<ResponseType>) => Score
4141
): Action {
42-
return ai.defineEvaluator(
42+
return evaluator(
4343
{
4444
name: `vertexai/${config.metric.toLocaleLowerCase()}`,
4545
displayName: config.displayName,
@@ -48,7 +48,6 @@ export class EvaluatorFactory {
4848
async (datapoint: BaseEvalDataPoint) => {
4949
const responseSchema = config.responseSchema;
5050
const response = await this.evaluateInstances(
51-
ai,
5251
toRequest(datapoint),
5352
responseSchema
5453
);
@@ -62,13 +61,11 @@ export class EvaluatorFactory {
6261
}
6362

6463
async evaluateInstances<ResponseType extends z.ZodTypeAny>(
65-
ai: Genkit,
6664
partialRequest: any,
6765
responseSchema: ResponseType
6866
): Promise<z.infer<ResponseType>> {
6967
const locationName = `projects/${this.projectId}/locations/${this.location}`;
7068
return await runInNewSpan(
71-
ai,
7269
{
7370
metadata: {
7471
name: 'EvaluationService#evaluateInstances',

js/plugins/vertexai/src/evaluation/index.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
import type { Genkit } from 'genkit';
18-
import { genkitPlugin, type GenkitPlugin } from 'genkit/plugin';
17+
import { genkitPluginV2, type GenkitPluginV2 } from 'genkit/plugin';
1918
import { getDerivedParams } from '../common/index.js';
2019
import { vertexEvaluators } from './evaluation.js';
2120
import type { PluginOptions } from './types.js';
@@ -25,10 +24,13 @@ export type { PluginOptions };
2524
/**
2625
* Add Google Cloud Vertex AI Rerankers API to Genkit.
2726
*/
28-
export function vertexAIEvaluation(options: PluginOptions): GenkitPlugin {
29-
return genkitPlugin('vertexAIEvaluation', async (ai: Genkit) => {
30-
const { projectId, location, authClient } = await getDerivedParams(options);
27+
export function vertexAIEvaluation(options: PluginOptions): GenkitPluginV2 {
28+
return genkitPluginV2({
29+
name: 'vertexAIEvaluation',
30+
init: async () => {
31+
const { projectId, location, authClient } = await getDerivedParams(options);
3132

32-
vertexEvaluators(ai, authClient, options.metrics, projectId, location);
33+
return vertexEvaluators(authClient, options.metrics, projectId, location);
34+
},
3335
});
3436
}

0 commit comments

Comments
 (0)