make ai sdk native #698

Merged · 16 commits · Apr 29, 2025
5 changes: 5 additions & 0 deletions .changeset/mean-plums-sin.md
@@ -0,0 +1,5 @@
+---
+"@browserbasehq/stagehand": patch
+---
+
+Fixing LLM client support to natively integrate with AI SDK
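In practice, this patch means a plain "provider/model" string is enough; no hand-built AISdkClient is required. A minimal sketch of the new call pattern, mirroring the updated examples later in this diff (the `env` field is an assumption here, standing in for the shared config the real examples spread in):

```ts
import { Stagehand } from "@browserbasehq/stagehand";

// The "openai/" prefix routes model creation through the Vercel AI SDK's
// OpenAI provider; any registered provider prefix works the same way.
const stagehand = new Stagehand({
  env: "LOCAL", // assumption: other config omitted for brevity
  modelName: "openai/gpt-4o",
});
await stagehand.init();
```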
2 changes: 1 addition & 1 deletion evals/index.eval.ts
@@ -34,12 +34,12 @@ import { StagehandEvalError } from "@/types/stagehandErrors";
import { CustomOpenAIClient } from "@/examples/external_clients/customOpenAI";
import OpenAI from "openai";
import { initStagehand } from "./initStagehand";
-import { AISdkClient } from "@/examples/external_clients/aisdk";
import { google } from "@ai-sdk/google";
import { anthropic } from "@ai-sdk/anthropic";
import { groq } from "@ai-sdk/groq";
import { cerebras } from "@ai-sdk/cerebras";
import { openai } from "@ai-sdk/openai";
+import { AISdkClient } from "@/examples/external_clients/aisdk";
dotenv.config();

/**
6 changes: 1 addition & 5 deletions evals/llm_clients/hn_aisdk.ts
@@ -1,7 +1,5 @@
import { Stagehand } from "@/dist";
-import { AISdkClient } from "@/examples/external_clients/aisdk";
import { EvalFunction } from "@/types/evals";
-import { openai } from "@ai-sdk/openai/dist";
import { z } from "zod";

export const hn_aisdk: EvalFunction = async ({
@@ -12,9 +10,7 @@ export const hn_aisdk: EvalFunction = async ({
}) => {
  const stagehand = new Stagehand({
    ...stagehandConfig,
-    llmClient: new AISdkClient({
-      model: openai("gpt-4o-mini"),
-    }),
+    modelName: "openai/gpt-4o-mini",
  });
  await stagehand.init();
  await stagehand.page.goto(
6 changes: 1 addition & 5 deletions examples/ai_sdk_example.ts
@@ -1,15 +1,11 @@
-import { openai } from "@ai-sdk/openai";
import { Stagehand } from "@/dist";
-import { AISdkClient } from "./external_clients/aisdk";
import StagehandConfig from "@/stagehand.config";
import { z } from "zod";

async function example() {
  const stagehand = new Stagehand({
    ...StagehandConfig,
-    llmClient: new AISdkClient({
-      model: openai("gpt-4o"),
-    }),
+    modelName: "openai/gpt-4o",
  });

  await stagehand.init();
11 changes: 10 additions & 1 deletion lib/index.ts
@@ -44,6 +44,8 @@ import {
  StagehandEnvironmentError,
  MissingEnvironmentVariableError,
  UnsupportedModelError,
+  UnsupportedAISDKModelProviderError,
+  InvalidAISDKModelFormatError,
} from "../types/stagehandErrors";

dotenv.config({ path: ".env" });
@@ -543,7 +545,13 @@ export class Stagehand {
          modelName ?? DEFAULT_MODEL_NAME,
          modelClientOptions,
        );
-      } catch {
+      } catch (error) {
+        if (
+          error instanceof UnsupportedAISDKModelProviderError ||
+          error instanceof InvalidAISDKModelFormatError
+        ) {
+          throw error;
+        }
        this.llmClient = undefined;
      }
    }
}
@@ -656,6 +664,7 @@
      projectId: this.projectId,
      logger: this.logger,
    });
+
    const modelApiKey =
      LLMProvider.getModelProvider(this.modelName) === "openai"
        ? process.env.OPENAI_API_KEY || this.llmClient.clientOptions.apiKey
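Note the behavioral change in the rewritten catch: AI SDK configuration mistakes now propagate instead of being silently swallowed into an undefined client. A hedged sketch of what a caller would observe (the misspelled provider string is illustrative):

```ts
import { Stagehand } from "@browserbasehq/stagehand";

try {
  // "openaii" is not a key in the AI SDK provider registry, so client
  // creation fails for this model string.
  new Stagehand({ env: "LOCAL", modelName: "openaii/gpt-4o" });
} catch (error) {
  // Before this PR the failure was swallowed and llmClient stayed undefined;
  // now UnsupportedAISDKModelProviderError reaches the caller.
  console.error(error);
}
```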
21 changes: 21 additions & 0 deletions lib/llm/LLMClient.ts
@@ -2,6 +2,17 @@ import { ZodType } from "zod";
import { LLMTool } from "../../types/llm";
import { LogLine } from "../../types/log";
import { AvailableModel, ClientOptions } from "../../types/model";
+import {
+  generateObject,
+  generateText,
+  streamText,
+  streamObject,
+  experimental_generateImage,
+  embed,
+  embedMany,
+  experimental_transcribe,
+  experimental_generateSpeech,
+} from "ai";

export interface ChatMessage {
role: "system" | "user" | "assistant";
@@ -102,4 +113,14 @@ export abstract class LLMClient {
usage?: LLMResponse["usage"];
},
>(options: CreateChatCompletionOptions): Promise<T>;

+  public generateObject = generateObject;
+  public generateText = generateText;
+  public streamText = streamText;
+  public streamObject = streamObject;
+  public generateImage = experimental_generateImage;
+  public embed = embed;
+  public embedMany = embedMany;
+  public transcribe = experimental_transcribe;
+  public generateSpeech = experimental_generateSpeech;
}
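Because the fields above are direct re-exports of the AI SDK entry points, any LLMClient instance doubles as a gateway to the SDK. A rough usage sketch, assuming an OpenAI key in the environment and that `stagehand.llmClient` is populated after init (the model handle is supplied by the caller, not filled in by the client):

```ts
import { Stagehand } from "@browserbasehq/stagehand";
import { openai } from "@ai-sdk/openai";

const stagehand = new Stagehand({ env: "LOCAL", modelName: "openai/gpt-4o-mini" });
await stagehand.init();

// generateText is the same function exported by the "ai" package; the
// client simply re-exposes it, so the usual AI SDK arguments apply.
const { text } = await stagehand.llmClient.generateText({
  model: openai("gpt-4o-mini"),
  prompt: "Say hello in five words.",
});
console.log(text);

await stagehand.close();
```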
76 changes: 66 additions & 10 deletions lib/llm/LLMProvider.ts
@@ -1,20 +1,50 @@
+import {
+  UnsupportedAISDKModelProviderError,
+  UnsupportedModelError,
+  UnsupportedModelProviderError,
+} from "@/types/stagehandErrors";
import { LogLine } from "../../types/log";
import {
AvailableModel,
ClientOptions,
ModelProvider,
} from "../../types/model";
import { LLMCache } from "../cache/LLMCache";
+import { AISdkClient } from "./aisdk";
import { AnthropicClient } from "./AnthropicClient";
import { CerebrasClient } from "./CerebrasClient";
import { GoogleClient } from "./GoogleClient";
import { GroqClient } from "./GroqClient";
import { LLMClient } from "./LLMClient";
import { OpenAIClient } from "./OpenAIClient";
-import {
-  UnsupportedModelError,
-  UnsupportedModelProviderError,
-} from "@/types/stagehandErrors";
+import { openai } from "@ai-sdk/openai";
+import { anthropic } from "@ai-sdk/anthropic";
+import { google } from "@ai-sdk/google";
+import { xai } from "@ai-sdk/xai";
+import { azure } from "@ai-sdk/azure";
+import { groq } from "@ai-sdk/groq";
+import { cerebras } from "@ai-sdk/cerebras";
+import { togetherai } from "@ai-sdk/togetherai";
+import { mistral } from "@ai-sdk/mistral";
+import { deepseek } from "@ai-sdk/deepseek";
+import { perplexity } from "@ai-sdk/perplexity";
+import { ollama } from "ollama-ai-provider";
+import { AISDKProvider } from "@/types/llm";

+const AISDKProviders: Record<string, AISDKProvider> = {
+  openai,
+  anthropic,
+  google,
+  xai,
+  azure,
+  groq,
+  cerebras,
+  togetherai,
+  mistral,
+  deepseek,
+  perplexity,
+  ollama,
+};

const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
"gpt-4.1": "openai",
@@ -84,50 +114,76 @@ export class LLMProvider {
modelName: AvailableModel,
clientOptions?: ClientOptions,
): LLMClient {
+    if (modelName.includes("/")) {
+      const firstSlashIndex = modelName.indexOf("/");
+      const subProvider = modelName.substring(0, firstSlashIndex);
+      const subModelName = modelName.substring(firstSlashIndex + 1);
+
+      const languageModel = getAISDKLanguageModel(subProvider, subModelName);
+
+      return new AISdkClient({
+        model: languageModel,
+        logger: this.logger,
+        enableCaching: this.enableCaching,
+        cache: this.cache,
+      });
+    }

+    function getAISDKLanguageModel(subProvider: string, subModelName: string) {
+      const aiSDKLanguageModel = AISDKProviders[subProvider];
+      if (!aiSDKLanguageModel) {
+        throw new UnsupportedAISDKModelProviderError(
+          subProvider,
+          Object.keys(AISDKProviders),
+        );
+      }
+      return aiSDKLanguageModel(subModelName);
+    }

const provider = modelToProviderMap[modelName];
if (!provider) {
throw new UnsupportedModelError(Object.keys(modelToProviderMap));
}

+    const availableModel = modelName as AvailableModel;
switch (provider) {
case "openai":
return new OpenAIClient({
logger: this.logger,
enableCaching: this.enableCaching,
cache: this.cache,
modelName,
-          modelName,
+          modelName: availableModel,
clientOptions,
});
case "anthropic":
return new AnthropicClient({
logger: this.logger,
enableCaching: this.enableCaching,
cache: this.cache,
modelName,
-          modelName,
+          modelName: availableModel,
clientOptions,
});
case "cerebras":
return new CerebrasClient({
logger: this.logger,
enableCaching: this.enableCaching,
cache: this.cache,
modelName,
-          modelName,
+          modelName: availableModel,
clientOptions,
});
case "groq":
return new GroqClient({
logger: this.logger,
enableCaching: this.enableCaching,
cache: this.cache,
modelName,
-          modelName,
+          modelName: availableModel,
clientOptions,
});
case "google":
return new GoogleClient({
logger: this.logger,
enableCaching: this.enableCaching,
cache: this.cache,
modelName,
-          modelName,
+          modelName: availableModel,
clientOptions,
});
default:
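To make the routing rule concrete: the substring before the first slash selects a provider from AISDKProviders, and everything after it (which may itself contain slashes, as in Together AI model ids) is passed to that provider unchanged. A standalone sketch of the same parsing, with illustrative model ids:

```ts
// Mirrors the parsing in getClient: split on the FIRST slash only, so
// model ids that contain further slashes survive intact.
function parseModelName(modelName: string): { provider: string; model: string } {
  const firstSlashIndex = modelName.indexOf("/");
  return {
    provider: modelName.substring(0, firstSlashIndex),
    model: modelName.substring(firstSlashIndex + 1),
  };
}

console.log(parseModelName("openai/gpt-4o"));
// -> { provider: "openai", model: "gpt-4o" }
console.log(parseModelName("togetherai/meta-llama/Llama-3.3-70B-Instruct-Turbo"));
// -> { provider: "togetherai", model: "meta-llama/Llama-3.3-70B-Instruct-Turbo" }
```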