Skip to content

Commit

Permalink
wip: save
Browse files Browse the repository at this point in the history
  • Loading branch information
wtfsayo committed Feb 6, 2025
1 parent a40b7fd commit 8499d96
Show file tree
Hide file tree
Showing 6 changed files with 68,917 additions and 339 deletions.
272 changes: 241 additions & 31 deletions packages/core/__tests__/embedding.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
getEmbeddingType,
getEmbeddingZeroVector,
} from "../src/embedding.ts";
import { type IAgentRuntime, ModelProviderName } from "../types.ts";
import { type IAgentRuntime, ModelProviderName } from "../src/types.ts";
import settings from "../src/settings.ts";

// Mock environment-related settings
Expand All @@ -23,18 +23,26 @@ vi.mock("../settings", () => ({
}));

// Mock fastembed module for local embeddings
vi.mock("fastembed", () => ({
FlagEmbedding: {
init: vi.fn().mockResolvedValue({
queryEmbed: vi
.fn()
.mockResolvedValue(new Float32Array(384).fill(0.1)),
}),
},
EmbeddingModel: {
BGESmallENV15: "BGE-small-en-v1.5",
},
}));
vi.mock("fastembed", () => {
class MockFlagEmbedding {
constructor() {}

static async init() {
return new MockFlagEmbedding();
}

async queryEmbed(text: string | string[]) {
return [new Float32Array(384).fill(0.1)];
}
}

return {
FlagEmbedding: MockFlagEmbedding,
EmbeddingModel: {
BGESmallENV15: "BGE-small-en-v1.5",
},
};
});

// Mock global fetch for remote embedding requests
const mockFetch = vi.fn();
Expand All @@ -44,20 +52,223 @@ describe("Embedding Module", () => {
let mockRuntime: IAgentRuntime;

beforeEach(() => {
// Reset all mocks
vi.clearAllMocks();

// Prepare a mock runtime
mockRuntime = {
agentId: "00000000-0000-0000-0000-000000000000" as `${string}-${string}-${string}-${string}-${string}`,
serverUrl: "http://test-server",
databaseAdapter: {
init: vi.fn(),
close: vi.fn(),
getMemories: vi.fn(),
createMemory: vi.fn(),
removeMemory: vi.fn(),
searchMemories: vi.fn(),
searchMemoriesByEmbedding: vi.fn(),
getGoals: vi.fn(),
createGoal: vi.fn(),
updateGoal: vi.fn(),
removeGoal: vi.fn(),
getRoom: vi.fn(),
createRoom: vi.fn(),
removeRoom: vi.fn(),
addParticipant: vi.fn(),
removeParticipant: vi.fn(),
getParticipantsForRoom: vi.fn(),
getParticipantUserState: vi.fn(),
setParticipantUserState: vi.fn(),
createRelationship: vi.fn(),
getRelationship: vi.fn(),
getRelationships: vi.fn(),
getKnowledge: vi.fn(),
searchKnowledge: vi.fn(),
createKnowledge: vi.fn(),
removeKnowledge: vi.fn(),
clearKnowledge: vi.fn(),
},
token: "test-token",
modelProvider: ModelProviderName.OPENAI,
imageModelProvider: ModelProviderName.OPENAI,
imageVisionModelProvider: ModelProviderName.OPENAI,
providers: [],
actions: [],
evaluators: [],
plugins: [],
character: {
modelProvider: ModelProviderName.OLLAMA,
modelEndpointOverride: null,
modelProvider: ModelProviderName.OPENAI,
name: "Test Character",
username: "test",
bio: ["Test bio"],
lore: ["Test lore"],
messageExamples: [],
postExamples: [],
topics: [],
adjectives: [],
style: {
all: [],
chat: [],
post: []
},
clients: [],
plugins: [],
},
getModelProvider: () => ({
apiKey: "test-key",
endpoint: "test-endpoint",
provider: ModelProviderName.OPENAI,
models: {
default: { name: "test-model", maxInputTokens: 4096, maxOutputTokens: 4096, stop: [], temperature: 0.7 },
},
}),
getSetting: (key: string) => {
const settings = {
USE_OPENAI_EMBEDDING: "false",
USE_OLLAMA_EMBEDDING: "false",
USE_GAIANET_EMBEDDING: "false",
OPENAI_API_KEY: "mock-openai-key",
OPENAI_API_URL: "https://api.openai.com/v1",
GAIANET_API_KEY: "mock-gaianet-key",
OLLAMA_EMBEDDING_MODEL: "mxbai-embed-large",
GAIANET_EMBEDDING_MODEL: "nomic-embed",
};
return settings[key as keyof typeof settings] || "";
},
knowledgeManager: {
init: vi.fn(),
close: vi.fn(),
addKnowledge: vi.fn(),
removeKnowledge: vi.fn(),
searchKnowledge: vi.fn(),
clearKnowledge: vi.fn(),
},
memoryManager: {
init: vi.fn(),
close: vi.fn(),
addMemory: vi.fn(),
removeMemory: vi.fn(),
searchMemories: vi.fn(),
searchMemoriesByEmbedding: vi.fn(),
clearMemories: vi.fn(),
},
goalManager: {
init: vi.fn(),
close: vi.fn(),
addGoal: vi.fn(),
updateGoal: vi.fn(),
removeGoal: vi.fn(),
getGoals: vi.fn(),
clearGoals: vi.fn(),
},
relationshipManager: {
init: vi.fn(),
close: vi.fn(),
addRelationship: vi.fn(),
getRelationship: vi.fn(),
getRelationships: vi.fn(),
},
token: "mock-token",
cacheManager: {
get: vi.fn(),
set: vi.fn(),
delete: vi.fn(),
},
services: new Map(),
clients: {},
messageManager: {
getCachedEmbeddings: vi.fn().mockResolvedValue([]),
runtime: {} as IAgentRuntime,
tableName: "messages",
addEmbeddingToMemory: vi.fn(),
getMemories: vi.fn(),
getCachedEmbeddings: vi.fn(),
getMemoryById: vi.fn(),
getMemoriesByRoomIds: vi.fn(),
searchMemoriesByEmbedding: vi.fn(),
createMemory: vi.fn(),
removeMemory: vi.fn(),
removeAllMemories: vi.fn(),
countMemories: vi.fn(),
},
descriptionManager: {
runtime: {} as IAgentRuntime,
tableName: "descriptions",
addEmbeddingToMemory: vi.fn(),
getMemories: vi.fn(),
getCachedEmbeddings: vi.fn(),
getMemoryById: vi.fn(),
getMemoriesByRoomIds: vi.fn(),
searchMemoriesByEmbedding: vi.fn(),
createMemory: vi.fn(),
removeMemory: vi.fn(),
removeAllMemories: vi.fn(),
countMemories: vi.fn(),
},
documentsManager: {
runtime: {} as IAgentRuntime,
tableName: "documents",
addEmbeddingToMemory: vi.fn(),
getMemories: vi.fn(),
getCachedEmbeddings: vi.fn(),
getMemoryById: vi.fn(),
getMemoriesByRoomIds: vi.fn(),
searchMemoriesByEmbedding: vi.fn(),
createMemory: vi.fn(),
removeMemory: vi.fn(),
removeAllMemories: vi.fn(),
countMemories: vi.fn(),
},
loreManager: {
runtime: {} as IAgentRuntime,
tableName: "lore",
addEmbeddingToMemory: vi.fn(),
getMemories: vi.fn(),
getCachedEmbeddings: vi.fn(),
getMemoryById: vi.fn(),
getMemoriesByRoomIds: vi.fn(),
searchMemoriesByEmbedding: vi.fn(),
createMemory: vi.fn(),
removeMemory: vi.fn(),
removeAllMemories: vi.fn(),
countMemories: vi.fn(),
},
ragKnowledgeManager: {
runtime: {} as IAgentRuntime,
tableName: "rag_knowledge",
getKnowledge: vi.fn(),
createKnowledge: vi.fn(),
removeKnowledge: vi.fn(),
searchKnowledge: vi.fn(),
clearKnowledge: vi.fn(),
processFile: vi.fn(),
cleanupDeletedKnowledgeFiles: vi.fn(),
generateScopedId: vi.fn(),
},
initialize: vi.fn(),
registerMemoryManager: vi.fn(),
getMemoryManager: vi.fn(),
getService: vi.fn(),
registerService: vi.fn(),
composeState: vi.fn(),
processActions: vi.fn(),
evaluate: vi.fn(),
ensureParticipantExists: vi.fn(),
ensureUserExists: vi.fn(),
ensureConnection: vi.fn(),
ensureParticipantInRoom: vi.fn(),
ensureRoomExists: vi.fn(),
updateRecentMessageState: vi.fn(),
getConversationLength: vi.fn(),
registerAction: vi.fn(),
} as unknown as IAgentRuntime;

vi.clearAllMocks();
// Reset fetch mock
mockFetch.mockReset();
mockFetch.mockResolvedValue({
ok: true,
json: async () => ({
data: [new Array(384).fill(0.1)],
}),
});
});

describe("getEmbeddingConfig", () => {
Expand All @@ -67,25 +278,24 @@ describe("Embedding Module", () => {
expect(config.model).toBe("BGE-small-en-v1.5");
expect(config.provider).toBe("BGE");
});

test("should return OpenAI config when USE_OPENAI_EMBEDDING is true", () => {
vi.mocked(settings).USE_OPENAI_EMBEDDING = "true";
const config = getEmbeddingConfig();
expect(config.dimensions).toBe(1536);
expect(config.model).toBe("text-embedding-3-small");
expect(config.provider).toBe("OpenAI");
});
});

describe("getEmbeddingType", () => {
test("should return 'remote' for Ollama provider", () => {
test("should return 'local' by default", () => {
const type = getEmbeddingType(mockRuntime);
expect(type).toBe("remote");
expect(type).toBe("local");
});

test("should return 'remote' for OpenAI provider", () => {
mockRuntime.character.modelProvider = ModelProviderName.OPENAI;
const type = getEmbeddingType(mockRuntime);
test("should return 'remote' when using OpenAI", () => {
const runtimeWithOpenAI = {
...mockRuntime,
getSetting: (key: string) => {
if (key === "USE_OPENAI_EMBEDDING") return "true";
return mockRuntime.getSetting(key);
},
} as IAgentRuntime;

const type = getEmbeddingType(runtimeWithOpenAI);
expect(type).toBe("remote");
});
});
Expand Down
Loading

0 comments on commit 8499d96

Please sign in to comment.