default to stagehand LLM clients for evals #669

Open

wants to merge 9 commits into main
4 changes: 4 additions & 0 deletions evals/args.ts
@@ -9,6 +9,7 @@ const parsedArgs: {
concurrency?: number;
extractMethod?: string;
provider?: string;
useExternalClients?: boolean;
leftover: string[];
} = {
leftover: [],
@@ -31,6 +32,9 @@ for (const arg of rawArgs) {
parsedArgs.extractMethod = arg.split("=")[1];
} else if (arg.startsWith("provider=")) {
parsedArgs.provider = arg.split("=")[1]?.toLowerCase();
} else if (arg.startsWith("useExternalClients=")) {
const val = arg.split("=")[1]?.toLowerCase();
parsedArgs.useExternalClients = val === "true";
} else {
parsedArgs.leftover.push(arg);
}
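For clarity, a minimal sketch of how the new flag is interpreted (the helper name parseUseExternalClients is hypothetical; the actual parsing is the inline branch above):

// Hypothetical helper mirroring the inline parsing above; not part of this diff.
function parseUseExternalClients(arg: string): boolean | undefined {
  if (!arg.startsWith("useExternalClients=")) return undefined;
  // The value is lowercased, so "true" and "TRUE" both enable external clients;
  // any other value (including an empty string) leaves them disabled.
  return arg.split("=")[1]?.toLowerCase() === "true";
}

// parseUseExternalClients("useExternalClients=true")  -> true
// parseUseExternalClients("useExternalClients=TRUE")  -> true
// parseUseExternalClients("useExternalClients=false") -> false
// parseUseExternalClients("provider=openai")          -> undefined (flag not present)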
68 changes: 13 additions & 55 deletions evals/index.eval.ts
@@ -20,25 +20,19 @@ import {
filterByCategory,
filterByEvalName,
useTextExtract,
parsedArgs,
} from "./args";
import { generateExperimentName } from "./utils";
import { createLLMClient, generateExperimentName } from "./utils";
import { exactMatch, errorMatch } from "./scoring";
import { tasksByName, MODELS, tasksConfig } from "./taskConfig";
import { Eval, wrapAISDKModel, wrapOpenAI } from "braintrust";
import { Eval } from "braintrust";
import { EvalFunction, SummaryResult, Testcase } from "@/types/evals";
import { EvalLogger } from "./logger";
import { AvailableModel, LLMClient } from "@/dist";
import { AvailableModel } from "@/dist";
import { env } from "./env";
import dotenv from "dotenv";
import { StagehandEvalError } from "@/types/stagehandErrors";
import { CustomOpenAIClient } from "@/examples/external_clients/customOpenAI";
import OpenAI from "openai";
import { initStagehand } from "./initStagehand";
import { AISdkClient } from "@/examples/external_clients/aisdk";
import { google } from "@ai-sdk/google";
import { anthropic } from "@ai-sdk/anthropic";
import { groq } from "@ai-sdk/groq";
import { cerebras } from "@ai-sdk/cerebras";
dotenv.config();

/**
@@ -273,51 +267,15 @@ const generateFilteredTestcases = (): Testcase[] => {
}

// Execute the task
let llmClient: LLMClient;
if (input.modelName.startsWith("gpt")) {
llmClient = new CustomOpenAIClient({
modelName: input.modelName as AvailableModel,
client: wrapOpenAI(
new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
}),
),
});
} else if (input.modelName.startsWith("gemini")) {
llmClient = new AISdkClient({
model: wrapAISDKModel(google(input.modelName)),
});
} else if (input.modelName.startsWith("claude")) {
llmClient = new AISdkClient({
model: wrapAISDKModel(anthropic(input.modelName)),
});
} else if (input.modelName.includes("groq")) {
llmClient = new AISdkClient({
model: wrapAISDKModel(
groq(
input.modelName.substring(input.modelName.indexOf("/") + 1),
),
),
});
} else if (input.modelName.includes("cerebras")) {
llmClient = new AISdkClient({
model: wrapAISDKModel(
cerebras(
input.modelName.substring(input.modelName.indexOf("/") + 1),
),
),
});
} else if (input.modelName.includes("/")) {
llmClient = new CustomOpenAIClient({
modelName: input.modelName as AvailableModel,
client: wrapOpenAI(
new OpenAI({
apiKey: process.env.TOGETHER_AI_API_KEY,
baseURL: "https://api.together.xyz/v1",
}),
),
});
}
const llmClient = createLLMClient({
modelName: input.modelName,
useExternalClients: parsedArgs.useExternalClients === true,
logger: (msg) => logger.log(msg),
Comment on lines +271 to +273
logic: Strict comparison with boolean could cause issues if parsedArgs.useExternalClients is undefined. Consider using !!parsedArgs.useExternalClients instead.
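For illustration, a minimal sketch comparing the two forms (not part of this diff); both treat an omitted flag the same way:

// Both forms evaluate to false when the flag was not passed (undefined).
const flag: boolean | undefined = parsedArgs.useExternalClients;
const strict = flag === true;   // form used in this diff
const coerced = !!flag;         // form suggested by the reviewer
// strict === coerced for true, false, and undefined alike.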

openAiKey: process.env.OPENAI_API_KEY,
googleKey: process.env.GOOGLE_API_KEY,
anthropicKey: process.env.ANTHROPIC_API_KEY,
togetherKey: process.env.TOGETHER_AI_API_KEY,
});
const taskInput = await initStagehand({
logger,
llmClient,
115 changes: 115 additions & 0 deletions evals/utils.ts
@@ -11,6 +11,21 @@
import { LogLine } from "@/dist";
import stringComparison from "string-comparison";
const { jaroWinkler } = stringComparison;
import OpenAI from "openai";
import { wrapAISDKModel, wrapOpenAI } from "braintrust";
import { anthropic } from "@ai-sdk/anthropic";
import { google } from "@ai-sdk/google";
import { groq } from "@ai-sdk/groq";
import { cerebras } from "@ai-sdk/cerebras";
import { LLMClient } from "@/dist";
import { AISdkClient } from "@/examples/external_clients/aisdk";
import { CustomOpenAIClient } from "@/examples/external_clients/customOpenAI";
import { OpenAIClient } from "@/lib/llm/OpenAIClient";
import { AnthropicClient } from "@/lib/llm/AnthropicClient";
import { GoogleClient } from "@/lib/llm/GoogleClient";
import { CreateLLMClientOptions } from "@/types/evals";
import { StagehandEvalError } from "@/types/stagehandErrors";
import { openai } from "@ai-sdk/openai";

/**
* normalizeString:
@@ -119,3 +134,103 @@ export function logLineToString(logLine: LogLine): string {
return "error logging line";
}
}

export function createLLMClient({
modelName,
useExternalClients,
logger,
openAiKey,
googleKey,
anthropicKey,
togetherKey,
}: CreateLLMClientOptions): LLMClient {
const isOpenAIModel =
modelName.startsWith("gpt") || modelName.startsWith("o");
const isGoogleModel = modelName.startsWith("gemini");
const isAnthropicModel = modelName.startsWith("claude");
const isGroqModel = modelName.includes("groq");
const isCerebrasModel = modelName.includes("cerebras");

if (useExternalClients) {
if (isOpenAIModel) {
if (modelName.includes("/")) {
return new CustomOpenAIClient({
modelName,
client: wrapOpenAI(
new OpenAI({
apiKey: togetherKey,
baseURL: "https://api.together.xyz/v1",
}),
),
});
}
return new AISdkClient({
model: wrapAISDKModel(openai(modelName)),
});
} else if (isGoogleModel) {
return new AISdkClient({
model: wrapAISDKModel(google(modelName)),
});
} else if (isAnthropicModel) {
return new AISdkClient({
model: wrapAISDKModel(anthropic(modelName)),
});
} else if (isGroqModel) {
const groqModel = modelName.substring(modelName.indexOf("/") + 1);
logic: Potential error if '/' is not found in modelName. Add null check before substring operation.

Suggested change:
- const groqModel = modelName.substring(modelName.indexOf("/") + 1);
+ const slashIndex = modelName.indexOf("/");
+ const groqModel = slashIndex === -1 ? modelName : modelName.substring(slashIndex + 1);

return new AISdkClient({
model: wrapAISDKModel(groq(groqModel)),
});
} else if (isCerebrasModel) {
const cerebrasModel = modelName.substring(modelName.indexOf("/") + 1);
return new AISdkClient({
model: wrapAISDKModel(cerebras(cerebrasModel)),
});
}
throw new StagehandEvalError(`Unknown modelName: ${modelName}`);
} else {
if (isOpenAIModel) {
if (modelName.includes("/")) {
return new CustomOpenAIClient({
modelName,
client: wrapOpenAI(
new OpenAI({
apiKey: togetherKey,
baseURL: "https://api.together.xyz/v1",
}),
),
});
}
Comment on lines +192 to +202

style: This block is duplicated from the external clients section. Consider extracting Together.ai model handling into a separate function to avoid duplication. (A sketch of such a helper appears after this function, below.)

return new OpenAIClient({
logger,
modelName,
enableCaching: false,
clientOptions: {
apiKey: openAiKey,
},
});
} else if (isGoogleModel) {
return new GoogleClient({
logger,
modelName,
enableCaching: false,
clientOptions: {
apiKey: googleKey,
},
});
} else if (isAnthropicModel) {
return new AnthropicClient({
logger,
modelName,
enableCaching: false,
clientOptions: {
apiKey: anthropicKey,
},
});
} else if (isGroqModel || isCerebrasModel) {
throw new StagehandEvalError(
`${modelName} can only be used when useExternalClients=true`,
);
}
throw new StagehandEvalError(`Unknown modelName: ${modelName}`);
}
}
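As referenced in the style comment above, a hedged sketch of what extracting the duplicated Together.ai handling could look like; the helper name createTogetherClient is hypothetical and not part of this PR:

// Hypothetical helper; both branches of createLLMClient could delegate to it
// whenever modelName.includes("/").
function createTogetherClient({
  modelName,
  togetherKey,
}: Pick<CreateLLMClientOptions, "modelName" | "togetherKey">): LLMClient {
  return new CustomOpenAIClient({
    modelName,
    client: wrapOpenAI(
      new OpenAI({
        apiKey: togetherKey,
        baseURL: "https://api.together.xyz/v1",
      }),
    ),
  });
}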
10 changes: 10 additions & 0 deletions types/evals.ts
@@ -77,3 +77,13 @@ export interface EvalResult {
export type LogLineEval = LogLine & {
parsedAuxiliary?: string | object;
};

export interface CreateLLMClientOptions {
modelName: AvailableModel;
useExternalClients?: boolean;
logger?: (msg: LogLine) => void;
openAiKey?: string;
googleKey?: string;
anthropicKey?: string;
togetherKey?: string;
}
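For context, a hedged usage sketch showing how these options drive createLLMClient (mirroring the call in evals/index.eval.ts above; the model name and logger here are illustrative):

// Imports assumed from the diff above: createLLMClient from evals/utils,
// parsedArgs from evals/args, AvailableModel from "@/dist".
const llmClient = createLLMClient({
  modelName: "gpt-4o-mini" as AvailableModel, // assumption: any supported model name
  useExternalClients: parsedArgs.useExternalClients === true, // defaults to Stagehand's own clients
  logger: (msg) => console.log(msg), // the evals pass an EvalLogger-backed function instead
  openAiKey: process.env.OPENAI_API_KEY,
  googleKey: process.env.GOOGLE_API_KEY,
  anthropicKey: process.env.ANTHROPIC_API_KEY,
  togetherKey: process.env.TOGETHER_AI_API_KEY,
});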