From b39d2aa54966e558053fd851e2899b4a3d7be66c Mon Sep 17 00:00:00 2001 From: Sayo Date: Thu, 6 Feb 2025 19:01:10 +0530 Subject: [PATCH] newer tests --- .env.example | 110 +++-- .gitignore | 5 +- packages/client-direct/src/index.ts | 2 +- packages/core/__tests__/embedding.test.ts | 431 +++++++++----------- packages/core/__tests__/environment.test.ts | 57 +-- packages/core/__tests__/models.test.ts | 186 --------- packages/core/__tests__/runtime.test.ts | 126 +++++- packages/core/src/embedding.ts | 10 +- packages/core/src/environment.ts | 7 +- packages/core/src/generation.ts | 22 +- 10 files changed, 407 insertions(+), 549 deletions(-) delete mode 100644 packages/core/__tests__/models.test.ts diff --git a/.env.example b/.env.example index bb373c27178..fb92bd2bfbc 100644 --- a/.env.example +++ b/.env.example @@ -117,77 +117,63 @@ EXPRESS_MAX_PAYLOAD= # Default: 100kb ####################################### ### Provider Settings ### -PROVIDER_ENDPOINT=https://api.openai.com/v1/chat/completions +PROVIDER_ENDPOINT=https://api.openai.com/v1 PROVIDER_NAME=openai PROVIDER_API_KEY= -### Model Names ### -DEFAULT_MODEL=gpt-3.5-turbo -SMALL_MODEL=gpt-3.5-turbo-instruct +# Model Settings - Default +DEFAULT_MODEL=gpt-4 +DEFAULT_MODEL_MAX_INPUT_TOKENS=4096 +DEFAULT_MODEL_MAX_OUTPUT_TOKENS=1024 +DEFAULT_MODEL_TEMPERATURE=0.7 +DEFAULT_MODEL_FREQUENCY_PENALTY=0 +DEFAULT_MODEL_PRESENCE_PENALTY=0 +DEFAULT_MODEL_REPETITION_PENALTY=1.0 +DEFAULT_MODEL_STOP= + +# Model Settings - Small +SMALL_MODEL=gpt-3.5-turbo +SMALL_MODEL_MAX_INPUT_TOKENS=2048 +SMALL_MODEL_MAX_OUTPUT_TOKENS=1024 +SMALL_MODEL_TEMPERATURE=0.7 +SMALL_MODEL_FREQUENCY_PENALTY=0 +SMALL_MODEL_PRESENCE_PENALTY=0 +SMALL_MODEL_REPETITION_PENALTY=1.0 +SMALL_MODEL_STOP= + +# Model Settings - Medium MEDIUM_MODEL=gpt-4 +MEDIUM_MODEL_MAX_INPUT_TOKENS=2048 +MEDIUM_MODEL_MAX_OUTPUT_TOKENS=1024 +MEDIUM_MODEL_TEMPERATURE=0.7 +MEDIUM_MODEL_FREQUENCY_PENALTY=0 +MEDIUM_MODEL_PRESENCE_PENALTY=0 +MEDIUM_MODEL_REPETITION_PENALTY=1.0 +MEDIUM_MODEL_STOP= + +# Model Settings - Large LARGE_MODEL=gpt-4-turbo -EMBEDDING_MODEL=text-embedding-ada-002 -USE_OPENAI_EMBEDDING=true -IMAGE_MODEL=dall-e-3 - -### Default Model Settings ### -DEFAULT_MAX_INPUT_TOKENS=4096 -DEFAULT_MAX_OUTPUT_TOKENS=1024 -DEFAULT_TEMPERATURE=0.7 -DEFAULT_STOP_SEQUENCES= -DEFAULT_FREQUENCY_PENALTY=0 -DEFAULT_PRESENCE_PENALTY=0 -DEFAULT_REPETITION_PENALTY=1.0 - -### Model-Specific Token Limits ### -# Small Model -SMALL_INPUT_TOKENS=2048 -SMALL_OUTPUT_TOKENS=512 - -## Small Model Settings ### -SMALL_MAX_INPUT_TOKENS=4096 -SMALL_MAX_OUTPUT_TOKENS=1024 -SMALL_TEMPERATURE=0.7 -SMALL_STOP_SEQUENCES=###,END -SMALL_FREQUENCY_PENALTY=0 -SMALL_PRESENCE_PENALTY=0 -SMALL_REPETITION_PENALTY=1.0 - -# Medium Model -MEDIUM_INPUT_TOKENS=8192 -MEDIUM_OUTPUT_TOKENS=2048 - -## Medium Model Settings ### -MEDIUM_MAX_INPUT_TOKENS=4096 -MEDIUM_MAX_OUTPUT_TOKENS=1024 -MEDIUM_TEMPERATURE=0.7 -MEDIUM_STOP_SEQUENCES=###,END -MEDIUM_FREQUENCY_PENALTY=0 -MEDIUM_PRESENCE_PENALTY=0 -MEDIUM_REPETITION_PENALTY=1.0 - -# Large Model -LARGE_INPUT_TOKENS=128000 -LARGE_OUTPUT_TOKENS=4096 - -## Large Model Settings ### -LARGE_MAX_INPUT_TOKENS=4096 -LARGE_MAX_OUTPUT_TOKENS=1024 -LARGE_TEMPERATURE=0.7 -LARGE_STOP_SEQUENCES=###,END -LARGE_FREQUENCY_PENALTY=0 -LARGE_PRESENCE_PENALTY=0 -LARGE_REPETITION_PENALTY=1.0 - - - -### Specialized Model Settings ### -# Embedding Model +LARGE_MODEL_MAX_INPUT_TOKENS=4096 +LARGE_MODEL_MAX_OUTPUT_TOKENS=1024 +LARGE_MODEL_TEMPERATURE=0.7 +LARGE_MODEL_FREQUENCY_PENALTY=0 +LARGE_MODEL_PRESENCE_PENALTY=0 
+LARGE_MODEL_REPETITION_PENALTY=1.0 +LARGE_MODEL_STOP= + +# Model Settings - Embedding +EMBEDDING_MODEL=text-embedding-3-small EMBEDDING_DIMENSIONS=1536 -# Image Generation +# Model Settings - Image +IMAGE_MODEL=dall-e-3 IMAGE_STEPS=50 +# Model Settings - Vision +IMAGE_VISION_MODEL=gpt-4-vision-preview + + + # Community Plugin for OpenAI Configuration ENABLE_OPEN_AI_COMMUNITY_PLUGIN=false diff --git a/.gitignore b/.gitignore index 2746205bef8..a3625f1e299 100644 --- a/.gitignore +++ b/.gitignore @@ -91,4 +91,7 @@ lit-config.json # Configuration to exclude the extra and local_docs directories extra -**/dist/** \ No newline at end of file +**/dist/** + +**/local_cache/** +**/local_cache \ No newline at end of file diff --git a/packages/client-direct/src/index.ts b/packages/client-direct/src/index.ts index 964acf02ffa..3019e3abc92 100644 --- a/packages/client-direct/src/index.ts +++ b/packages/client-direct/src/index.ts @@ -302,7 +302,7 @@ export class DirectClient { // save response to memory const responseMessage: Memory = { - id: stringToUuid(messageId + "-" + runtime.agentId), + id: stringToUuid(`${messageId}-${runtime.agentId}`), ...userMessage, userId: runtime.agentId, content: response, diff --git a/packages/core/__tests__/embedding.test.ts b/packages/core/__tests__/embedding.test.ts index 9358137c4c5..60eef920716 100644 --- a/packages/core/__tests__/embedding.test.ts +++ b/packages/core/__tests__/embedding.test.ts @@ -1,12 +1,13 @@ -import { describe, test, expect, vi, beforeEach } from "vitest"; +import { beforeEach, describe, expect, test, vi } from "vitest"; import { embed, + EmbeddingProvider, getEmbeddingConfig, getEmbeddingType, - getEmbeddingZeroVector, + getEmbeddingZeroVector } from "../src/embedding.ts"; -import { type IAgentRuntime, ModelProviderName } from "../src/types.ts"; import settings from "../src/settings.ts"; +import { type IAgentRuntime, type IDatabaseAdapter, ModelClass, ModelProviderName } from "../src/types.ts"; // Mock environment-related settings vi.mock("../settings", () => ({ @@ -14,18 +15,24 @@ vi.mock("../settings", () => ({ USE_OPENAI_EMBEDDING: "false", USE_OLLAMA_EMBEDDING: "false", USE_GAIANET_EMBEDDING: "false", - OPENAI_API_KEY: "mock-openai-key", - OPENAI_API_URL: "https://api.openai.com/v1", - GAIANET_API_KEY: "mock-gaianet-key", - OLLAMA_EMBEDDING_MODEL: "mxbai-embed-large", - GAIANET_EMBEDDING_MODEL: "nomic-embed", }, })); + +describe("Embedding Configuration", () => { + test("should return default config when no runtime provided", () => { + const config = getEmbeddingConfig(); + expect(config.provider).toBe(EmbeddingProvider.OpenAI); + expect(config.model).toBe("text-embedding-3-small"); + expect(config.dimensions).toBe(1536); + }); +}); + + // Mock fastembed module for local embeddings vi.mock("fastembed", () => { class MockFlagEmbedding { - constructor() {} + static async init() { return new MockFlagEmbedding(); @@ -44,9 +51,8 @@ vi.mock("fastembed", () => { }; }); -// Mock global fetch for remote embedding requests -const mockFetch = vi.fn(); -(global as any).fetch = mockFetch; +// Mock fetch for remote embedding calls +global.fetch = vi.fn(); describe("Embedding Module", () => { let mockRuntime: IAgentRuntime; @@ -55,39 +61,13 @@ describe("Embedding Module", () => { // Reset all mocks vi.clearAllMocks(); + // Reset settings + vi.mocked(settings).USE_OPENAI_EMBEDDING = "false"; + // Prepare a mock runtime mockRuntime = { agentId: "00000000-0000-0000-0000-000000000000" as `${string}-${string}-${string}-${string}-${string}`, serverUrl: 
"http://test-server", - databaseAdapter: { - init: vi.fn(), - close: vi.fn(), - getMemories: vi.fn(), - createMemory: vi.fn(), - removeMemory: vi.fn(), - searchMemories: vi.fn(), - searchMemoriesByEmbedding: vi.fn(), - getGoals: vi.fn(), - createGoal: vi.fn(), - updateGoal: vi.fn(), - removeGoal: vi.fn(), - getRoom: vi.fn(), - createRoom: vi.fn(), - removeRoom: vi.fn(), - addParticipant: vi.fn(), - removeParticipant: vi.fn(), - getParticipantsForRoom: vi.fn(), - getParticipantUserState: vi.fn(), - setParticipantUserState: vi.fn(), - createRelationship: vi.fn(), - getRelationship: vi.fn(), - getRelationships: vi.fn(), - getKnowledge: vi.fn(), - searchKnowledge: vi.fn(), - createKnowledge: vi.fn(), - removeKnowledge: vi.fn(), - clearKnowledge: vi.fn(), - }, token: "test-token", modelProvider: ModelProviderName.OPENAI, imageModelProvider: ModelProviderName.OPENAI, @@ -103,180 +83,83 @@ describe("Embedding Module", () => { bio: ["Test bio"], lore: ["Test lore"], messageExamples: [], - postExamples: [], - topics: [], - adjectives: [], - style: { - all: [], - chat: [], - post: [] - }, - clients: [], - plugins: [], }, getModelProvider: () => ({ apiKey: "test-key", endpoint: "test-endpoint", provider: ModelProviderName.OPENAI, models: { - default: { name: "test-model", maxInputTokens: 4096, maxOutputTokens: 4096, stop: [], temperature: 0.7 }, + [ModelClass.EMBEDDING]: { + name: "text-embedding-3-small", + dimensions: 1536 + } }, }), getSetting: (key: string) => { const settings = { USE_OPENAI_EMBEDDING: "false", - USE_OLLAMA_EMBEDDING: "false", - USE_GAIANET_EMBEDDING: "false", - OPENAI_API_KEY: "mock-openai-key", - OPENAI_API_URL: "https://api.openai.com/v1", - GAIANET_API_KEY: "mock-gaianet-key", - OLLAMA_EMBEDDING_MODEL: "mxbai-embed-large", - GAIANET_EMBEDDING_MODEL: "nomic-embed", + PROVIDER_ENDPOINT: "https://api.openai.com/v1", + PROVIDER_API_KEY: "test-key" }; return settings[key as keyof typeof settings] || ""; }, - knowledgeManager: { - init: vi.fn(), - close: vi.fn(), - addKnowledge: vi.fn(), - removeKnowledge: vi.fn(), - searchKnowledge: vi.fn(), - clearKnowledge: vi.fn(), - }, - memoryManager: { - init: vi.fn(), - close: vi.fn(), - addMemory: vi.fn(), - removeMemory: vi.fn(), - searchMemories: vi.fn(), - searchMemoriesByEmbedding: vi.fn(), - clearMemories: vi.fn(), - }, - goalManager: { - init: vi.fn(), - close: vi.fn(), - addGoal: vi.fn(), - updateGoal: vi.fn(), - removeGoal: vi.fn(), - getGoals: vi.fn(), - clearGoals: vi.fn(), - }, - relationshipManager: { - init: vi.fn(), - close: vi.fn(), - addRelationship: vi.fn(), - getRelationship: vi.fn(), - getRelationships: vi.fn(), - }, - cacheManager: { - get: vi.fn(), - set: vi.fn(), - delete: vi.fn(), - }, - services: new Map(), - clients: {}, messageManager: { - runtime: {} as IAgentRuntime, - tableName: "messages", - addEmbeddingToMemory: vi.fn(), - getMemories: vi.fn(), - getCachedEmbeddings: vi.fn(), - getMemoryById: vi.fn(), - getMemoriesByRoomIds: vi.fn(), - searchMemoriesByEmbedding: vi.fn(), - createMemory: vi.fn(), - removeMemory: vi.fn(), - removeAllMemories: vi.fn(), - countMemories: vi.fn(), - }, - descriptionManager: { - runtime: {} as IAgentRuntime, - tableName: "descriptions", - addEmbeddingToMemory: vi.fn(), - getMemories: vi.fn(), - getCachedEmbeddings: vi.fn(), - getMemoryById: vi.fn(), - getMemoriesByRoomIds: vi.fn(), - searchMemoriesByEmbedding: vi.fn(), - createMemory: vi.fn(), - removeMemory: vi.fn(), - removeAllMemories: vi.fn(), - countMemories: vi.fn(), - }, - documentsManager: { - runtime: {} as IAgentRuntime, - 
tableName: "documents", - addEmbeddingToMemory: vi.fn(), - getMemories: vi.fn(), - getCachedEmbeddings: vi.fn(), - getMemoryById: vi.fn(), - getMemoriesByRoomIds: vi.fn(), - searchMemoriesByEmbedding: vi.fn(), - createMemory: vi.fn(), - removeMemory: vi.fn(), - removeAllMemories: vi.fn(), - countMemories: vi.fn(), - }, - loreManager: { - runtime: {} as IAgentRuntime, - tableName: "lore", - addEmbeddingToMemory: vi.fn(), - getMemories: vi.fn(), - getCachedEmbeddings: vi.fn(), - getMemoryById: vi.fn(), - getMemoriesByRoomIds: vi.fn(), - searchMemoriesByEmbedding: vi.fn(), - createMemory: vi.fn(), - removeMemory: vi.fn(), - removeAllMemories: vi.fn(), - countMemories: vi.fn(), - }, - ragKnowledgeManager: { - runtime: {} as IAgentRuntime, - tableName: "rag_knowledge", - getKnowledge: vi.fn(), - createKnowledge: vi.fn(), - removeKnowledge: vi.fn(), - searchKnowledge: vi.fn(), - clearKnowledge: vi.fn(), - processFile: vi.fn(), - cleanupDeletedKnowledgeFiles: vi.fn(), - generateScopedId: vi.fn(), - }, - initialize: vi.fn(), - registerMemoryManager: vi.fn(), - getMemoryManager: vi.fn(), - getService: vi.fn(), - registerService: vi.fn(), - composeState: vi.fn(), - processActions: vi.fn(), - evaluate: vi.fn(), - ensureParticipantExists: vi.fn(), - ensureUserExists: vi.fn(), - ensureConnection: vi.fn(), - ensureParticipantInRoom: vi.fn(), - ensureRoomExists: vi.fn(), - updateRecentMessageState: vi.fn(), - getConversationLength: vi.fn(), - registerAction: vi.fn(), + getCachedEmbeddings: vi.fn().mockResolvedValue([]) + } } as unknown as IAgentRuntime; - // Reset fetch mock - mockFetch.mockReset(); - mockFetch.mockResolvedValue({ + // Reset fetch mock with proper Response object + const mockResponse = { ok: true, json: async () => ({ - data: [new Array(384).fill(0.1)], + data: [{ embedding: new Array(384).fill(0.1) }], }), - }); + headers: new Headers(), + redirected: false, + status: 200, + statusText: "OK", + type: "basic", + url: "https://api.openai.com/v1/embeddings", + body: null, + bodyUsed: false, + clone: () => ({} as Response), + arrayBuffer: async () => new ArrayBuffer(0), + blob: async () => new Blob(), + formData: async () => new FormData(), + text: async () => "" + } as Response; + + vi.mocked(global.fetch).mockReset(); + vi.mocked(global.fetch).mockResolvedValue(mockResponse); }); describe("getEmbeddingConfig", () => { - test("should return BGE config by default", () => { + test("should return OpenAI config by default", () => { const config = getEmbeddingConfig(); - expect(config.dimensions).toBe(384); - expect(config.model).toBe("BGE-small-en-v1.5"); - expect(config.provider).toBe("BGE"); + expect(config.provider).toBe(EmbeddingProvider.OpenAI); + expect(config.model).toBe("text-embedding-3-small"); + expect(config.dimensions).toBe(1536); + }); + + test("should use runtime provider when available", () => { + const mockModelProvider = { + provider: EmbeddingProvider.OpenAI, + models: { + [ModelClass.EMBEDDING]: { + name: "text-embedding-3-small", + dimensions: 1536 + } + } + }; + + const runtime = { + getModelProvider: () => mockModelProvider + } as unknown as IAgentRuntime; + + const config = getEmbeddingConfig(runtime); + expect(config.provider).toBe(EmbeddingProvider.OpenAI); + expect(config.model).toBe("text-embedding-3-small"); + expect(config.dimensions).toBe(1536); }); }); @@ -303,8 +186,6 @@ describe("Embedding Module", () => { describe("getEmbeddingZeroVector", () => { beforeEach(() => { vi.mocked(settings).USE_OPENAI_EMBEDDING = "false"; - vi.mocked(settings).USE_OLLAMA_EMBEDDING 
= "false"; - vi.mocked(settings).USE_GAIANET_EMBEDDING = "false"; }); test("should return 384-length zero vector by default (BGE)", () => { @@ -322,17 +203,6 @@ describe("Embedding Module", () => { }); describe("embed function", () => { - beforeEach(() => { - // Mock a successful remote response with an example 384-dim embedding - mockFetch.mockResolvedValue({ - ok: true, - json: () => - Promise.resolve({ - data: [{ embedding: new Array(384).fill(0.1) }], - }), - }); - }); - test("should return an empty array for empty input text", async () => { const result = await embed(mockRuntime, ""); expect(result).toEqual([]); @@ -348,57 +218,61 @@ describe("Embedding Module", () => { expect(result).toBe(cachedEmbedding); }); - test("should handle local embedding successfully (fastembed fallback)", async () => { - // By default, it tries local first if in Node. - // Then uses the mock fastembed response above. + test("should handle local embedding successfully", async () => { const result = await embed(mockRuntime, "test input"); expect(result).toHaveLength(384); expect(result.every((v) => typeof v === "number")).toBe(true); }); - test("should fallback to remote if local embedding fails", async () => { - // Force fastembed import to fail - vi.mock("fastembed", () => { - throw new Error("Module not found"); - }); - - // Mock a valid remote response - const mockResponse = { - ok: true, - json: () => - Promise.resolve({ - data: [{ embedding: new Array(384).fill(0.1) }], - }), - }; - mockFetch.mockResolvedValueOnce(mockResponse); - - const result = await embed(mockRuntime, "test input"); + test("should handle remote embedding successfully", async () => { + // Force remote embedding + const runtimeWithOpenAI = { + ...mockRuntime, + getSetting: (key: string) => { + if (key === "USE_OPENAI_EMBEDDING") return "true"; + return mockRuntime.getSetting(key); + }, + getModelProvider: () => ({ + ...mockRuntime.getModelProvider(), + provider: EmbeddingProvider.OpenAI, + models: { + [ModelClass.EMBEDDING]: { + name: "text-embedding-3-small", + dimensions: 1536 + } + } + }) + } as IAgentRuntime; + + const result = await embed(runtimeWithOpenAI, "test input"); expect(result).toHaveLength(384); - expect(mockFetch).toHaveBeenCalled(); + expect(vi.mocked(global.fetch)).toHaveBeenCalled(); }); test("should throw on remote embedding if fetch fails", async () => { - mockFetch.mockRejectedValueOnce(new Error("API Error")); - vi.mocked(settings).USE_OPENAI_EMBEDDING = "true"; // Force remote - - await expect(embed(mockRuntime, "test input")).rejects.toThrow( - "API Error" - ); - }); + // Force remote embedding + const runtimeWithOpenAI = { + ...mockRuntime, + getSetting: (key: string) => { + if (key === "USE_OPENAI_EMBEDDING") return "true"; + return mockRuntime.getSetting(key); + }, + getModelProvider: () => ({ + ...mockRuntime.getModelProvider(), + provider: EmbeddingProvider.OpenAI, + models: { + [ModelClass.EMBEDDING]: { + name: "text-embedding-3-small", + dimensions: 1536 + } + } + }) + } as IAgentRuntime; - test("should throw on non-200 remote response", async () => { - const errorResponse = { - ok: false, - status: 400, - statusText: "Bad Request", - text: () => Promise.resolve("Invalid input"), - }; - mockFetch.mockResolvedValueOnce(errorResponse); - vi.mocked(settings).USE_OPENAI_EMBEDDING = "true"; // Force remote + // Mock fetch to reject + vi.mocked(global.fetch).mockRejectedValueOnce(new Error("API Error")); - await expect(embed(mockRuntime, "test input")).rejects.toThrow( - "Embedding API Error" - ); + await 
expect(embed(runtimeWithOpenAI, "test input")).rejects.toThrow("API Error"); }); test("should handle concurrent embedding requests", async () => { @@ -408,4 +282,73 @@ describe("Embedding Module", () => { await expect(Promise.all(promises)).resolves.toBeDefined(); }); }); + + // Add tests for new embedding configurations + describe("embedding configuration", () => { + test("should handle embedding provider configuration", async () => { + const mockModelProvider = { + generateText: vi.fn(), + generateObject: vi.fn(), + generateImage: vi.fn(), + generateEmbedding: vi.fn(), + provider: EmbeddingProvider.OpenAI, + models: { + [ModelClass.EMBEDDING]: { + name: "text-embedding-3-small", + dimensions: 1536 + } + }, + getModelProvider: () => mockModelProvider + }; + + const runtime = { + agentId: "test-agent", + serverUrl: "http://test.com", + databaseAdapter: {} as IDatabaseAdapter, + token: "test-token", + modelProvider: mockModelProvider, + imageModelProvider: mockModelProvider, + imageVisionModelProvider: mockModelProvider, + embeddingModelProvider: mockModelProvider, + getModelProvider: () => mockModelProvider, + settings: { + USE_OPENAI_EMBEDDING: "true", + USE_OLLAMA_EMBEDDING: "true", + USE_GAIANET_EMBEDDING: "false", + OPENAI_API_KEY: "test-key", + OLLAMA_EMBEDDING_MODEL: "mxbai-embed-large", + } + } as unknown as IAgentRuntime; + + const config = getEmbeddingConfig(runtime); + expect(config.provider).toBe(EmbeddingProvider.OpenAI); + expect(config.model).toBe("text-embedding-3-small"); + expect(config.dimensions).toBe(1536); + }); + + test("should return default config when no runtime provided", () => { + const config = getEmbeddingConfig(); + expect(config.provider).toBe(EmbeddingProvider.OpenAI); + expect(config.model).toBe("text-embedding-3-small"); + expect(config.dimensions).toBe(1536); + }); + }); + + describe("embedding type detection", () => { + test("should determine embedding type based on runtime configuration", () => { + const mockRuntimeRemote = { + ...mockRuntime, + getSetting: (key: string) => key === "USE_OPENAI_EMBEDDING" ? 
"true" : "false" + } as IAgentRuntime; + + expect(getEmbeddingType(mockRuntimeRemote)).toBe("remote"); + + const mockRuntimeLocal = { + ...mockRuntime, + getSetting: (key: string) => "false" + } as IAgentRuntime; + + expect(getEmbeddingType(mockRuntimeLocal)).toBe("local"); + }); + }); }); diff --git a/packages/core/__tests__/environment.test.ts b/packages/core/__tests__/environment.test.ts index 7f26c0b672e..88d21776d50 100644 --- a/packages/core/__tests__/environment.test.ts +++ b/packages/core/__tests__/environment.test.ts @@ -8,13 +8,17 @@ describe("Environment Configuration", () => { beforeEach(() => { process.env = { ...originalEnv, - OPENAI_API_KEY: "sk-test123", - REDPILL_API_KEY: "test-key", - GROK_API_KEY: "test-key", - GROQ_API_KEY: "gsk_test123", - OPENROUTER_API_KEY: "test-key", - GOOGLE_GENERATIVE_AI_API_KEY: "test-key", + PROVIDER_NAME: "OPENAI", + PROVIDER_API_KEY: "sk-test123", + PROVIDER_ENDPOINT: "https://api.openai.com/v1", ELEVENLABS_XI_API_KEY: "test-key", + DEFAULT_MODEL: "gpt-4", + SMALL_MODEL: "gpt-3.5-turbo", + MEDIUM_MODEL: "gpt-4", + LARGE_MODEL: "gpt-4-turbo", + EMBEDDING_MODEL: "text-embedding-3-small", + IMAGE_MODEL: "dall-e-3", + IMAGE_VISION_MODEL: "gpt-4-vision-preview", }; }); @@ -26,35 +30,38 @@ describe("Environment Configuration", () => { expect(() => validateEnv()).not.toThrow(); }); - it("should throw error for invalid OpenAI API key format", () => { - process.env.OPENAI_API_KEY = "invalid-key"; + it("should throw error for missing required provider configuration", () => { + delete process.env.PROVIDER_NAME; + delete process.env.PROVIDER_API_KEY; expect(() => validateEnv()).toThrow( - "OpenAI API key must start with 'sk-'" + "Environment validation failed:\n" + + "PROVIDER_NAME: Required\n" + + "PROVIDER_API_KEY: Required" ); }); - it("should throw error for invalid GROQ API key format", () => { - process.env.GROQ_API_KEY = "invalid-key"; + it("should throw error for invalid provider endpoint URL", () => { + process.env.PROVIDER_ENDPOINT = "invalid-url"; expect(() => validateEnv()).toThrow( - "GROQ API key must start with 'gsk_'" + "Provider endpoint must be a valid URL" ); }); - it("should throw error for missing required keys", () => { - delete process.env.REDPILL_API_KEY; - expect(() => validateEnv()).toThrow("REDPILL_API_KEY: Required"); + it("should validate with optional fields missing", () => { + delete process.env.DEFAULT_MODEL; + delete process.env.SMALL_MODEL; + delete process.env.MEDIUM_MODEL; + delete process.env.LARGE_MODEL; + delete process.env.EMBEDDING_MODEL; + delete process.env.IMAGE_MODEL; + delete process.env.IMAGE_VISION_MODEL; + delete process.env.ELEVENLABS_XI_API_KEY; + expect(() => validateEnv()).not.toThrow(); }); - it("should throw error for multiple missing required keys", () => { - delete process.env.REDPILL_API_KEY; - delete process.env.GROK_API_KEY; - delete process.env.OPENROUTER_API_KEY; - expect(() => validateEnv()).toThrow( - "Environment validation failed:\n" + - "REDPILL_API_KEY: Required\n" + - "GROK_API_KEY: Required\n" + - "OPENROUTER_API_KEY: Required" - ); + it("should validate with invalid model provider name", () => { + process.env.PROVIDER_NAME = "INVALID_PROVIDER"; + expect(() => validateEnv()).toThrow(); }); }); diff --git a/packages/core/__tests__/models.test.ts b/packages/core/__tests__/models.test.ts deleted file mode 100644 index 53afe1e6719..00000000000 --- a/packages/core/__tests__/models.test.ts +++ /dev/null @@ -1,186 +0,0 @@ -import { describe, test, expect, vi, beforeEach } from "vitest"; 
-import { AgentRuntime } from "../src/runtime"; -import { ModelProviderName, ModelClass, type ModelSettings, type ImageModelSettings, type EmbeddingModelSettings } from "../src/types"; - -// Mock settings -vi.mock("../settings", () => { - return { - default: { - PROVIDER_NAME: process.env.PROVIDER_NAME || "openai", - PROVIDER_API_KEY: process.env.PROVIDER_API_KEY || "mock-openai-key", - PROVIDER_ENDPOINT: process.env.PROVIDER_ENDPOINT || "https://api.openai.com/v1", - DEFAULT_MODEL_NAME: process.env.DEFAULT_MODEL_NAME || "gpt-4o-mini", - DEFAULT_MODEL_MAX_INPUT_TOKENS: process.env.DEFAULT_MODEL_MAX_INPUT_TOKENS || "4096", - DEFAULT_MODEL_MAX_OUTPUT_TOKENS: process.env.DEFAULT_MODEL_MAX_OUTPUT_TOKENS || "1024", - DEFAULT_MODEL_TEMPERATURE: process.env.DEFAULT_MODEL_TEMPERATURE || "0.7", - DEFAULT_MODEL_STOP: process.env.DEFAULT_MODEL_STOP || "", - DEFAULT_MODEL_FREQUENCY_PENALTY: process.env.DEFAULT_MODEL_FREQUENCY_PENALTY || "0", - DEFAULT_MODEL_PRESENCE_PENALTY: process.env.DEFAULT_MODEL_PRESENCE_PENALTY || "0", - DEFAULT_MODEL_REPETITION_PENALTY: process.env.DEFAULT_MODEL_REPETITION_PENALTY || "1.0", - SMALL_MODEL_NAME: process.env.SMALL_MODEL_NAME || "gpt-4o-mini", - MEDIUM_MODEL_NAME: process.env.MEDIUM_MODEL_NAME || "gpt-4o-mini", - LARGE_MODEL_NAME: process.env.LARGE_MODEL_NAME || "gpt-4o-mini", - EMBEDDING_MODEL_NAME: process.env.EMBEDDING_MODEL_NAME || "text-embedding-3-small", - EMBEDDING_DIMENSIONS: process.env.EMBEDDING_DIMENSIONS || "1536", - IMAGE_MODEL_NAME: process.env.IMAGE_MODEL_NAME || "dall-e-3", - IMAGE_VISION_MODEL_NAME: process.env.IMAGE_VISION_MODEL_NAME || "gpt-4-vision-preview", - }, - loadEnv: vi.fn(), - }; -}); - -// Mock database adapter -const mockDatabaseAdapter = { - db: {}, - init: vi.fn().mockResolvedValue(undefined), - close: vi.fn().mockResolvedValue(undefined), - getAccountById: vi.fn().mockResolvedValue(null), - createAccount: vi.fn().mockResolvedValue(true), - getMemories: vi.fn().mockResolvedValue([]), - getMemoryById: vi.fn().mockResolvedValue(null), - getMemoriesByRoomIds: vi.fn().mockResolvedValue([]), - getMemoriesByIds: vi.fn().mockResolvedValue([]), - getCachedEmbeddings: vi.fn().mockResolvedValue([]), - log: vi.fn().mockResolvedValue(undefined), - getActorDetails: vi.fn().mockResolvedValue([]), - searchMemories: vi.fn().mockResolvedValue([]), - updateGoalStatus: vi.fn().mockResolvedValue(undefined), - searchMemoriesByEmbedding: vi.fn().mockResolvedValue([]), - createMemory: vi.fn().mockResolvedValue(undefined), - removeMemory: vi.fn().mockResolvedValue(undefined), - removeAllMemories: vi.fn().mockResolvedValue(undefined), - countMemories: vi.fn().mockResolvedValue(0), - getGoals: vi.fn().mockResolvedValue([]), - updateGoal: vi.fn().mockResolvedValue(undefined), - createGoal: vi.fn().mockResolvedValue(undefined), - removeGoal: vi.fn().mockResolvedValue(undefined), - removeAllGoals: vi.fn().mockResolvedValue(undefined), - getRoom: vi.fn().mockResolvedValue(null), - createRoom: vi.fn().mockResolvedValue("test-room-id"), - removeRoom: vi.fn().mockResolvedValue(undefined), - getRoomsForParticipant: vi.fn().mockResolvedValue([]), - getRoomsForParticipants: vi.fn().mockResolvedValue([]), - addParticipant: vi.fn().mockResolvedValue(true), - removeParticipant: vi.fn().mockResolvedValue(true), - getParticipantsForAccount: vi.fn().mockResolvedValue([]), - getParticipantsForRoom: vi.fn().mockResolvedValue([]), - getParticipantUserState: vi.fn().mockResolvedValue(null), - setParticipantUserState: vi.fn().mockResolvedValue(undefined), - createRelationship: 
vi.fn().mockResolvedValue(true), - getRelationship: vi.fn().mockResolvedValue(null), - getRelationships: vi.fn().mockResolvedValue([]), - getKnowledge: vi.fn().mockResolvedValue([]), - searchKnowledge: vi.fn().mockResolvedValue([]), - createKnowledge: vi.fn().mockResolvedValue(undefined), - removeKnowledge: vi.fn().mockResolvedValue(undefined), - clearKnowledge: vi.fn().mockResolvedValue(undefined), -}; - -// Mock cache manager -const mockCacheManager = { - get: vi.fn().mockResolvedValue(null), - set: vi.fn().mockResolvedValue(undefined), - delete: vi.fn().mockResolvedValue(undefined), -}; - -describe("Model Provider Configuration", () => { - let runtime: AgentRuntime; - - beforeEach(() => { - vi.clearAllMocks(); - runtime = new AgentRuntime({ - token: "test-token", - character: { - name: "Test Character", - username: "test", - bio: "Test bio", - lore: ["Test lore"], - modelProvider: ModelProviderName.OPENAI, - messageExamples: [], - postExamples: [], - topics: [], - adjectives: [], - style: { - all: [], - chat: [], - post: [] - }, - clients: [], - plugins: [], - }, - databaseAdapter: mockDatabaseAdapter, - cacheManager: mockCacheManager, - modelProvider: ModelProviderName.OPENAI, - }); - }); - - describe("Provider Configuration", () => { - test("should load provider configuration from environment", () => { - const provider = runtime.getModelProvider(); - expect(provider.endpoint).toBe(process.env.PROVIDER_ENDPOINT || "https://api.openai.com/v1"); - expect(provider.apiKey).toBe(process.env.PROVIDER_API_KEY || "mock-openai-key"); - expect(provider.provider).toBe(process.env.PROVIDER_NAME || "openai"); - }); - - test("should load model mappings from environment", () => { - const provider = runtime.getModelProvider(); - const models = provider.models; - - expect(models.default.name).toBe(process.env.DEFAULT_MODEL_NAME || "gpt-4o-mini"); - expect(models[ModelClass.SMALL]?.name).toBe(process.env.SMALL_MODEL_NAME || "gpt-4o-mini"); - expect(models[ModelClass.MEDIUM]?.name).toBe(process.env.MEDIUM_MODEL_NAME || "gpt-4o-mini"); - expect(models[ModelClass.LARGE]?.name).toBe(process.env.LARGE_MODEL_NAME || "gpt-4o-mini"); - expect(models[ModelClass.EMBEDDING]?.name).toBe(process.env.EMBEDDING_MODEL_NAME || "text-embedding-3-small"); - }); - - test("should load model settings from environment", () => { - const provider = runtime.getModelProvider(); - const defaultModel = provider.models.default as ModelSettings; - - expect(defaultModel.maxInputTokens).toBe(parseInt(process.env.DEFAULT_MODEL_MAX_INPUT_TOKENS || "4096")); - expect(defaultModel.maxOutputTokens).toBe(parseInt(process.env.DEFAULT_MODEL_MAX_OUTPUT_TOKENS || "1024")); - expect(defaultModel.stop).toEqual(process.env.DEFAULT_MODEL_STOP ? 
process.env.DEFAULT_MODEL_STOP.split(",") : []); - expect(defaultModel.temperature).toBe(parseFloat(process.env.DEFAULT_MODEL_TEMPERATURE || "0.7")); - expect(defaultModel.frequency_penalty).toBe(parseFloat(process.env.DEFAULT_MODEL_FREQUENCY_PENALTY || "0")); - expect(defaultModel.presence_penalty).toBe(parseFloat(process.env.DEFAULT_MODEL_PRESENCE_PENALTY || "0")); - expect(defaultModel.repetition_penalty).toBe(parseFloat(process.env.DEFAULT_MODEL_REPETITION_PENALTY || "1.0")); - }); - - test("should load embedding model configuration from environment", () => { - const provider = runtime.getModelProvider(); - const embeddingModel = provider.models[ModelClass.EMBEDDING] as EmbeddingModelSettings; - - expect(embeddingModel?.name).toBe(process.env.EMBEDDING_MODEL_NAME || "text-embedding-3-small"); - expect(embeddingModel?.dimensions).toBe(parseInt(process.env.EMBEDDING_DIMENSIONS || "1536")); - }); - }); - - describe("Model Provider Validation", () => { - test("should validate model provider name format", () => { - expect(() => new AgentRuntime({ - token: "test-token", - character: { - name: "Test Character", - username: "test", - bio: "Test bio", - lore: ["Test lore"], - modelProvider: "invalid@provider" as ModelProviderName, - messageExamples: [], - postExamples: [], - topics: [], - adjectives: [], - style: { - all: [], - chat: [], - post: [] - }, - clients: [], - plugins: [], - }, - databaseAdapter: mockDatabaseAdapter, - cacheManager: mockCacheManager, - modelProvider: "invalid@provider" as ModelProviderName, - })).toThrow(/Invalid model provider/); - }); - }); -}); diff --git a/packages/core/__tests__/runtime.test.ts b/packages/core/__tests__/runtime.test.ts index 144fe304bfb..75c11253eb0 100644 --- a/packages/core/__tests__/runtime.test.ts +++ b/packages/core/__tests__/runtime.test.ts @@ -16,6 +16,13 @@ import { } from "../src/types"; import { defaultCharacter } from "../src/defaultCharacter"; +// Mock the embedding module +vi.mock("../src/embedding", () => ({ + embed: vi.fn().mockResolvedValue([0.1, 0.2, 0.3]), + getRemoteEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])), + getLocalEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])) +})); + // Mock dependencies with minimal implementations const mockDatabaseAdapter: IDatabaseAdapter = { db: {}, @@ -66,7 +73,7 @@ const mockDatabaseAdapter: IDatabaseAdapter = { const mockCacheManager = { get: vi.fn().mockResolvedValue(null), set: vi.fn().mockResolvedValue(undefined), - delete: vi.fn().mockResolvedValue(undefined), + delete: vi.fn().mockResolvedValue(undefined) }; // Mock action creator @@ -115,12 +122,6 @@ const mockModelProvider: IModelProvider = { }, }; -// Mock embedding API -vi.mock("../src/embedding", () => ({ - getRemoteEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])), - getLocalEmbedding: vi.fn().mockResolvedValue(new Float32Array([0.1, 0.2, 0.3])), -})); - describe("AgentRuntime", () => { let runtime: AgentRuntime; @@ -484,3 +485,114 @@ describe("Model Provider Configuration", () => { }); }); }); + +describe("ModelProviderManager", () => { + test("should get correct model provider settings", async () => { + const runtime = new AgentRuntime({ + token: "test-token", + modelProvider: ModelProviderName.OPENAI, + databaseAdapter: mockDatabaseAdapter, + cacheManager: { + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), + }, + }); + + const provider = runtime.getModelProvider(); + expect(provider).toBeDefined(); + expect(provider.provider).toBe(ModelProviderName.OPENAI); 
+ }); +}); + +describe("MemoryManagerService", () => { + test("should provide access to different memory managers", async () => { + const runtime = new AgentRuntime({ + token: "test-token", + modelProvider: ModelProviderName.OPENAI, + databaseAdapter: mockDatabaseAdapter, + cacheManager: mockCacheManager + }); + + expect(runtime.messageManager).toBeDefined(); + expect(runtime.descriptionManager).toBeDefined(); + expect(runtime.loreManager).toBeDefined(); + expect(runtime.documentsManager).toBeDefined(); + expect(runtime.knowledgeManager).toBeDefined(); + expect(runtime.ragKnowledgeManager).toBeDefined(); + }); + + test("should allow registering custom memory managers", async () => { + const runtime = new AgentRuntime({ + token: "test-token", + modelProvider: ModelProviderName.OPENAI, + databaseAdapter: mockDatabaseAdapter, + cacheManager: mockCacheManager + }); + + const customManager: IMemoryManager = { + runtime: runtime, + tableName: "custom", + getMemories: vi.fn(), + getCachedEmbeddings: vi.fn(), + getMemoryById: vi.fn(), + getMemoriesByRoomIds: vi.fn(), + searchMemoriesByEmbedding: vi.fn(), + createMemory: vi.fn(), + removeMemory: vi.fn(), + removeAllMemories: vi.fn(), + countMemories: vi.fn(), + addEmbeddingToMemory: vi.fn() + }; + + runtime.registerMemoryManager(customManager); + expect(runtime.getMemoryManager("custom")).toBe(customManager); + }); +}); + +describe("ServiceManager", () => { + test("should handle service registration and retrieval", async () => { + const runtime = new AgentRuntime({ + token: "test-token", + modelProvider: ModelProviderName.OPENAI, + databaseAdapter: mockDatabaseAdapter, + cacheManager: mockCacheManager + }); + + const mockService = { + serviceType: ServiceType.TEXT_GENERATION, + type: ServiceType.TEXT_GENERATION, + initialize: vi.fn().mockResolvedValue(undefined) + }; + + await runtime.registerService(mockService); + const retrievedService = runtime.getService(ServiceType.TEXT_GENERATION); + expect(retrievedService).toBe(mockService); + }); +}); + +describe("Verifiable Inference", () => { + test("should handle verifiable inference adapter", async () => { + const runtime = new AgentRuntime({ + token: "test-token", + modelProvider: ModelProviderName.OPENAI, + databaseAdapter: mockDatabaseAdapter, + cacheManager: mockCacheManager + }); + + const mockAdapter = { + verify: vi.fn(), + options: {}, + generateText: vi.fn(), + verifyProof: vi.fn() + }; + + expect(runtime.getVerifiableInferenceAdapter()).toBeUndefined(); + + runtime.setVerifiableInferenceAdapter(mockAdapter); + expect(runtime.getVerifiableInferenceAdapter()).toBe(mockAdapter); + + runtime.setVerifiableInferenceAdapter(undefined); + expect(runtime.getVerifiableInferenceAdapter()).toBeUndefined(); + }); +}); diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts index b93a1f2fc99..cf65c2bac4c 100644 --- a/packages/core/src/embedding.ts +++ b/packages/core/src/embedding.ts @@ -45,7 +45,7 @@ export const getEmbeddingConfig = (runtime?: IAgentRuntime): EmbeddingConfig => // Fallback to default config return { - dimensions: 1536, // OpenAI's text-embedding-ada-002 dimension + dimensions: 1536, // OpenAI's text-embedding-3-small dimension model: "text-embedding-3-small", // Default to OpenAI's latest embedding model provider: EmbeddingProvider.OpenAI }; @@ -131,16 +131,14 @@ export function getEmbeddingType(runtime: IAgentRuntime): "local" | "remote" { // - Running in Node.js // - Not using OpenAI provider // - Not forcing OpenAI embeddings - const isLocal = isNode && 
!settings.USE_OPENAI_EMBEDDING;
+    const useOpenAI = runtime.getSetting("USE_OPENAI_EMBEDDING") === "true" || settings.USE_OPENAI_EMBEDDING === "true";
+    const isLocal = isNode && !useOpenAI;
 
     return isLocal ? "local" : "remote";
 }
 
 export function getEmbeddingZeroVector(): number[] {
-    let embeddingDimension = 384; // Default BGE dimension
-
-    // TODO: add logic to get from character settings
-
+    let embeddingDimension = settings.USE_OPENAI_EMBEDDING === "true" ? 1536 : 384;
     return Array(embeddingDimension).fill(0);
 }
 
diff --git a/packages/core/src/environment.ts b/packages/core/src/environment.ts
index 1cc83678a70..05954731d78 100644
--- a/packages/core/src/environment.ts
+++ b/packages/core/src/environment.ts
@@ -28,7 +28,12 @@ export type EnvConfig = z.infer<typeof envSchema>;
 
 // Validation function
 export function validateEnv(): EnvConfig {
     try {
-        return envSchema.parse(process.env);
+        // Transform provider name to lowercase before validation
+        const envWithLowercaseProvider = {
+            ...process.env,
+            PROVIDER_NAME: process.env.PROVIDER_NAME?.toLowerCase(),
+        };
+        return envSchema.parse(envWithLowercaseProvider);
     } catch (error) {
         if (error instanceof z.ZodError) {
             const errorMessages = error.errors
diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts
index d82d6b55da0..46ebf6a6501 100644
--- a/packages/core/src/generation.ts
+++ b/packages/core/src/generation.ts
@@ -35,17 +35,6 @@ type Tool = CoreTool<any, any>;
 type StepResult = AIStepResult<any>;
 
-interface ModelSettings {
-    prompt: string;
-    temperature: number;
-    maxTokens: number;
-    frequencyPenalty: number;
-    presencePenalty: number;
-    stop?: string[];
-    experimental_telemetry?: TelemetrySettings;
-}
-
-
 interface VerifiedInferenceOptions {
     verifiableInference?: boolean;
     verifiableInferenceAdapter?: IVerifiableInferenceAdapter;
@@ -62,7 +51,7 @@ interface GenerateObjectOptions extends VerifiedInferenceOptions {
     schemaDescription?: string;
     stop?: string[];
     mode?: 'auto' | 'json' | 'tool';
-    enum?: string[];
+    enum?: Array<string>;
 }
 
 // ================ COMMON UTILITIES ================
@@ -313,7 +302,8 @@ async function generateEnum<T extends string>({
     runtime,
     context,
     modelClass,
-    schema: z.enum(enumValues as [T, ...T[]]),
+    output: 'enum',
+    enum: enumValues as unknown as string[],
 });
 
     elizaLogger.debug("Received enum response:", result);
@@ -340,7 +330,7 @@ export async function generateShouldRespond({
         context,
         modelClass,
         enumValues: RESPONSE_VALUES,
-        functionName: 'generateShouldRespond'
+        functionName: 'generateShouldRespond',
     });
 
     return result as ResponseType;
@@ -376,7 +366,7 @@ export const generateObject = async ({
     runtime,
     context,
     modelClass=ModelClass.DEFAULT,
-    mode,
+    output='object',
     schema,
     schemaName,
     schemaDescription,
@@ -400,7 +390,7 @@
         model: client.languageModel(model),
         prompt: context.toString(),
         system: runtime.character.system ?? settings.SYSTEM_PROMPT ?? undefined,
-        output: schema ? undefined : 'no-schema',
+        output: output as never,
         ...(schema ? { schema, schemaName, schemaDescription } : {}),
         mode: 'json'
     })
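
For context on the provider refactor above: the tests assert that runtime.getModelProvider() exposes per-class model settings driven by the new .env.example variables. Below is a minimal sketch of that mapping, assuming a hypothetical modelSettingsFromEnv helper — the field names (maxInputTokens, frequency_penalty, etc.) come from the assertions in the deleted models.test.ts, but the helper itself is not part of this patch.

```ts
// Illustrative only — NOT part of the patch. Shows how the
// DEFAULT_/SMALL_/MEDIUM_/LARGE_MODEL_* variables from .env.example could be
// collapsed into the per-class settings the tests read back via
// runtime.getModelProvider().models.
type EnvModelPrefix = "DEFAULT" | "SMALL" | "MEDIUM" | "LARGE";

function modelSettingsFromEnv(prefix: EnvModelPrefix) {
    const env = process.env;
    const stop = env[`${prefix}_MODEL_STOP`];
    return {
        name: env[`${prefix}_MODEL`] ?? "gpt-4",
        maxInputTokens: parseInt(env[`${prefix}_MODEL_MAX_INPUT_TOKENS`] ?? "4096", 10),
        maxOutputTokens: parseInt(env[`${prefix}_MODEL_MAX_OUTPUT_TOKENS`] ?? "1024", 10),
        temperature: parseFloat(env[`${prefix}_MODEL_TEMPERATURE`] ?? "0.7"),
        frequency_penalty: parseFloat(env[`${prefix}_MODEL_FREQUENCY_PENALTY`] ?? "0"),
        presence_penalty: parseFloat(env[`${prefix}_MODEL_PRESENCE_PENALTY`] ?? "0"),
        repetition_penalty: parseFloat(env[`${prefix}_MODEL_REPETITION_PENALTY`] ?? "1.0"),
        // Stop sequences arrive as a comma-separated string, or empty.
        stop: stop ? stop.split(",") : [],
    };
}

// The embedding entry the tests expect, e.g.
// { name: "text-embedding-3-small", dimensions: 1536 }:
const embeddingSettings = {
    name: process.env.EMBEDDING_MODEL ?? "text-embedding-3-small",
    dimensions: parseInt(process.env.EMBEDDING_DIMENSIONS ?? "1536", 10),
};

console.log(modelSettingsFromEnv("DEFAULT"), embeddingSettings);
```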