@@ -28,7 +28,11 @@ import {
28
28
type ModelReference ,
29
29
type z ,
30
30
} from 'genkit' ;
31
- import { genkitPluginV2 , ResolvableAction , type GenkitPluginV2 } from 'genkit/plugin' ;
31
+ import {
32
+ ResolvableAction ,
33
+ genkitPluginV2 ,
34
+ type GenkitPluginV2 ,
35
+ } from 'genkit/plugin' ;
32
36
import type { ActionType } from 'genkit/registry' ;
33
37
import { getDerivedParams } from './common/index.js' ;
34
38
import type { PluginOptions } from './common/types.js' ;
@@ -116,20 +120,22 @@ async function initializer(options?: PluginOptions) {
116
120
117
121
const actions : ResolvableAction [ ] = [ ] ;
118
122
119
- Object . keys ( SUPPORTED_IMAGEN_MODELS ) . map ( ( name ) =>
120
- actions . push ( defineImagenModel ( name , authClient , { projectId, location } ) )
121
- ) ;
122
- Object . keys ( SUPPORTED_GEMINI_MODELS ) . map ( ( name ) =>
123
- actions . push ( defineGeminiKnownModel (
124
- name ,
125
- vertexClientFactory ,
126
- {
127
- projectId,
128
- location,
129
- } ,
130
- options ?. experimental_debugTraces
131
- ) )
132
- ) ;
123
+ for ( const name of Object . keys ( SUPPORTED_IMAGEN_MODELS ) ) {
124
+ actions . push ( defineImagenModel ( name , authClient , { projectId, location } ) ) ;
125
+ }
126
+ for ( const name of Object . keys ( SUPPORTED_GEMINI_MODELS ) ) {
127
+ actions . push (
128
+ defineGeminiKnownModel (
129
+ name ,
130
+ vertexClientFactory ,
131
+ {
132
+ projectId,
133
+ location,
134
+ } ,
135
+ options ?. experimental_debugTraces
136
+ )
137
+ ) ;
138
+ }
133
139
if ( options ?. models ) {
134
140
for ( const modelOrRef of options ?. models ) {
135
141
const modelName =
@@ -139,22 +145,26 @@ async function initializer(options?: PluginOptions) {
139
145
modelOrRef . name . split ( '/' ) [ 1 ] ;
140
146
const modelRef =
141
147
typeof modelOrRef === 'string' ? gemini ( modelOrRef ) : modelOrRef ;
142
- actions . push ( defineGeminiModel ( {
143
- modelName : modelRef . name ,
144
- version : modelName ,
145
- modelInfo : modelRef . info ,
146
- vertexClientFactory,
147
- options : {
148
- projectId,
149
- location,
150
- } ,
151
- debugTraces : options . experimental_debugTraces ,
152
- } ) ) ;
148
+ actions . push (
149
+ defineGeminiModel ( {
150
+ modelName : modelRef . name ,
151
+ version : modelName ,
152
+ modelInfo : modelRef . info ,
153
+ vertexClientFactory,
154
+ options : {
155
+ projectId,
156
+ location,
157
+ } ,
158
+ debugTraces : options . experimental_debugTraces ,
159
+ } )
160
+ ) ;
153
161
}
154
162
}
155
163
156
164
Object . keys ( SUPPORTED_EMBEDDER_MODELS ) . map ( ( name ) =>
157
- actions . push ( defineVertexAIEmbedder ( name , authClient , { projectId, location } ) )
165
+ actions . push (
166
+ defineVertexAIEmbedder ( name , authClient , { projectId, location } )
167
+ )
158
168
) ;
159
169
160
170
return actions ;
@@ -185,8 +195,7 @@ async function resolveModel(
185
195
await getDerivedParams ( options ) ;
186
196
187
197
if ( actionName . startsWith ( 'imagen' ) ) {
188
- defineImagenModel ( actionName , authClient , { projectId, location } ) ;
189
- return ;
198
+ return defineImagenModel ( actionName , authClient , { projectId, location } ) ;
190
199
}
191
200
192
201
const modelRef = gemini ( actionName ) ;
@@ -209,7 +218,10 @@ async function resolveEmbedder(
209
218
) : Promise < ResolvableAction | undefined > {
210
219
const { projectId, location, authClient } = await getDerivedParams ( options ) ;
211
220
212
- return defineVertexAIEmbedder ( actionName , authClient , { projectId, location } )
221
+ return defineVertexAIEmbedder ( actionName , authClient , {
222
+ projectId,
223
+ location,
224
+ } ) ;
213
225
}
214
226
215
227
// Vertex AI list models still returns these and the API does not indicate in any way
@@ -268,14 +280,14 @@ function vertexAIPlugin(options?: PluginOptions): GenkitPluginV2 {
268
280
let listActionsCache ;
269
281
return genkitPluginV2 ( {
270
282
name : 'vertexai' ,
271
- init : async ( ) => initializer ( ) ,
283
+ init : async ( ) => initializer ( options ) ,
272
284
resolve : async ( actionType : ActionType , actionName : string ) =>
273
- await resolver ( actionType , actionName ) ,
285
+ await resolver ( actionType , actionName , options ) ,
274
286
list : async ( ) => {
275
287
if ( listActionsCache ) return listActionsCache ;
276
288
listActionsCache = await listActions ( options ) ;
277
289
return listActionsCache ;
278
- }
290
+ } ,
279
291
} ) ;
280
292
}
281
293
0 commit comments