Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Google vertex provider support (CONFLICTED) #3136

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -953,3 +953,16 @@ FLASHBOTS_RELAY_SIGNING_KEY= # Signing key for Flashbots relay interactions
BUNDLE_EXECUTOR_ADDRESS= # Address of the bundle executor contract


####################################
#### Google Vertex Provider ####
####################################
GOOGLE_VERTEX_KEY= # "-----BEGIN PRIVATE KEY-----\n YOUR_PRIVATE \n-----END PRIVATE KEY-----\n"
GOOGLE_VERTEX_EMAIL= # [email protected]

GOOGLE_VERTEX_LOCATION= # us-central1
GOOGLE_VERTEX_PROJECT= # project-plexus-447108-e4

SMALL_GOOGLE_VERTEX_MODEL= # gemini-2.0-flash-exp
MEDIUM_GOOGLE_VERTEX_MODEL= # gemini-2.0-flash-exp
LARGE_GOOGLE_VERTEX_MODEL= # gemini-1.5-pro-002
GOOGLE_VERTEX_EMBEDDING_MODEL=
5 changes: 5 additions & 0 deletions agent/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,11 @@ export function getTokenForProvider(
character.settings?.secrets?.DEEPSEEK_API_KEY ||
settings.DEEPSEEK_API_KEY
);
case ModelProviderName.GOOGLE_VERTEX:
return (
character.settings?.secrets?.GOOGLE_VERTEX_API_KEY ||
settings.GOOGLE_VERTEX_API_KEY
);
case ModelProviderName.LIVEPEER:
return (
character.settings?.secrets?.LIVEPEER_GATEWAY_URL ||
Expand Down
4 changes: 2 additions & 2 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@
"typescript": "5.6.3"
},
"dependencies": {
"@ai-sdk/amazon-bedrock": "1.1.0",
"@ai-sdk/anthropic": "0.0.56",
"@ai-sdk/google": "0.0.55",
"@ai-sdk/google-vertex": "0.0.43",
"@ai-sdk/google-vertex": "1.0.0",
"@ai-sdk/groq": "0.0.3",
"@ai-sdk/mistral": "1.0.9",
"@ai-sdk/openai": "1.0.5",
"@ai-sdk/amazon-bedrock": "1.1.0",
"@fal-ai/client": "1.2.0",
"@tavily/core": "^0.0.2",
"@types/uuid": "10.0.0",
Expand Down
68 changes: 68 additions & 0 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import { createGroq } from "@ai-sdk/groq";
import { createOpenAI } from "@ai-sdk/openai";
import { bedrock } from "@ai-sdk/amazon-bedrock";
import { createVertex } from "@ai-sdk/google-vertex";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import {
generateObject as aiGenerateObject,
Expand Down Expand Up @@ -1283,6 +1284,40 @@
break;
}

case ModelProviderName.GOOGLE_VERTEX: {
elizaLogger.debug("Initializing Google Vertex API model.");

const vertex = createVertex({
location: settings.GOOGLE_VERTEX_LOCATION,
project: settings. GOOGLE_VERTEX_PROJECT,
googleAuthOptions: {
credentials: {
client_email: settings.GOOGLE_VERTEX_EMAIL,
private_key: settings.GOOGLE_VERTEX_KEY || runtime.token
},
},
});

const { text: openaiResponse } = await aiGenerateText({
// model: vertex.languageModel(model),
model: vertex(model),
prompt: context,
system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined,
tools: tools,
onStepFinish: onStepFinish,
maxSteps: maxSteps,
temperature: temperature,
maxTokens: max_response_length,
frequencyPenalty: frequency_penalty,
presencePenalty: presence_penalty,
experimental_telemetry: experimental_telemetry,
});

response = openaiResponse;
elizaLogger.debug("Received response from Google Vertex API model.");
break;
}

default: {
const errorMessage = `Unsupported provider: ${provider}`;
elizaLogger.error(errorMessage);
Expand Down Expand Up @@ -1618,42 +1653,42 @@
}
}

export const generateImage = async (
data: {
prompt: string;
width: number;
height: number;
count?: number;
negativePrompt?: string;
numIterations?: number;
guidanceScale?: number;
seed?: number;
modelId?: string;
jobId?: string;
stylePreset?: string;
hideWatermark?: boolean;
safeMode?: boolean;
cfgScale?: number;
},
runtime: IAgentRuntime
): Promise<{
success: boolean;
data?: string[];
error?: any;
}> => {
const modelSettings = getImageModelSettings(runtime.imageModelProvider);
if (!modelSettings) {
elizaLogger.warn("No model settings found for the image model provider.");
return { success: false, error: "No model settings available" };
}
const model = modelSettings.name;
elizaLogger.info("Generating image with options:", {
imageModelProvider: model,
});

const apiKey =
runtime.imageModelProvider === runtime.modelProvider
? runtime.token

Check notice on line 1691 in packages/core/src/generation.ts

View check run for this annotation

codefactor.io / CodeFactor

packages/core/src/generation.ts#L1656-L1691

Complex Method
: (() => {
// First try to match the specific provider
switch (runtime.imageModelProvider) {
Expand Down Expand Up @@ -2212,6 +2247,10 @@
return await handleDeepSeek(options);
case ModelProviderName.LIVEPEER:
return await handleLivepeer(options);
case ModelProviderName.BEDROCK:
return await handleBedrock(options);
case ModelProviderName.GOOGLE_VERTEX:
return await handleGoogleVertexApi(options);
default: {
const errorMessage = `Unsupported provider: ${provider}`;
elizaLogger.error(errorMessage);
Expand Down Expand Up @@ -2559,6 +2598,35 @@
});
}

async function handleGoogleVertexApi({
model,
apiKey,
schema,
schemaName,
schemaDescription,
mode,
modelOptions,
}: ProviderOptions): Promise<GenerateObjectResult<unknown>> {
const vertex = createVertex({
location: settings.GOOGLE_VERTEX_LOCATION,
project: settings.GOOGLE_VERTEX_PROJECT,
googleAuthOptions: {
credentials: {
client_email: settings.GOOGLE_VERTEX_EMAIL,
private_key: settings.GOOGLE_VERTEX_KEY || apiKey,
},
},
});

return await aiGenerateObject({
model: vertex(model),
schema,
schemaName,
schemaDescription,
mode,
...modelOptions,
});
}
// Add type definition for Together AI response
interface TogetherAIImageResponse {
data: Array<{
Expand Down
35 changes: 35 additions & 0 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,41 @@ export const models: Models = {
},
},
},
[ModelProviderName.GOOGLE_VERTEX]: {
model: {
[ModelClass.SMALL]: {
name: settings.SMALL_GOOGLE_VERTEX_MODEL || "text-bison@002",
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
frequency_penalty: 0.0,
presence_penalty: 0.0,
temperature: 0.6,
},
[ModelClass.MEDIUM]: {
name: settings.MEDIUM_GOOGLE_VERTEX_MODEL || "text-bison@002",
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
frequency_penalty: 0.0,
presence_penalty: 0.0,
temperature: 0.6,
},
[ModelClass.LARGE]: {
name: settings.LARGE_GOOGLE_VERTEX_MODEL || "text-bison@002",
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
frequency_penalty: 0.0,
presence_penalty: 0.0,
temperature: 0.6,
},
[ModelClass.EMBEDDING]: {
name: settings.GOOGLE_VERTEX_EMBEDDING_MODEL || "textembedding-gecko@003",
dimensions: 768,
},
},
},
[ModelProviderName.BEDROCK]: {
model: {
[ModelClass.SMALL]: {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ export type Models = {
[ModelProviderName.LIVEPEER]: Model;
[ModelProviderName.DEEPSEEK]: Model;
[ModelProviderName.INFERA]: Model;
[ModelProviderName.GOOGLE_VERTEX]: Model;
[ModelProviderName.BEDROCK]: Model;
[ModelProviderName.ATOMA]: Model;
};
Expand Down Expand Up @@ -272,6 +273,7 @@ export enum ModelProviderName {
INFERA = "infera",
BEDROCK = "bedrock",
ATOMA = "atoma",
GOOGLE_VERTEX = "google_vertex_api"
}

/**
Expand Down
Loading
Loading