Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 88 additions & 21 deletions packages/aiqtoolkit_mem0ai/src/aiq/plugins/mem0ai/mem0_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is a part of the Nvidia AIQ Toolkit project, Mem0 Plugin, and has been modified to support the Mem0 v2 API.

import asyncio

from mem0 import AsyncMemory
from mem0 import AsyncMemoryClient

from aiq.memory.interfaces import MemoryEditor
Expand All @@ -26,16 +29,20 @@ class Mem0Editor(MemoryEditor):
Wrapper class that implements AIQ Toolkit Interfaces for Mem0 Integrations Async.
"""

def __init__(self, mem0_client: AsyncMemoryClient):
def __init__(self, mem0_client: AsyncMemoryClient | AsyncMemory):
    """
    Initialize class with Predefined Mem0 Client.

    Args:
        mem0_client (AsyncMemoryClient | AsyncMemory): Preinstantiated
            AsyncMemoryClient (hosted Mem0) or AsyncMemory (local Mem0)
            object.

    Raises:
        ValueError: If ``mem0_client`` is ``None``.
    """
    # Validate before assigning so a half-initialized instance is never
    # left behind when the check fails.
    if mem0_client is None:
        raise ValueError("Mem0 client cannot be None")

    self._client = mem0_client

async def add_items(self, items: list[MemoryItem]) -> None:
"""
Insert Multiple MemoryItems into the memory.
Expand All @@ -51,15 +58,29 @@ async def add_items(self, items: list[MemoryItem]) -> None:

user_id = memory_item.user_id # This must be specified
run_id = item_meta.pop("run_id", None)
tags = memory_item.tags

coroutines.append(
self._client.add(content,
user_id=user_id,
run_id=run_id,
tags=tags,
metadata=item_meta,
output_format="v1.1"))
# UPDATED: In mem0 v2 API, tags are now part of metadata
# Moving tags into metadata dictionary
tags = memory_item.tags
if tags:
item_meta["categories"] = tags

# UPDATED: In mem0 v2 API, content is passed as messages array
# Handle different types of content
if isinstance(content, str):
messages = [{"role": "user", "content": content}]
elif isinstance(content, list) and all(isinstance(msg, dict) for msg in content):
# If content is already in the correct format (list of message dicts)
messages = content
else:
# Try to convert to string as a fallback
try:
messages = [{"role": "user", "content": str(content)}]
except Exception:
raise ValueError(f"Unable to convert content to a valid message format: {type(content)}")

# UPDATED: Removed output_format parameter as it's deprecated in v2 API
coroutines.append(self._client.add(messages, user_id=user_id, run_id=run_id, metadata=item_meta))

await asyncio.gather(*coroutines)

async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem]:
    """
    Retrieve the memories most relevant to ``query`` for a given user.

    Args:
        query (str): Natural-language query matched against stored memories.
        top_k (int): Maximum number of results to return.
        **kwargs: Must contain ``user_id``; remaining keys are forwarded to
            the underlying client's ``search`` call.

    Returns:
        list[MemoryItem]: Reconstructed memory items, best matches first.

    Raises:
        KeyError: If ``user_id`` is not supplied in ``kwargs``.
    """
    user_id = kwargs.pop("user_id")  # Ensure user ID is in keyword arguments

    # UPDATED: Removed output_format parameter as it's deprecated in v2 API
    search_result = await self._client.search(query, user_id=user_id, limit=top_k, **kwargs)

    # v1 wraps results in a dict under "results"; v2 returns the list directly.
    if isinstance(search_result, dict) and "results" in search_result:
        results_to_process = search_result["results"]
    else:
        results_to_process = search_result

    memories = []

    for res in results_to_process:
        if isinstance(res, dict):
            raw_meta = res.get("metadata", {})
            # Copy so the caller's result payload is never mutated by the
            # pop of "categories" below.
            item_meta = dict(raw_meta) if isinstance(raw_meta, dict) else {}

            # UPDATED: In v2 API, tags/categories live inside metadata.
            tags = item_meta.pop("categories", [])
            if not isinstance(tags, list):
                tags = []

            memory_content = res.get("memory", "")

            # Prefer the original conversation from the 'input' field
            # (present in older API versions).
            conversation = res.get("input", [])
        elif isinstance(res, str):
            # A bare string result carries only the memory text itself.
            memory_content = res
            item_meta = {}
            tags = []
            conversation = []
        else:
            # Skip results in an unrecognized format.
            continue

        # Fall back to synthesizing a single-turn conversation from the
        # memory text when no usable 'input' list was found.
        if not conversation or not isinstance(conversation, list):
            conversation = [{
                "role": "user", "content": memory_content
            }] if isinstance(memory_content, str) else memory_content

        memories.append(
            MemoryItem(conversation=conversation,
                       user_id=user_id,
                       memory=memory_content,
                       tags=tags,
                       metadata=item_meta))

    return memories

async def remove_items(self, **kwargs):
    """
    Delete memories from the underlying Mem0 store.

    Keyword Args:
        memory_id: When supplied, delete the single memory with this ID.
        user_id: When supplied (and ``memory_id`` is absent), delete every
            memory belonging to this user.

    Notes:
        ``memory_id`` takes precedence over ``user_id``. When neither key is
        present, the call is a silent no-op.
    """
    if "memory_id" in kwargs:
        await self._client.delete(kwargs.pop("memory_id"))
    elif "user_id" in kwargs:
        await self._client.delete_all(user_id=kwargs.pop("user_id"))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is a part of the Nvidia AIQ Toolkit project, Mem0 Plugin
# as a way of integrating local Mem0 with local Ollama and vector db instances.

from aiq.builder.builder import Builder
from aiq.cli.register_workflow import register_memory
from aiq.data_models.memory import MemoryBaseConfig


class Mem0LocalMemoryClientConfig(MemoryBaseConfig, name="mem0_memory_local_ollama"):
    """
    Mem0 Memory Client Configuration. Setup for use with local Ollama instances.

    Defaults are set to work with Ollama and Milvus local instances with a
    local Mem0 instance; change them according to your local setup or override
    them in your workflow config file.
    """
    # --- Vector store settings ---
    vec_store_provider: str = "milvus"  # Change to "qdrant" if you prefer that
    vec_store_collection_name: str = "DefaultAIQCollectionNew"
    vec_store_url: str = "http://localhost:19530"  # Default Local Milvus URL, change if needed
    vec_store_embedding_model_dims: int = 1024  # Must match the embedder model's output dimensions
    # --- LLM settings ---
    # NOTE(review): Ollama-specific today; consider abstracting to any LLM
    # provider supported by the toolkit, as raised in PR review.
    llm_provider: str = "ollama"
    llm_model: str = "aliafshar/gemma3-it-qat-tools:27b"  # Change to your preferred model
    llm_temperature: float = 0.0
    llm_max_tokens: int = 2000
    llm_base_url: str = "http://localhost:11434"  # Default Ollama URL, change if needed
    # --- Embedder settings ---
    # NOTE(review): likewise, the toolkit's embedder interface could
    # generalize this beyond Ollama-served embedding models.
    embedder_provider: str = "ollama"
    embedder_model: str = "snowflake-arctic-embed2:latest"
    embedder_base_url: str = "http://localhost:11434"  # Default Ollama URL, change if needed


@register_memory(config_type=Mem0LocalMemoryClientConfig)
async def mem0_memory_client(config: Mem0LocalMemoryClientConfig, builder: Builder):
    """
    Build and yield a Mem0Editor backed by a local AsyncMemory instance.

    The vector store, LLM, and embedder sections are assembled from the
    workflow configuration and handed to ``AsyncMemory.from_config``, which
    is compatible with the Mem0 v2 API and the updated mem0_editor.py.
    """
    # UPDATED: Import AsyncMemory for v2 API compatibility
    from mem0 import AsyncMemory

    from aiq.plugins.mem0ai.mem0_editor import Mem0Editor

    # Vector store backing the memory collection (Milvus by default).
    vector_store_section = {
        "provider": config.vec_store_provider,
        "config": {
            "collection_name": config.vec_store_collection_name,
            "url": config.vec_store_url,
            "embedding_model_dims": config.vec_store_embedding_model_dims,
        },
    }

    # LLM Mem0 uses internally, served by a local Ollama instance.
    llm_section = {
        "provider": config.llm_provider,
        "config": {
            "model": config.llm_model,
            "temperature": config.llm_temperature,
            "max_tokens": config.llm_max_tokens,
            "ollama_base_url": config.llm_base_url,
        },
    }

    # Embedder that produces vectors for the store, also via local Ollama.
    embedder_section = {
        "provider": config.embedder_provider,
        "config": {
            "model": config.embedder_model,
            "ollama_base_url": config.embedder_base_url,
        },
    }

    # UPDATED: from_config creates an AsyncMemory instance from a plain dict.
    mem0_client = await AsyncMemory.from_config({
        "vector_store": vector_store_section,
        "llm": llm_section,
        "embedder": embedder_section,
    })

    yield Mem0Editor(mem0_client=mem0_client)
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
# Import any providers which need to be automatically registered here

from . import memory
from . import memory_local_ollama
Loading