Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ mem0_memory = [
# Need to be optional as a fix for https://github.com/strands-agents/docs/issues/19
"mem0ai>=0.1.99,<1.0.0",
"opensearch-py>=2.8.0,<3.0.0",
"psycopg2-binary",
]
local_chromium_browser = ["nest-asyncio>=1.5.0,<2.0.0", "playwright>=1.42.0,<2.0.0"]
agent_core_browser = [
Expand Down
79 changes: 75 additions & 4 deletions src/strands_tools/mem0_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,30 @@ class Mem0ServiceClient:
},
}

def _get_postgresql_config(self) -> Dict:
"""Get PostgreSQL configuration based on the current provider."""
# Start with the default embedder and llm config
config = {
"embedder": self.DEFAULT_CONFIG["embedder"].copy(),
"llm": self.DEFAULT_CONFIG["llm"].copy(),
}

# Add PostgreSQL vector store configuration
config["vector_store"] = {
"provider": "pgvector",
"config": {
"host": os.environ.get("POSTGRESQL_HOST"),
"port": int(os.environ.get("POSTGRESQL_PORT", 5432)),
"user": os.environ.get("POSTGRESQL_USER"),
"password": os.environ.get("POSTGRESQL_PASSWORD"),
"dbname": os.environ.get("DB_NAME", "postgres"),
"collection_name": os.environ.get("DB_COLLECTION_NAME", "mem0_memories"),
"embedding_model_dims": 1024,
},
}

return config

def __init__(self, config: Optional[Dict] = None):
"""Initialize the Mem0 service client.

Expand Down Expand Up @@ -208,6 +232,10 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
config = self._configure_neptune_analytics_backend(config)

if os.environ.get("POSTGRESQL_HOST"):
logger.info("Using PostgreSQL backend (Mem0Memory with PostgreSQL)")
return self._initialize_postgresql_client(config)

if os.environ.get("OPENSEARCH_HOST"):
logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)")
return self._initialize_opensearch_client(config)
Expand All @@ -231,6 +259,37 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) ->
}
return config

def _initialize_postgresql_client(self, config: Optional[Dict] = None) -> Mem0Memory:
"""Initialize a Mem0 client with PostgreSQL backend.

Args:
config: Optional configuration dictionary to override defaults.

Returns:
An initialized Mem0Memory instance configured for PostgreSQL.

Raises:
ValueError: If required PostgreSQL environment variables are missing.
"""
# Validate required environment variables
required_vars = ["POSTGRESQL_HOST", "POSTGRESQL_USER", "POSTGRESQL_PASSWORD"]
missing_vars = [var for var in required_vars if not os.environ.get(var)]
if missing_vars:
raise ValueError(f"Missing required PostgreSQL environment variables: {', '.join(missing_vars)}")

# Get PostgreSQL configuration
pg_config = self._get_postgresql_config()

# Validate OpenAI API key if using OpenAI
provider = os.environ.get("MEM0_LLM_PROVIDER", "aws_bedrock")
if provider == "openai" and not os.environ.get("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY environment variable is required when using OpenAI provider")

# Merge with user-provided config if any
merged_config = self._merge_configs(pg_config, config)

return Mem0Memory.from_config(config_dict=merged_config)

def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory:
"""Initialize a Mem0 client with OpenSearch backend.

Expand Down Expand Up @@ -296,12 +355,24 @@ def _merge_config(self, config: Optional[Dict] = None) -> Dict:
Returns:
A merged configuration dictionary.
"""
merged_config = self.DEFAULT_CONFIG.copy()
if not config:
return self._merge_configs(self.DEFAULT_CONFIG, config)

def _merge_configs(self, base_config: Dict, override_config: Optional[Dict] = None) -> Dict:
"""Merge two configuration dictionaries.

Args:
base_config: Base configuration dictionary
override_config: Optional configuration to merge into base

Returns:
A merged configuration dictionary.
"""
merged_config = base_config.copy()
if not override_config:
return merged_config

# Deep merge the configs
for key, value in config.items():
# Merge the configs
for key, value in override_config.items():
if key in merged_config and isinstance(value, dict) and isinstance(merged_config[key], dict):
merged_config[key].update(value)
else:
Expand Down
Loading