Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 245 additions & 5 deletions js/plugins/chroma/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,19 @@ import {
indexerRef,
retrieverRef,
z,
type EmbedderAction,
type EmbedderArgument,
type Embedding,
type Genkit,
} from 'genkit';
import { genkitPlugin, type GenkitPlugin } from 'genkit/plugin';
import {
genkitPluginV2,
indexer,
retriever,
type GenkitPluginV2,
type ResolvableAction,
} from 'genkit/plugin';
import type { ActionType } from 'genkit/registry';
import { CommonRetrieverOptionsSchema } from 'genkit/retriever';
import { Md5 } from 'ts-md5';

Expand Down Expand Up @@ -74,12 +82,53 @@ type ChromaPluginParams<
/**
* Chroma plugin that provides the Chroma retriever and indexer
*/

export function chroma<EmbedderCustomOptions extends z.ZodTypeAny>(
params: ChromaPluginParams<EmbedderCustomOptions>
): GenkitPlugin {
return genkitPlugin('chroma', async (ai: Genkit) => {
params.map((i) => chromaRetriever(ai, i));
params.map((i) => chromaIndexer(ai, i));
): GenkitPluginV2 {
return genkitPluginV2({
name: 'chroma',
async init() {
const actions: ResolvableAction[] = [];
for (const param of params) {
actions.push(createChromaRetriever(param));
actions.push(createChromaIndexer(param));
}
return actions;
},
async resolve(actionType: ActionType, name: string) {
// Find the matching param by collection name
const collectionName = name.replace('chroma/', '');
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we definitely need to strip thechroma prefix off here?

const param = params.find((p) => p.collectionName === collectionName);
if (!param) return undefined;

switch (actionType) {
case 'retriever':
return createChromaRetriever(param);
case 'indexer':
return createChromaIndexer(param);
default:
return undefined;
}
},
async list() {
return params.flatMap((param) => [
{
name: `chroma/${param.collectionName}`,
type: 'retriever' as const,
info: {
label: `Chroma DB - ${param.collectionName}`,
},
},
{
name: `chroma/${param.collectionName}`,
type: 'indexer' as const,
info: {
label: `Chroma DB - ${param.collectionName}`,
},
},
]);
},
});
}

Expand Down Expand Up @@ -369,6 +418,197 @@ export async function deleteChromaCollection(params: {
});
}

/**
* Standalone Chroma retriever action (v2 API)
*/
function createChromaRetriever<
EmbedderCustomOptions extends z.ZodTypeAny,
>(params: {
clientParams?: ChromaClientParams;
collectionName: string;
createCollectionIfMissing?: boolean;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
const { embedder, collectionName, embedderOptions } = params;

return retriever(
{
name: `chroma/${collectionName}`,
configSchema: ChromaRetrieverOptionsSchema.optional(),
},
async (content, options) => {
const clientParams = await resolve(params.clientParams);
const client = new ChromaClient(clientParams);
let collection: Collection;
if (params.createCollectionIfMissing) {
collection = await client.getOrCreateCollection({
name: collectionName,
});
} else {
collection = await client.getCollection({
name: collectionName,
});
}

// For v2 API, we need to handle embedding differently
// The embedder will be resolved at runtime
const queryEmbeddings = await resolveEmbedder(embedder, {
content,
options: embedderOptions,
});

const results = await collection.query({
nResults: options?.k,
include: getIncludes(options?.include),
where: options?.where,
whereDocument: options?.whereDocument,
queryEmbeddings: queryEmbeddings[0].embedding,
});

const documents = results.documents[0];
const metadatas = results.metadatas;
const embeddings = results.embeddings;
const distances = results.distances;

const combined = documents
.map((d, i) => {
if (d !== null) {
return {
document: d,
metadata: constructMetadata(i, metadatas, embeddings, distances),
};
}
return undefined;
})
.filter(
(r): r is { document: string; metadata: Record<string, any> } => !!r
);

return {
documents: combined.map((result) => {
const data = result.document;
const metadata = result.metadata.metadata[0];
const dataType = metadata.dataType;
const docMetadata = metadata.docMetadata
? JSON.parse(metadata.docMetadata)
: undefined;
return Document.fromData(data, dataType, docMetadata).toJSON();
}),
};
}
);
}

/**
* Standalone Chroma indexer action (v2 API)
*/
function createChromaIndexer<
EmbedderCustomOptions extends z.ZodTypeAny,
>(params: {
clientParams?: ChromaClientParams;
collectionName: string;
createCollectionIfMissing?: boolean;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
const { collectionName, embedder, embedderOptions } = {
...params,
};

return indexer(
{
name: `chroma/${params.collectionName}`,
configSchema: ChromaIndexerOptionsSchema,
},
async (docs) => {
const clientParams = await resolve(params.clientParams);
const client = new ChromaClient(clientParams);

let collection: Collection;
if (params.createCollectionIfMissing) {
collection = await client.getOrCreateCollection({
name: collectionName,
});
} else {
collection = await client.getCollection({
name: collectionName,
});
}

const embeddings = await Promise.all(
docs.map((doc) =>
resolveEmbedder(embedder, {
content: doc,
options: embedderOptions,
})
)
);

const entries = embeddings
.map((value, i) => {
const doc = docs[i];
// The array of embeddings for this document
const docEmbeddings: Embedding[] = value;
const embeddingDocs = doc.getEmbeddingDocuments(docEmbeddings);
return docEmbeddings.map((docEmbedding, j) => {
const metadata: Metadata = {
docMetadata: JSON.stringify(embeddingDocs[j].metadata),
dataType: embeddingDocs[j].dataType || '',
};

const data = embeddingDocs[j].data;
const id = Md5.hashStr(JSON.stringify(embeddingDocs[j]));
return {
id,
value: docEmbedding.embedding,
document: data,
metadata,
};
});
})
.reduce((acc, val) => {
return acc.concat(val);
}, []);

await collection.add({
ids: entries.map((e) => e.id),
embeddings: entries.map((e) => e.value),
metadatas: entries.map((e) => e.metadata),
documents: entries.map((e) => e.document),
});
}
);
}

/**
* Helper function to resolve embedder and get embeddings
* Call embedder actions directly
*/
async function resolveEmbedder<EmbedderCustomOptions extends z.ZodTypeAny>(
embedder: EmbedderArgument<EmbedderCustomOptions>,
params: {
content: Document;
options?: z.infer<EmbedderCustomOptions>;
}
): Promise<Embedding[]> {
// If embedder is an action (function with __action property), call it directly
if (typeof embedder === 'function' && '__action' in embedder) {
const embedderAction = embedder as EmbedderAction<EmbedderCustomOptions>;
const response = await embedderAction({
input: [params.content],
options: params.options,
});
return response.embeddings;
}

// If embedder is a string reference, we need to resolve it
// throw an error as this requires registry access
throw new Error(
`Embedder resolution for string references not supported in v2 API: ${embedder}`
);
}

async function resolve(
params?: ChromaClientParams
): Promise<NativeChromaClientParams | undefined> {
Expand Down