Skip to content

Commit

Permalink
Merge branch 'v2-develop' into tcm/elevnlabs-plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
tcm390 authored Feb 14, 2025
2 parents b1b9fdf + 89a22b2 commit 8a26f38
Show file tree
Hide file tree
Showing 11 changed files with 959 additions and 407 deletions.
1 change: 1 addition & 0 deletions packages/plugin-anthropic/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
],
"dependencies": {
"@ai-sdk/anthropic": "^1.1.6",
"fastembed": "^1.0.0",
"@elizaos/core": "workspace:*",
"tsup": "8.3.5",
"zod": "3.21.4"
Expand Down
39 changes: 37 additions & 2 deletions packages/plugin-anthropic/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
} from "@elizaos/core";
import { generateText } from "ai";
import { z } from "zod";
import { EmbeddingModel, FlagEmbedding } from "fastembed";

// Define a configuration schema for the Anthropics plugin.
const configSchema = z.object({
Expand Down Expand Up @@ -82,12 +83,46 @@ export const anthropicPlugin: Plugin = {
stopSequences,
});
return text;
},
[ModelClass.TEXT_EMBEDDING]: async (runtime, text: string | null) => {

// TODO: Make this fallback only!!!
// TODO: pass on cacheDir to FlagEmbedding.init
if (!text) return new Array(1536).fill(0);

const model = await FlagEmbedding.init({ model: EmbeddingModel.BGESmallENV15 });
const embedding = await model.queryEmbed(text);

const finalEmbedding = Array.isArray(embedding)
? ArrayBuffer.isView(embedding[0])
? Array.from(embedding[0] as never)
: embedding
: Array.from(embedding);

if (!Array.isArray(finalEmbedding) || finalEmbedding[0] === undefined) {
throw new Error("Invalid embedding format");
}

return finalEmbedding.map(Number);
}
},
tests: [
{
name: "anthropic_plugin_tests",
tests: [
{
name: 'anthropic_test_text_embedding',
fn: async (runtime) => {
try {
console.log("testing embedding");
const embedding = await runtime.useModel(ModelClass.TEXT_EMBEDDING, "Hello, world!");
console.log("embedding done", embedding);
} catch (error) {
console.error("Error in test_text_embedding:", error);
throw error;
}
}
},
{
name: 'anthropic_test_text_small',
fn: async (runtime) => {
Expand Down Expand Up @@ -117,9 +152,9 @@ export const anthropicPlugin: Plugin = {
if (text.length === 0) {
throw new Error("Failed to generate text");
}
console.log("generated with test_text_small:", text);
console.log("generated with test_text_large:", text);
} catch (error) {
console.error("Error in test_text_small:", error);
console.error("Error in test_text_large:", error);
throw error;
}
}
Expand Down
14 changes: 8 additions & 6 deletions packages/plugin-discord/__tests__/discord-client.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Events } from 'discord.js';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { DiscordClient } from '../src';
import { DiscordConfig } from '../src/environment';

// Mock @elizaos/core
vi.mock('@elizaos/core', () => ({
Expand Down Expand Up @@ -56,15 +57,13 @@ vi.mock('discord.js', () => {
});

describe('DiscordClient', () => {
let mockConfig: DiscordConfig;
let mockRuntime: any;
let discordClient: DiscordClient;

beforeEach(() => {
mockRuntime = {
getSetting: vi.fn((key: string) => {
if (key === 'DISCORD_API_TOKEN') return 'mock-token';
return undefined;
}),
getSetting: vi.fn(),
getState: vi.fn(),
setState: vi.fn(),
getMemory: vi.fn(),
Expand All @@ -81,13 +80,16 @@ describe('DiscordClient', () => {
}
};

discordClient = new DiscordClient(mockRuntime);
mockConfig = {
DISCORD_API_TOKEN: "mock-token",
}

discordClient = new DiscordClient(mockRuntime, mockConfig);
});

it('should initialize with correct configuration', () => {
expect(discordClient.apiToken).toBe('mock-token');
expect(discordClient.client).toBeDefined();
expect(mockRuntime.getSetting).toHaveBeenCalledWith('DISCORD_API_TOKEN');
});

it('should login to Discord on initialization', () => {
Expand Down
37 changes: 37 additions & 0 deletions packages/plugin-discord/__tests__/environment.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { describe, it, expect } from 'vitest';
import { validateDiscordConfig } from '../src/environment';
import type { IAgentRuntime } from '@elizaos/core';

// Mock runtime environment
const mockRuntime: IAgentRuntime = {
env: {
DISCORD_API_TOKEN: 'mocked-discord-token',
},
getEnv: function (key: string) {
return this.env[key] || null;
},
getSetting: function (key: string) {
return this.env[key] || null;
}
} as unknown as IAgentRuntime;

describe('Discord Environment Configuration', () => {
it('should validate correct configuration', async () => {
const config = await validateDiscordConfig(mockRuntime);
expect(config).toBeDefined();
expect(config.DISCORD_API_TOKEN).toBe('mocked-discord-token');
});

it('should throw an error when DISCORD_API_TOKEN is missing', async () => {
const invalidRuntime = {
...mockRuntime,
env: {
DISCORD_API_TOKEN: undefined,
},
} as IAgentRuntime;

await expect(validateDiscordConfig(invalidRuntime)).rejects.toThrowError(
'Discord configuration validation failed:\nDISCORD_API_TOKEN: Expected string, received null'
);
});
});
211 changes: 211 additions & 0 deletions packages/plugin-discord/__tests__/message.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { MessageManager } from "../src/messages.ts";
import { ChannelType, Client, Collection } from "discord.js";
import { type IAgentRuntime } from "@elizaos/core";
import type { VoiceManager } from "../src/voice";

vi.mock("@elizaos/core", () => ({
logger: {
info: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
},
stringToUuid: (str: string) => str,
messageCompletionFooter:
"# INSTRUCTIONS: Choose the best response for the agent.",
shouldRespondFooter: "# INSTRUCTIONS: Choose if the agent should respond.",
generateMessageResponse: vi.fn(),
generateShouldRespond: vi.fn().mockResolvedValue("IGNORE"), // Prevent API calls by always returning "IGNORE"
composeContext: vi.fn(),
ModelClass: {
TEXT_SMALL: "TEXT_SMALL",
},
ServiceType: {
VIDEO: "VIDEO",
BROWSER: "BROWSER",
},
}));

describe("Discord MessageManager", () => {
let mockRuntime: IAgentRuntime;
let mockClient: Client;
let mockDiscordClient: { client: Client; runtime: IAgentRuntime };
let mockVoiceManager: VoiceManager;
let mockMessage: any;
let messageManager: MessageManager;

beforeEach(() => {
vi.clearAllMocks();
mockRuntime = {
character: {
name: "TestBot",
templates: {},
clientConfig: {
discord: {
allowedChannelIds: ["mock-channal-id"],
shouldIgnoreBotMessages: true,
shouldIgnoreDirectMessages: true,
},
},
},
evaluate: vi.fn(),
composeState: vi.fn(),
ensureConnection: vi.fn(),
ensureUserExists: vi.fn(),
messageManager: {
createMemory: vi.fn(),
addEmbeddingToMemory: vi.fn(),
},
databaseAdapter: {
getParticipantUserState: vi.fn().mockResolvedValue("ACTIVE"),
log: vi.fn(),
},
processActions: vi.fn(),
} as unknown as IAgentRuntime;

mockClient = new Client({ intents: [] });
mockClient.user = {
id: "mock-bot-id",
username: "MockBot",
tag: "MockBot#0001",
displayName: "MockBotDisplay",
} as any;

mockDiscordClient = {
client: mockClient,
runtime: mockRuntime,
};

mockVoiceManager = {
playAudioStream: vi.fn(),
} as unknown as VoiceManager;

messageManager = new MessageManager(mockDiscordClient, mockVoiceManager);

const guild = {
members: {
cache: {
get: vi.fn().mockReturnValue({
nickname: "MockBotNickname",
permissions: {
has: vi.fn().mockReturnValue(true), // Bot has permissions
},
}),
},
},
};
mockMessage = {
content: "Hello, MockBot!",
author: {
id: "mock-user-id",
username: "MockUser",
bot: false,
},
guild,
channel: {
id: "mock-channal-id",
type: ChannelType.GuildText,
send: vi.fn(),
guild,
client: {
user: mockClient.user,
},
permissionsFor: vi.fn().mockReturnValue({
has: vi.fn().mockReturnValue(true),
}),
},
id: "mock-message-id",
createdTimestamp: Date.now(),
mentions: {
has: vi.fn().mockReturnValue(false),
},
reference: null,
attachments: [],
};
});

it("should initialize MessageManager", () => {
expect(messageManager).toBeDefined();
});

it("should process user messages", async () => {
// Prevent further message processing after response check
vi.spyOn(
Object.getPrototypeOf(messageManager),
"_shouldRespond"
).mockReturnValueOnce(false);

await messageManager.handleMessage(mockMessage);
expect(mockRuntime.ensureConnection).toHaveBeenCalled();
expect(mockRuntime.messageManager.createMemory).toHaveBeenCalled();
});

it("should ignore bot messages", async () => {
mockMessage.author.bot = true;
await messageManager.handleMessage(mockMessage);
expect(mockRuntime.ensureConnection).not.toHaveBeenCalled();
});

it("should ignore messages from restricted channels", async () => {
mockMessage.channel.id = "undefined-channel-id";
await messageManager.handleMessage(mockMessage);
expect(mockRuntime.ensureConnection).not.toHaveBeenCalled();
});

it.each([
["Hey MockBot, are you there?", "username"],
["MockBot#0001, respond please.", "tag"],
["MockBotNickname, can you help?", "nickname"],
["MoCkBoT, can you help?", "mixed case mention"],
])(
"should respond if the bot name is included in the message",
async (content) => {
mockMessage.content = content;

const result = await messageManager["_shouldRespond"](
mockMessage,
{} as any
);
expect(result).toBe(true);
}
);

it("should process audio attachments", async () => {
vi.spyOn(
Object.getPrototypeOf(messageManager),
"_shouldRespond"
).mockReturnValueOnce(false);
vi.spyOn(messageManager, "processMessageMedia").mockReturnValueOnce(
Promise.resolve({ processedContent: "", attachments: [] })
);

const myVariable = new Collection<string, any>([
[
"mock-attachment-id",
{
attachment: "https://www.example.mp3",
name: "mock-attachment.mp3",
contentType: "audio/mpeg",
},
],
]);

mockMessage.attachments = myVariable;

const processAttachmentsMock = vi.fn().mockResolvedValue([]);

const mockAttachmentManager = {
processAttachments: processAttachmentsMock,
} as unknown as (typeof messageManager)["attachmentManager"];

// Override the private property with a mock
Object.defineProperty(messageManager, "attachmentManager", {
value: mockAttachmentManager,
writable: true,
});

await messageManager.handleMessage(mockMessage);

expect(processAttachmentsMock).toHaveBeenCalled();
});
});
5 changes: 0 additions & 5 deletions packages/plugin-discord/src/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ import type { IAgentRuntime } from "@elizaos/core";
import { z } from "zod";

export const discordEnvSchema = z.object({
DISCORD_APPLICATION_ID: z
.string()
.min(1, "Discord application ID is required"),
DISCORD_API_TOKEN: z.string().min(1, "Discord API token is required"),
});

Expand All @@ -15,8 +12,6 @@ export async function validateDiscordConfig(
): Promise<DiscordConfig> {
try {
const config = {
DISCORD_APPLICATION_ID:
runtime.getSetting("DISCORD_APPLICATION_ID"),
DISCORD_API_TOKEN:
runtime.getSetting("DISCORD_API_TOKEN"),
};
Expand Down
Loading

0 comments on commit 8a26f38

Please sign in to comment.