Skip to content

Commit

Permalink
feat: add OpenAI provider with structured output support (#28)
Browse files Browse the repository at this point in the history
* feat: add OpenAI provider with structured output support

Co-Authored-By: Han Xiao <[email protected]>

* fix: add @ai-sdk/openai dependency and fix modelConfigs access

Co-Authored-By: Han Xiao <[email protected]>

* fix: correct indentation in agent.ts

Co-Authored-By: Han Xiao <[email protected]>

* refactor: centralize model initialization in config.ts

Co-Authored-By: Han Xiao <[email protected]>

* refactor: improve model config access patterns

Co-Authored-By: Han Xiao <[email protected]>

* fix: remove unused imports

Co-Authored-By: Han Xiao <[email protected]>

* refactor: clean up

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Han Xiao <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and hanxiao authored Feb 6, 2025
1 parent f1c7ada commit 50dff08
Show file tree
Hide file tree
Showing 15 changed files with 271 additions and 100 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ COPY . .

# Set environment variables
ENV GEMINI_API_KEY=${GEMINI_API_KEY}
ENV OPENAI_API_KEY=${OPENAI_API_KEY}
ENV JINA_API_KEY=${JINA_API_KEY}
ENV BRAVE_API_KEY=${BRAVE_API_KEY}

Expand Down
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ flowchart LR

## Install

We use gemini for llm, [jina reader](https://jina.ai/reader) for searching and reading webpages.

```bash
export GEMINI_API_KEY=... # for gemini api, ask han
export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader

git clone https://github.com/jina-ai/node-DeepResearch.git
cd node-DeepResearch
npm install
Expand All @@ -39,7 +34,14 @@ npm install

## Usage

We use Gemini/OpenAI for reasoning and [Jina Reader](https://jina.ai/reader) for searching and reading webpages. You can get a free Jina API key with 1M tokens from jina.ai.

```bash
export GEMINI_API_KEY=... # for gemini
# export OPENAI_API_KEY=... # for openai
# export LLM_PROVIDER=openai # for openai
export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader

npm run dev $QUERY
```

Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ services:
dockerfile: Dockerfile
environment:
- GEMINI_API_KEY=${GEMINI_API_KEY}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- JINA_API_KEY=${JINA_API_KEY}
- BRAVE_API_KEY=${BRAVE_API_KEY}
ports:
Expand Down
23 changes: 20 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"description": "",
"dependencies": {
"@ai-sdk/google": "^1.0.0",
"@ai-sdk/openai": "^1.1.9",
"@types/cors": "^2.8.17",
"@types/express": "^5.0.0",
"@types/node-fetch": "^2.6.12",
Expand Down
11 changes: 5 additions & 6 deletions src/agent.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import {createGoogleGenerativeAI} from '@ai-sdk/google';
import {z} from 'zod';
import {generateObject} from 'ai';
import {getModel, getMaxTokens, SEARCH_PROVIDER, STEP_SLEEP} from "./config";
import {readUrl} from "./tools/read";
import {handleGenerateObjectError} from './utils/error-handling';
import fs from 'fs/promises';
Expand All @@ -10,7 +10,6 @@ import {rewriteQuery} from "./tools/query-rewriter";
import {dedupQueries} from "./tools/dedup";
import {evaluateAnswer} from "./tools/evaluator";
import {analyzeSteps} from "./tools/error-analyzer";
import {SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
import {TokenTracker} from "./utils/token-tracker";
import {ActionTracker} from "./utils/action-tracker";
import {StepAction, AnswerAction} from "./types";
Expand Down Expand Up @@ -325,15 +324,15 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
false
);

const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agent.model);
const model = getModel('agent');
let object;
let totalTokens = 0;
try {
const result = await generateObject({
model,
schema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch),
prompt,
maxTokens: modelConfigs.agent.maxTokens
maxTokens: getMaxTokens('agent')
});
object = result.object;
totalTokens = result.usage?.totalTokens || 0;
Expand Down Expand Up @@ -671,15 +670,15 @@ You decided to think out of the box or cut from a completely different angle.`);
true
);

const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agentBeastMode.model);
const model = getModel('agentBeastMode');
let object;
let totalTokens = 0;
try {
const result = await generateObject({
model,
schema: getSchema(false, false, allowAnswer, false),
prompt,
maxTokens: modelConfigs.agentBeastMode.maxTokens
maxTokens: getMaxTokens('agentBeastMode')
});
object = result.object;
totalTokens = result.usage?.totalTokens || 0;
Expand Down
121 changes: 93 additions & 28 deletions src/config.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
import dotenv from 'dotenv';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { createOpenAI } from '@ai-sdk/openai';

interface ModelConfig {
export type LLMProvider = 'openai' | 'gemini';
export type ToolName = keyof ToolConfigs;

/**
 * Type guard that narrows a raw string (typically from process.env)
 * to the LLMProvider union.
 */
function isValidProvider(provider: string): provider is LLMProvider {
  const supported = ['openai', 'gemini'];
  return supported.includes(provider);
}

/**
 * Validates one tool's model configuration at startup.
 *
 * @param config - the candidate configuration to check
 * @param toolName - tool label used only in error messages
 * @returns the same config object, unchanged, when valid
 * @throws Error when the model name is empty, the temperature is outside
 *   [0, 1], or maxTokens is not a positive finite number
 */
function validateModelConfig(config: ModelConfig, toolName: string): ModelConfig {
  if (typeof config.model !== 'string' || config.model.length === 0) {
    throw new Error(`Invalid model name for ${toolName}`);
  }
  // Number.isFinite also rejects NaN and ±Infinity, which plain comparison
  // operators would silently accept (NaN < 0 and NaN > 1 are both false).
  if (!Number.isFinite(config.temperature) || config.temperature < 0 || config.temperature > 1) {
    throw new Error(`Invalid temperature for ${toolName}`);
  }
  if (!Number.isFinite(config.maxTokens) || config.maxTokens <= 0) {
    throw new Error(`Invalid maxTokens for ${toolName}`);
  }
  return config;
}

/**
 * Generation settings for a single tool's LLM call.
 */
export interface ModelConfig {
  /** Provider-specific model identifier, e.g. 'gemini-1.5-flash' or 'gpt-4o-mini'. */
  model: string;
  /** Sampling temperature; validateModelConfig enforces the [0, 1] range. */
  temperature: number;
  /** Upper bound on generated tokens; validateModelConfig enforces > 0. */
  maxTokens: number;
}

interface ToolConfigs {
export interface ToolConfigs {
dedup: ModelConfig;
evaluator: ModelConfig;
errorAnalyzer: ModelConfig;
Expand All @@ -31,44 +53,87 @@ if (process.env.https_proxy) {
}

export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
export const OPENAI_API_KEY = process.env.OPENAI_API_KEY as string;
export const JINA_API_KEY = process.env.JINA_API_KEY as string;
export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;
export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
/**
 * Active LLM provider, selected via the LLM_PROVIDER env var.
 * Defaults to 'gemini' when unset or empty; any unrecognized value
 * aborts module load with an error.
 */
export const LLM_PROVIDER: LLMProvider = (() => {
  const requested = process.env.LLM_PROVIDER || 'gemini';
  if (isValidProvider(requested)) {
    return requested;
  }
  throw new Error(`Invalid LLM provider: ${requested}`);
})();

const DEFAULT_MODEL = 'gemini-1.5-flash';
const DEFAULT_GEMINI_MODEL = 'gemini-1.5-flash';
const DEFAULT_OPENAI_MODEL = 'gpt-4o-mini';

const defaultConfig: ModelConfig = {
model: DEFAULT_MODEL,
const defaultGeminiConfig: ModelConfig = {
model: DEFAULT_GEMINI_MODEL,
temperature: 0,
maxTokens: 1000
};

export const modelConfigs: ToolConfigs = {
dedup: {
...defaultConfig,
temperature: 0.1
},
evaluator: {
...defaultConfig
},
errorAnalyzer: {
...defaultConfig
},
queryRewriter: {
...defaultConfig,
temperature: 0.1
},
agent: {
...defaultConfig,
temperature: 0.7
const defaultOpenAIConfig: ModelConfig = {
model: DEFAULT_OPENAI_MODEL,
temperature: 0,
maxTokens: 1000
};

export const modelConfigs: Record<LLMProvider, ToolConfigs> = {
gemini: {
dedup: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'dedup'),
evaluator: validateModelConfig({ ...defaultGeminiConfig }, 'evaluator'),
errorAnalyzer: validateModelConfig({ ...defaultGeminiConfig }, 'errorAnalyzer'),
queryRewriter: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'queryRewriter'),
agent: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agent'),
agentBeastMode: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agentBeastMode')
},
agentBeastMode: {
...defaultConfig,
temperature: 0.7
openai: {
dedup: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'dedup'),
evaluator: validateModelConfig({ ...defaultOpenAIConfig }, 'evaluator'),
errorAnalyzer: validateModelConfig({ ...defaultOpenAIConfig }, 'errorAnalyzer'),
queryRewriter: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'queryRewriter'),
agent: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agent'),
agentBeastMode: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agentBeastMode')
}
};

/**
 * Looks up the model configuration for a tool under the active provider.
 *
 * @param toolName - one of the keys of ToolConfigs
 * @returns the tool's ModelConfig for the current LLM_PROVIDER
 * @throws Error when the tool name has no entry for the active provider
 */
export function getToolConfig(toolName: ToolName): ModelConfig {
  const config = modelConfigs[LLM_PROVIDER][toolName];
  if (!config) {
    throw new Error(`Invalid tool name: ${toolName}`);
  }
  return config;
}

/** Convenience accessor: the active provider's maxTokens budget for a tool. */
export function getMaxTokens(toolName: ToolName): number {
  const { maxTokens } = getToolConfig(toolName);
  return maxTokens;
}


/**
 * Builds a provider-specific language-model instance for the given tool,
 * using the tool's configured model name and the active LLM_PROVIDER.
 *
 * @param toolName - one of the keys of ToolConfigs
 * @throws Error when the active provider's API key env var is missing
 */
export function getModel(toolName: ToolName) {
  const { model } = getToolConfig(toolName);

  switch (LLM_PROVIDER) {
    case 'openai': {
      if (!OPENAI_API_KEY) {
        throw new Error('OPENAI_API_KEY not found');
      }
      const openai = createOpenAI({
        apiKey: OPENAI_API_KEY,
        compatibility: 'strict'
      });
      return openai(model);
    }
    default: {
      if (!GEMINI_API_KEY) {
        throw new Error('GEMINI_API_KEY not found');
      }
      const gemini = createGoogleGenerativeAI({ apiKey: GEMINI_API_KEY });
      return gemini(model);
    }
  }
}

export const STEP_SLEEP = 1000;

if (!GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
if (LLM_PROVIDER === 'gemini' && !GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
if (LLM_PROVIDER === 'openai' && !OPENAI_API_KEY) throw new Error("OPENAI_API_KEY not found");
if (!JINA_API_KEY) throw new Error("JINA_API_KEY not found");

console.log('LLM Provider:', LLM_PROVIDER)
39 changes: 30 additions & 9 deletions src/tools/__tests__/dedup.test.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,37 @@
import { dedupQueries } from '../dedup';
import { LLMProvider } from '../../config';

describe('dedupQueries', () => {
it('should remove duplicate queries', async () => {
jest.setTimeout(10000); // Increase timeout to 10s
const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
const { unique_queries } = await dedupQueries(queries, []);
expect(unique_queries).toHaveLength(2);
expect(unique_queries).toContain('javascript basics');
const providers: Array<LLMProvider> = ['openai', 'gemini'];
const originalEnv = process.env;

beforeEach(() => {
jest.resetModules();
process.env = { ...originalEnv };
});

afterEach(() => {
process.env = originalEnv;
});

it('should handle empty input', async () => {
const { unique_queries } = await dedupQueries([], []);
expect(unique_queries).toHaveLength(0);
providers.forEach(provider => {
describe(`with ${provider} provider`, () => {
beforeEach(() => {
process.env.LLM_PROVIDER = provider;
});

it('should remove duplicate queries', async () => {
jest.setTimeout(10000);
const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
const { unique_queries } = await dedupQueries(queries, []);
expect(unique_queries).toHaveLength(2);
expect(unique_queries).toContain('javascript basics');
});

it('should handle empty input', async () => {
const { unique_queries } = await dedupQueries([], []);
expect(unique_queries).toHaveLength(0);
});
});
});
});
31 changes: 26 additions & 5 deletions src/tools/__tests__/error-analyzer.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
import { analyzeSteps } from '../error-analyzer';
import { LLMProvider } from '../../config';

describe('analyzeSteps', () => {
it('should analyze error steps', async () => {
const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
expect(response).toHaveProperty('recap');
expect(response).toHaveProperty('blame');
expect(response).toHaveProperty('improvement');
const providers: Array<LLMProvider> = ['openai', 'gemini'];
const originalEnv = process.env;

beforeEach(() => {
jest.resetModules();
process.env = { ...originalEnv };
});

afterEach(() => {
process.env = originalEnv;
});

providers.forEach(provider => {
describe(`with ${provider} provider`, () => {
beforeEach(() => {
process.env.LLM_PROVIDER = provider;
});

it('should analyze error steps', async () => {
const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
expect(response).toHaveProperty('recap');
expect(response).toHaveProperty('blame');
expect(response).toHaveProperty('improvement');
});
});
});
});
Loading

0 comments on commit 50dff08

Please sign in to comment.