Skip to content

Adding zod validation #690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions lib/llm/GoogleClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import {
} from "./LLMClient";
import {
CreateChatCompletionResponseError,
CreateChatCompletionResponseValidationError,
StagehandError,
} from "@/types/stagehandErrors";
import { validateZodSchemaWithResult } from "@/types/zod";

// Mapping from generic roles to Gemini roles
const roleMap: { [key in ChatMessage["role"]]: string } = {
Expand Down Expand Up @@ -440,7 +442,12 @@ export class GoogleClient extends LLMClient {
);
}

if (!validateZodSchema(response_model.schema, parsedData)) {
const validationResult = validateZodSchemaWithResult(
response_model.schema,
parsedData,
);

if (!validationResult.success) {
logger({
category: "google",
message: "Response failed Zod schema validation",
Expand All @@ -453,8 +460,8 @@ export class GoogleClient extends LLMClient {
retries: retries - 1,
});
}
throw new CreateChatCompletionResponseError(
"Invalid response schema",
throw new CreateChatCompletionResponseValidationError(
validationResult.error,
);
}

Expand Down
14 changes: 11 additions & 3 deletions lib/llm/OpenAIClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ import {
LLMResponse,
} from "./LLMClient";
import {
CreateChatCompletionResponseError,
CreateChatCompletionResponseValidationError,
StagehandError,
} from "@/types/stagehandErrors";
import { validateZodSchemaWithResult } from "@/types/zod";

export class OpenAIClient extends LLMClient {
public type = "openai" as const;
Expand Down Expand Up @@ -411,7 +412,12 @@ export class OpenAIClient extends LLMClient {
const extractedData = response.choices[0].message.content;
const parsedData = JSON.parse(extractedData);

if (!validateZodSchema(options.response_model.schema, parsedData)) {
const validationResult = validateZodSchemaWithResult(
options.response_model.schema,
parsedData,
);

if (!validationResult.success) {
if (retries > 0) {
// as-casting to account for o1 models not supporting all options
return this.createChatCompletion({
Expand All @@ -421,7 +427,9 @@ export class OpenAIClient extends LLMClient {
});
}

throw new CreateChatCompletionResponseError("Invalid response schema");
throw new CreateChatCompletionResponseValidationError(
validationResult.error,
);
}

if (this.enableCaching) {
Expand Down
8 changes: 8 additions & 0 deletions types/stagehandErrors.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { ZodValidationError } from "./zod";

export class StagehandError extends Error {
constructor(message: string) {
super(message);
Expand Down Expand Up @@ -136,6 +138,12 @@ export class CreateChatCompletionResponseError extends StagehandError {
}
}

export class CreateChatCompletionResponseValidationError extends StagehandError {
constructor(message: ZodValidationError) {
super(`ResponseValidationError: ${message.format()}`);
}
}

export class StagehandEvalError extends StagehandError {
constructor(message: string) {
super(`StagehandEvalError: ${message}`);
Expand Down
23 changes: 23 additions & 0 deletions types/zod.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { z, ZodError } from "zod";

export interface ZodValidationResult {
success: boolean;
error?: ZodError;
}

export function validateZodSchemaWithResult(
schema: z.ZodTypeAny,
data: unknown
): ZodValidationResult {
try {
schema.parse(data);
return {
success: true,
};
} catch (error) {
return {
success: false,
error: error instanceof ZodError ? error : new ZodError([]),
};
}
}