diff --git a/packages/plugin-discord/__tests__/discord-client.test.ts b/packages/plugin-discord/__tests__/discord-client.test.ts index ee1cb0b8030..73fc2f94edd 100644 --- a/packages/plugin-discord/__tests__/discord-client.test.ts +++ b/packages/plugin-discord/__tests__/discord-client.test.ts @@ -1,6 +1,7 @@ import { Events } from 'discord.js'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import { DiscordClient } from '../src'; +import { DiscordConfig } from '../src/environment'; // Mock @elizaos/core vi.mock('@elizaos/core', () => ({ @@ -56,15 +57,13 @@ vi.mock('discord.js', () => { }); describe('DiscordClient', () => { + let mockConfig: DiscordConfig; let mockRuntime: any; let discordClient: DiscordClient; beforeEach(() => { mockRuntime = { - getSetting: vi.fn((key: string) => { - if (key === 'DISCORD_API_TOKEN') return 'mock-token'; - return undefined; - }), + getSetting: vi.fn(), getState: vi.fn(), setState: vi.fn(), getMemory: vi.fn(), @@ -81,13 +80,16 @@ describe('DiscordClient', () => { } }; - discordClient = new DiscordClient(mockRuntime); + mockConfig = { + DISCORD_API_TOKEN: "mock-token", + } + + discordClient = new DiscordClient(mockRuntime, mockConfig); }); it('should initialize with correct configuration', () => { expect(discordClient.apiToken).toBe('mock-token'); expect(discordClient.client).toBeDefined(); - expect(mockRuntime.getSetting).toHaveBeenCalledWith('DISCORD_API_TOKEN'); }); it('should login to Discord on initialization', () => { diff --git a/packages/plugin-discord/__tests__/environment.test.ts b/packages/plugin-discord/__tests__/environment.test.ts new file mode 100644 index 00000000000..a250a4833ed --- /dev/null +++ b/packages/plugin-discord/__tests__/environment.test.ts @@ -0,0 +1,37 @@ +import { describe, it, expect } from 'vitest'; +import { validateDiscordConfig } from '../src/environment'; +import type { IAgentRuntime } from '@elizaos/core'; + +// Mock runtime environment +const mockRuntime: IAgentRuntime = { + env: { + DISCORD_API_TOKEN: 'mocked-discord-token', + }, + getEnv: function (key: string) { + return this.env[key] || null; + }, + getSetting: function (key: string) { + return this.env[key] || null; + } +} as unknown as IAgentRuntime; + +describe('Discord Environment Configuration', () => { + it('should validate correct configuration', async () => { + const config = await validateDiscordConfig(mockRuntime); + expect(config).toBeDefined(); + expect(config.DISCORD_API_TOKEN).toBe('mocked-discord-token'); + }); + + it('should throw an error when DISCORD_API_TOKEN is missing', async () => { + const invalidRuntime = { + ...mockRuntime, + env: { + DISCORD_API_TOKEN: undefined, + }, + } as IAgentRuntime; + + await expect(validateDiscordConfig(invalidRuntime)).rejects.toThrowError( + 'Discord configuration validation failed:\nDISCORD_API_TOKEN: Expected string, received null' + ); + }); +}); diff --git a/packages/plugin-discord/__tests__/message.test.ts b/packages/plugin-discord/__tests__/message.test.ts new file mode 100644 index 00000000000..b2b498b70de --- /dev/null +++ b/packages/plugin-discord/__tests__/message.test.ts @@ -0,0 +1,211 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { MessageManager } from "../src/messages.ts"; +import { ChannelType, Client, Collection } from "discord.js"; +import { type IAgentRuntime } from "@elizaos/core"; +import type { VoiceManager } from "../src/voice"; + +vi.mock("@elizaos/core", () => ({ + logger: { + info: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + stringToUuid: (str: string) => str, + messageCompletionFooter: + "# INSTRUCTIONS: Choose the best response for the agent.", + shouldRespondFooter: "# INSTRUCTIONS: Choose if the agent should respond.", + generateMessageResponse: vi.fn(), + generateShouldRespond: vi.fn().mockResolvedValue("IGNORE"), // Prevent API calls by always returning "IGNORE" + composeContext: vi.fn(), + ModelClass: { + TEXT_SMALL: "TEXT_SMALL", + }, + ServiceType: { + VIDEO: "VIDEO", + BROWSER: "BROWSER", + }, +})); + +describe("Discord MessageManager", () => { + let mockRuntime: IAgentRuntime; + let mockClient: Client; + let mockDiscordClient: { client: Client; runtime: IAgentRuntime }; + let mockVoiceManager: VoiceManager; + let mockMessage: any; + let messageManager: MessageManager; + + beforeEach(() => { + vi.clearAllMocks(); + mockRuntime = { + character: { + name: "TestBot", + templates: {}, + clientConfig: { + discord: { + allowedChannelIds: ["mock-channal-id"], + shouldIgnoreBotMessages: true, + shouldIgnoreDirectMessages: true, + }, + }, + }, + evaluate: vi.fn(), + composeState: vi.fn(), + ensureConnection: vi.fn(), + ensureUserExists: vi.fn(), + messageManager: { + createMemory: vi.fn(), + addEmbeddingToMemory: vi.fn(), + }, + databaseAdapter: { + getParticipantUserState: vi.fn().mockResolvedValue("ACTIVE"), + log: vi.fn(), + }, + processActions: vi.fn(), + } as unknown as IAgentRuntime; + + mockClient = new Client({ intents: [] }); + mockClient.user = { + id: "mock-bot-id", + username: "MockBot", + tag: "MockBot#0001", + displayName: "MockBotDisplay", + } as any; + + mockDiscordClient = { + client: mockClient, + runtime: mockRuntime, + }; + + mockVoiceManager = { + playAudioStream: vi.fn(), + } as unknown as VoiceManager; + + messageManager = new MessageManager(mockDiscordClient, mockVoiceManager); + + const guild = { + members: { + cache: { + get: vi.fn().mockReturnValue({ + nickname: "MockBotNickname", + permissions: { + has: vi.fn().mockReturnValue(true), // Bot has permissions + }, + }), + }, + }, + }; + mockMessage = { + content: "Hello, MockBot!", + author: { + id: "mock-user-id", + username: "MockUser", + bot: false, + }, + guild, + channel: { + id: "mock-channal-id", + type: ChannelType.GuildText, + send: vi.fn(), + guild, + client: { + user: mockClient.user, + }, + permissionsFor: vi.fn().mockReturnValue({ + has: vi.fn().mockReturnValue(true), + }), + }, + id: "mock-message-id", + createdTimestamp: Date.now(), + mentions: { + has: vi.fn().mockReturnValue(false), + }, + reference: null, + attachments: [], + }; + }); + + it("should initialize MessageManager", () => { + expect(messageManager).toBeDefined(); + }); + + it("should process user messages", async () => { + // Prevent further message processing after response check + vi.spyOn( + Object.getPrototypeOf(messageManager), + "_shouldRespond" + ).mockReturnValueOnce(false); + + await messageManager.handleMessage(mockMessage); + expect(mockRuntime.ensureConnection).toHaveBeenCalled(); + expect(mockRuntime.messageManager.createMemory).toHaveBeenCalled(); + }); + + it("should ignore bot messages", async () => { + mockMessage.author.bot = true; + await messageManager.handleMessage(mockMessage); + expect(mockRuntime.ensureConnection).not.toHaveBeenCalled(); + }); + + it("should ignore messages from restricted channels", async () => { + mockMessage.channel.id = "undefined-channel-id"; + await messageManager.handleMessage(mockMessage); + expect(mockRuntime.ensureConnection).not.toHaveBeenCalled(); + }); + + it.each([ + ["Hey MockBot, are you there?", "username"], + ["MockBot#0001, respond please.", "tag"], + ["MockBotNickname, can you help?", "nickname"], + ["MoCkBoT, can you help?", "mixed case mention"], + ])( + "should respond if the bot name is included in the message", + async (content) => { + mockMessage.content = content; + + const result = await messageManager["_shouldRespond"]( + mockMessage, + {} as any + ); + expect(result).toBe(true); + } + ); + + it("should process audio attachments", async () => { + vi.spyOn( + Object.getPrototypeOf(messageManager), + "_shouldRespond" + ).mockReturnValueOnce(false); + vi.spyOn(messageManager, "processMessageMedia").mockReturnValueOnce( + Promise.resolve({ processedContent: "", attachments: [] }) + ); + + const myVariable = new Collection([ + [ + "mock-attachment-id", + { + attachment: "https://www.example.mp3", + name: "mock-attachment.mp3", + contentType: "audio/mpeg", + }, + ], + ]); + + mockMessage.attachments = myVariable; + + const processAttachmentsMock = vi.fn().mockResolvedValue([]); + + const mockAttachmentManager = { + processAttachments: processAttachmentsMock, + } as unknown as (typeof messageManager)["attachmentManager"]; + + // Override the private property with a mock + Object.defineProperty(messageManager, "attachmentManager", { + value: mockAttachmentManager, + writable: true, + }); + + await messageManager.handleMessage(mockMessage); + + expect(processAttachmentsMock).toHaveBeenCalled(); + }); +}); diff --git a/packages/plugin-discord/src/environment.ts b/packages/plugin-discord/src/environment.ts index ad3489a8fc2..a99dbb1e172 100644 --- a/packages/plugin-discord/src/environment.ts +++ b/packages/plugin-discord/src/environment.ts @@ -2,9 +2,6 @@ import type { IAgentRuntime } from "@elizaos/core"; import { z } from "zod"; export const discordEnvSchema = z.object({ - DISCORD_APPLICATION_ID: z - .string() - .min(1, "Discord application ID is required"), DISCORD_API_TOKEN: z.string().min(1, "Discord API token is required"), }); @@ -15,8 +12,6 @@ export async function validateDiscordConfig( ): Promise { try { const config = { - DISCORD_APPLICATION_ID: - runtime.getSetting("DISCORD_APPLICATION_ID"), DISCORD_API_TOKEN: runtime.getSetting("DISCORD_API_TOKEN"), }; diff --git a/packages/plugin-discord/src/index.ts b/packages/plugin-discord/src/index.ts index 1427733fc12..8f6ff84d43b 100644 --- a/packages/plugin-discord/src/index.ts +++ b/packages/plugin-discord/src/index.ts @@ -9,14 +9,14 @@ import { ClientInstance } from "@elizaos/core"; import { - Client, - Events, - GatewayIntentBits, - Partials, - PermissionsBitField, - type Guild, - type MessageReaction, - type User, + Client, + Events, + GatewayIntentBits, + Partials, + PermissionsBitField, + type Guild, + type MessageReaction, + type User, } from "discord.js"; import { EventEmitter } from "events"; import chat_with_attachments from "./actions/chat_with_attachments.ts"; @@ -31,401 +31,380 @@ import voiceStateProvider from "./providers/voiceState.ts"; import reply from "./actions/reply.ts"; import type { IDiscordClient } from "./types.ts"; import { VoiceManager } from "./voice.ts"; +import { validateDiscordConfig, DiscordConfig } from "./environment.ts"; +import { DiscordTestSuite } from "./test-suite.ts"; import { DISCORD_CLIENT_NAME } from "./constants.ts"; export class DiscordClient extends EventEmitter implements IDiscordClient { - apiToken: string; - client: Client; - runtime: IAgentRuntime; - character: Character; - private messageManager: MessageManager; - private voiceManager: VoiceManager; - - constructor(runtime: IAgentRuntime) { - super(); - - this.apiToken = runtime.getSetting("DISCORD_API_TOKEN") as string; - this.client = new Client({ - intents: [ - GatewayIntentBits.Guilds, - GatewayIntentBits.DirectMessages, - GatewayIntentBits.GuildVoiceStates, - GatewayIntentBits.MessageContent, - GatewayIntentBits.GuildMessages, - GatewayIntentBits.DirectMessageTyping, - GatewayIntentBits.GuildMessageTyping, - GatewayIntentBits.GuildMessageReactions, - ], - partials: [ - Partials.Channel, - Partials.Message, - Partials.User, - Partials.Reaction, - ], - }); - - this.runtime = runtime; - this.voiceManager = new VoiceManager(this); - this.messageManager = new MessageManager(this, this.voiceManager); - - this.client.once(Events.ClientReady, this.onClientReady.bind(this)); - this.client.login(this.apiToken); - - this.setupEventListeners(); - } - - private setupEventListeners() { - // When joining to a new server - this.client.on("guildCreate", this.handleGuildCreate.bind(this)); - - this.client.on( - Events.MessageReactionAdd, - this.handleReactionAdd.bind(this) - ); - this.client.on( - Events.MessageReactionRemove, - this.handleReactionRemove.bind(this) - ); - - // Handle voice events with the voice manager - this.client.on( - "voiceStateUpdate", - this.voiceManager.handleVoiceStateUpdate.bind(this.voiceManager) - ); - this.client.on( - "userStream", - this.voiceManager.handleUserStream.bind(this.voiceManager) - ); - - // Handle a new message with the message manager - this.client.on( - Events.MessageCreate, - this.messageManager.handleMessage.bind(this.messageManager) - ); - - // Handle a new interaction - this.client.on( - Events.InteractionCreate, - this.handleInteractionCreate.bind(this) - ); + apiToken: string; + client: Client; + runtime: IAgentRuntime; + character: Character; + messageManager: MessageManager; + voiceManager: VoiceManager; + + constructor(runtime: IAgentRuntime, discordConfig: DiscordConfig) { + super(); + + this.apiToken = discordConfig.DISCORD_API_TOKEN; + this.client = new Client({ + intents: [ + GatewayIntentBits.Guilds, + GatewayIntentBits.DirectMessages, + GatewayIntentBits.GuildVoiceStates, + GatewayIntentBits.MessageContent, + GatewayIntentBits.GuildMessages, + GatewayIntentBits.DirectMessageTyping, + GatewayIntentBits.GuildMessageTyping, + GatewayIntentBits.GuildMessageReactions, + ], + partials: [ + Partials.Channel, + Partials.Message, + Partials.User, + Partials.Reaction, + ], + }); + + this.runtime = runtime; + this.voiceManager = new VoiceManager(this); + this.messageManager = new MessageManager(this, this.voiceManager); + + this.client.once(Events.ClientReady, this.onClientReady.bind(this)); + this.client.login(this.apiToken); + + this.setupEventListeners(); + } + + private setupEventListeners() { + // When joining to a new server + this.client.on("guildCreate", this.handleGuildCreate.bind(this)); + + this.client.on( + Events.MessageReactionAdd, + this.handleReactionAdd.bind(this) + ); + this.client.on( + Events.MessageReactionRemove, + this.handleReactionRemove.bind(this) + ); + + // Handle voice events with the voice manager + this.client.on( + "voiceStateUpdate", + this.voiceManager.handleVoiceStateUpdate.bind(this.voiceManager) + ); + this.client.on( + "userStream", + this.voiceManager.handleUserStream.bind(this.voiceManager) + ); + + // Handle a new message with the message manager + this.client.on( + Events.MessageCreate, + this.messageManager.handleMessage.bind(this.messageManager) + ); + + // Handle a new interaction + this.client.on( + Events.InteractionCreate, + this.handleInteractionCreate.bind(this) + ); + } + + async stop() { + try { + // disconnect websocket + // this unbinds all the listeners + await this.client.destroy(); + } catch (e) { + logger.error("client-discord instance stop err", e); } - - async stop() { - try { - // disconnect websocket - // this unbinds all the listeners - await this.client.destroy(); - } catch (e) { - logger.error("client-discord instance stop err", e); - } + } + + private async onClientReady(readyClient: { user: { tag: any; id: any } }) { + logger.success(`Logged in as ${readyClient.user?.tag}`); + + // Register slash commands + const commands = [ + { + name: "joinchannel", + description: "Join a voice channel", + options: [ + { + name: "channel", + type: 7, // CHANNEL type + description: "The voice channel to join", + required: true, + channel_types: [2], // GuildVoice type + }, + ], + }, + { + name: "leavechannel", + description: "Leave the current voice channel", + }, + ]; + + try { + await this.client.application?.commands.set(commands); + logger.success("Slash commands registered"); + } catch (error) { + console.error("Error registering slash commands:", error); } - private async onClientReady(readyClient: { user: { tag: any; id: any } }) { - logger.success(`Logged in as ${readyClient.user?.tag}`); - - // Register slash commands - const commands = [ - { - name: "joinchannel", - description: "Join a voice channel", - options: [ - { - name: "channel", - type: 7, // CHANNEL type - description: "The voice channel to join", - required: true, - channel_types: [2], // GuildVoice type - }, - ], - }, - { - name: "leavechannel", - description: "Leave the current voice channel", - }, - ]; - + // Required permissions for the bot + const requiredPermissions = [ + // Text Permissions + PermissionsBitField.Flags.ViewChannel, + PermissionsBitField.Flags.SendMessages, + PermissionsBitField.Flags.SendMessagesInThreads, + PermissionsBitField.Flags.CreatePrivateThreads, + PermissionsBitField.Flags.CreatePublicThreads, + PermissionsBitField.Flags.EmbedLinks, + PermissionsBitField.Flags.AttachFiles, + PermissionsBitField.Flags.AddReactions, + PermissionsBitField.Flags.UseExternalEmojis, + PermissionsBitField.Flags.UseExternalStickers, + PermissionsBitField.Flags.MentionEveryone, + PermissionsBitField.Flags.ManageMessages, + PermissionsBitField.Flags.ReadMessageHistory, + // Voice Permissions + PermissionsBitField.Flags.Connect, + PermissionsBitField.Flags.Speak, + PermissionsBitField.Flags.UseVAD, + PermissionsBitField.Flags.PrioritySpeaker, + ].reduce((a, b) => a | b, 0n); + + logger.success("Use this URL to add the bot to your server:"); + logger.success( + `https://discord.com/api/oauth2/authorize?client_id=${readyClient.user?.id}&permissions=${requiredPermissions}&scope=bot%20applications.commands` + ); + await this.onReady(); + } + + async handleReactionAdd(reaction: MessageReaction, user: User) { + try { + logger.log("Reaction added"); + + // Early returns + if (!reaction || !user) { + logger.warn("Invalid reaction or user"); + return; + } + + // Get emoji info + let emoji = reaction.emoji.name; + if (!emoji && reaction.emoji.id) { + emoji = `<:${reaction.emoji.name}:${reaction.emoji.id}>`; + } + + // Fetch full message if partial + if (reaction.partial) { try { - await this.client.application?.commands.set(commands); - logger.success("Slash commands registered"); + await reaction.fetch(); } catch (error) { - console.error("Error registering slash commands:", error); + logger.error("Failed to fetch partial reaction:", error); + return; } - - // Required permissions for the bot - const requiredPermissions = [ - // Text Permissions - PermissionsBitField.Flags.ViewChannel, - PermissionsBitField.Flags.SendMessages, - PermissionsBitField.Flags.SendMessagesInThreads, - PermissionsBitField.Flags.CreatePrivateThreads, - PermissionsBitField.Flags.CreatePublicThreads, - PermissionsBitField.Flags.EmbedLinks, - PermissionsBitField.Flags.AttachFiles, - PermissionsBitField.Flags.AddReactions, - PermissionsBitField.Flags.UseExternalEmojis, - PermissionsBitField.Flags.UseExternalStickers, - PermissionsBitField.Flags.MentionEveryone, - PermissionsBitField.Flags.ManageMessages, - PermissionsBitField.Flags.ReadMessageHistory, - // Voice Permissions - PermissionsBitField.Flags.Connect, - PermissionsBitField.Flags.Speak, - PermissionsBitField.Flags.UseVAD, - PermissionsBitField.Flags.PrioritySpeaker, - ].reduce((a, b) => a | b, 0n); - - logger.success("Use this URL to add the bot to your server:"); - logger.success( - `https://discord.com/api/oauth2/authorize?client_id=${readyClient.user?.id}&permissions=${requiredPermissions}&scope=bot%20applications.commands` - ); - await this.onReady(); - } - - async handleReactionAdd(reaction: MessageReaction, user: User) { - try { - logger.log("Reaction added"); - - // Early returns - if (!reaction || !user) { - logger.warn("Invalid reaction or user"); - return; - } - - // Get emoji info - let emoji = reaction.emoji.name; - if (!emoji && reaction.emoji.id) { - emoji = `<:${reaction.emoji.name}:${reaction.emoji.id}>`; - } - - // Fetch full message if partial - if (reaction.partial) { - try { - await reaction.fetch(); - } catch (error) { - logger.error( - "Failed to fetch partial reaction:", - error - ); - return; - } - } - - // Generate IDs with timestamp to ensure uniqueness - const timestamp = Date.now(); - const roomId = stringToUuid( - `${reaction.message.channel.id}-${this.runtime.agentId}` - ); - const userIdUUID = stringToUuid( - `${user.id}-${this.runtime.agentId}` - ); - const reactionUUID = stringToUuid( - `${reaction.message.id}-${user.id}-${emoji}-${timestamp}-${this.runtime.agentId}` - ); - - // Validate IDs - if (!userIdUUID || !roomId) { - logger.error("Invalid user ID or room ID", { - userIdUUID, - roomId, - }); - return; - } - - // Process message content - const messageContent = reaction.message.content || ""; - const truncatedContent = - messageContent.length > 100 - ? `${messageContent.substring(0, 100)}...` - : messageContent; - const reactionMessage = `*<${emoji}>: "${truncatedContent}"*`; - - // Get user info - const userName = reaction.message.author?.username || "unknown"; - const name = reaction.message.author?.displayName || userName; - - // Ensure connection - await this.runtime.ensureConnection( - userIdUUID, - roomId, - userName, - name, - "discord" - ); - - // Create memory with retry logic - const memory = { - id: reactionUUID, - userId: userIdUUID, - agentId: this.runtime.agentId, - content: { - text: reactionMessage, - source: "discord", - inReplyTo: stringToUuid( - `${reaction.message.id}-${this.runtime.agentId}` - ), - }, - roomId, - createdAt: timestamp, - }; - - try { - await this.runtime.messageManager.createMemory(memory); - logger.debug("Reaction memory created", { - reactionId: reactionUUID, - emoji, - userId: user.id, - }); - } catch (error) { - if (error.code === "23505") { - // Duplicate key error - logger.warn("Duplicate reaction memory, skipping", { - reactionId: reactionUUID, - }); - return; - } - throw error; // Re-throw other errors - } - } catch (error) { - logger.error("Error handling reaction:", error); + } + + // Generate IDs with timestamp to ensure uniqueness + const timestamp = Date.now(); + const roomId = stringToUuid( + `${reaction.message.channel.id}-${this.runtime.agentId}` + ); + const userIdUUID = stringToUuid(`${user.id}-${this.runtime.agentId}`); + const reactionUUID = stringToUuid( + `${reaction.message.id}-${user.id}-${emoji}-${timestamp}-${this.runtime.agentId}` + ); + + // Validate IDs + if (!userIdUUID || !roomId) { + logger.error("Invalid user ID or room ID", { + userIdUUID, + roomId, + }); + return; + } + + // Process message content + const messageContent = reaction.message.content || ""; + const truncatedContent = + messageContent.length > 100 + ? `${messageContent.substring(0, 100)}...` + : messageContent; + const reactionMessage = `*<${emoji}>: "${truncatedContent}"*`; + + // Get user info + const userName = reaction.message.author?.username || "unknown"; + const name = reaction.message.author?.displayName || userName; + + // Ensure connection + await this.runtime.ensureConnection( + userIdUUID, + roomId, + userName, + name, + "discord" + ); + + // Create memory with retry logic + const memory = { + id: reactionUUID, + userId: userIdUUID, + agentId: this.runtime.agentId, + content: { + text: reactionMessage, + source: "discord", + inReplyTo: stringToUuid( + `${reaction.message.id}-${this.runtime.agentId}` + ), + }, + roomId, + createdAt: timestamp, + }; + + try { + await this.runtime.messageManager.createMemory(memory); + logger.debug("Reaction memory created", { + reactionId: reactionUUID, + emoji, + userId: user.id, + }); + } catch (error) { + if (error.code === "23505") { + // Duplicate key error + logger.warn("Duplicate reaction memory, skipping", { + reactionId: reactionUUID, + }); + return; } + throw error; // Re-throw other errors + } + } catch (error) { + logger.error("Error handling reaction:", error); } + } - async handleReactionRemove(reaction: MessageReaction, user: User) { - logger.log("Reaction removed"); - // if (user.bot) return; - - let emoji = reaction.emoji.name; - if (!emoji && reaction.emoji.id) { - emoji = `<:${reaction.emoji.name}:${reaction.emoji.id}>`; - } + async handleReactionRemove(reaction: MessageReaction, user: User) { + logger.log("Reaction removed"); + // if (user.bot) return; - // Fetch the full message if it's a partial - if (reaction.partial) { - try { - await reaction.fetch(); - } catch (error) { - console.error( - "Something went wrong when fetching the message:", - error - ); - return; - } - } - - const messageContent = reaction.message.content; - const truncatedContent = - messageContent.length > 50 - ? messageContent.substring(0, 50) + "..." - : messageContent; - - const reactionMessage = `*Removed <${emoji} emoji> from: "${truncatedContent}"*`; - - const roomId = stringToUuid( - reaction.message.channel.id + "-" + this.runtime.agentId - ); - const userIdUUID = stringToUuid(user.id); - - // Generate a unique UUID for the reaction removal - const reactionUUID = stringToUuid( - `${reaction.message.id}-${user.id}-${emoji}-removed-${this.runtime.agentId}` - ); - - const userName = reaction.message.author.username; - const name = reaction.message.author.displayName; - - await this.runtime.ensureConnection( - userIdUUID, - roomId, - userName, - name, - "discord" - ); - - try { - // Save the reaction removal as a message - await this.runtime.messageManager.createMemory({ - id: reactionUUID, // This is the ID of the reaction removal message - userId: userIdUUID, - agentId: this.runtime.agentId, - content: { - text: reactionMessage, - source: "discord", - inReplyTo: stringToUuid( - reaction.message.id + "-" + this.runtime.agentId - ), // This is the ID of the original message - }, - roomId, - createdAt: Date.now(), - }); - } catch (error) { - console.error("Error creating reaction removal message:", error); - } + let emoji = reaction.emoji.name; + if (!emoji && reaction.emoji.id) { + emoji = `<:${reaction.emoji.name}:${reaction.emoji.id}>`; } - private handleGuildCreate(guild: Guild) { - console.log(`Joined guild ${guild.name}`); - this.voiceManager.scanGuild(guild); + // Fetch the full message if it's a partial + if (reaction.partial) { + try { + await reaction.fetch(); + } catch (error) { + console.error("Something went wrong when fetching the message:", error); + return; + } } - private async handleInteractionCreate(interaction: any) { - if (!interaction.isCommand()) return; - - switch (interaction.commandName) { - case "joinchannel": - await this.voiceManager.handleJoinChannelCommand(interaction); - break; - case "leavechannel": - await this.voiceManager.handleLeaveChannelCommand(interaction); - break; - } + const messageContent = reaction.message.content; + const truncatedContent = + messageContent.length > 50 + ? messageContent.substring(0, 50) + "..." + : messageContent; + + const reactionMessage = `*Removed <${emoji} emoji> from: "${truncatedContent}"*`; + + const roomId = stringToUuid( + reaction.message.channel.id + "-" + this.runtime.agentId + ); + const userIdUUID = stringToUuid(user.id); + + // Generate a unique UUID for the reaction removal + const reactionUUID = stringToUuid( + `${reaction.message.id}-${user.id}-${emoji}-removed-${this.runtime.agentId}` + ); + + const userName = reaction.message.author.username; + const name = reaction.message.author.displayName; + + await this.runtime.ensureConnection( + userIdUUID, + roomId, + userName, + name, + "discord" + ); + + try { + // Save the reaction removal as a message + await this.runtime.messageManager.createMemory({ + id: reactionUUID, // This is the ID of the reaction removal message + userId: userIdUUID, + agentId: this.runtime.agentId, + content: { + text: reactionMessage, + source: "discord", + inReplyTo: stringToUuid( + reaction.message.id + "-" + this.runtime.agentId + ), // This is the ID of the original message + }, + roomId, + createdAt: Date.now(), + }); + } catch (error) { + console.error("Error creating reaction removal message:", error); + } + } + + private handleGuildCreate(guild: Guild) { + console.log(`Joined guild ${guild.name}`); + this.voiceManager.scanGuild(guild); + } + + private async handleInteractionCreate(interaction: any) { + if (!interaction.isCommand()) return; + + switch (interaction.commandName) { + case "joinchannel": + await this.voiceManager.handleJoinChannelCommand(interaction); + break; + case "leavechannel": + await this.voiceManager.handleLeaveChannelCommand(interaction); + break; } + } - private async onReady() { - const guilds = await this.client.guilds.fetch(); - for (const [, guild] of guilds) { - const fullGuild = await guild.fetch(); - this.voiceManager.scanGuild(fullGuild); - } + private async onReady() { + const guilds = await this.client.guilds.fetch(); + for (const [, guild] of guilds) { + const fullGuild = await guild.fetch(); + this.voiceManager.scanGuild(fullGuild); } + } } const DiscordClientInterface: ElizaClient = { - name: DISCORD_CLIENT_NAME, - start: async (runtime: IAgentRuntime) => new DiscordClient(runtime), -}; - -const testSuite: TestSuite = { - name: "discord", - tests: [ - { - name: "test creating discord client", - fn: async (runtime: IAgentRuntime) => { - const discordClient = new DiscordClient(runtime); - console.log("Created a discord client"); - } - } - ] + name: "discord", + start: async (runtime: IAgentRuntime) => { + const discordConfig: DiscordConfig = await validateDiscordConfig(runtime); + return new DiscordClient(runtime, discordConfig); + }, }; const discordPlugin: Plugin = { - name: "discord", - description: "Discord client plugin", - clients: [DiscordClientInterface], - actions: [ - reply, - chat_with_attachments, - download_media, - joinvoice, - leavevoice, - summarize, - transcribe_media, - ], - providers: [ - channelStateProvider, - voiceStateProvider, - ], - tests: [ - testSuite, - ] + name: "discord", + description: "Discord client plugin", + clients: [DiscordClientInterface], + actions: [ + reply, + chat_with_attachments, + download_media, + joinvoice, + leavevoice, + summarize, + transcribe_media, + ], + providers: [channelStateProvider, voiceStateProvider], + tests: [new DiscordTestSuite()], }; -export default discordPlugin; \ No newline at end of file +export default discordPlugin; diff --git a/packages/plugin-discord/src/messages.ts b/packages/plugin-discord/src/messages.ts index cc95ad25359..c87304949da 100644 --- a/packages/plugin-discord/src/messages.ts +++ b/packages/plugin-discord/src/messages.ts @@ -19,8 +19,6 @@ import { MESSAGE_LENGTH_THRESHOLDS } from "./constants.ts"; import { - discordAnnouncementHypeTemplate, - discordAutoPostTemplate, discordMessageHandlerTemplate, discordShouldRespondTemplate } from "./templates.ts"; @@ -50,29 +48,21 @@ export class MessageManager { private runtime: IAgentRuntime; private attachmentManager: AttachmentManager; private interestChannels: InterestChannels = {}; - private discordClient: any; private voiceManager: VoiceManager; - //Auto post - private lastChannelActivity: { [channelId: string]: number } = {}; - private autoPostInterval: NodeJS.Timeout; constructor(discordClient: any, voiceManager: VoiceManager) { this.client = discordClient.client; this.voiceManager = voiceManager; - this.discordClient = discordClient; this.runtime = discordClient.runtime; this.attachmentManager = new AttachmentManager(this.runtime); } - async handleMessage(message: DiscordMessage) { + async handleMessage(message: DiscordMessage) { if (this.runtime.character.clientConfig?.discord?.allowedChannelIds && - !this.runtime.character.clientConfig.discord.allowedChannelIds.includes(message.channelId)) { + !this.runtime.character.clientConfig.discord.allowedChannelIds.some((id: string) => id == message.channel.id)) { return; } - // Update last activity time for the channel - this.lastChannelActivity[message.channelId] = Date.now(); - if ( message.interaction || message.author.id === @@ -106,7 +96,6 @@ export class MessageManager { try { const { processedContent, attachments } = await this.processMessageMedia(message); - const audioAttachments = message.attachments.filter((attachment) => attachment.contentType?.startsWith("audio/") ); diff --git a/packages/plugin-discord/src/test-suite.ts b/packages/plugin-discord/src/test-suite.ts new file mode 100644 index 00000000000..fe4e404fe64 --- /dev/null +++ b/packages/plugin-discord/src/test-suite.ts @@ -0,0 +1,291 @@ +import { + logger, + type TestSuite, + type IAgentRuntime, + ModelClass, +} from "@elizaos/core"; +import { DiscordClient } from "./index.ts"; +import { DiscordConfig, validateDiscordConfig } from "./environment"; +import { sendMessageInChunks } from "./utils.ts"; +import { ChannelType, Events, TextChannel } from "discord.js"; +import { + createAudioPlayer, + NoSubscriberBehavior, + createAudioResource, + AudioPlayerStatus, +} from "@discordjs/voice"; + +export class DiscordTestSuite implements TestSuite { + name = "discord"; + private discordClient: DiscordClient | null = null; + tests: { name: string; fn: (runtime: IAgentRuntime) => Promise }[]; + + constructor() { + this.tests = [ + { + name: "test creating discord client", + fn: this.testCreatingDiscordClient.bind(this), + }, + { + name: "test joining voice channel", + fn: this.testJoiningVoiceChannel.bind(this), + }, + { + name: "test text-to-speech playback", + fn: this.testTextToSpeechPlayback.bind(this), + }, + { + name: "test sending message", + fn: this.testSendingTextMessage.bind(this), + }, + { + name: "handle message in message manager", + fn: this.testHandlingMessage.bind(this), + }, + ]; + } + + async testCreatingDiscordClient(runtime: IAgentRuntime) { + try { + const existingPlugin = runtime.getClient("discord"); + + if (existingPlugin) { + // Reuse the existing DiscordClient if available + this.discordClient = existingPlugin as DiscordClient; + logger.info("Reusing existing DiscordClient instance."); + } else { + if (!this.discordClient) { + const discordConfig: DiscordConfig = await validateDiscordConfig( + runtime + ); + this.discordClient = new DiscordClient(runtime, discordConfig); + await new Promise((resolve, reject) => { + this.discordClient.client.once(Events.ClientReady, resolve); + this.discordClient.client.once(Events.Error, reject); + }); + } else { + logger.info("Reusing existing DiscordClient instance."); + } + logger.success("DiscordClient successfully initialized."); + } + } catch (error) { + throw new Error(`Error in test creating Discord client: ${error}`); + } + } + + async testJoiningVoiceChannel(runtime: IAgentRuntime) { + try { + let voiceChannel = null; + let channelId = process.env.DISCORD_VOICE_CHANNEL_ID || null; + + if (!channelId) { + const guilds = await this.discordClient.client.guilds.fetch(); + for (const [, guild] of guilds) { + const fullGuild = await guild.fetch(); + const voiceChannels = fullGuild.channels.cache + .filter((c) => c.type === ChannelType.GuildVoice) + .values(); + voiceChannel = voiceChannels.next().value; + if (voiceChannel) break; + } + + if (!voiceChannel) { + logger.warn("No suitable voice channel found to join."); + return; + } + } else { + voiceChannel = await this.discordClient.client.channels.fetch( + channelId + ); + } + + if (!voiceChannel || voiceChannel.type !== ChannelType.GuildVoice) { + logger.error("Invalid voice channel."); + return; + } + + await this.discordClient.voiceManager.joinChannel(voiceChannel); + + logger.success(`Joined voice channel: ${voiceChannel.id}`); + } catch (error) { + logger.error("Error joining voice channel:", error); + } + } + + async testTextToSpeechPlayback(runtime: IAgentRuntime) { + try { + let guildId = this.discordClient.client.guilds.cache.find( + (guild) => guild.members.me?.voice.channelId + )?.id; + + if (!guildId) { + logger.warn( + "Bot is not connected to a voice channel. Attempting to join one..." + ); + + await this.testJoiningVoiceChannel(runtime); + + guildId = this.discordClient.client.guilds.cache.find( + (guild) => guild.members.me?.voice.channelId + )?.id; + + if (!guildId) { + logger.error("Failed to join a voice channel. TTS playback aborted."); + return; + } + } + + const connection = + this.discordClient.voiceManager.getVoiceConnection(guildId); + if (!connection) { + logger.warn("No active voice connection found for the bot."); + return; + } + + let responseStream = null; + try { + responseStream = await runtime.useModel( + ModelClass.TEXT_TO_SPEECH, + `Hi! I'm ${runtime.character.name}! How are you doing today?` + ); + } catch(error) { + logger.warn("No text to speech service found"); + return; + } + + + if (!responseStream) { + logger.error("TTS response stream is null or undefined."); + return; + } + + const audioPlayer = createAudioPlayer({ + behaviors: { + noSubscriber: NoSubscriberBehavior.Pause, + }, + }); + + const audioResource = createAudioResource(responseStream); + + audioPlayer.play(audioResource); + connection.subscribe(audioPlayer); + + logger.success("TTS playback started successfully."); + + await new Promise((resolve, reject) => { + audioPlayer.once(AudioPlayerStatus.Idle, () => { + logger.info("TTS playback finished."); + resolve(); + }); + + audioPlayer.once("error", (error) => { + logger.error("TTS playback error:", error); + reject(error); + }); + }); + } catch (error) { + logger.error("Error in TTS playback test:", error); + } + } + + async testSendingTextMessage(runtime: IAgentRuntime) { + try { + const channel = await this.getTextChannel(); + if (!channel) return; + + await this.sendMessageToChannel(channel, "Testing sending message"); + } catch (error) { + logger.error("Error in sending text message:", error); + } + } + + async testHandlingMessage(runtime: IAgentRuntime) { + try { + const channel = await this.getTextChannel(); + if (!channel) return; + + const fakeMessage = { + content: `Hello, ${runtime.character.name}! How are you?`, + author: { + id: "mock-user-id", + username: "MockUser", + bot: false, + }, + channel, + id: "mock-message-id", + createdTimestamp: Date.now(), + mentions: { + has: () => false, + }, + reference: null, + attachments: [], + }; + await this.discordClient.messageManager.handleMessage(fakeMessage as any); + + } catch (error) { + logger.error("Error in sending text message:", error); + } + } + + + async getTextChannel(): Promise { + try { + let channel: TextChannel | null = null; + const channelId = process.env.DISCORD_TEXT_CHANNEL_ID || null; + + if (!channelId) { + const guilds = await this.discordClient.client.guilds.fetch(); + for (const [, guild] of guilds) { + const fullGuild = await guild.fetch(); + const textChannels = fullGuild.channels.cache + .filter((c) => c.type === ChannelType.GuildText) + .values(); + channel = textChannels.next().value as TextChannel; + if (channel) break; // Stop if we found a valid channel + } + + if (!channel) { + logger.warn("No suitable text channel found."); + return null; + } + } else { + const fetchedChannel = await this.discordClient.client.channels.fetch(channelId); + if (fetchedChannel && fetchedChannel.isTextBased()) { + channel = fetchedChannel as TextChannel; + } else { + logger.warn(`Provided channel ID (${channelId}) is invalid or not a text channel.`); + return null; + } + } + + if (!channel) { + logger.warn("Failed to determine a valid text channel."); + return null; + } + + return channel; + } catch (error) { + logger.error("Error fetching text channel:", error); + return null; + } + } + + + async sendMessageToChannel(channel: TextChannel, messageContent: string) { + try { + if (!channel || !channel.isTextBased()) { + logger.error("Channel is not a text-based channel or does not exist."); + return; + } + + await sendMessageInChunks( + channel as TextChannel, + messageContent, + null, + null + ); + } catch (error) { + logger.error("Error sending message:", error); + } + } +} diff --git a/packages/plugin-discord/src/voice.ts b/packages/plugin-discord/src/voice.ts index 070ec363d34..1886d8c6142 100644 --- a/packages/plugin-discord/src/voice.ts +++ b/packages/plugin-discord/src/voice.ts @@ -327,7 +327,7 @@ export class VoiceManager extends EventEmitter { } } - private getVoiceConnection(guildId: string) { + getVoiceConnection(guildId: string) { const connections = getVoiceConnections(this.client.user.id); if (!connections) { return; @@ -509,7 +509,7 @@ export class VoiceManager extends EventEmitter { } finally { this.processingVoice = false; } - }, DEBOUNCE_TRANSCRIPTION_THRESHOLD); + }, DEBOUNCE_TRANSCRIPTION_THRESHOLD) as unknown as NodeJS.Timeout; } async handleUserStream(