diff --git a/packages/adapter-drizzle/.gitignore b/packages/adapter-drizzle/.gitignore
new file mode 100644
index 0000000000..7a2e48a9e7
--- /dev/null
+++ b/packages/adapter-drizzle/.gitignore
@@ -0,0 +1,176 @@
+# Based on https://raw.githubusercontent.com/github/gitignore/main/Node.gitignore
+
+# Logs
+
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+lerna-debug.log*
+.pnpm-debug.log*
+
+# Caches
+
+.cache
+
+# Diagnostic reports (https://nodejs.org/api/report.html)
+
+report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+# Runtime data
+
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+
+lib-cov
+
+# Coverage directory used by tools like istanbul
+
+coverage
+*.lcov
+
+# nyc test coverage
+
+.nyc_output
+
+# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+
+bower_components
+
+# node-waf configuration
+
+.lock-wscript
+
+# Compiled binary addons (https://nodejs.org/api/addons.html)
+
+build/Release
+
+# Dependency directories
+
+node_modules/
+jspm_packages/
+
+# Snowpack dependency directory (https://snowpack.dev/)
+
+web_modules/
+
+# TypeScript cache
+
+*.tsbuildinfo
+
+# Optional npm cache directory
+
+.npm
+
+# Optional eslint cache
+
+.eslintcache
+
+# Optional stylelint cache
+
+.stylelintcache
+
+# Microbundle cache
+
+.rpt2_cache/
+.rts2_cache_cjs/
+.rts2_cache_es/
+.rts2_cache_umd/
+
+# Optional REPL history
+
+.node_repl_history
+
+# Output of 'npm pack'
+
+*.tgz
+
+# Yarn Integrity file
+
+.yarn-integrity
+
+# dotenv environment variable files
+
+.env
+.env.development.local
+.env.test.local
+.env.production.local
+.env.local
+.env.test
+
+# parcel-bundler cache (https://parceljs.org/)
+
+.parcel-cache
+
+# Next.js build output
+
+.next
+out
+
+# Nuxt.js build / generate output
+
+.nuxt
+dist
+
+# Gatsby files
+
+# Uncomment the public line if your project uses Gatsby and not Next.js
+
+# https://nextjs.org/blog/next-9-1#public-directory-support
+
+# public
+
+# vuepress build output
+
+.vuepress/dist
+
+# vuepress v2.x temp and cache directory
+
+.temp
+
+# Docusaurus cache and generated files
+
+.docusaurus
+
+# Serverless directories
+
+.serverless/
+
+# FuseBox cache
+
+.fusebox/
+
+# DynamoDB Local files
+
+.dynamodb/
+
+# TernJS port file
+
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+
+.vscode-test
+
+# yarn v2
+
+.yarn/cache
+.yarn/unplugged
+.yarn/build-state.yml
+.yarn/install-state.gz
+.pnp.*
+
+# IntelliJ based IDEs
+.idea
+
+# Finder (macOS) folder config
+.DS_Store
diff --git a/packages/adapter-drizzle/.npmignore b/packages/adapter-drizzle/.npmignore
new file mode 100644
index 0000000000..eb4b3947ff
--- /dev/null
+++ b/packages/adapter-drizzle/.npmignore
@@ -0,0 +1,9 @@
+*
+
+!dist/**
+!package.json
+!readme.md
+!tsup.config.ts
+!schema.sql
+!seed.sql
+!config.toml
\ No newline at end of file
diff --git a/packages/adapter-drizzle/README.md b/packages/adapter-drizzle/README.md
new file mode 100644
index 0000000000..888f03e95e
--- /dev/null
+++ b/packages/adapter-drizzle/README.md
@@ -0,0 +1,15 @@
+# @elizaos/adapter-drizzle
+
+To install dependencies:
+
+```bash
+bun install
+```
+
+To run:
+
+```bash
+bun run dist/index.js
+```
+
+This project was created using `bun init` in bun v1.2.1. [Bun](https://bun.sh) is a fast all-in-one JavaScript runtime.
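+
+## Usage
+
+A minimal sketch of wiring the adapter up by hand (assumptions: a local Postgres
+with the `vector` extension available, and the connection-string constructor shape
+used by the tests in `src/__tests__`):
+
+```ts
+import { DrizzleDatabaseAdapter } from "@elizaos/adapter-drizzle";
+import { stringToUuid } from "@elizaos/core";
+
+// Example connection string; substitute your own database.
+const adapter = new DrizzleDatabaseAdapter(
+    "postgres://postgres:postgres@localhost:5432/eliza"
+);
+
+// Applies the schema on first run; subsequent runs detect it and skip.
+await adapter.init();
+
+const roomId = stringToUuid("example-room");
+await adapter.createRoom(roomId);
+console.log(await adapter.getRoom(roomId)); // same UUID back
+```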
diff --git a/packages/adapter-drizzle/config.toml b/packages/adapter-drizzle/config.toml
new file mode 100644
index 0000000000..c1f016d4a4
--- /dev/null
+++ b/packages/adapter-drizzle/config.toml
@@ -0,0 +1,159 @@
+# A string used to distinguish different Supabase projects on the same host. Defaults to the
+# working directory name when running `supabase init`.
+project_id = "eliza"
+
+[api]
+enabled = true
+# Port to use for the API URL.
+port = 54321
+# Schemas to expose in your API. Tables, views and stored procedures in this schema will get API
+# endpoints. public and storage are always included.
+schemas = ["public", "storage", "graphql_public"]
+# Extra schemas to add to the search_path of every request. public is always included.
+extra_search_path = ["public", "extensions"]
+# The maximum number of rows returned from a view, table, or stored procedure. Limits payload size
+# for accidental or malicious requests.
+max_rows = 1000
+
+[db]
+# Port to use for the local database URL.
+port = 54322
+# Port used by db diff command to initialize the shadow database.
+shadow_port = 54320
+# The database major version to use. This has to be the same as your remote database's. Run `SHOW
+# server_version;` on the remote database to check.
+major_version = 15
+
+[db.pooler]
+enabled = false
+# Port to use for the local connection pooler.
+port = 54329
+# Specifies when a server connection can be reused by other clients.
+# Configure one of the supported pooler modes: `transaction`, `session`.
+pool_mode = "transaction"
+# How many server connections to allow per user/database pair.
+default_pool_size = 20
+# Maximum number of client connections allowed.
+max_client_conn = 100
+
+[realtime]
+enabled = true
+# Bind realtime via either IPv4 or IPv6. (default: IPv6)
+# ip_version = "IPv6"
+# The maximum length in bytes of HTTP request headers. (default: 4096)
+# max_header_length = 4096
+
+[studio]
+enabled = true
+# Port to use for Supabase Studio.
+port = 54323
+# External URL of the API server that frontend connects to.
+api_url = "http://127.0.0.1"
+
+# Email testing server. Emails sent with the local dev setup are not actually sent - rather, they
+# are monitored, and you can view the emails that would have been sent from the web interface.
+[inbucket]
+enabled = true
+# Port to use for the email testing server web interface.
+port = 54324
+# Uncomment to expose additional ports for testing user applications that send emails.
+# smtp_port = 54325
+# pop3_port = 54326
+
+[storage]
+enabled = true
+# The maximum file size allowed (e.g. "5MB", "500KB").
+file_size_limit = "50MiB"
+
+[auth]
+enabled = true
+# The base URL of your website. Used as an allow-list for redirects and for constructing URLs used
+# in emails.
+site_url = "http://127.0.0.1:3000"
+# A list of *exact* URLs that auth providers are permitted to redirect to post authentication.
+additional_redirect_urls = ["https://127.0.0.1:3000"]
+# How long tokens are valid for, in seconds. Defaults to 3600 (1 hour), maximum 604,800 (1 week).
+jwt_expiry = 3600
+# If disabled, the refresh token will never expire.
+enable_refresh_token_rotation = true
+# Allows refresh tokens to be reused after expiry, up to the specified interval in seconds.
+# Requires enable_refresh_token_rotation = true.
+refresh_token_reuse_interval = 10
+# Allow/disallow new user signups to your project.
+enable_signup = true
+# Allow/disallow testing manual linking of accounts
+enable_manual_linking = false
+
+[auth.email]
+# Allow/disallow new user signups via email to your project.
+enable_signup = true
+# If enabled, a user will be required to confirm any email change on both the old and new email
+# addresses. If disabled, only the new email is required to confirm.
+double_confirm_changes = true
+# If enabled, users need to confirm their email address before signing in.
+enable_confirmations = false
+
+# Uncomment to customize email template
+# [auth.email.template.invite]
+# subject = "You have been invited"
+# content_path = "./supabase/templates/invite.html"
+
+[auth.sms]
+# Allow/disallow new user signups via SMS to your project.
+enable_signup = true
+# If enabled, users need to confirm their phone number before signing in.
+enable_confirmations = false
+# Template for sending OTP to users
+template = "Your code is {{ .Code }} ."
+
+# Use a pre-defined map of phone number to OTP for testing.
+[auth.sms.test_otp]
+# 4152127777 = "123456"
+
+# This hook runs before a token is issued and allows you to add additional claims based on the authentication method used.
+[auth.hook.custom_access_token]
+# enabled = true
+# uri = "pg-functions:////"
+
+
+# Configure one of the supported SMS providers: `twilio`, `twilio_verify`, `messagebird`, `textlocal`, `vonage`.
+[auth.sms.twilio]
+enabled = false
+account_sid = ""
+message_service_sid = ""
+# DO NOT commit your Twilio auth token to git. Use environment variable substitution instead:
+auth_token = "env(SUPABASE_AUTH_SMS_TWILIO_AUTH_TOKEN)"
+
+# Use an external OAuth provider. The full list of providers is: `apple`, `azure`, `bitbucket`,
+# `discord`, `facebook`, `github`, `gitlab`, `google`, `keycloak`, `linkedin_oidc`, `notion`, `twitch`,
+# `twitter`, `slack`, `spotify`, `workos`, `zoom`.
+[auth.external.apple]
+enabled = false
+client_id = ""
+# DO NOT commit your OAuth provider secret to git. Use environment variable substitution instead:
+secret = "env(SUPABASE_AUTH_EXTERNAL_APPLE_SECRET)"
+# Overrides the default auth redirectUrl.
+redirect_uri = ""
+# Overrides the default auth provider URL. Used to support self-hosted GitLab, single-tenant Azure,
+# or any other third-party OIDC providers.
+url = ""
+
+[analytics]
+enabled = false
+port = 54327
+vector_port = 54328
+# Configure one of the supported backends: `postgres`, `bigquery`.
+backend = "postgres"
+
+# Experimental features may be deprecated at any time
+[experimental]
+# Configures Postgres storage engine to use OrioleDB (S3)
+orioledb_version = ""
+# Configures S3 bucket URL, eg. .s3-.amazonaws.com
+s3_host = "env(S3_HOST)"
+# Configures S3 bucket region, eg.
us-east-1 +s3_region = "env(S3_REGION)" +# Configures AWS_ACCESS_KEY_ID for S3 bucket +s3_access_key = "env(S3_ACCESS_KEY)" +# Configures AWS_SECRET_ACCESS_KEY for S3 bucket +s3_secret_key = "env(S3_SECRET_KEY)" diff --git a/packages/adapter-drizzle/drizzle.config.ts b/packages/adapter-drizzle/drizzle.config.ts new file mode 100644 index 0000000000..99961d3f87 --- /dev/null +++ b/packages/adapter-drizzle/drizzle.config.ts @@ -0,0 +1,16 @@ +import { defineConfig } from "drizzle-kit"; + +export default defineConfig({ + dialect: "postgresql", + schema: "./src/schema.ts", + out: "./drizzle/migrations", + dbCredentials: { + url: "postgres://postgres:postgres@localhost:5432/eliza", + }, + migrations: { + table: "__drizzle_migrations", + schema: "public", + prefix: "timestamp", + }, + breakpoints: true, +}); diff --git a/packages/adapter-drizzle/drizzle/migrations/20250201002018_init.sql b/packages/adapter-drizzle/drizzle/migrations/20250201002018_init.sql new file mode 100644 index 0000000000..7889470c14 --- /dev/null +++ b/packages/adapter-drizzle/drizzle/migrations/20250201002018_init.sql @@ -0,0 +1,169 @@ +-- Custom SQL migration file, put your code below! -- +-- Enable pgvector extension + +-- -- Drop existing tables and extensions +-- DROP EXTENSION IF EXISTS vector CASCADE; +-- DROP TABLE IF EXISTS relationships CASCADE; +-- DROP TABLE IF EXISTS participants CASCADE; +-- DROP TABLE IF EXISTS logs CASCADE; +-- DROP TABLE IF EXISTS goals CASCADE; +-- DROP TABLE IF EXISTS memories CASCADE; +-- DROP TABLE IF EXISTS rooms CASCADE; +-- DROP TABLE IF EXISTS accounts CASCADE; +-- DROP TABLE IF EXISTS knowledge CASCADE; + + +CREATE EXTENSION IF NOT EXISTS vector; +CREATE EXTENSION IF NOT EXISTS fuzzystrmatch; + +-- Create a function to determine vector dimension +CREATE OR REPLACE FUNCTION get_embedding_dimension() +RETURNS INTEGER AS $$ +BEGIN + -- Check for OpenAI first + IF current_setting('app.use_openai_embedding', TRUE) = 'true' THEN + RETURN 1536; -- OpenAI dimension + -- Then check for Ollama + ELSIF current_setting('app.use_ollama_embedding', TRUE) = 'true' THEN + RETURN 1024; -- Ollama mxbai-embed-large dimension + -- Then check for GAIANET + ELSIF current_setting('app.use_gaianet_embedding', TRUE) = 'true' THEN + RETURN 768; -- Gaianet nomic-embed dimension + ELSE + RETURN 384; -- BGE/Other embedding dimension + END IF; +END; +$$ LANGUAGE plpgsql; + +BEGIN; + +CREATE TABLE IF NOT EXISTS accounts ( + "id" UUID PRIMARY KEY, + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "name" TEXT, + "username" TEXT, + "email" TEXT NOT NULL, + "avatarUrl" TEXT, + "details" JSONB DEFAULT '{}'::jsonb +); + +CREATE TABLE IF NOT EXISTS rooms ( + "id" UUID PRIMARY KEY, + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP +); + +DO $$ +DECLARE + vector_dim INTEGER; +BEGIN + vector_dim := get_embedding_dimension(); + + EXECUTE format(' + CREATE TABLE IF NOT EXISTS memories ( + "id" UUID PRIMARY KEY, + "type" TEXT NOT NULL, + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "content" JSONB NOT NULL, + "embedding" vector(%s), + "userId" UUID REFERENCES accounts("id"), + "agentId" UUID REFERENCES accounts("id"), + "roomId" UUID REFERENCES rooms("id"), + "unique" BOOLEAN DEFAULT true NOT NULL, + CONSTRAINT fk_room FOREIGN KEY ("roomId") REFERENCES rooms("id") ON DELETE CASCADE, + CONSTRAINT fk_user FOREIGN KEY ("userId") REFERENCES accounts("id") ON DELETE CASCADE, + CONSTRAINT fk_agent FOREIGN KEY ("agentId") REFERENCES accounts("id") ON DELETE CASCADE + )', vector_dim); +END $$; + +CREATE TABLE IF NOT 
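+-- Example (an assumption, mirroring initializeDatabase() in src/__tests__): choose
+-- the dimension get_embedding_dimension() reports by setting these flags on the
+-- target database before the migration runs, e.g. for 1536-dim OpenAI embeddings:
+--   ALTER DATABASE eliza SET app.use_openai_embedding = 'true';
+--   ALTER DATABASE eliza SET app.use_ollama_embedding = 'false';
+-- With no flag set, current_setting(..., TRUE) returns NULL and the function falls
+-- through to the 384-dim BGE default.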
EXISTS goals ( + "id" UUID PRIMARY KEY, + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "userId" UUID REFERENCES accounts("id"), + "name" TEXT, + "status" TEXT, + "description" TEXT, + "roomId" UUID REFERENCES rooms("id"), + "objectives" JSONB DEFAULT '[]'::jsonb NOT NULL, + CONSTRAINT fk_room FOREIGN KEY ("roomId") REFERENCES rooms("id") ON DELETE CASCADE, + CONSTRAINT fk_user FOREIGN KEY ("userId") REFERENCES accounts("id") ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS logs ( + "id" UUID PRIMARY KEY DEFAULT gen_random_uuid(), + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "userId" UUID NOT NULL REFERENCES accounts("id"), + "body" JSONB NOT NULL, + "type" TEXT NOT NULL, + "roomId" UUID NOT NULL REFERENCES rooms("id"), + CONSTRAINT fk_room FOREIGN KEY ("roomId") REFERENCES rooms("id") ON DELETE CASCADE, + CONSTRAINT fk_user FOREIGN KEY ("userId") REFERENCES accounts("id") ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS participants ( + "id" UUID PRIMARY KEY, + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "userId" UUID REFERENCES accounts("id"), + "roomId" UUID REFERENCES rooms("id"), + "userState" TEXT, + "last_message_read" TEXT, + UNIQUE("userId", "roomId"), + CONSTRAINT fk_room FOREIGN KEY ("roomId") REFERENCES rooms("id") ON DELETE CASCADE, + CONSTRAINT fk_user FOREIGN KEY ("userId") REFERENCES accounts("id") ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS relationships ( + "id" UUID PRIMARY KEY, + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "userA" UUID NOT NULL REFERENCES accounts("id"), + "userB" UUID NOT NULL REFERENCES accounts("id"), + "status" TEXT, + "userId" UUID NOT NULL REFERENCES accounts("id"), + CONSTRAINT fk_user_a FOREIGN KEY ("userA") REFERENCES accounts("id") ON DELETE CASCADE, + CONSTRAINT fk_user_b FOREIGN KEY ("userB") REFERENCES accounts("id") ON DELETE CASCADE, + CONSTRAINT fk_user FOREIGN KEY ("userId") REFERENCES accounts("id") ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS cache ( + "key" TEXT NOT NULL, + "agentId" TEXT NOT NULL, + "value" JSONB DEFAULT '{}'::jsonb, + "createdAt" TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + "expiresAt" TIMESTAMP, + PRIMARY KEY ("key", "agentId") +); + +DO $$ +DECLARE + vector_dim INTEGER; +BEGIN + vector_dim := get_embedding_dimension(); + + EXECUTE format(' + CREATE TABLE IF NOT EXISTS knowledge ( + "id" UUID PRIMARY KEY, + "agentId" UUID REFERENCES accounts("id"), + "content" JSONB NOT NULL, + "embedding" vector(%s), + "createdAt" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + "isMain" BOOLEAN DEFAULT FALSE, + "originalId" UUID REFERENCES knowledge("id"), + "chunkIndex" INTEGER, + "isShared" BOOLEAN DEFAULT FALSE, + CHECK(("isShared" = true AND "agentId" IS NULL) OR ("isShared" = false AND "agentId" IS NOT NULL)) + )', vector_dim); +END $$; + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories USING hnsw ("embedding" vector_cosine_ops); +CREATE INDEX IF NOT EXISTS idx_memories_type_room ON memories("type", "roomId"); +CREATE INDEX IF NOT EXISTS idx_participants_user ON participants("userId"); +CREATE INDEX IF NOT EXISTS idx_participants_room ON participants("roomId"); +CREATE INDEX IF NOT EXISTS idx_relationships_users ON relationships("userA", "userB"); +CREATE INDEX IF NOT EXISTS idx_knowledge_agent ON knowledge("agentId"); +CREATE INDEX IF NOT EXISTS idx_knowledge_agent_main ON knowledge("agentId", "isMain"); +CREATE INDEX IF NOT EXISTS idx_knowledge_original ON knowledge("originalId"); +CREATE INDEX IF NOT EXISTS idx_knowledge_created ON knowledge("agentId", 
"createdAt"); +CREATE INDEX IF NOT EXISTS idx_knowledge_shared ON knowledge("isShared"); +CREATE INDEX IF NOT EXISTS idx_knowledge_embedding ON knowledge USING ivfflat (embedding vector_cosine_ops); + +COMMIT; diff --git a/packages/adapter-drizzle/drizzle/migrations/meta/20250201002018_snapshot.json b/packages/adapter-drizzle/drizzle/migrations/meta/20250201002018_snapshot.json new file mode 100644 index 0000000000..6453ebcdf8 --- /dev/null +++ b/packages/adapter-drizzle/drizzle/migrations/meta/20250201002018_snapshot.json @@ -0,0 +1,18 @@ +{ + "id": "1011a59b-2c21-47a9-bf57-cc3585fcc660", + "prevId": "00000000-0000-0000-0000-000000000000", + "version": "7", + "dialect": "postgresql", + "tables": {}, + "enums": {}, + "schemas": {}, + "views": {}, + "sequences": {}, + "roles": {}, + "policies": {}, + "_meta": { + "columns": {}, + "schemas": {}, + "tables": {} + } +} \ No newline at end of file diff --git a/packages/adapter-drizzle/drizzle/migrations/meta/_journal.json b/packages/adapter-drizzle/drizzle/migrations/meta/_journal.json new file mode 100644 index 0000000000..47e5d13dc9 --- /dev/null +++ b/packages/adapter-drizzle/drizzle/migrations/meta/_journal.json @@ -0,0 +1,13 @@ +{ + "version": "7", + "dialect": "postgresql", + "entries": [ + { + "idx": 0, + "version": "7", + "when": 1738369218670, + "tag": "20250201002018_init", + "breakpoints": true + } + ] +} \ No newline at end of file diff --git a/packages/adapter-drizzle/package.json b/packages/adapter-drizzle/package.json new file mode 100644 index 0000000000..d27758e0ee --- /dev/null +++ b/packages/adapter-drizzle/package.json @@ -0,0 +1,43 @@ +{ + "name": "@elizaos/adapter-drizzle", + "version": "0.1.9", + "type": "module", + "main": "dist/index.js", + "module": "dist/index.js", + "types": "dist/index.d.ts", + "exports": { + "./package.json": "./package.json", + ".": { + "import": { + "@elizaos/source": "./src/index.ts", + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + } + } + }, + "files": [ + "dist", + "schema.sql", + "seed.sql" + ], + "dependencies": { + "@elizaos/core": "workspace:*", + "@types/pg": "8.11.10", + "bun": "1.2.0", + "drizzle-kit": "^0.30.4", + "drizzle-orm": "^0.39.1", + "pg": "8.13.1" + }, + "devDependencies": { + "bun-types": "^1.2.0", + "dockerode": "^4.0.4", + "tsup": "8.3.5" + }, + "scripts": { + "build": "tsup --format esm --dts", + "dev": "tsup --format esm --dts --watch" + }, + "peerDependencies": { + "typescript": "^5.0.0" + } +} \ No newline at end of file diff --git a/packages/adapter-drizzle/src/__tests__/database.test.ts b/packages/adapter-drizzle/src/__tests__/database.test.ts new file mode 100644 index 0000000000..1e1014b4f6 --- /dev/null +++ b/packages/adapter-drizzle/src/__tests__/database.test.ts @@ -0,0 +1,689 @@ +import { + describe, + expect, + test, + beforeAll, + beforeEach, + afterEach, + afterAll, +} from "bun:test"; +import { DrizzleDatabaseAdapter } from "../index"; +import { elizaLogger, stringToUuid } from "@elizaos/core"; +import Docker from "dockerode"; +import getPort from "get-port"; +import pg from "pg"; +import { v4 as uuid } from "uuid"; +import { sql } from "drizzle-orm"; +import { getEmbeddingForTest } from "@elizaos/core"; +import { MemorySeedManager } from "./seed.ts"; + +const { Client } = pg; + +type DatabaseConnection = { + client: pg.Client; + adapter: DrizzleDatabaseAdapter; + docker: Docker; + container: Docker.Container; + }; + +async function createDockerDB(docker: Docker): Promise { + const port = await getPort({ port: 5432 }); + const image = 
"pgvector/pgvector:pg16"; + + const pullStream = await docker.pull(image); + await new Promise((resolve, reject) => + docker.modem.followProgress(pullStream, (err) => + err ? reject(err) : resolve(err) + ) + ); + + const container = await docker.createContainer({ + Image: image, + Env: [ + "POSTGRES_PASSWORD=postgres", + "POSTGRES_USER=postgres", + "POSTGRES_DB=postgres", + ], + name: `drizzle-integration-tests-${uuid()}`, + HostConfig: { + AutoRemove: true, + PortBindings: { + "5432/tcp": [{ HostPort: `${port}` }], + }, + }, + }); + + await container.start(); + + return `postgres://postgres:postgres@localhost:${port}/postgres`; +} + +async function connectDatabase(): Promise { + const docker = new Docker(); + const connectionString = process.env["PG_VECTOR_CONNECTION_STRING"] ?? + (await createDockerDB(docker)); + + const sleep = 250; + let timeLeft = 5000; + let connected = false; + let lastError: unknown | undefined; + let client: pg.Client | undefined; + let container: Docker.Container | undefined; + + // Get the container reference if we created one + if (!process.env["PG_VECTOR_CONNECTION_STRING"]) { + const containers = await docker.listContainers(); + container = docker.getContainer( + containers.find(c => c.Names[0].includes('drizzle-integration-tests'))?.Id! + ); + } + + do { + try { + client = new Client(connectionString); + await client.connect(); + connected = true; + break; + } catch (e) { + lastError = e; + await new Promise((resolve) => setTimeout(resolve, sleep)); + timeLeft -= sleep; + } + } while (timeLeft > 0); + + if (!connected || !client) { + elizaLogger.error("Cannot connect to Postgres"); + await client?.end().catch(console.error); + await container?.stop().catch(console.error); + throw lastError; + } + + const adapter = new DrizzleDatabaseAdapter(connectionString); + + return { + client, + adapter, + docker, + container: container! + }; +} + +const parseVectorString = (vectorStr: string): number[] => { + if (!vectorStr) return []; + // Remove brackets and split by comma + return vectorStr.replace(/[[\]]/g, '').split(',').map(Number); +}; + +async function cleanDatabase(client: pg.Client) { + try { + await client.query('DROP TABLE IF EXISTS relationships CASCADE'); + await client.query('DROP TABLE IF EXISTS participants CASCADE'); + await client.query('DROP TABLE IF EXISTS logs CASCADE'); + await client.query('DROP TABLE IF EXISTS goals CASCADE'); + await client.query('DROP TABLE IF EXISTS memories CASCADE'); + await client.query('DROP TABLE IF EXISTS rooms CASCADE'); + await client.query('DROP TABLE IF EXISTS accounts CASCADE'); + await client.query('DROP TABLE IF EXISTS cache CASCADE'); + await client.query('DROP EXTENSION IF EXISTS vector CASCADE'); + await client.query('DROP SCHEMA IF EXISTS extensions CASCADE'); + await client.query("DROP TABLE IF EXISTS __drizzle_migrations"); + elizaLogger.success("Database cleanup completed successfully"); + } catch (error) { + elizaLogger.error( + `Database cleanup failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } +} + +async function stopContainers(client: pg.Client, docker: Docker) { + try { + // First end the client connection + await client?.end().catch(error => { + elizaLogger.error(`Failed to close client: ${error instanceof Error ? 
error.message : "Unknown error"}`); + }); + + // Get all containers with our test prefix + const containers = await docker.listContainers({ + all: true, // Include stopped containers + filters: { + name: ['drizzle-integration-tests'] + } + }); + + // Stop all matching containers + await Promise.all( + containers.map(async containerInfo => { + const container = docker.getContainer(containerInfo.Id); + try { + await container.stop(); + elizaLogger.success(`Stopped container: ${containerInfo.Id.substring(0, 12)}`); + } catch (error) { + // If container is already stopped, that's fine + if (error instanceof Error && !error.message.includes('container already stopped')) { + elizaLogger.error( + `Failed to stop container ${containerInfo.Id.substring(0, 12)}: ${error.message}` + ); + } + } + }) + ); + } catch (error) { + elizaLogger.error( + `Container cleanup failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } +} + +const initializeDatabase = async (client: pg.Client) => { + try { + await client.query(` + ALTER DATABASE postgres SET app.use_openai_embedding = 'true'; + ALTER DATABASE postgres SET app.use_ollama_embedding = 'false'; + `); + + await client.query("CREATE EXTENSION IF NOT EXISTS vector"); + + const { rows: vectorExt } = await client.query(` + SELECT * FROM pg_extension WHERE extname = 'vector' + `); + elizaLogger.info("Vector extension status:", { + isInstalled: vectorExt.length > 0, + }); + + const { rows: searchPath } = await client.query("SHOW search_path"); + elizaLogger.info("Search path:", { + searchPath: searchPath[0].search_path, + }); + } catch (error) { + elizaLogger.error( + `Database initialization failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } +}; + +describe("DrizzleDatabaseAdapter - Vector Extension Validation", () => { + describe("Schema and Extension Management", () => { + let adapter: DrizzleDatabaseAdapter; + let client: pg.Client; + let docker: Docker; + + beforeEach(async () => { + ({ client, adapter, docker } = await connectDatabase()); + await initializeDatabase(client); + }); + + afterEach(async () => { + await stopContainers(client, docker); + }); + + test("should initialize with vector extension", async () => { + elizaLogger.info("Testing vector extension initialization..."); + try { + await adapter.init(); + + const { rows } = await client.query(` + SELECT 1 FROM pg_extension WHERE extname = 'vector' + `); + expect(rows.length).toBe(1); + elizaLogger.success("Vector extension verified successfully"); + } catch (error) { + elizaLogger.error( + `Vector extension test failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } + }); + + test("should handle missing rooms table", async () => { + try { + // First initialize adapter which should create the rooms table + await adapter.init(); + + const id = stringToUuid("test-room"); + + // Try creating new room + await adapter.createRoom(id); + + // Try getting room + const roomId = await adapter.getRoom(id); + expect(roomId).toEqual(id); + + elizaLogger.success("Rooms table verified successfully"); + } catch (error) { + elizaLogger.error( + `Rooms table test failed: ${ + error instanceof Error ? 
error.message : "Unknown error" + }` + ); + throw error; + } + }); + + test("should not reapply schema when everything exists", async () => { + elizaLogger.info("Testing schema reapplication prevention..."); + try { + // First initialization + await adapter.init(); + + // Get table count after first initialization + const { rows: firstCount } = await client.query(` + SELECT count(*) FROM information_schema.tables + WHERE table_schema = 'public' + `); + + // Second initialization + await adapter.init(); + + // Get table count after second initialization + const { rows: secondCount } = await client.query(` + SELECT count(*) FROM information_schema.tables + WHERE table_schema = 'public' + `); + + // Verify counts are the same + expect(firstCount[0].count).toEqual(secondCount[0].count); + elizaLogger.success("Verified schema was not reapplied"); + } catch (error) { + elizaLogger.error( + `Schema reapplication test failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } + }); + }); +}); + + +describe("Memory Operations with Vector", () => { + let adapter: DrizzleDatabaseAdapter; + let client: pg.Client; + let docker: Docker; + + beforeAll(async () => { + ({ adapter, client, docker } = await connectDatabase()); + await adapter.init(); + + const seedManager = new MemorySeedManager(); + await seedManager.createMemories(); + + // Create necessary account and room first + await adapter.createAccount({ + id: agentId, + name: "Agent Test", + username: "agent-test", + email: "agent-test@test.com", + }); + + await adapter.createRoom(roomId); + await adapter.addParticipant(agentId, roomId); + }); + + afterAll(async () => { + await cleanDatabase(client); + await stopContainers(client, docker); + }); + + test("should create and retrieve memory with vector embedding", async () => { + const content = "This is a test memory about cats and dogs"; + const dimensions = 384; + const embedding = await getEmbeddingForTest(content, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: dimensions, + isOllama: false, + provider: "OpenAI", + }); + + + // Create memory + await adapter.createMemory({ + id: memoryId, + content: { + text: content, + type: "message" + }, + embedding: embedding, + userId: agentId, + agentId: agentId, + roomId: roomId, + createdAt: Date.now(), + unique: true + }, TEST_TABLE); + + const memory = await adapter.getMemoryById(memoryId); + + // Verify memory and embedding + expect(memory).toBeDefined(); + const parsedEmbedding = typeof memory?.embedding === 'string' ? 
parseVectorString(memory.embedding) : memory?.embedding; + expect(Array.isArray(parsedEmbedding)).toBe(true); + expect(parsedEmbedding).toHaveLength(dimensions); + expect(memory?.content?.text).toEqual(content); + }); + + test("should create and retrieve memory with vector embedding", async () => { + const testMemoryId = stringToUuid('memory-test-2'); + const content = "The quick brown fox jumps over the lazy dog"; + const dimensions = 384; + const embedding = await getEmbeddingForTest(content, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: dimensions, + isOllama: false, + provider: "OpenAI", + }); + + // Create memory + await adapter.createMemory({ + id: testMemoryId, + content: { + text: content, + type: "message" + }, + embedding: embedding, + userId: agentId, + agentId: agentId, + roomId: roomId, + createdAt: Date.now(), + unique: true + }, TEST_TABLE); + + // Search by embedding and verify + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.8, + count: 1 + }); + + expect(results).toHaveLength(1); + expect(results[0].similarity).toBeGreaterThanOrEqual(0.8); + expect(results[0].content.text).toBe(content); + expect(results[0].embedding).toEqual(embedding); + expect(results[0].roomId).toBe(roomId); + expect(results[0].agentId).toBe(agentId); + }); + + test("should handle invalid embedding dimensions", async () => { + const wrongDimensionEmbedding = new Array(100).fill(0.1); + + const [{ get_embedding_dimension: embeddingDimension }] = await adapter.db.execute( + sql`SELECT get_embedding_dimension()` + ); + + const memoryWithWrongDimension = { + id: memoryId, + content: { + text: "This is a test memory with wrong dimensions", + type: "message" + }, + embedding: wrongDimensionEmbedding, + userId: agentId, + agentId: agentId, + roomId: roomId, + createdAt: Date.now(), + unique: true + }; + + try { + await adapter.createMemory(memoryWithWrongDimension, TEST_TABLE); + } catch (error) { + expect(error).toBeDefined(); + expect((error as Error).message).toBe(`different vector dimensions ${embeddingDimension} and ${wrongDimensionEmbedding.length}`); + } + }); +}); + + +describe("Advanced Vector Memory Operations", () => { + // Test data constants + const TEST_TABLE = 'test_memories'; + const MEMORY_SETS = { + programming: [ + "JavaScript is a versatile programming language used for web development", + "Python is known for its simplicity and readability in coding", + "Java remains popular for enterprise application development", + "TypeScript adds static typing to JavaScript for better development", + "React is a popular framework for building user interfaces" + ], + science: [ + "Quantum physics explores the behavior of matter at atomic scales", + "Biology studies the structure and function of living organisms", + "Chemistry investigates the composition of substances", + "Astronomy examines celestial bodies and phenomena", + "Geology focuses on Earth's structure and history" + ], + cooking: [ + "Italian cuisine emphasizes fresh ingredients and simple preparation", + "French cooking techniques form the basis of culinary arts", + "Asian fusion combines traditional flavors with modern methods", + "Baking requires precise measurements and temperature control", + "Mediterranean diet includes olive oil, vegetables, and seafood" + ] + }; + const memoryIds = new Map(); + + // Helper function to create memory with 
embedding + async function createMemoryWithContent(content: string, category: string): Promise { + const memoryId = stringToUuid(`memory-${category}-${Date.now()}`); + const embedding = await getEmbeddingForTest(content, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: 384, + isOllama: false, + provider: "OpenAI" + }); + + await adapter.createMemory({ + id: memoryId, + content: { + text: content, + type: "message" + }, + embedding, + userId: agentId, + agentId: agentId, + roomId: roomId, + createdAt: Date.now(), + unique: true + }, TEST_TABLE); + + return memoryId; + } + + // Setup test environment + beforeAll(async () => { + ({ adapter, client, docker } = await connectDatabase()); + await adapter.init(); + + // Create test account and room + await adapter.createAccount({ + id: agentId, + name: "Agent Test", + username: "agent-test", + email: "agent-test@test.com", + }); + + await adapter.createRoom(roomId); + await adapter.addParticipant(agentId, roomId); + + // Create memories for each category + for (const [category, contents] of Object.entries(MEMORY_SETS)) { + const ids = await Promise.all( + contents.map(content => createMemoryWithContent(content, category)) + ); + memoryIds.set(category, ids); + } + + elizaLogger.success("Test environment setup completed"); + }); + + // Cleanup after tests + afterAll(async () => { + await cleanDatabase(client); + await stopContainers(client, docker); + }); + + test("should find similar memories within same context", async () => { + const queryContent = "How do programming languages like JavaScript and Python compare?"; + const embedding = await getEmbeddingForTest(queryContent, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: 384, + isOllama: false, + provider: "OpenAI" + }); + + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.7, + count: 3 + }); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].similarity).toBeGreaterThan(0.7); + expect(results.some(r => r.content.text.includes("JavaScript"))).toBe(true); + expect(results.some(r => r.content.text.includes("Python"))).toBe(true); + }); + + test("should effectively filter cross-context searches", async () => { + const queryContent = "What are the best programming frameworks for web development?"; + const embedding = await getEmbeddingForTest(queryContent, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: 384, + isOllama: false, + provider: "OpenAI" + }); + + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.75, + count: 5 + }); + + // Should find programming-related memories but not cooking or science + expect(results.every(r => !r.content.text.toLowerCase().includes("cuisine"))).toBe(true); + expect(results.every(r => !r.content.text.toLowerCase().includes("physics"))).toBe(true); + }); + + test("should handle threshold-based filtering accurately", async () => { + const queryContent = "Tell me about web development and user interfaces"; + const embedding = await getEmbeddingForTest(queryContent, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: 384, + isOllama: 
false, + provider: "OpenAI" + }); + + // Test with different thresholds + const highThresholdResults = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.9, + count: 5 + }); + + const lowThresholdResults = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.6, + count: 5 + }); + + expect(highThresholdResults.length).toBeLessThan(lowThresholdResults.length); + expect(highThresholdResults.every(r => r.similarity >= 0.9)).toBe(true); + }); + + test("should return paginated results for large-scale searches", async () => { + const queryContent = "Tell me about science and research"; + const embedding = await getEmbeddingForTest(queryContent, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: 384, + isOllama: false, + provider: "OpenAI" + }); + + const firstPage = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.6, + count: 2 + }); + + const secondPage = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.6, + count: 2, + offset: 2 + }); + + expect(firstPage.length).toBe(2); + expect(secondPage.length).toBe(2); + expect(firstPage[0].id).not.toBe(secondPage[0].id); + }); + + test("should handle complex multi-context search scenarios", async () => { + const queryContent = "How does scientific research methodology compare to programming best practices?"; + const embedding = await getEmbeddingForTest(queryContent, { + model: "text-embedding-3-large", + endpoint: "https://api.openai.com/v1", + apiKey: process.env.OPENAI_API_KEY, + dimensions: 384, + isOllama: false, + provider: "OpenAI" + }); + + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: roomId, + agentId: agentId, + match_threshold: 0.65, + count: 6 + }); + + // Should find both programming and science related memories + const hasScience = results.some(r => + r.content.text.toLowerCase().includes("science") || + r.content.text.toLowerCase().includes("research") + ); + const hasProgramming = results.some(r => + r.content.text.toLowerCase().includes("programming") || + r.content.text.toLowerCase().includes("development") + ); + + expect(hasScience && hasProgramming).toBe(true); + expect(results.length).toBeGreaterThan(3); + }); +}); diff --git a/packages/adapter-drizzle/src/__tests__/memory.test.ts b/packages/adapter-drizzle/src/__tests__/memory.test.ts new file mode 100644 index 0000000000..3fb31a82bc --- /dev/null +++ b/packages/adapter-drizzle/src/__tests__/memory.test.ts @@ -0,0 +1,1185 @@ +import { + describe, + expect, + test, + beforeAll, + afterAll, + afterEach, +} from "bun:test"; +import { DrizzleDatabaseAdapter } from "../index"; +import { elizaLogger, Memory, stringToUuid, UUID } from "@elizaos/core"; +import Docker from "dockerode"; +import pg from "pg"; +import { sql } from "drizzle-orm"; +import { getEmbeddingForTest } from "@elizaos/core"; +import { EMBEDDING_OPTIONS, MemorySeedManager } from "./seed.ts"; +import { + connectDatabase, + cleanDatabase, + stopContainers, + parseVectorString, +} from "./utils.ts"; + +describe("Memory Operations with Vector", () => { + const TEST_TABLE = "test_memories"; + let adapter: DrizzleDatabaseAdapter; + let client: 
pg.Client; + let docker: Docker; + const seedManager = new MemorySeedManager(); + let MEMORY_IDS: Map; + + beforeAll(async () => { + ({ adapter, client, docker } = await connectDatabase()); + await adapter.init(); + await adapter.createAccount({ + id: seedManager.AGENT_ID, + name: "Agent Test", + username: "agent-test", + email: "agent-test@test.com", + }); + + await adapter.createAccount({ + id: seedManager.USER_ID, + name: "User Test", + username: "user-test", + email: "user-test@test.com", + }); + + await adapter.createRoom(seedManager.ROOM_ID); + await adapter.addParticipant(seedManager.AGENT_ID, seedManager.ROOM_ID); + await adapter.addParticipant(seedManager.USER_ID, seedManager.ROOM_ID); + MEMORY_IDS = await seedManager.createMemories(adapter, TEST_TABLE); + }); + + afterAll(async () => { + await cleanDatabase(client); + // Wait for cleanup to complete + await new Promise((resolve) => setTimeout(resolve, 500)); + await stopContainers(client, docker); + }); + + afterEach(async () => { + // Get all current memory IDs + const allMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + tableName: TEST_TABLE, + }); + + // Get all seeded memory IDs as a flat array + const seededIds = Array.from(MEMORY_IDS.values()).flat(); + + // Remove memories that aren't in our seeded set + for (const memory of allMemories) { + if (!seededIds.includes(memory.id as UUID)) { + await adapter.removeMemory(memory.id as UUID, TEST_TABLE); + } + } + }); + + test("should create and retrieve memory with vector embedding", async () => { + const memoryId = stringToUuid("memory-test-1"); + const content = "This is a test memory about cats and dogs"; + const dimensions = 384; + const embedding = await getEmbeddingForTest(content, EMBEDDING_OPTIONS); + + // Create memory + await adapter.createMemory( + { + id: memoryId, + content: { + text: content, + type: "message", + }, + embedding: embedding, + userId: seedManager.USER_ID, + agentId: seedManager.AGENT_ID, + roomId: seedManager.ROOM_ID, + createdAt: Date.now(), + unique: true, + }, + TEST_TABLE + ); + + const memory = await adapter.getMemoryById(memoryId); + + // Verify memory and embedding + expect(memory).toBeDefined(); + const parsedEmbedding = + typeof memory?.embedding === "string" + ? 
parseVectorString(memory.embedding) + : memory?.embedding; + expect(Array.isArray(parsedEmbedding)).toBe(true); + expect(parsedEmbedding).toHaveLength(dimensions); + expect(memory?.content?.text).toEqual(content); + }); + + test("should create and retrieve memory with vector embedding", async () => { + const testMemoryId = stringToUuid("memory-test-2"); + const content = "The quick brown fox jumps over the lazy dog"; + const embedding = await getEmbeddingForTest(content, EMBEDDING_OPTIONS); + + // Create memory + await adapter.createMemory( + { + id: testMemoryId, + content: { + text: content, + type: "message", + }, + embedding: embedding, + userId: seedManager.USER_ID, + agentId: seedManager.AGENT_ID, + roomId: seedManager.ROOM_ID, + createdAt: Date.now(), + unique: true, + }, + TEST_TABLE + ); + + // Search by embedding and verify + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.8, + count: 1, + }); + + expect(results).toHaveLength(1); + expect(results[0].similarity).toBeGreaterThanOrEqual(0.8); + expect(results[0].content.text).toBe(content); + expect(results[0].embedding).toEqual(embedding); + expect(results[0].roomId).toBe(seedManager.ROOM_ID); + expect(results[0].agentId).toBe(seedManager.AGENT_ID); + }); + + test("should handle invalid embedding dimensions", async () => { + const wrongDimensionEmbedding = new Array(100).fill(0.1); + + const [{ get_embedding_dimension: embeddingDimension }] = + await adapter.db.execute(sql`SELECT get_embedding_dimension()`); + + const memoryWithWrongDimension = { + id: stringToUuid("memory-test-3"), + content: { + text: "This is a test memory with wrong dimensions", + type: "message", + }, + embedding: wrongDimensionEmbedding, + userId: seedManager.USER_ID, + agentId: seedManager.AGENT_ID, + roomId: seedManager.ROOM_ID, + createdAt: Date.now(), + unique: true, + }; + + try { + await adapter.createMemory(memoryWithWrongDimension, TEST_TABLE); + } catch (error) { + expect(error).toBeDefined(); + expect((error as Error).message).toBe( + `different vector dimensions ${embeddingDimension} and ${wrongDimensionEmbedding.length}` + ); + } + }); + + test("should find similar memories within same context", async () => { + const queryContent = + "How do programming languages like JavaScript and Python compare?"; + const embedding = await getEmbeddingForTest( + queryContent, + EMBEDDING_OPTIONS + ); + + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.35, + count: 3, + }); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].similarity).toBeGreaterThan(0.35); + expect(results.some((r) => r.content.text.includes("JavaScript"))).toBe( + true + ); + expect(results.some((r) => r.content.text.includes("Python"))).toBe( + true + ); + }); + + test("should find similar memories within same context - testing various thresholds", async () => { + const testQueries = [ + { + query: "How do programming languages like JavaScript and Python compare?", + expectedTerms: ["JavaScript", "Python"], + context: "programming", + }, + { + query: "Tell me about web development frameworks and tools", + expectedTerms: ["React", "JavaScript", "TypeScript"], + context: "programming", + }, + { + query: "What's the relationship between physics and chemistry?", + expectedTerms: ["physics", "chemistry"], + context: 
"science", + }, + ]; + + for (const testCase of testQueries) { + const embedding = await getEmbeddingForTest( + testCase.query, + EMBEDDING_OPTIONS + ); + + const thresholdResults = await Promise.all([ + adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.25, + count: 5, + }), + adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.5, + count: 5, + }), + ]); + + elizaLogger.debug(`Results for query: ${testCase.query}`, { + threshold_0_25: { + count: thresholdResults[0].length, + similarities: thresholdResults[0].map((r) => + (r.similarity ?? 0).toFixed(4) + ), + texts: thresholdResults[0].map((r) => r.content.text), + }, + threshold_0_50: { + count: thresholdResults[1].length, + similarities: thresholdResults[1].map((r) => + (r.similarity ?? 0).toFixed(4) + ), + texts: thresholdResults[1].map((r) => r.content.text), + }, + }); + + const results = thresholdResults[0]; + + // Test basic search functionality + expect(results.length).toBeGreaterThan(0); + expect(results[0].similarity).toBeGreaterThan(0.25); + + // Test context relevance + const contextMatches = results.filter((r) => + testCase.expectedTerms.some((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ) + ); + expect(contextMatches.length).toBeGreaterThan(0); + + // Test semantic relevance - results should be ordered by relevance + expect(results).toEqual( + [...results].sort( + (a, b) => (b.similarity ?? 0) - (a.similarity ?? 0) + ) + ); + + // Demonstrate threshold impact + expect(thresholdResults[1].length).toBeLessThanOrEqual( + thresholdResults[0].length + ); + + // Log cross-context contamination + const otherContextTerms = + testCase.context === "programming" + ? ["physics", "chemistry", "cuisine"] + : ["JavaScript", "Python", "cuisine"]; + + const crossContextMatches = results.filter((r) => + otherContextTerms.some((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ) + ); + + elizaLogger.debug("Cross-context analysis:", { + query: testCase.query, + expectedContext: testCase.context, + crossContextMatchCount: crossContextMatches.length, + crossContextMatches: crossContextMatches.map((r) => ({ + similarity: (r.similarity ?? 0).toFixed(4), + text: r.content.text, + })), + }); + } + }); + + test("should effectively filter cross-context searches", async () => { + const queryContent = + "What are the best programming frameworks for web development?"; + const embedding = await getEmbeddingForTest( + queryContent, + EMBEDDING_OPTIONS + ); + + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.35, // Lowered based on our 384 dimension findings + count: 5, + }); + + elizaLogger.debug("Semantic search results:", { + count: results.length, + similarities: results.map((r) => ({ + similarity: (r.similarity ?? 
0).toFixed(4), + text: r.content.text, + })), + }); + + // Positive matches - should find programming-related content + expect(results.length).toBeGreaterThan(0); + expect(results[0].similarity).toBeGreaterThan(0.35); + + const programmingTerms = [ + "JavaScript", + "web", + "React", + "development", + "TypeScript", + "Python", + ]; + const hasProgrammingContent = results.some((r) => + programmingTerms.some((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ) + ); + expect(hasProgrammingContent).toBe(true); + + // Negative matches - should not find unrelated content + const unrelatedTerms = [ + "cuisine", + "physics", + "chemistry", + "biology", + "cooking", + ]; + const hasUnrelatedContent = results.some((r) => + unrelatedTerms.some((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ) + ); + expect(hasUnrelatedContent).toBe(false); + + // Verify ordering by relevance + expect(results).toEqual( + [...results].sort( + (a, b) => (b.similarity ?? 0) - (a.similarity ?? 0) + ) + ); + + // Check relative rankings + const rankingCheck = results.map((r) => ({ + text: r.content.text, + terms: programmingTerms.filter((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ), + similarity: r.similarity, + })); + + elizaLogger.debug("Ranking analysis:", { + results: rankingCheck, + }); + + // Most relevant results should have higher similarity scores + expect(rankingCheck[0].terms.length).toBeGreaterThan(0); + }); + + test("should handle threshold-based filtering accurately", async () => { + const queryContent = + "Tell me about web development and user interfaces"; + const embedding = await getEmbeddingForTest( + queryContent, + EMBEDDING_OPTIONS + ); + + // Test with various thresholds matching our 384 dimension expectations + const thresholds = [0.5, 0.35, 0.25]; + const thresholdResults = await Promise.all( + thresholds.map((threshold) => + adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: threshold, + count: 5, + }) + ) + ); + + // Log results for each threshold + thresholds.forEach((threshold, i) => { + elizaLogger.debug(`Results for threshold ${threshold}:`, { + count: thresholdResults[i].length, + matches: thresholdResults[i].map((r) => ({ + similarity: (r.similarity ?? 0).toFixed(4), + text: r.content.text, + })), + }); + }); + + // Verify descending result counts with increasing thresholds + expect(thresholdResults[0].length).toBeLessThanOrEqual( + thresholdResults[1].length + ); + expect(thresholdResults[1].length).toBeLessThanOrEqual( + thresholdResults[2].length + ); + + // Check threshold enforcement + thresholds.forEach((threshold, i) => { + expect( + thresholdResults[i].every( + (r) => (r.similarity ?? 0) >= threshold + ) + ).toBe(true); + }); + + // Verify relevance of results + const webDevTerms = [ + "JavaScript", + "web", + "React", + "development", + "interface", + "TypeScript", + ]; + const relevantResults = thresholdResults[2].filter((r) => + webDevTerms.some((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ) + ); + + elizaLogger.debug("Relevance analysis:", { + totalResults: thresholdResults[2].length, + relevantResults: relevantResults.length, + relevantMatches: relevantResults.map((r) => ({ + similarity: (r.similarity ?? 
0).toFixed(4), + text: r.content.text, + matchedTerms: webDevTerms.filter((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ), + })), + }); + + expect(relevantResults.length).toBeGreaterThan(0); + }); + + test("should handle large-scale semantic searches effectively", async () => { + const queries = [ + { + content: "Tell me about science and research", + expectedContext: "science", + relevantTerms: [ + "physics", + "biology", + "chemistry", + "research", + "science", + ], + // Only consider it irrelevant if it's primarily about these topics + primaryIrrelevantContexts: { + programming: [ + "JavaScript", + "Python", + "programming", + "coding", + ], + cooking: ["recipe", "cuisine", "ingredients", "baking"], + }, + }, + { + content: "Explain different research methodologies", + expectedContext: "science", + relevantTerms: ["research", "study", "methodology", "analysis"], + primaryIrrelevantContexts: { + programming: [ + "JavaScript", + "Python", + "programming", + "coding", + ], + cooking: ["recipe", "cuisine", "ingredients", "baking"], + }, + }, + ]; + + for (const query of queries) { + const embedding = await getEmbeddingForTest( + query.content, + EMBEDDING_OPTIONS + ); + + // Test different result set sizes + const resultSets = await Promise.all([ + adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.35, + count: 2, + }), + adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.35, + count: 5, + }), + adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.35, + count: 10, + }), + ]); + + // Log results for analysis + elizaLogger.debug(`Search results for query: ${query.content}`, { + smallSet: { + count: resultSets[0].length, + similarities: resultSets[0].map((r) => ({ + similarity: (r.similarity ?? 0).toFixed(4), + text: r.content.text, + })), + }, + mediumSet: { + count: resultSets[1].length, + similarities: resultSets[1].map((r) => ({ + similarity: (r.similarity ?? 0).toFixed(4), + text: r.content.text, + })), + }, + largeSet: { + count: resultSets[2].length, + similarities: resultSets[2].map((r) => ({ + similarity: (r.similarity ?? 0).toFixed(4), + text: r.content.text, + })), + }, + }); + + // Test result set sizes + expect(resultSets[0].length).toBeLessThanOrEqual(2); + expect(resultSets[1].length).toBeLessThanOrEqual(5); + expect(resultSets[2].length).toBeLessThanOrEqual(10); + + // Verify ordering consistency + const verifyOrderingConsistency = ( + smaller: Memory[], + larger: Memory[] + ) => { + smaller.forEach((result, index) => { + expect(result.id).toBe(larger[index].id); + expect(result.similarity).toBe(larger[index].similarity); + }); + }; + + verifyOrderingConsistency(resultSets[0], resultSets[1]); + verifyOrderingConsistency(resultSets[1], resultSets[2]); + + for (const resultSet of resultSets) { + // Verify descending similarity order + expect(resultSet).toEqual( + [...resultSet].sort( + (a, b) => (b.similarity ?? 0) - (a.similarity ?? 
0) + ) + ); + + // Check context relevance + const relevantResults = resultSet.filter((r) => + query.relevantTerms.some((term) => + r.content.text + .toLowerCase() + .includes(term.toLowerCase()) + ) + ); + expect(relevantResults.length).toBeGreaterThan(0); + + // Check that no result is primarily about irrelevant contexts + for (const [context, terms] of Object.entries( + query.primaryIrrelevantContexts + )) { + const resultsInContext = resultSet.filter((r) => { + const text = r.content.text.toLowerCase(); + // Consider it primarily about this context if it matches multiple terms + return ( + terms.filter((term) => + text.includes(term.toLowerCase()) + ).length >= 2 + ); + }); + expect(resultsInContext.length).toBe(0); + } + } + } + }); + + test("should effectively handle complex multi-context semantic relationships", async () => { + // Test various cross-domain queries that require understanding relationships + const queries = [ + { + content: + "How does scientific research methodology compare to programming best practices?", + expectedContexts: ["science", "programming"], + relationshipTerms: [ + "methodology", + "practices", + "testing", + "analysis", + ], + // Instead of exact weights, define min/max ranges + contextRanges: { + science: { min: 0.2, max: 0.6 }, // Allow 20-60% science content + programming: { min: 0.2, max: 0.6 }, // Allow 20-60% programming content + }, + contextTerms: { + science: [ + "science", + "research", + "experiment", + "methodology", + "analysis", + "hypothesis", + ], + programming: [ + "programming", + "development", + "software", + "testing", + "agile", + "code", + ], + }, + }, + { + content: + "What similarities exist between software testing and scientific experimentation?", + expectedContexts: ["science", "programming"], + relationshipTerms: [ + "testing", + "experimentation", + "verification", + "validation", + ], + contextRanges: { + science: { min: 0.2, max: 0.6 }, + programming: { min: 0.2, max: 0.6 }, + }, + contextTerms: { + science: [ + "science", + "research", + "experiment", + "methodology", + "analysis", + "hypothesis", + ], + programming: [ + "programming", + "development", + "software", + "testing", + "agile", + "code", + ], + }, + }, + ]; + + for (const query of queries) { + const embedding = await getEmbeddingForTest( + query.content, + EMBEDDING_OPTIONS + ); + + const results = await adapter.searchMemoriesByEmbedding(embedding, { + tableName: TEST_TABLE, + roomId: seedManager.ROOM_ID, + agentId: seedManager.AGENT_ID, + match_threshold: 0.3, // Lowered threshold to catch more semantic relationships + count: 10, + }); + + elizaLogger.debug( + "Raw results:", + results.map((r) => ({ + text: r.content.text, + similarity: r.similarity?.toFixed(4), + })) + ); + + // Analyze context distribution with expanded terms + const contextCounts = { + science: results.filter((r) => + query.contextTerms.science.some((term) => + r.content.text + .toLowerCase() + .includes(term.toLowerCase()) + ) + ).length, + programming: results.filter((r) => + query.contextTerms.programming.some((term) => + r.content.text + .toLowerCase() + .includes(term.toLowerCase()) + ) + ).length, + }; + + const totalResults = results.length; + const distributions = { + science: contextCounts.science / totalResults, + programming: contextCounts.programming / totalResults, + }; + + // Log detailed distribution analysis + elizaLogger.debug( + `Distribution analysis for query: ${query.content}`, + { + totalResults, + contextCounts, + distributions, + results: results.map((r) => ({ + 
text: r.content.text, + similarity: r.similarity?.toFixed(4), + contexts: Object.entries(query.contextTerms).reduce( + (acc, [context, terms]) => ({ + ...acc, + [context]: terms.some((term) => + r.content.text + .toLowerCase() + .includes(term.toLowerCase()) + ), + }), + {} + ), + })), + } + ); + + // Verify distributions fall within expected ranges + for (const [context, range] of Object.entries( + query.contextRanges + )) { + const actualDistribution = distributions[context]; + expect(actualDistribution).toBeGreaterThanOrEqual(range.min); + expect(actualDistribution).toBeLessThanOrEqual(range.max); + + elizaLogger.debug(`${context} distribution check:`, { + actual: actualDistribution, + range, + }); + } + + // Verify semantic relevance ordering + expect(results).toEqual( + [...results].sort( + (a, b) => (b.similarity ?? 0) - (a.similarity ?? 0) + ) + ); + + // Analyze similarity scores + const similarities = results.map((r) => r.similarity ?? 0); + const avgSimilarity = + similarities.reduce((a, b) => a + b, 0) / similarities.length; + const maxSimilarity = Math.max(...similarities); + + elizaLogger.debug("Similarity analysis:", { + max: maxSimilarity.toFixed(4), + average: avgSimilarity.toFixed(4), + distribution: similarities.map((s) => s.toFixed(4)), + }); + + // Test for semantic coherence with adjusted threshold + const hasReasonableRelevance = results.some( + (r) => (r.similarity ?? 0) > 0.4 + ); + expect(hasReasonableRelevance).toBe(true); + + // Verify relationship terms + const relationshipTermMatches = query.relationshipTerms.map( + (term) => ({ + term, + matches: results.filter((r) => + r.content.text + .toLowerCase() + .includes(term.toLowerCase()) + ), + }) + ); + + elizaLogger.debug("Relationship term analysis:", { + matches: relationshipTermMatches.map(({ term, matches }) => ({ + term, + matchCount: matches.length, + examples: matches.map((m) => ({ + similarity: m.similarity?.toFixed(4), + text: m.content.text, + })), + })), + }); + + // Expect at least some relationship terms to be present + const hasRelationshipTerms = relationshipTermMatches.some( + ({ matches }) => matches.length > 0 + ); + expect(hasRelationshipTerms).toBe(true); + + // Analyze cross-context coverage with more flexible criteria + const crossContextResults = results.filter((r) => { + const hasScience = query.contextTerms.science.some((term) => + r.content.text.toLowerCase().includes(term.toLowerCase()) + ); + const hasProgramming = query.contextTerms.programming.some( + (term) => + r.content.text + .toLowerCase() + .includes(term.toLowerCase()) + ); + return hasScience && hasProgramming; + }); + + elizaLogger.debug("Cross-context analysis:", { + crossContextCount: crossContextResults.length, + totalResults: results.length, + crossContextRatio: crossContextResults.length / results.length, + examples: crossContextResults.map((r) => ({ + similarity: r.similarity?.toFixed(4), + text: r.content.text, + })), + }); + + // Verify we have some cross-context results or high similarity results + const hasValidResults = + crossContextResults.length > 0 || + results.some((r) => (r.similarity ?? 
0) > 0.7); + expect(hasValidResults).toBe(true); + } + }); + + test("should get memory by ID - existing and non-existing cases", async () => { + // Pick an existing memory ID from programming category + const existingMemoryId = MEMORY_IDS.get("programming")![0]; + + // Test getting existing memory + const retrievedMemory = await adapter.getMemoryById(existingMemoryId); + expect(retrievedMemory).toBeDefined(); + expect(retrievedMemory?.id).toBe(existingMemoryId); + + // We can verify the content matches what's in our seed data + const expectedText = seedManager.getTextByMemoryId(existingMemoryId); + expect(retrievedMemory?.content.text).toBe(expectedText); + expect(retrievedMemory?.userId).toBe(seedManager.USER_ID); + expect(retrievedMemory?.agentId).toBe(seedManager.AGENT_ID); + expect(retrievedMemory?.roomId).toBe(seedManager.ROOM_ID); + + // Test getting non-existent memory + const nonExistentId = stringToUuid("non-existent-memory"); + const nonExistentMemory = await adapter.getMemoryById(nonExistentId); + expect(nonExistentMemory).toBeNull(); + }); + + test("should successfully create and remove memory", async () => { + // Create a new test memory + const testMemoryId = stringToUuid("test-removal-memory"); + const testContent = "This is a test memory for removal testing"; + + // Create the test memory + await adapter.createMemory( + { + id: testMemoryId, + content: { + text: testContent, + type: "message", + }, + userId: seedManager.USER_ID, + agentId: seedManager.AGENT_ID, + roomId: seedManager.ROOM_ID, + createdAt: Date.now(), + unique: true, + }, + TEST_TABLE + ); + + // Verify memory was created successfully + const createdMemory = await adapter.getMemoryById(testMemoryId); + expect(createdMemory).toBeDefined(); + expect(createdMemory?.content.text).toBe(testContent); + + // Remove the memory + await adapter.removeMemory(testMemoryId, TEST_TABLE); + + // Verify memory no longer exists + const memoryAfterRemoval = await adapter.getMemoryById(testMemoryId); + expect(memoryAfterRemoval).toBeNull(); + }); + + test("should handle removal of non-existent memory without errors", async () => { + const nonExistentId = stringToUuid("non-existent-memory"); + + // Verify memory doesn't exist before removal attempt + const memoryBeforeRemoval = await adapter.getMemoryById(nonExistentId); + expect(memoryBeforeRemoval).toBeNull(); + + // Attempt to remove non-existent memory (should complete without error) + await adapter.removeMemory(nonExistentId, TEST_TABLE); + + // Verify memory is still non-existent + const memoryAfterRemoval = await adapter.getMemoryById(nonExistentId); + expect(memoryAfterRemoval).toBeNull(); + }); + + test("should retrieve all memories for a given room", async () => { + const allMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + tableName: TEST_TABLE, + }); + expect(allMemories.length).toBeGreaterThan(0); + allMemories.forEach((memory) => { + expect(memory.roomId).toBe(seedManager.ROOM_ID); + }); + }); + + test("should limit number of memories returned when count is specified", async () => { + const limitedMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + count: 2, + tableName: TEST_TABLE, + }); + expect(limitedMemories.length).toBe(2); + }); + + test("should retrieve memories within specified time range", async () => { + const now = Date.now(); + const hourAgo = now - 60 * 60 * 1000; + + const timeRangeMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + tableName: TEST_TABLE, + start: hourAgo, + end: now, + 
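// start and end are epoch-millisecond timestamps; getMemories converts them to Dates and filters createdAt to the [start, end] range +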
}); + expect(timeRangeMemories.length).toBeGreaterThan(0); + timeRangeMemories.forEach((memory) => { + expect(memory.createdAt).toBeLessThanOrEqual(now); + expect(memory.createdAt).toBeGreaterThanOrEqual(hourAgo); + }); + }); + + test("should retrieve only unique memories when unique flag is set", async () => { + const uniqueMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + unique: true, + tableName: TEST_TABLE, + }); + expect(uniqueMemories.length).toBeGreaterThan(0); + uniqueMemories.forEach((memory) => { + expect(memory.unique).toBe(true); + }); + }); + + test("should count all memories in a room", async () => { + const totalCount = await adapter.countMemories( + seedManager.ROOM_ID, + false, + TEST_TABLE + ); + expect(totalCount).toBe( + MEMORY_IDS.get("programming")!.length + + MEMORY_IDS.get("science")!.length + + MEMORY_IDS.get("cooking")!.length + ); + }); + + test("should count only unique memories in a room", async () => { + const uniqueCount = await adapter.countMemories( + seedManager.ROOM_ID, + true, + TEST_TABLE + ); + // Since our seed data creates unique memories by default + expect(uniqueCount).toBe( + MEMORY_IDS.get("programming")!.length + + MEMORY_IDS.get("science")!.length + + MEMORY_IDS.get("cooking")!.length + ); + }); + + test("should return zero count for non-existent room", async () => { + const nonExistentRoomId = stringToUuid("non-existent-room"); + const count = await adapter.countMemories( + nonExistentRoomId, + false, + TEST_TABLE + ); + expect(count).toBe(0); + }); + + test("should get memories with various filters", async () => { + // Test basic retrieval + const allMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + tableName: TEST_TABLE, + }); + + // Test with count limit + const limitedMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + count: 2, + tableName: TEST_TABLE, + }); + + // Test time range + const timeRangeMemories = await adapter.getMemories({ + roomId: seedManager.ROOM_ID, + tableName: TEST_TABLE, + start: Date.now() - 3600000, // 1 hour ago + end: Date.now(), + }); + + // Assertions + expect(allMemories.length).toBeGreaterThan(0); + expect(limitedMemories.length).toBe(2); + expect( + timeRangeMemories.every( + (m) => + m.createdAt && + m.createdAt >= Date.now() - 3600000 && + m.createdAt <= Date.now() + ) + ).toBe(true); + }); + + test("should handle batch memory operations", async () => { + // Get memories by IDs + const ids = MEMORY_IDS.get("programming")!.slice(0, 2); + const memoriesByIds = await adapter.getMemoriesByIds(ids, TEST_TABLE); + expect(memoriesByIds.length).toBe(2); + + // Get memories by room IDs + const roomIds = [seedManager.ROOM_ID]; + const memoriesByRooms = await adapter.getMemoriesByRoomIds({ + roomIds, + tableName: TEST_TABLE, + limit: 5, + }); + expect(memoriesByRooms.length).toBeLessThanOrEqual(5); + }); + + test("should remove all memories from room", async () => { + const roomId = stringToUuid("test-room"); + await adapter.createRoom(roomId); + + // Create test memories + const memory1 = { + id: stringToUuid("test-1"), + content: { text: "Test 1" }, + roomId, + userId: seedManager.USER_ID, + agentId: seedManager.AGENT_ID, + }; + const memory2 = { + id: stringToUuid("test-2"), + content: { text: "Test 2" }, + roomId, + userId: seedManager.USER_ID, + agentId: seedManager.AGENT_ID, + }; + + await adapter.createMemory(memory1, TEST_TABLE); + await adapter.createMemory(memory2, TEST_TABLE); + + // Verify memories exist + const beforeCount = await 
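// countMemories(roomId, unique, tableName): passing unique = false counts every memory in the room, not just rows flagged unique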
adapter.countMemories(
+                roomId,
+                false,
+                TEST_TABLE
+            );
+        expect(beforeCount).toBe(2);
+
+        // Remove all memories
+        await adapter.removeAllMemories(roomId, TEST_TABLE);
+
+        // Verify memories removed
+        const afterCount = await adapter.countMemories(
+            roomId,
+            false,
+            TEST_TABLE
+        );
+        expect(afterCount).toBe(0);
+    });
+
+    test("should handle cached embeddings retrieval", async () => {
+        const testMemoryId = stringToUuid("cache-test");
+        const testEmbedding = new Array(384).fill(0.1);
+        const testContent = "test cache text specific content";
+
+        await adapter.createMemory(
+            {
+                id: testMemoryId,
+                content: {
+                    text: testContent,
+                    type: "message",
+                },
+                embedding: testEmbedding,
+                userId: seedManager.USER_ID,
+                agentId: seedManager.AGENT_ID,
+                roomId: seedManager.ROOM_ID,
+                createdAt: Date.now(),
+            },
+            TEST_TABLE
+        );
+
+        const savedMemory = await adapter.getMemoryById(testMemoryId);
+        elizaLogger.debug("Saved memory:", { memory: savedMemory });
+
+        const params = {
+            query_table_name: TEST_TABLE,
+            query_threshold: 10,
+            query_input: testContent,
+            query_field_name: "text",
+            query_field_sub_name: "type",
+            query_match_count: 5,
+        };
+
+        const results = await adapter.getCachedEmbeddings(params);
+        elizaLogger.debug("Search results:", { results });
+
+        expect(Array.isArray(results)).toBe(true);
+        expect(results.length).toBeGreaterThan(0);
+    });
+
+    test("should handle edge cases for batch operations", async () => {
+        // Empty arrays
+        const emptyResults = await adapter.getMemoriesByIds([]);
+        expect(emptyResults).toHaveLength(0);
+
+        const emptyRoomResults = await adapter.getMemoriesByRoomIds({
+            roomIds: [],
+            tableName: TEST_TABLE,
+        });
+        expect(emptyRoomResults).toHaveLength(0);
+
+        // Non-existent IDs
+        const fakeId = stringToUuid("fake-id");
+        const nonExistentResults = await adapter.getMemoriesByIds([fakeId]);
+        expect(nonExistentResults).toHaveLength(0);
+    });
+});
diff --git a/packages/adapter-drizzle/src/__tests__/seed.ts b/packages/adapter-drizzle/src/__tests__/seed.ts
new file mode 100644
index 0000000000..7e67386c3f
--- /dev/null
+++ b/packages/adapter-drizzle/src/__tests__/seed.ts
@@ -0,0 +1,221 @@
+import fs from "fs/promises";
+import path from "path";
+import {
+    stringToUuid,
+    UUID,
+    getEmbeddingForTest,
+    Memory,
+    EmbeddingOptions,
+} from "@elizaos/core";
+import { DrizzleDatabaseAdapter } from "..";
+
+interface EmbeddingCacheEntry {
+    textHash: string;
+    text: string;
+    embedding: number[];
+    memoryId?: string;
+}
+
+interface EmbeddingCache {
+    byHash: { [hash: string]: EmbeddingCacheEntry };
+    byMemoryId: { [memoryId: string]: string };
+}
+
+export const EMBEDDING_OPTIONS: EmbeddingOptions = {
+    model: "text-embedding-3-large",
+    endpoint: "https://api.openai.com/v1",
+    apiKey: process.env.OPENAI_API_KEY,
+    dimensions: 384,
+    isOllama: false,
+    provider: "OpenAI",
+};
+
+export class MemorySeedManager {
+    private readonly CACHE_PATH = path.join(__dirname, "embedding-cache.json");
+    private cache: EmbeddingCache = {
+        byHash: {},
+        byMemoryId: {},
+    };
+
+    private readonly MEMORY_SETS = {
+        programming: [
+            "JavaScript is a versatile programming language used for web development",
+            "Python is known for its simplicity and readability in coding",
+            "Java remains popular for enterprise application development",
+            "TypeScript adds static typing to JavaScript for better development",
+            "React is a popular framework for building user interfaces",
+            "Test-driven development emphasizes writing tests before implementing features",
+            "Agile methodology promotes iterative development and continuous feedback",
+        ],
+        science: [
+            "Quantum physics explores the behavior of matter at atomic scales",
+            "Biology studies the structure and function of living organisms",
+            "Chemistry investigates the composition of substances",
+            "Astronomy examines celestial bodies and phenomena",
+            "Geology focuses on Earth's structure and history",
+            "Scientific research methodology includes observation, hypothesis testing, and data analysis",
+            "Experimental research methods rely on controlled variables and reproducible results",
+        ],
+        cooking: [
+            "Italian cuisine emphasizes fresh ingredients and simple preparation",
+            "French cooking techniques form the basis of culinary arts",
+            "Asian fusion combines traditional flavors with modern methods",
+            "Baking requires precise measurements and temperature control",
+            "Mediterranean diet includes olive oil, vegetables, and seafood",
+            "Molecular gastronomy applies scientific principles to cooking",
+            "Kitchen workflow organization improves cooking efficiency",
+        ]
+    };
+
+    public readonly AGENT_ID: UUID = stringToUuid(`agent-test`);
+    public readonly ROOM_ID: UUID = stringToUuid(`room-test`);
+    public readonly USER_ID: UUID = stringToUuid(`user-test`);
+
+    constructor() {
+        this.loadEmbeddingCache();
+    }
+
+    private generateTextHash(text: string): string {
+        return stringToUuid(`text-${text.trim().toLowerCase()}`);
+    }
+
+    private async loadEmbeddingCache(): Promise<void> {
+        try {
+            const cacheData = await fs.readFile(this.CACHE_PATH, "utf-8");
+            this.cache = JSON.parse(cacheData);
+            this.cache.byHash = this.cache.byHash || {};
+            this.cache.byMemoryId = this.cache.byMemoryId || {};
+        } catch (error) {
+            this.cache = { byHash: {}, byMemoryId: {} };
+        }
+    }
+
+    private async saveEmbeddingCache(): Promise<void> {
+        await fs.writeFile(
+            this.CACHE_PATH,
+            JSON.stringify(this.cache, null, 2)
+        );
+    }
+
+    private async getEmbeddingWithCache(
+        content: string,
+        config: EmbeddingOptions
+    ): Promise<number[]> {
+        const textHash = this.generateTextHash(content);
+
+        if (this.cache.byHash[textHash]) {
+            return this.cache.byHash[textHash].embedding;
+        }
+
+        const embedding = await getEmbeddingForTest(content, config);
+
+        this.cache.byHash[textHash] = {
+            textHash,
+            text: content,
+            embedding,
+        };
+
+        await this.saveEmbeddingCache();
+        return embedding;
+    }
+
+    private async addMemoryToCache(
+        memoryId: string,
+        text: string
+    ): Promise<void> {
+        const textHash = this.generateTextHash(text);
+        this.cache.byMemoryId[memoryId] = textHash;
+        if (this.cache.byHash[textHash]) {
+            this.cache.byHash[textHash].memoryId = memoryId;
+        }
+        await this.saveEmbeddingCache();
+    }
+
+    private async generateMemoryData(
+        content: string,
+        config: EmbeddingOptions
+    ): Promise<Memory> {
+        const contentHash = this.generateTextHash(content);
+        const memoryId = stringToUuid(`memory-test-${contentHash}`);
+
+        const embedding = await this.getEmbeddingWithCache(content, config);
+        await this.addMemoryToCache(memoryId, content);
+
+        return {
+            id: memoryId,
+            content: {
+                text: content,
+                type: "message",
+            },
+            embedding,
+            userId: this.USER_ID,
+            agentId: this.AGENT_ID,
+            roomId: this.ROOM_ID,
+        };
+    }
+
+    public getEmbeddingByMemoryId(memoryId: string): number[] | null {
+        const textHash = this.cache.byMemoryId[memoryId];
+        if (textHash && this.cache.byHash[textHash]) {
+            return this.cache.byHash[textHash].embedding;
+        }
+        return null;
+    }
+
+    public getTextByMemoryId(memoryId: string): string | null {
+        const textHash = this.cache.byMemoryId[memoryId];
+        if (textHash && this.cache.byHash[textHash]) {
+            return this.cache.byHash[textHash].text;
+        }
+        return null;
+    }
+
+    public getEmbeddingByText(text: string): number[] | null {
+        const textHash = this.generateTextHash(text);
+        if (this.cache.byHash[textHash]) {
+            return this.cache.byHash[textHash].embedding;
+        }
+        return null;
+    }
+
+    public async createMemories(
+        adapter: DrizzleDatabaseAdapter,
+        tableName: string,
+    ): Promise<Map<string, UUID[]>> {
+        const memoryIds = new Map<string, UUID[]>();
+
+        for (const [category, contents] of Object.entries(this.MEMORY_SETS)) {
+            const categoryMemories: UUID[] = [];
+
+            for (const content of contents) {
+                const memoryData = await this.generateMemoryData(
+                    content,
+                    EMBEDDING_OPTIONS
+                );
+                await adapter.createMemory(memoryData, tableName);
+                categoryMemories.push(memoryData.id as UUID);
+            }
+
+            memoryIds.set(category, categoryMemories);
+        }
+
+        return memoryIds;
+    }
+
+    public getCacheStats(): {
+        totalEntries: number;
+        memoryIdsMapped: number;
+        cacheSize: number;
+    } {
+        return {
+            totalEntries: Object.keys(this.cache.byHash).length,
+            memoryIdsMapped: Object.keys(this.cache.byMemoryId).length,
+            cacheSize: JSON.stringify(this.cache).length,
+        };
+    }
+
+    public async clearCache(): Promise<void> {
+        this.cache = { byHash: {}, byMemoryId: {} };
+        await this.saveEmbeddingCache();
+    }
+}
diff --git a/packages/adapter-drizzle/src/__tests__/setup.ts b/packages/adapter-drizzle/src/__tests__/setup.ts
new file mode 100644
index 0000000000..9bfa043322
--- /dev/null
+++ b/packages/adapter-drizzle/src/__tests__/setup.ts
@@ -0,0 +1,80 @@
+import { afterAll, beforeAll, beforeEach, expect, test } from "bun:test";
+import { config } from "dotenv";
+
+// Load test environment variables
+config({ path: ".env.test" });
+
+import { DrizzleDatabaseAdapter } from "../index";
+import { getEmbeddingConfig, getEmbeddingForTest } from '@elizaos/core';
+import type { UUID } from '@elizaos/core';
+import { sql } from "drizzle-orm";
+import { v4 as uuid } from 'uuid';
+
+let drizzleAdapter: DrizzleDatabaseAdapter;
+let userId: UUID;
+let agentId: UUID;
+let roomId: UUID;
+
+export const TEST_DATABASE_URL = process.env.TEST_DATABASE_URL || " ";
+export const OPENAI_API_KEY = process.env.OPENAI_API_KEY || " ";
+
+// beforeAll(async () => {
+//     drizzleAdapter = new DrizzleDatabaseAdapter(TEST_DATABASE_URL);
+
+//     await drizzleAdapter.init();
+
+//     // Create test users
+//     userId = uuid() as UUID;
+//     agentId = uuid() as UUID;
+//     roomId = uuid() as UUID;
+
+//     // Create user account
+//     await drizzleAdapter.createAccount({
+//         id: userId,
+//         name: "test-user",
+//         username: "test-user",
+//         email: "test@test.com"
+//     });
+
+//     // Create agent account
+//     await drizzleAdapter.createAccount({
+//         id: agentId,
+//         name: "test-agent",
+//         username: "test-agent",
+//         email: "agent@test.com"
+//     });
+
+//     // Create test room
+//     await drizzleAdapter.createRoom(roomId);
+// });
+
+// beforeEach(async () => {
+//     // Clear any test data before each test
+//     await drizzleAdapter.db.execute(sql`DELETE FROM memories WHERE TRUE`);
+//     await drizzleAdapter.db.execute(sql`DELETE FROM knowledge WHERE TRUE`);
+// });
+
+// afterAll(async () => {
+//     await drizzleAdapter.close();
+// });
+
+// // Helper function to generate test embedding
+// const generateTestEmbedding = async (text: string) => {
+//     const embeddingConfig = getEmbeddingConfig();
+//     return await getEmbeddingForTest(text, {
+//         model: 'text-embedding-3-large',
+//         endpoint: 'https://api.openai.com/v1',
+//         apiKey: process.env.OPENAI_API_KEY!,
+//         dimensions: embeddingConfig.dimensions,
+//         isOllama: false,
+//         provider: 'OpenAI'
+//     });
+// }
+
+// generateTestEmbedding is defined above but currently commented out, so it is not exported here.
+export {
+    drizzleAdapter,
+    userId,
+    agentId,
+    roomId,
+};
\ No newline at end of file
diff --git a/packages/adapter-drizzle/src/__tests__/utils.ts b/packages/adapter-drizzle/src/__tests__/utils.ts
new file mode 100644
index 0000000000..8a54051562
--- /dev/null
+++ b/packages/adapter-drizzle/src/__tests__/utils.ts
@@ -0,0 +1,225 @@
+import { DrizzleDatabaseAdapter } from "../index";
+import { elizaLogger } from "@elizaos/core";
+import Docker from "dockerode";
+import getPort from "get-port";
+import pg from "pg";
+import { v4 as uuid } from "uuid";
+
+const { Client } = pg;
+
+export type DatabaseConnection = {
+    client: pg.Client;
+    adapter: DrizzleDatabaseAdapter;
+    docker: Docker;
+    container: Docker.Container;
+};
+
+export async function createDockerDB(docker: Docker): Promise<string> {
+    const port = await getPort({ port: 5432 });
+    const image = "pgvector/pgvector:pg16";
+
+    const pullStream = await docker.pull(image);
+    await new Promise((resolve, reject) =>
+        docker.modem.followProgress(pullStream, (err) =>
+            err ? reject(err) : resolve(err)
+        )
+    );
+
+    const container = await docker.createContainer({
+        Image: image,
+        Env: [
+            "POSTGRES_PASSWORD=postgres",
+            "POSTGRES_USER=postgres",
+            "POSTGRES_DB=postgres",
+        ],
+        name: `drizzle-integration-tests-${uuid()}`,
+        HostConfig: {
+            AutoRemove: true,
+            PortBindings: {
+                "5432/tcp": [{ HostPort: `${port}` }],
+            },
+        },
+    });
+
+    await container.start();
+
+    return `postgres://postgres:postgres@localhost:${port}/postgres`;
+}
+
+export async function connectDatabase(): Promise<DatabaseConnection> {
+    const docker = new Docker();
+    const connectionString =
+        process.env["PG_VECTOR_CONNECTION_STRING"] ??
+        (await createDockerDB(docker));
+
+    const sleep = 250;
+    let timeLeft = 5000;
+    let connected = false;
+    let lastError: unknown | undefined;
+    let client: pg.Client | undefined;
+    let container: Docker.Container | undefined;
+
+    // Get the container reference if we created one
+    if (!process.env["PG_VECTOR_CONNECTION_STRING"]) {
+        const containers = await docker.listContainers();
+        // Note: find() can return undefined if the container is not listed yet;
+        // the non-null assertion assumes the container started above is visible.
+        container = docker.getContainer(
+            containers.find((c) =>
+                c.Names[0].includes("drizzle-integration-tests")
+            )?.Id!
+ ); + } + + do { + try { + client = new Client(connectionString); + await client.connect(); + connected = true; + break; + } catch (e) { + lastError = e; + await new Promise((resolve) => setTimeout(resolve, sleep)); + timeLeft -= sleep; + } + } while (timeLeft > 0); + + if (!connected || !client) { + elizaLogger.error("Cannot connect to Postgres"); + await client?.end().catch(console.error); + await container?.stop().catch(console.error); + throw lastError; + } + + const adapter = new DrizzleDatabaseAdapter(connectionString); + + return { + client, + adapter, + docker, + container: container!, + }; +} + +export const parseVectorString = (vectorStr: string): number[] => { + if (!vectorStr) return []; + // Remove brackets and split by comma + return vectorStr.replace(/[[\]]/g, "").split(",").map(Number); +}; + +export async function cleanDatabase(client: pg.Client) { + try { + await client.query("DROP TABLE IF EXISTS relationships CASCADE"); + await client.query("DROP TABLE IF EXISTS participants CASCADE"); + await client.query("DROP TABLE IF EXISTS logs CASCADE"); + await client.query("DROP TABLE IF EXISTS goals CASCADE"); + await client.query("DROP TABLE IF EXISTS memories CASCADE"); + await client.query("DROP TABLE IF EXISTS rooms CASCADE"); + await client.query("DROP TABLE IF EXISTS accounts CASCADE"); + await client.query("DROP TABLE IF EXISTS cache CASCADE"); + await client.query("DROP EXTENSION IF EXISTS vector CASCADE"); + await client.query("DROP SCHEMA IF EXISTS extensions CASCADE"); + await client.query("DROP TABLE IF EXISTS __drizzle_migrations"); + elizaLogger.success("Database cleanup completed successfully"); + } catch (error) { + elizaLogger.error( + `Database cleanup failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } +} + +export async function stopContainers(client: pg.Client, docker: Docker) { + try { + // First, terminate all existing connections except our current one + await client.query(` + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE pid <> pg_backend_pid() + AND datname = current_database() + `); + + // Wait a bit for connections to terminate + await new Promise((resolve) => setTimeout(resolve, 1000)); + + // Now end our client connection + await client.end(); + + // Get all containers with our test prefix + const containers = await docker.listContainers({ + all: true, + filters: { + name: ["drizzle-integration-tests"], + }, + }); + + // Stop all matching containers + await Promise.all( + containers.map(async (containerInfo) => { + const container = docker.getContainer(containerInfo.Id); + try { + // Force stop to ensure it terminates + await container.stop({ t: 5 }); + elizaLogger.success( + `Stopped container: ${containerInfo.Id.substring( + 0, + 12 + )}` + ); + } catch (error) { + if ( + error instanceof Error && + !error.message.includes("container already stopped") + ) { + elizaLogger.error( + `Failed to stop container ${containerInfo.Id.substring( + 0, + 12 + )}: ${error.message}` + ); + } + } + }) + ); + + // Wait for containers to fully stop + await new Promise((resolve) => setTimeout(resolve, 1000)); + } catch (error) { + elizaLogger.error( + `Container cleanup failed: ${ + error instanceof Error ? 
error.message : "Unknown error" + }` + ); + throw error; + } +} + +export const initializeDatabase = async (client: pg.Client) => { + try { + await client.query(` + ALTER DATABASE postgres SET app.use_openai_embedding = 'true'; + ALTER DATABASE postgres SET app.use_ollama_embedding = 'false'; + `); + + await client.query("CREATE EXTENSION IF NOT EXISTS vector"); + + const { rows: vectorExt } = await client.query(` + SELECT * FROM pg_extension WHERE extname = 'vector' + `); + elizaLogger.info("Vector extension status:", { + isInstalled: vectorExt.length > 0, + }); + + const { rows: searchPath } = await client.query("SHOW search_path"); + elizaLogger.info("Search path:", { + searchPath: searchPath[0].search_path, + }); + } catch (error) { + elizaLogger.error( + `Database initialization failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } +}; \ No newline at end of file diff --git a/packages/adapter-drizzle/src/__tests__/vector-extension.test.ts b/packages/adapter-drizzle/src/__tests__/vector-extension.test.ts new file mode 100644 index 0000000000..7acd3c8dcf --- /dev/null +++ b/packages/adapter-drizzle/src/__tests__/vector-extension.test.ts @@ -0,0 +1,116 @@ +import { + describe, + expect, + test, + beforeEach, + afterEach, +} from "bun:test"; +import { DrizzleDatabaseAdapter } from "../index"; +import { elizaLogger, stringToUuid } from "@elizaos/core"; +import Docker from "dockerode"; +import pg from "pg"; +import { + connectDatabase, + initializeDatabase, + cleanDatabase, + stopContainers, +} from "./utils.ts"; + +describe("DrizzleDatabaseAdapter - Vector Extension Validation", () => { + describe("Schema and Extension Management", () => { + let adapter: DrizzleDatabaseAdapter; + let client: pg.Client; + let docker: Docker; + + beforeEach(async () => { + ({ client, adapter, docker } = await connectDatabase()); + await initializeDatabase(client); + }); + + afterEach(async () => { + await cleanDatabase(client); + // Wait for cleanup to complete + await new Promise((resolve) => setTimeout(resolve, 500)); + await stopContainers(client, docker); + }); + + test("should initialize with vector extension", async () => { + elizaLogger.info("Testing vector extension initialization..."); + try { + await adapter.init(); + + const { rows } = await client.query(` + SELECT 1 FROM pg_extension WHERE extname = 'vector' + `); + expect(rows.length).toBe(1); + elizaLogger.success("Vector extension verified successfully"); + } catch (error) { + elizaLogger.error( + `Vector extension test failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } + }); + + test("should handle missing rooms table", async () => { + try { + // First initialize adapter which should create the rooms table + await adapter.init(); + + const id = stringToUuid("test-room"); + + // Try creating new room + await adapter.createRoom(id); + + // Try getting room + const roomId = await adapter.getRoom(id); + expect(roomId).toEqual(id); + + elizaLogger.success("Rooms table verified successfully"); + } catch (error) { + elizaLogger.error( + `Rooms table test failed: ${ + error instanceof Error ? 
error.message : "Unknown error" + }` + ); + throw error; + } + }); + + test("should not reapply schema when everything exists", async () => { + elizaLogger.info("Testing schema reapplication prevention..."); + try { + // First initialization + await adapter.init(); + + // Get table count after first initialization + const { rows: firstCount } = await client.query(` + SELECT count(*) FROM information_schema.tables + WHERE table_schema = 'public' + `); + + // Second initialization + await adapter.init(); + + // Get table count after second initialization + const { rows: secondCount } = await client.query(` + SELECT count(*) FROM information_schema.tables + WHERE table_schema = 'public' + `); + + // Verify counts are the same + expect(firstCount[0].count).toEqual(secondCount[0].count); + elizaLogger.success("Verified schema was not reapplied"); + } catch (error) { + elizaLogger.error( + `Schema reapplication test failed: ${ + error instanceof Error ? error.message : "Unknown error" + }` + ); + throw error; + } + }); + }); +}); diff --git a/packages/adapter-drizzle/src/index.ts b/packages/adapter-drizzle/src/index.ts new file mode 100644 index 0000000000..34b761c6ad --- /dev/null +++ b/packages/adapter-drizzle/src/index.ts @@ -0,0 +1,1599 @@ +import { + type Account, + type Actor, + DatabaseAdapter, + EmbeddingProvider, + type GoalStatus, + type Participant, + type RAGKnowledgeItem, + elizaLogger, + getEmbeddingConfig, + type Goal, + type IDatabaseCacheAdapter, + type Memory, + type Relationship, + type UUID, +} from "@elizaos/core"; +import { and, eq, gte, lte, sql, desc, inArray, or, cosineDistance, gt } from "drizzle-orm"; +import { + accounts, + goals, + logs, + memories, + participants, + relationships, + rooms, + knowledges, + caches, +} from "./schema"; +import { drizzle, BunSQLDatabase } from "drizzle-orm/bun-sql"; +import { v4 as uuid } from "uuid"; +import { runMigrations } from "./migrations"; +import { Pool } from "pg"; + +export class DrizzleDatabaseAdapter + extends DatabaseAdapter + implements IDatabaseCacheAdapter +{ + private databaseUrl: string; + constructor( + databaseUrl: string, + circuitBreakerConfig?: { + failureThreshold?: number; + resetTimeout?: number; + halfOpenMaxAttempts?: number; + } + ) { + super({ + failureThreshold: circuitBreakerConfig?.failureThreshold ?? 5, + resetTimeout: circuitBreakerConfig?.resetTimeout ?? 60000, + halfOpenMaxAttempts: circuitBreakerConfig?.halfOpenMaxAttempts ?? 
3,
+        });
+
+        this.databaseUrl = databaseUrl;
+        this.db = drizzle(databaseUrl);
+    }
+
+    async init(): Promise<void> {
+        try {
+            const embeddingConfig = getEmbeddingConfig();
+            elizaLogger.info("Initializing database with embedding config:", embeddingConfig);
+            if (embeddingConfig.provider === EmbeddingProvider.OpenAI) {
+                await this.db.execute(sql`SET app.use_openai_embedding = 'true'`);
+                await this.db.execute(sql`SET app.use_ollama_embedding = 'false'`);
+                await this.db.execute(sql`SET app.use_gaianet_embedding = 'false'`);
+            } else if (embeddingConfig.provider === EmbeddingProvider.Ollama) {
+                await this.db.execute(sql`SET app.use_openai_embedding = 'false'`);
+                await this.db.execute(sql`SET app.use_ollama_embedding = 'true'`);
+                await this.db.execute(sql`SET app.use_gaianet_embedding = 'false'`);
+            } else if (embeddingConfig.provider === EmbeddingProvider.GaiaNet) {
+                await this.db.execute(sql`SET app.use_openai_embedding = 'false'`);
+                await this.db.execute(sql`SET app.use_ollama_embedding = 'false'`);
+                await this.db.execute(sql`SET app.use_gaianet_embedding = 'true'`);
+            } else {
+                await this.db.execute(sql`SET app.use_openai_embedding = 'false'`);
+                await this.db.execute(sql`SET app.use_ollama_embedding = 'false'`);
+                await this.db.execute(sql`SET app.use_gaianet_embedding = 'false'`);
+            }
+
+            const exists: boolean = await this.checkTable();
+
+            elizaLogger.debug("Rooms table exists:", exists);
+
+            if (!exists || !(await this.validateVectorSetup())) {
+                elizaLogger.info("Schema or vector setup missing; running migrations");
+                const pool = new Pool({
+                    connectionString: this.databaseUrl,
+                });
+                await runMigrations(pool);
+            }
+        } catch (error) {
+            elizaLogger.error("Failed to initialize database:", error);
+            throw error;
+        }
+    }
+
+    private async checkTable(): Promise<boolean> {
+        try {
+            const result = await this.db.execute<{
+                to_regclass: string | null;
+            }>(sql`
+                SELECT to_regclass('public.rooms') as to_regclass
+            `);
+            return Boolean(result[0]?.to_regclass);
+        } catch (error) {
+            elizaLogger.error("checkTable() failed:", error);
+            return false;
+        }
+    }
+
+    private async validateVectorSetup(): Promise<boolean> {
+        try {
+            const vectorExt = await this.db.execute(sql`
+                SELECT * FROM pg_extension WHERE extname = 'vector'
+            `);
+
+            const hasVector = vectorExt?.length > 0;
+
+            if (!hasVector) {
+                elizaLogger.warn("Vector extension not found");
+                return false;
+            }
+
+            return true;
+        } catch (error) {
+            elizaLogger.error("Error validating vector setup:", error);
+            return false;
+        }
+    }
+
+    async close(): Promise<void> {
+        try {
+            // For Bun SQL we just need to close the main connection
+            if (this.db && (this.db as any).client) {
+                await (this.db as any).client.close();
+            }
+        } catch (error) {
+            elizaLogger.error("Failed to close database connection:", {
+                error: error instanceof Error ? error.message : String(error),
+            });
+            throw error;
+        }
+    }
+
+    async getAccountById(userId: UUID): Promise<Account | null> {
+        try {
+            const result = await this.db
+                .select()
+                .from(accounts)
+                .where(eq(accounts.id, userId))
+                .limit(1);
+
+            if (result.length === 0) return null;
+
+            const account = result[0];
+
+            return {
+                id: account.id as UUID,
+                name: account.name ?? "",
+                username: account.username ?? "",
+                email: account.email ?? "",
+                avatarUrl: account.avatarUrl ?? "",
+                details: account.details ?? {},
+            };
+        } catch (error) {
+            elizaLogger.error("Failed to get account by ID:", {
+                error: error instanceof Error ? 
error.message : String(error), + userId, + }); + throw error; + } + } + + async createAccount(account: Account): Promise { + try { + const accountId = account.id ?? uuid(); + + await this.db.insert(accounts).values({ + id: accountId, + name: account.name ?? null, + username: account.username ?? null, + email: account.email ?? "", + avatarUrl: account.avatarUrl ?? null, + details: account.details ?? {}, + }); + + elizaLogger.debug("Account created successfully:", { + accountId, + }); + + return true; + } catch (error) { + elizaLogger.error("Error creating account:", { + error: error instanceof Error ? error.message : String(error), + accountId: account.id, + }); + return false; + } + } + + async getMemories(params: { + roomId: UUID; + count?: number; + unique?: boolean; + tableName: string; + agentId?: UUID; + start?: number; + end?: number; + }): Promise { + if (!params.tableName) throw new Error("tableName is required"); + if (!params.roomId) throw new Error("roomId is required"); + + try { + const conditions = [ + eq(memories.type, params.tableName), + eq(memories.roomId, params.roomId), + ]; + + if (params.start) { + conditions.push( + gte(memories.createdAt, new Date(params.start)) + ); + } + + if (params.end) { + conditions.push(lte(memories.createdAt, new Date(params.end))); + } + + if (params.unique) { + conditions.push(eq(memories.unique, true)); + } + + if (params.agentId) { + conditions.push(eq(memories.agentId, params.agentId)); + } + + const query = this.db + .select() + .from(memories) + .where(and(...conditions)) + .orderBy(desc(memories.createdAt)); + + const rows = params.count + ? await query.limit(params.count) + : await query; + + elizaLogger.debug("Fetching memories:", { + roomId: params.roomId, + tableName: params.tableName, + unique: params.unique, + agentId: params.agentId, + timeRange: + params.start || params.end + ? { + start: params.start + ? new Date(params.start).toISOString() + : undefined, + end: params.end + ? new Date(params.end).toISOString() + : undefined, + } + : undefined, + limit: params.count, + }); + + return rows.map((row) => ({ + id: row.id as UUID, + type: row.type, + createdAt: row.createdAt ? row.createdAt.getTime() : Date.now(), + content: + typeof row.content === "string" + ? JSON.parse(row.content) + : row.content, + embedding: row.embedding ?? undefined, + userId: row.userId as UUID, + agentId: row.agentId as UUID, + roomId: row.roomId as UUID, + unique: row.unique, + })); + } catch (error) { + elizaLogger.error("Failed to fetch memories:", { + error: error instanceof Error ? error.message : String(error), + params, + }); + throw error; + } + } + + async getMemoriesByRoomIds(params: { + roomIds: UUID[]; + agentId?: UUID; + tableName: string; + limit?: number; + }): Promise { + try { + if (params.roomIds.length === 0) return []; + + const conditions = [ + eq(memories.type, params.tableName), + inArray(memories.roomId, params.roomIds), + ]; + + if (params.agentId) { + conditions.push(eq(memories.agentId, params.agentId)); + } + + const query = this.db + .select() + .from(memories) + .where(and(...conditions)) + .orderBy(desc(memories.createdAt)); + + const rows = params.limit + ? await query.limit(params.limit) + : await query; + + return rows.map((row) => ({ + id: row.id as UUID, + createdAt: row.createdAt ? row.createdAt.getTime() : Date.now(), + content: + typeof row.content === "string" + ? 
JSON.parse(row.content) + : row.content, + embedding: row.embedding, + userId: row.userId as UUID, + agentId: row.agentId as UUID, + roomId: row.roomId as UUID, + unique: row.unique, + })) as Memory[]; + } catch (error) { + elizaLogger.error("Error in getMemoriesByRoomIds:", { + error: error instanceof Error ? error.message : String(error), + roomIds: params.roomIds, + tableName: params.tableName, + }); + throw error; + } + } + + async getMemoryById(id: UUID): Promise { + try { + const result = await this.db + .select() + .from(memories) + .where(eq(memories.id, id)) + .limit(1); + + if (result.length === 0) return null; + + const row = result[0]; + return { + id: row.id as UUID, + createdAt: row.createdAt ? row.createdAt.getTime() : Date.now(), + content: + typeof row.content === "string" + ? JSON.parse(row.content) + : row.content, + embedding: row.embedding ?? undefined, + userId: row.userId as UUID, + agentId: row.agentId as UUID, + roomId: row.roomId as UUID, + unique: row.unique, + }; + } catch (error) { + elizaLogger.error("Error in getMemoryById:", error); + throw error; + } + } + + async getMemoriesByIds( + memoryIds: UUID[], + tableName?: string + ): Promise { + if (memoryIds.length === 0) return []; + + try { + const conditions = [inArray(memories.id, memoryIds)]; + + if (tableName) { + conditions.push(eq(memories.type, tableName)); + } + + const rows = await this.db + .select() + .from(memories) + .where(and(...conditions)) + .orderBy(desc(memories.createdAt)); + + return rows.map((row) => ({ + id: row.id as UUID, + createdAt: row.createdAt ? row.createdAt.getTime() : Date.now(), + content: + typeof row.content === "string" + ? JSON.parse(row.content) + : row.content, + embedding: row.embedding ?? undefined, + userId: row.userId as UUID, + agentId: row.agentId as UUID, + roomId: row.roomId as UUID, + unique: row.unique, + })); + } catch (error) { + elizaLogger.error("Failed to fetch memories by IDs:", { + error: error instanceof Error ? error.message : String(error), + memoryIds, + tableName, + }); + throw error; + } + } + + async getCachedEmbeddings(params: { + query_table_name: string; + query_threshold: number; + query_input: string; + query_field_name: string; + query_field_sub_name: string; + query_match_count: number; + }): Promise<{ embedding: number[]; levenshtein_score: number }[]> { + try { + const results = await this.db.execute<{ + embedding: number[]; + levenshtein_score: number; + }>(sql` + WITH content_text AS ( + SELECT + embedding, + content#>>'{text}' as content_text + FROM memories + WHERE type = ${params.query_table_name} + ) + SELECT + embedding, + levenshtein(${params.query_input}, content_text) as levenshtein_score + FROM content_text + WHERE content_text IS NOT NULL + AND levenshtein(${params.query_input}, content_text) <= ${params.query_threshold} + ORDER BY levenshtein_score + LIMIT ${params.query_match_count} + `); + + return results + .map(row => ({ + embedding: Array.isArray(row.embedding) ? row.embedding : + typeof row.embedding === 'string' ? 
JSON.parse(row.embedding) : [], + levenshtein_score: Number(row.levenshtein_score) + })) + .filter(row => Array.isArray(row.embedding)); + } catch (error) { + elizaLogger.error("Error in getCachedEmbeddings:", error); + throw error; + } + } + + async log(params: { + body: { [key: string]: unknown }; + userId: UUID; + roomId: UUID; + type: string; + }): Promise { + try { + const logId = uuid(); + + elizaLogger.debug("Creating log entry:", { + logId, + type: params.type, + roomId: params.roomId, + userId: params.userId, + bodyKeys: Object.keys(params.body), + }); + + await this.db.insert(logs).values({ + body: params.body, + userId: params.userId, + roomId: params.roomId, + type: params.type, + }); + } catch (error) { + elizaLogger.error("Failed to create log entry:", { + error: error instanceof Error ? error.message : String(error), + type: params.type, + roomId: params.roomId, + userId: params.userId, + }); + throw error; + } + } + + async getActorDetails(params: { roomId: UUID }): Promise { + if (!params.roomId) { + throw new Error("roomId is required"); + } + + try { + const result = await this.db + .select({ + id: accounts.id, + name: accounts.name, + username: accounts.username, + details: accounts.details, + }) + .from(participants) + .leftJoin(accounts, eq(participants.userId, accounts.id)) + .where(eq(participants.roomId, params.roomId)) + .orderBy(accounts.name); + + elizaLogger.debug("Retrieved actor details:", { + roomId: params.roomId, + actorCount: result.length, + }); + + return result.map((row) => { + try { + const details = + typeof row.details === "string" + ? JSON.parse(row.details) + : row.details || {}; + + return { + id: row.id as UUID, + name: row.name ?? "", + username: row.username ?? "", + details: { + tagline: details.tagline ?? "", + summary: details.summary ?? "", + quote: details.quote ?? "", + }, + }; + } catch (error) { + elizaLogger.warn("Failed to parse actor details:", { + actorId: row.id, + error: + error instanceof Error + ? error.message + : String(error), + }); + + return { + id: row.id as UUID, + name: row.name ?? "", + username: row.username ?? "", + details: { + tagline: "", + summary: "", + quote: "", + }, + }; + } + }); + } catch (error) { + elizaLogger.error("Failed to fetch actor details:", { + roomId: params.roomId, + error: error instanceof Error ? error.message : String(error), + }); + throw error; + } + } + + async searchMemories(params: { + tableName: string; + agentId: UUID; + roomId: UUID; + embedding: number[]; + match_threshold: number; + match_count: number; + unique: boolean; + }): Promise { + try { + return await this.searchMemoriesByEmbedding(params.embedding, { + match_threshold: params.match_threshold, + count: params.match_count, + agentId: params.agentId, + roomId: params.roomId, + unique: params.unique, + tableName: params.tableName, + }); + } catch (error) { + elizaLogger.error("Failed to search memories:", { + error: error instanceof Error ? error.message : String(error), + tableName: params.tableName, + agentId: params.agentId, + roomId: params.roomId, + }); + throw error; + } + } + + async updateGoalStatus(params: { + goalId: UUID; + status: GoalStatus; + }): Promise { + try { + await this.db + .update(goals) + .set({ status: params.status }) + .where(eq(goals.id, params.goalId)); + + elizaLogger.debug("Updated goal status:", { + goalId: params.goalId, + newStatus: params.status, + }); + } catch (error) { + elizaLogger.error("Failed to update goal status:", { + error: error instanceof Error ? 
error.message : String(error),
+                goalId: params.goalId,
+                status: params.status,
+            });
+            throw error;
+        }
+    }
+
+    async searchMemoriesByEmbedding(
+        embedding: number[],
+        params: {
+            match_threshold?: number;
+            count?: number;
+            roomId?: UUID;
+            agentId?: UUID;
+            unique?: boolean;
+            tableName: string;
+        }
+    ): Promise<Memory[]> {
+        try {
+            // Ensure vector is properly formatted
+            const cleanVector = embedding.map((n) => {
+                if (!Number.isFinite(n)) return 0;
+                // Limit precision to avoid floating point issues
+                return Number(n.toFixed(6));
+            });
+
+            const similarity = sql<number>`1 - (${cosineDistance(memories.embedding, cleanVector)})`;
+
+            const conditions = [eq(memories.type, params.tableName)];
+
+            if (params.unique) {
+                conditions.push(eq(memories.unique, true));
+            }
+            if (params.agentId) {
+                conditions.push(eq(memories.agentId, params.agentId));
+            }
+            if (params.roomId) {
+                conditions.push(eq(memories.roomId, params.roomId));
+            }
+
+            if (params.match_threshold) {
+                conditions.push(gte(similarity, params.match_threshold));
+            }
+
+            const results = await this.db
+                .select({
+                    id: memories.id,
+                    type: memories.type,
+                    createdAt: memories.createdAt,
+                    content: memories.content,
+                    embedding: memories.embedding,
+                    userId: memories.userId,
+                    agentId: memories.agentId,
+                    roomId: memories.roomId,
+                    unique: memories.unique,
+                    similarity: similarity,
+                })
+                .from(memories)
+                .where(and(...conditions))
+                .orderBy(desc(similarity))
+                .limit(params.count ?? 10);
+
+            return results.map(row => ({
+                id: row.id as UUID,
+                type: row.type,
+                createdAt: row.createdAt ? row.createdAt.getTime() : Date.now(),
+                content: typeof row.content === "string"
+                    ? JSON.parse(row.content)
+                    : row.content,
+                embedding: row.embedding ?? undefined,
+                userId: row.userId as UUID,
+                agentId: row.agentId as UUID,
+                roomId: row.roomId as UUID,
+                unique: row.unique,
+                similarity: row.similarity,
+            }));
+        } catch (error) {
+            elizaLogger.error("Failed to search memories by embedding:", {
+                error: error instanceof Error ? error.message : String(error),
+                vectorLength: embedding.length,
+                tableName: params.tableName,
+                roomId: params.roomId,
+                agentId: params.agentId,
+            });
+            throw error;
+        }
+    }
+
+    async createMemory(
+        memory: Memory,
+        tableName: string,
+        unique?: boolean
+    ): Promise<void> {
+        try {
+            elizaLogger.debug("DrizzleAdapter createMemory:", {
+                memoryId: memory.id,
+                embeddingLength: memory.embedding?.length,
+                contentLength: memory.content?.text?.length,
+            });
+
+            // Treat the memory as unique unless a near-duplicate embedding already exists
+            let isUnique = true;
+            if (memory.embedding) {
+                const similarMemories = await this.searchMemoriesByEmbedding(
+                    memory.embedding,
+                    {
+                        tableName,
+                        roomId: memory.roomId,
+                        match_threshold: 0.95,
+                        count: 1,
+                    }
+                );
+                isUnique = similarMemories.length === 0;
+            }
+
+            await this.db.insert(memories).values([
+                {
+                    id: memory.id ?? uuid(),
+                    type: tableName,
+                    content: memory.content as any,
+                    embedding: memory.embedding,
+                    userId: memory.userId,
+                    roomId: memory.roomId,
+                    agentId: memory.agentId,
+                    unique: memory.unique ?? isUnique,
+                    createdAt: memory.createdAt
+                        ? new Date(memory.createdAt)
+                        : new Date(),
+                },
+            ]);
+        } catch (error) {
+            elizaLogger.error("Failed to create memory:", {
+                error: error instanceof Error ? error.message : String(error),
+                memoryId: memory.id,
+                tableName,
+                roomId: memory.roomId,
+            });
+            throw error;
+        }
+    }
+
+    async removeMemory(memoryId: UUID, tableName: string): Promise<void> {
+        try {
+            await this.db
+                .delete(memories)
+                .where(
+                    and(eq(memories.id, memoryId), eq(memories.type, tableName))
+                );
+
+            elizaLogger.debug("Memory removed successfully:", {
+                memoryId,
+                tableName,
+            });
+        } catch (error) {
+            elizaLogger.error("Failed to remove memory:", {
+                error: error instanceof Error ? error.message : String(error),
+                memoryId,
+                tableName,
+            });
+            throw error;
+        }
+    }
+
+    async removeAllMemories(roomId: UUID, tableName: string): Promise<void> {
+        try {
+            await this.db
+                .delete(memories)
+                .where(
+                    and(
+                        eq(memories.roomId, roomId),
+                        eq(memories.type, tableName)
+                    )
+                );
+
+            elizaLogger.debug("All memories removed successfully:", {
+                roomId,
+                tableName,
+            });
+        } catch (error) {
+            elizaLogger.error("Failed to remove all memories:", {
+                error: error instanceof Error ? error.message : String(error),
+                roomId,
+                tableName,
+            });
+            throw error;
+        }
+    }
+
+    async countMemories(
+        roomId: UUID,
+        unique = true,
+        tableName = ""
+    ): Promise<number> {
+        if (!tableName) throw new Error("tableName is required");
+
+        try {
+            const conditions = [
+                eq(memories.roomId, roomId),
+                eq(memories.type, tableName),
+            ];
+
+            if (unique) {
+                conditions.push(eq(memories.unique, true));
+            }
+
+            const result = await this.db
+                .select({ count: sql<number>`count(*)` })
+                .from(memories)
+                .where(and(...conditions));
+
+            return Number(result[0]?.count ?? 0);
+        } catch (error) {
+            elizaLogger.error("Failed to count memories:", {
+                error: error instanceof Error ? error.message : String(error),
+                roomId,
+                tableName,
+                unique,
+            });
+            throw error;
+        }
+    }
+
+    async getGoals(params: {
+        roomId: UUID;
+        userId?: UUID | null;
+        onlyInProgress?: boolean;
+        count?: number;
+    }): Promise<Goal[]> {
+        try {
+            const conditions = [eq(goals.roomId, params.roomId)];
+
+            if (params.userId) {
+                conditions.push(eq(goals.userId, params.userId));
+            }
+
+            if (params.onlyInProgress) {
+                conditions.push(eq(goals.status, "IN_PROGRESS" as GoalStatus));
+            }
+
+            const query = this.db
+                .select()
+                .from(goals)
+                .where(and(...conditions))
+                .orderBy(desc(goals.createdAt));
+
+            const result = await (params.count
+                ? query.limit(params.count)
+                : query);
+
+            return result.map((row) => ({
+                id: row.id as UUID,
+                roomId: row.roomId as UUID,
+                userId: row.userId as UUID,
+                name: row.name ?? "",
+                status: (row.status ?? "NOT_STARTED") as GoalStatus,
+                description: row.description ?? "",
+                objectives: row.objectives as any[],
+                createdAt: row.createdAt?.getTime() ?? Date.now(),
+            }));
+        } catch (error) {
+            elizaLogger.error("Failed to get goals:", {
+                error: error instanceof Error ? error.message : String(error),
+                roomId: params.roomId,
+                userId: params.userId,
+                onlyInProgress: params.onlyInProgress,
+            });
+            throw error;
+        }
+    }
+
+    async updateGoal(goal: Goal): Promise<void> {
+        try {
+            await this.db
+                .update(goals)
+                .set({
+                    name: goal.name,
+                    status: goal.status,
+                    objectives: goal.objectives,
+                })
+                .where(eq(goals.id, goal.id as string));
+        } catch (error) {
+            elizaLogger.error("Failed to update goal:", {
+                error: error instanceof Error ? error.message : String(error),
+                goalId: goal.id,
+                status: goal.status,
+            });
+            throw error;
+        }
+    }
+
+    async createGoal(goal: Goal): Promise<void> {
+        try {
+            await this.db.insert(goals).values({
+                id: goal.id ?? 
uuid(), + roomId: goal.roomId, + userId: goal.userId, + name: goal.name, + status: goal.status, + objectives: goal.objectives, + }); + } catch (error) { + elizaLogger.error("Failed to create goal:", { + error: error instanceof Error ? error.message : String(error), + goalId: goal.id, + }); + throw error; + } + } + + async removeGoal(goalId: UUID): Promise { + if (!goalId) throw new Error("Goal ID is required"); + + try { + await this.db.delete(goals).where(eq(goals.id, goalId)); + + elizaLogger.debug("Goal removal attempt:", { + goalId, + removed: true, + }); + } catch (error) { + elizaLogger.error("Failed to remove goal:", { + error: error instanceof Error ? error.message : String(error), + goalId, + }); + throw error; + } + } + + async removeAllGoals(roomId: UUID): Promise { + try { + await this.db.delete(goals).where(eq(goals.roomId, roomId)); + } catch (error) { + elizaLogger.error("Failed to remove all goals:", { + error: error instanceof Error ? error.message : String(error), + roomId, + }); + throw error; + } + } + + async getRoom(roomId: UUID): Promise { + try { + const result = await this.db + .select({ + id: rooms.id, + }) + .from(rooms) + .where(eq(rooms.id, roomId)) + .limit(1); + + return (result[0]?.id as UUID) ?? null; + } catch (error) { + elizaLogger.error("Failed to get room:", { + error: error instanceof Error ? error.message : String(error), + roomId, + }); + throw error; + } + } + + async createRoom(roomId?: UUID): Promise { + try { + const id = roomId ?? uuid(); + + await this.db.insert(rooms).values([ + { + id: id as string, + }, + ]); + + return id as UUID; + } catch (error) { + elizaLogger.error("Failed to create room:", { + error: error instanceof Error ? error.message : String(error), + roomId, + }); + throw error; + } + } + + async removeRoom(roomId: UUID): Promise { + try { + await this.db.delete(rooms).where(eq(rooms.id, roomId)); + } catch (error) { + elizaLogger.error("Failed to remove room:", { + error: error instanceof Error ? error.message : String(error), + roomId, + }); + throw error; + } + } + + async getRoomsForParticipant(userId: UUID): Promise { + const result = await this.db + .select({ roomId: participants.roomId }) + .from(participants) + .where(eq(participants.userId, userId)); + + return result.map((row) => row.roomId as UUID); + } + + async getRoomsForParticipants(userIds: UUID[]): Promise { + const result = await this.db + .selectDistinct({ roomId: participants.roomId }) + .from(participants) + .where(inArray(participants.userId, userIds)); + + return result.map((row) => row.roomId as UUID); + } + + async addParticipant(userId: UUID, roomId: UUID): Promise { + try { + await this.db.insert(participants).values({ + id: uuid(), + userId, + roomId, + }); + return true; + } catch (error) { + console.log("Error adding participant", error); + return false; + } + } + + async removeParticipant(userId: UUID, roomId: UUID): Promise { + try { + const result = await this.db + .delete(participants) + .where( + and( + eq(participants.userId, userId), + eq(participants.roomId, roomId) + ) + ) + .returning(); + + return result.length > 0; + } catch (error) { + elizaLogger.error("Failed to remove participant:", { + error: error instanceof Error ? 
error.message : String(error), + userId, + roomId, + }); + throw error; + } + } + + async getParticipantsForAccount(userId: UUID): Promise { + try { + const result = await this.db + .select({ + id: participants.id, + userId: participants.userId, + roomId: participants.roomId, + lastMessageRead: participants.lastMessageRead, + }) + .from(participants) + .where(eq(participants.userId, userId)); + + const account = await this.getAccountById(userId); + + return result.map((row) => ({ + id: row.id as UUID, + account: account!, + })); + } catch (error) { + elizaLogger.error("Failed to get participants for account:", { + error: error instanceof Error ? error.message : String(error), + userId, + }); + throw error; + } + } + + async getParticipantsForRoom(roomId: UUID): Promise { + try { + const result = await this.db + .select({ userId: participants.userId }) + .from(participants) + .where(eq(participants.roomId, roomId)); + + return result.map((row) => row.userId as UUID); + } catch (error) { + elizaLogger.error("Failed to get participants for room:", { + error: error instanceof Error ? error.message : String(error), + roomId, + }); + throw error; + } + } + + async getParticipantUserState( + roomId: UUID, + userId: UUID + ): Promise<"FOLLOWED" | "MUTED" | null> { + try { + const result = await this.db + .select({ userState: participants.userState }) + .from(participants) + .where( + and( + eq(participants.roomId, roomId), + eq(participants.userId, userId) + ) + ) + .limit(1); + + return ( + (result[0]?.userState as "FOLLOWED" | "MUTED" | null) ?? null + ); + } catch (error) { + elizaLogger.error("Failed to get participant user state:", { + error: error instanceof Error ? error.message : String(error), + roomId, + userId, + }); + throw error; + } + } + + async setParticipantUserState( + roomId: UUID, + userId: UUID, + state: "FOLLOWED" | "MUTED" | null + ): Promise { + try { + await this.db + .update(participants) + .set({ userState: state }) + .where( + and( + eq(participants.roomId, roomId), + eq(participants.userId, userId) + ) + ); + } catch (error) { + elizaLogger.error("Failed to set participant user state:", { + error: error instanceof Error ? error.message : String(error), + roomId, + userId, + state, + }); + throw error; + } + } + + async createRelationship(params: { + userA: UUID; + userB: UUID; + }): Promise { + try { + const relationshipId = uuid(); + await this.db.insert(relationships).values({ + id: relationshipId, + userA: params.userA, + userB: params.userB, + userId: params.userA, + }); + + elizaLogger.debug("Relationship created successfully:", { + relationshipId, + userA: params.userA, + userB: params.userB, + }); + + return true; + } catch (error) { + if ((error as { code?: string }).code === "23505") { + // Unique violation + elizaLogger.warn("Relationship already exists:", { + userA: params.userA, + userB: params.userB, + error: + error instanceof Error ? error.message : String(error), + }); + return false; + } + + elizaLogger.error("Failed to create relationship:", { + error: error instanceof Error ? 
+
+    async getRelationship(params: {
+        userA: UUID;
+        userB: UUID;
+    }): Promise<Relationship | null> {
+        try {
+            const result = await this.db
+                .select()
+                .from(relationships)
+                .where(
+                    or(
+                        and(
+                            eq(relationships.userA, params.userA),
+                            eq(relationships.userB, params.userB)
+                        ),
+                        and(
+                            eq(relationships.userA, params.userB),
+                            eq(relationships.userB, params.userA)
+                        )
+                    )
+                )
+                .limit(1);
+
+            if (result.length > 0) {
+                return result[0] as unknown as Relationship;
+            }
+
+            elizaLogger.debug("No relationship found between users:", {
+                userA: params.userA,
+                userB: params.userB,
+            });
+            return null;
+        } catch (error) {
+            elizaLogger.error("Error fetching relationship:", {
+                error: error instanceof Error ? error.message : String(error),
+                userA: params.userA,
+                userB: params.userB,
+            });
+            throw error;
+        }
+    }
+
+    async getRelationships(params: { userId: UUID }): Promise<Relationship[]> {
+        try {
+            const result = await this.db
+                .select()
+                .from(relationships)
+                .where(
+                    or(
+                        eq(relationships.userA, params.userId),
+                        eq(relationships.userB, params.userId)
+                    )
+                )
+                .orderBy(desc(relationships.createdAt));
+
+            elizaLogger.debug("Retrieved relationships:", {
+                userId: params.userId,
+                count: result.length,
+            });
+
+            return result as unknown as Relationship[];
+        } catch (error) {
+            elizaLogger.error("Failed to fetch relationships:", {
+                error: error instanceof Error ? error.message : String(error),
+                userId: params.userId,
+            });
+            throw error;
+        }
+    }
+
+    async getKnowledge(params: {
+        id?: UUID;
+        agentId: UUID;
+        limit?: number;
+        query?: string;
+    }): Promise<RAGKnowledgeItem[]> {
+        try {
+            const conditions = [
+                or(
+                    eq(knowledges.agentId, params.agentId),
+                    eq(knowledges.isShared, true)
+                ),
+            ];
+
+            if (params.id) {
+                conditions.push(eq(knowledges.id, params.id));
+            }
+
+            const query = this.db
+                .select()
+                .from(knowledges)
+                .where(and(...conditions))
+                .orderBy(desc(knowledges.createdAt));
+
+            const result = await (params.limit
+                ? query.limit(params.limit)
+                : query);
+
+            return result.map((row) => ({
+                id: row.id as UUID,
+                agentId: row.agentId as UUID,
+                content:
+                    typeof row.content === "string"
+                        ? JSON.parse(row.content)
+                        : row.content,
+                embedding: row.embedding
+                    ? new Float32Array(row.embedding)
+                    : undefined,
+                createdAt: row.createdAt?.getTime(),
+            }));
+        } catch (error) {
+            elizaLogger.error("Failed to get knowledge:", {
+                error: error instanceof Error ? error.message : String(error),
+                id: params.id,
+                agentId: params.agentId,
+                limit: params.limit,
+            });
+            throw error;
+        }
+    }
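Note that createRelationship stores a single row, while getRelationship matches the pair in either order, so callers never need to normalize id order themselves. A sketch of a connectivity check built on that contract; the types are assumed to come from @elizaos/core, and the helper is hypothetical.

```ts
import type { Relationship, UUID } from "@elizaos/core";

interface RelationshipOps {
    getRelationship(params: {
        userA: UUID;
        userB: UUID;
    }): Promise<Relationship | null>;
}

// Hypothetical helper: the order of the two ids is irrelevant by construction.
async function areConnected(
    adapter: RelationshipOps,
    a: UUID,
    b: UUID
): Promise<boolean> {
    return (await adapter.getRelationship({ userA: a, userB: b })) !== null;
}
```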
+
+    async searchKnowledge(params: {
+        agentId: UUID;
+        embedding: Float32Array;
+        match_threshold: number;
+        match_count: number;
+        searchText?: string;
+    }): Promise<RAGKnowledgeItem[]> {
+        try {
+            const cacheKey = `embedding_${params.agentId}_${params.searchText}`;
+
+            const cachedResult = await this.getCache({
+                key: cacheKey,
+                agentId: params.agentId,
+            });
+
+            if (cachedResult) {
+                return JSON.parse(cachedResult);
+            }
+
+            // pgvector expects a bracketed literal, e.g. "[0.1,0.2,...]"
+            const vectorStr = `[${Array.from(params.embedding).join(",")}]`;
+
+            const result = await this.db.execute<Record<string, unknown>>(sql`
+                WITH vector_scores AS (
+                    SELECT id,
+                        1 - (embedding <-> ${vectorStr}::vector) as vector_score
+                    FROM knowledge
+                    WHERE (("agentId" IS NULL AND "isShared" = true) OR "agentId" = ${
+                        params.agentId
+                    })
+                    AND embedding IS NOT NULL
+                ),
+                keyword_matches AS (
+                    SELECT id,
+                        CASE
+                            WHEN content->>'text' ILIKE ${`%${
+                                params.searchText || ""
+                            }%`} THEN 3.0
+                            ELSE 1.0
+                        END *
+                        CASE
+                            WHEN (content->'metadata'->>'isChunk')::boolean = true THEN 1.5
+                            WHEN (content->'metadata'->>'isMain')::boolean = true THEN 1.2
+                            ELSE 1.0
+                        END as keyword_score
+                    FROM knowledge
+                    WHERE ("agentId" IS NULL AND "isShared" = true) OR "agentId" = ${
+                        params.agentId
+                    }
+                )
+                SELECT k.*,
+                    v.vector_score,
+                    kw.keyword_score,
+                    (v.vector_score * kw.keyword_score) as combined_score
+                FROM knowledge k
+                JOIN vector_scores v ON k.id = v.id
+                LEFT JOIN keyword_matches kw ON k.id = kw.id
+                WHERE ((k."agentId" IS NULL AND k."isShared" = true) OR k."agentId" = ${
+                    params.agentId
+                })
+                AND (
+                    v.vector_score >= ${params.match_threshold}
+                    OR (kw.keyword_score > 1.0 AND v.vector_score >= 0.3)
+                )
+                ORDER BY combined_score DESC
+                LIMIT ${params.match_count}
+            `);
+
+            const mappedResults = result.map((row: any) => ({
+                id: row.id as UUID,
+                agentId: row.agentId as UUID,
+                content:
+                    typeof row.content === "string"
+                        ? JSON.parse(row.content)
+                        : row.content,
+                embedding: row.embedding
+                    ? new Float32Array(
+                          row.embedding.slice(1, -1).split(",").map(Number)
+                      )
+                    : undefined,
+                createdAt: row.createdAt?.getTime(),
+                similarity: row.combined_score,
+            }));
+
+            await this.setCache({
+                key: cacheKey,
+                agentId: params.agentId,
+                value: JSON.stringify(mappedResults),
+            });
+
+            return mappedResults;
+        } catch (error) {
+            elizaLogger.error("Error in searchKnowledge:", {
+                error: error instanceof Error ? error.message : String(error),
+                agentId: params.agentId,
+                searchText: params.searchText,
+            });
+            throw error;
+        }
+    }
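The query ranks rows by vector_score * keyword_score: an ILIKE hit on the search text weighs 3.0, chunk rows get a further 1.5 and main items 1.2. Be aware that an undefined searchText produces the pattern '%%', which matches every row and turns the text boost into uniform scaling. An approximate TypeScript mirror of the weighting, useful for unit-testing the ranking outside the database; the function names are illustrative.

```ts
// Approximate mirror of the SQL CASE weights above.
function keywordScore(
    text: string,
    searchText: string | undefined,
    meta: { isChunk?: boolean; isMain?: boolean }
): number {
    const textBoost =
        searchText && text.toLowerCase().includes(searchText.toLowerCase())
            ? 3.0
            : 1.0;
    const structureBoost = meta.isChunk ? 1.5 : meta.isMain ? 1.2 : 1.0;
    return textBoost * structureBoost;
}

// Rows are ordered by this product (combined_score DESC in the SQL).
const combinedScore = (vectorScore: number, kwScore: number): number =>
    vectorScore * kwScore;
```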
+
+    async createKnowledge(knowledge: RAGKnowledgeItem): Promise<void> {
+        await this.db.transaction(async (tx) => {
+            try {
+                const metadata = knowledge.content.metadata || {};
+
+                // If this is a chunk, use createKnowledgeChunk
+                if (metadata.isChunk && metadata.originalId) {
+                    await this.createKnowledgeChunk({
+                        id: knowledge.id,
+                        originalId: metadata.originalId,
+                        agentId: metadata.isShared ? null : knowledge.agentId,
+                        content: knowledge.content,
+                        embedding: knowledge.embedding,
+                        chunkIndex: metadata.chunkIndex || 0,
+                        isShared: metadata.isShared || false,
+                        createdAt: knowledge.createdAt || Date.now(),
+                    });
+                } else {
+                    // This is a main knowledge item
+                    await tx.insert(knowledges).values({
+                        id: knowledge.id,
+                        agentId: metadata.isShared ? null : knowledge.agentId,
+                        content: knowledge.content,
+                        embedding: knowledge.embedding
+                            ? Array.from(knowledge.embedding)
+                            : null,
+                        createdAt: new Date(knowledge.createdAt || Date.now()),
+                        isMain: true,
+                        originalId: null,
+                        chunkIndex: null,
+                        isShared: metadata.isShared || false,
+                    });
+                }
+            } catch (error) {
+                elizaLogger.error("Failed to create knowledge:", error);
+                throw error;
+            }
+        });
+    }
+
+    private async createKnowledgeChunk(params: {
+        id: UUID;
+        originalId: UUID;
+        agentId: UUID | null;
+        content: any;
+        embedding: Float32Array | undefined | null;
+        chunkIndex: number;
+        isShared: boolean;
+        createdAt: number;
+    }): Promise<void> {
+        const embedding = params.embedding
+            ? Array.from(params.embedding)
+            : null;
+
+        const patternId = `${params.originalId}-chunk-${params.chunkIndex}`;
+        const contentWithPatternId = {
+            ...params.content,
+            metadata: {
+                ...params.content.metadata,
+                patternId,
+            },
+        };
+
+        await this.db.insert(knowledges).values({
+            id: params.id,
+            agentId: params.agentId,
+            content: contentWithPatternId,
+            embedding: embedding,
+            createdAt: new Date(params.createdAt),
+            isMain: false,
+            originalId: params.originalId,
+            chunkIndex: params.chunkIndex,
+            isShared: params.isShared,
+        });
+    }
+
+    async removeKnowledge(id: UUID): Promise<void> {
+        try {
+            await this.db.delete(knowledges).where(eq(knowledges.id, id));
+        } catch (error) {
+            elizaLogger.error("Failed to remove knowledge:", {
+                error: error instanceof Error ? error.message : String(error),
+                id,
+            });
+            throw error;
+        }
+    }
+
+    async clearKnowledge(agentId: UUID, shared?: boolean): Promise<void> {
+        try {
+            // When `shared` is set, also clear shared entries; otherwise only
+            // this agent's own knowledge is removed.
+            const condition = shared
+                ? or(
+                      eq(knowledges.agentId, agentId),
+                      eq(knowledges.isShared, true)
+                  )
+                : eq(knowledges.agentId, agentId);
+
+            await this.db.delete(knowledges).where(condition);
+        } catch (error) {
+            elizaLogger.error("Failed to clear knowledge:", {
+                error: error instanceof Error ? error.message : String(error),
+                agentId,
+                shared,
+            });
+            throw error;
+        }
+    }
+
+    async getCache(params: {
+        agentId: UUID;
+        key: string;
+    }): Promise<string | undefined> {
+        try {
+            const result = await this.db
+                .select()
+                .from(caches)
+                .where(
+                    and(
+                        eq(caches.agentId, params.agentId),
+                        eq(caches.key, params.key)
+                    )
+                );
+            return result[0]?.value as string | undefined;
+        } catch (error) {
+            elizaLogger.error("Failed to get cache:", {
+                error: error instanceof Error ? error.message : String(error),
+                agentId: params.agentId,
+                key: params.key,
+            });
+            throw error;
+        }
+    }
+
+    async setCache(params: {
+        agentId: UUID;
+        key: string;
+        value: string;
+    }): Promise<boolean> {
+        try {
+            await this.db
+                .insert(caches)
+                .values({
+                    key: params.key,
+                    agentId: params.agentId,
+                    value: params.value,
+                    createdAt: new Date(),
+                })
+                .onConflictDoUpdate({
+                    target: [caches.key, caches.agentId],
+                    set: {
+                        value: params.value,
+                        createdAt: new Date(),
+                    },
+                });
+            return true;
+        } catch (error) {
+            elizaLogger.error("Error setting cache", {
+                error: error instanceof Error ? error.message : String(error),
+                key: params.key,
+                agentId: params.agentId,
+            });
+            return false;
+        }
+    }
+
+    async deleteCache(params: {
+        agentId: UUID;
+        key: string;
+    }): Promise<boolean> {
+        try {
+            await this.db
+                .delete(caches)
+                .where(
+                    and(
+                        eq(caches.agentId, params.agentId),
+                        eq(caches.key, params.key)
+                    )
+                );
+            return true;
+        } catch (error) {
+            elizaLogger.error("Error deleting cache", {
+                error: error instanceof Error ? error.message : String(error),
+                key: params.key,
+                agentId: params.agentId,
+            });
+            return false;
+        }
+    }
+}
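Since setCache upserts on the (key, agentId) pair, the cache methods compose naturally into a read-through helper. A sketch under the same structural-typing assumption as the earlier examples; the helper itself is not part of this PR.

```ts
interface CacheOps {
    getCache(params: { agentId: string; key: string }): Promise<string | undefined>;
    setCache(params: {
        agentId: string;
        key: string;
        value: string;
    }): Promise<boolean>;
}

// Hypothetical read-through helper: parse on hit, compute and store on miss.
async function readThroughCache<T>(
    adapter: CacheOps,
    agentId: string,
    key: string,
    compute: () => Promise<T>
): Promise<T> {
    const hit = await adapter.getCache({ agentId, key });
    if (hit !== undefined) return JSON.parse(hit) as T;
    const value = await compute();
    await adapter.setCache({ agentId, key, value: JSON.stringify(value) });
    return value;
}
```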
diff --git a/packages/adapter-drizzle/src/migrations.ts b/packages/adapter-drizzle/src/migrations.ts
new file mode 100644
index 0000000000..2c6ddbb2a1
--- /dev/null
+++ b/packages/adapter-drizzle/src/migrations.ts
@@ -0,0 +1,18 @@
+import { migrate } from "drizzle-orm/node-postgres/migrator";
+import path from "path";
+import { fileURLToPath } from "url";
+import { drizzle } from "drizzle-orm/node-postgres";
+import { Pool } from "pg";
+import { elizaLogger } from "@elizaos/core";
+
+// The package builds to ESM, which has no __dirname; derive it instead.
+const __dirname = path.dirname(fileURLToPath(import.meta.url));
+
+export async function runMigrations(pgPool: Pool): Promise<void> {
+    try {
+        const db = drizzle(pgPool);
+        await migrate(db, {
+            migrationsFolder: path.resolve(__dirname, "../drizzle/migrations"),
+        });
+        elizaLogger.info("Migrations completed successfully!");
+    } catch (error) {
+        elizaLogger.error("Failed to run database migrations:", error);
+        throw error;
+    }
+}
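Typical startup wiring for the runner might look like the following sketch; the environment variable name is an assumption, not something this PR defines.

```ts
import { Pool } from "pg";
import { runMigrations } from "./migrations";

// Hypothetical bootstrap: apply pending migrations before serving traffic.
const pool = new Pool({ connectionString: process.env.POSTGRES_URL });
await runMigrations(pool); // logs and rethrows if any migration fails
```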
body: jsonb("body").notNull(), + type: text("type").notNull(), + roomId: uuid("roomId") + .notNull() + .references(() => rooms.id) + .references(() => rooms.id), +}); + +export const participants = pgTable("participants", { + id: uuid("id").primaryKey().notNull(), + createdAt: timestamp("createdAt", { + withTimezone: true, + mode: "date", + }).defaultNow(), + userId: uuid("userId") + .references(() => accounts.id) + .references(() => accounts.id), + roomId: uuid("roomId") + .references(() => rooms.id) + .references(() => rooms.id), + userState: text("userState"), + lastMessageRead: text("last_message_read"), +}); + +export const relationships = pgTable("relationships", { + id: uuid("id").primaryKey().notNull(), + createdAt: timestamp("createdAt", { + withTimezone: true, + mode: "date", + }).defaultNow(), + userA: uuid("userA") + .notNull() + .references(() => accounts.id) + .references(() => accounts.id), + userB: uuid("userB") + .notNull() + .references(() => accounts.id) + .references(() => accounts.id), + status: text("status"), + userId: uuid("userId") + .notNull() + .references(() => accounts.id) + .references(() => accounts.id), +}); + +export const knowledges = pgTable("knowledge", { + id: uuid("id").primaryKey().notNull(), + agentId: uuid("agentId").references(() => accounts.id), + content: jsonb("content").notNull(), + embedding: vector("embedding", { + dimensions: dimensions, + }), + createdAt: timestamp("createdAt", { + withTimezone: true, + mode: "date", + }).defaultNow(), + isMain: boolean("isMain").default(false), + originalId: uuid("originalId"), + chunkIndex: integer("chunkIndex"), + isShared: boolean("isShared").default(false), +}); + +export const caches = pgTable("cache", { + key: text("key").notNull(), + agentId: text("agentId").notNull(), + value: jsonb("value").default({}), + createdAt: timestamp("createdAt", { mode: "date" }).defaultNow(), + expiresAt: timestamp("expiresAt", { mode: "date" }), +}); diff --git a/packages/adapter-drizzle/tsconfig.json b/packages/adapter-drizzle/tsconfig.json new file mode 100644 index 0000000000..ad27e288d0 --- /dev/null +++ b/packages/adapter-drizzle/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../core/tsconfig.json", + "compilerOptions": { + "outDir": "dist", + "rootDir": "src", + "strict": true + }, + "include": ["src/**/*.ts", "src/migrations.ts", "src/schema.ts"], + "exclude": ["node_modules", "dist"] +} diff --git a/packages/adapter-drizzle/tsup.config.ts b/packages/adapter-drizzle/tsup.config.ts new file mode 100644 index 0000000000..9acebc5ba9 --- /dev/null +++ b/packages/adapter-drizzle/tsup.config.ts @@ -0,0 +1,21 @@ +import { defineConfig } from "tsup"; + +export default defineConfig({ + entry: ["src/index.ts"], + outDir: "dist", + sourcemap: true, + clean: true, + format: ["esm"], // Ensure you're targeting CommonJS + external: [ + "dotenv", // Externalize dotenv to prevent bundling + "fs", // Externalize fs to use Node.js built-in module + "path", // Externalize other built-ins if necessary + "@reflink/reflink", + "@node-llama-cpp", + "https", + "http", + "agentkeepalive", + "uuid", + // Add other modules you want to externalize + ], +}); diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts index 5fb7d75ede..75a74ec4fc 100644 --- a/packages/core/src/embedding.ts +++ b/packages/core/src/embedding.ts @@ -4,7 +4,7 @@ import settings from "./settings.ts"; import elizaLogger from "./logger.ts"; import LocalEmbeddingModelManager from "./localembeddingManager.ts"; -interface EmbeddingOptions { +export 
diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts
index 5fb7d75ede..75a74ec4fc 100644
--- a/packages/core/src/embedding.ts
+++ b/packages/core/src/embedding.ts
@@ -4,7 +4,7 @@ import settings from "./settings.ts";
 import elizaLogger from "./logger.ts";
 import LocalEmbeddingModelManager from "./localembeddingManager.ts";
 
-interface EmbeddingOptions {
+export interface EmbeddingOptions {
     model: string;
     endpoint: string;
     apiKey?: string;
@@ -99,6 +99,8 @@ async function getRemoteEmbedding(
     };
 
     try {
+        elizaLogger.debug("Full URL:", fullUrl);
+        elizaLogger.debug("Request Options:", { ...requestOptions, headers: "[redacted]" }); // avoid logging the Authorization header
         const response = await fetch(fullUrl, requestOptions);
 
         if (!response.ok) {
@@ -302,3 +304,14 @@ export async function embed(runtime: IAgentRuntime, input: string) {
         return null;
     }
 }
+
+export async function getEmbeddingForTest(input: string, config: EmbeddingOptions) {
+    return await getRemoteEmbedding(input, {
+        model: config.model,
+        endpoint: config.endpoint,
+        apiKey: config.apiKey,
+        dimensions: config.dimensions,
+        isOllama: config.isOllama,
+        provider: config.provider
+    });
+}
\ No newline at end of file
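A possible way to exercise the new helper against a local Ollama endpoint; the endpoint, model name, and dimensions below are illustrative assumptions rather than defaults from this change.

```ts
import { getEmbeddingForTest } from "./embedding.ts";

const embedding = await getEmbeddingForTest("hello world", {
    model: "mxbai-embed-large", // assumed local model
    endpoint: "http://localhost:11434", // assumed Ollama address
    dimensions: 1024,
    isOllama: true,
    provider: "ollama",
});
// On success the vector length should match the configured dimensions.
console.log(Array.isArray(embedding) ? embedding.length : embedding);
```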