Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions js/plugins/vertexai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,17 @@ async function initializer(options?: PluginOptions) {

const actions: ResolvableAction[] = [];

<<<<<<< HEAD
for (const name of Object.keys(SUPPORTED_IMAGEN_MODELS)) {
actions.push(defineImagenModel(name, authClient, { projectId, location }));
}
for (const name of Object.keys(SUPPORTED_GEMINI_MODELS)) {
=======
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left over from a bad rebase, TODO: fix

Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) =>
actions.push(defineImagenModel(name, authClient, { projectId, location }))
);
Object.keys(SUPPORTED_GEMINI_MODELS).map((name) =>
>>>>>>> 69cac47ac (refactor(plugins/vertexai): add embedder v2 API support)
actions.push(
defineGeminiKnownModel(
name,
Expand All @@ -134,8 +141,13 @@ async function initializer(options?: PluginOptions) {
},
options?.experimental_debugTraces
)
<<<<<<< HEAD
);
}
=======
)
);
>>>>>>> 69cac47ac (refactor(plugins/vertexai): add embedder v2 API support)
if (options?.models) {
for (const modelOrRef of options?.models) {
const modelName =
Expand Down
48 changes: 44 additions & 4 deletions js/plugins/vertexai/src/vectorsearch/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
* limitations under the License.
*/

import type { EmbedderAction } from 'genkit/embedder';
import {
genkitPluginV2,
ResolvableAction,
type GenkitPluginV2,
} from 'genkit/plugin';
import { getDerivedParams } from '../common/index.js';
import { defineVertexAIEmbedder } from '../embedder.js';
import type { PluginOptions } from './types.js';
import { vertexAiIndexers, vertexAiRetrievers } from './vector_search/index.js';
export type { PluginOptions } from '../common/types.js';
Expand Down Expand Up @@ -124,31 +126,69 @@ export function vertexAIVectorSearch(options?: PluginOptions): GenkitPluginV2 {
return genkitPluginV2({
name: 'vertexAIVectorSearch',
init: async () => {
const { authClient } = await getDerivedParams(options);
const { authClient, projectId, location } =
await getDerivedParams(options);

const actions: ResolvableAction[] = [];

// Resolve default embedder if provided
let defaultEmbedderAction: EmbedderAction | undefined;
if (options?.embedder) {
// Create an embedder action for the default embedder
const embedderName = options.embedder.name.includes('/')
? options.embedder.name.split('/')[1]
: options.embedder.name;
defaultEmbedderAction = defineVertexAIEmbedder(
embedderName,
authClient,
{ projectId, location }
);
}

if (
options?.vectorSearchOptions &&
options.vectorSearchOptions.length > 0
) {
// Process each vector search option to resolve embedders
const processedOptions = { ...options };
if (processedOptions.vectorSearchOptions) {
processedOptions.vectorSearchOptions = await Promise.all(
processedOptions.vectorSearchOptions.map(async (vso) => {
const processed = { ...vso };
// If this option has an embedder reference, resolve it to an action
if (vso.embedder && !vso.embedderAction) {
const embedderName = vso.embedder.name.includes('/')
? vso.embedder.name.split('/')[1]
: vso.embedder.name;
processed.embedderAction = defineVertexAIEmbedder(
embedderName,
authClient,
{ projectId, location }
);
}
return processed;
})
);
}

actions.push(
...vertexAiIndexers({
pluginOptions: options,
pluginOptions: processedOptions,
authClient,
defaultEmbedder: options.embedder,
defaultEmbedderAction,
})
);

actions.push(
...vertexAiRetrievers({
pluginOptions: options,
pluginOptions: processedOptions,
authClient,
defaultEmbedder: options.embedder,
defaultEmbedderAction,
})
);
}

return actions;
},
});
Expand Down
18 changes: 12 additions & 6 deletions js/plugins/vertexai/src/vectorsearch/vector_search/indexers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import type { z } from 'genkit';
import { indexer } from 'genkit/plugin';
import { indexerRef, type IndexerAction } from 'genkit/retriever';
import { Document, indexerRef, type IndexerAction } from 'genkit/retriever';
import {
Datapoint,
VertexAIVectorIndexerOptionsSchema,
Expand Down Expand Up @@ -66,12 +66,14 @@ export function vertexAiIndexers<EmbedderCustomOptions extends z.ZodTypeAny>(

for (const vectorSearchOption of vectorSearchOptions) {
const { documentIndexer, indexId } = vectorSearchOption;
const embedderAction =
vectorSearchOption.embedderAction ?? params.defaultEmbedderAction;
const embedderReference =
vectorSearchOption.embedder ?? params.defaultEmbedder;

if (!embedderReference) {
if (!embedderAction && !embedderReference) {
throw new Error(
'Embedder reference is required to define Vertex AI retriever'
'Embedder action or reference is required to define Vertex AI indexer'
);
}
const embedderOptions = vectorSearchOption.embedderOptions;
Expand All @@ -91,11 +93,15 @@ export function vertexAiIndexers<EmbedderCustomOptions extends z.ZodTypeAny>(
);
}

const embeddings = await ai.embedMany({
embedder: embedderReference,
content: docs,
// Call the embedder action directly
if (!embedderAction) {
throw new Error('Embedder action is required for indexing');
}
const response = await embedderAction({
input: docs.map((d) => (d instanceof Document ? d : new Document(d))),
options: embedderOptions,
});
const embeddings = response.embeddings;

const datapoints = embeddings.map(({ embedding }, i) => {
const dp = new Datapoint({
Expand Down
26 changes: 13 additions & 13 deletions js/plugins/vertexai/src/vectorsearch/vector_search/retrievers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import { retrieverRef, type RetrieverAction, type z } from 'genkit';
import { retriever } from 'genkit/plugin';
import { Document } from 'genkit/retriever';
import { queryPublicEndpoint } from './query_public_endpoint';
import {
VertexAIVectorRetrieverOptionsSchema,
Expand All @@ -38,7 +39,6 @@ export function vertexAiRetrievers<EmbedderCustomOptions extends z.ZodTypeAny>(
params: VertexVectorSearchOptions<EmbedderCustomOptions>
): RetrieverAction<z.ZodTypeAny>[] {
const vectorSearchOptions = params.pluginOptions.vectorSearchOptions;
const defaultEmbedder = params.defaultEmbedder;

const retrieverActions: RetrieverAction<z.ZodTypeAny>[] = [];

Expand All @@ -49,29 +49,29 @@ export function vertexAiRetrievers<EmbedderCustomOptions extends z.ZodTypeAny>(
for (const vectorSearchOption of vectorSearchOptions) {
const { documentRetriever, indexId, publicDomainName } = vectorSearchOption;
const embedderOptions = vectorSearchOption.embedderOptions;
const embedderAction =
vectorSearchOption.embedderAction ?? params.defaultEmbedderAction;

const retrieverAction = retriever(
{
name: `vertexai/${indexId}`,
configSchema: VertexAIVectorRetrieverOptionsSchema.optional(),
},
async (content, options) => {
const embedderReference =
vectorSearchOption.embedder ?? defaultEmbedder;

if (!embedderReference) {
if (!embedderAction) {
throw new Error(
'Embedder reference is required to define Vertex AI retriever'
'Embedder action is required to define Vertex AI retriever'
);
}

const queryEmbedding = (
await ai.embed({
embedder: embedderReference,
options: embedderOptions,
content,
})
)[0].embedding; // Single embedding for text
// Call the embedder action directly with a single document
const response = await embedderAction({
input: [
content instanceof Document ? content : new Document(content),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not 100% if this is passed in correctly

],
options: embedderOptions,
});
const queryEmbedding = response.embeddings[0].embedding; // Single embedding for text

const accessToken = await params.authClient.getAccessToken();

Expand Down
4 changes: 3 additions & 1 deletion js/plugins/vertexai/src/vectorsearch/vector_search/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import * as aiplatform from '@google-cloud/aiplatform';
import { z } from 'genkit';
import type { EmbedderReference } from 'genkit/embedder';
import type { EmbedderAction, EmbedderReference } from 'genkit/embedder';
import { CommonRetrieverOptionsSchema, type Document } from 'genkit/retriever';
import type { GoogleAuth } from 'google-auth-library';
import type { PluginOptions } from '../types.js';
Expand All @@ -28,6 +28,7 @@ export interface VertexVectorSearchOptions<
pluginOptions: PluginOptions;
authClient: GoogleAuth;
defaultEmbedder?: EmbedderReference<EmbedderCustomOptions>;
defaultEmbedderAction?: EmbedderAction<EmbedderCustomOptions>;
}

export type IIndexDatapoint =
Expand Down Expand Up @@ -186,5 +187,6 @@ export interface VectorSearchOptions<
documentIndexer: DocumentIndexer<IndexerOptions>;
// Embedder and default options to use for indexing and retrieval
embedder?: EmbedderReference<EmbedderCustomOptions>;
embedderAction?: EmbedderAction<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}
Loading