Skip to content

Commit

Permalink
add specific extensible return types to memory query and update_context
Browse files Browse the repository at this point in the history
  • Loading branch information
victordibia committed Jan 11, 2025
1 parent c6e4825 commit 61bcf34
Showing 8 changed files with 79 additions and 120 deletions.
Original file line number Diff line number Diff line change
@@ -342,9 +342,11 @@ async def on_messages_stream(
# Update the model context with memory content.
if self._memory:
for memory in self._memory:
memory_query_result = await memory.update_context(self._model_context)
if memory_query_result and len(memory_query_result) > 0:
memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name)
update_context_result = await memory.update_context(self._model_context)
if update_context_result and len(update_context_result.memories.results) > 0:
memory_query_event_msg = MemoryQueryEvent(
content=update_context_result.memories.results, source=self.name
)
inner_messages.append(memory_query_event_msg)
yield memory_query_event_msg

15 changes: 12 additions & 3 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import LLMMessage
from autogen_core.models._model_client import ModelFamily
@@ -543,17 +543,24 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:

empty_context = BufferedChatCompletionContext(buffer_size=2)
empty_results = await memory.update_context(empty_context)
assert len(empty_results) == 0
assert len(empty_results.memories.results) == 0

# Test various content types
memory = ListMemory()
await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT))
await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON))
await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE))

# Test query functionality
query_result = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
assert isinstance(query_result, MemoryQueryResult)
# Should have all three memories we added
assert len(query_result.results) == 3

# Test clear and cleanup
await memory.clear()
assert await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) == []
empty_query = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
assert len(empty_query.results) == 0
await memory.close() # Should not raise

# Test invalid memory type
@@ -576,6 +583,8 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(result.messages) > 0
memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None)
assert memory_event is not None
assert len(memory_event.content) > 0
assert isinstance(memory_event.content[0], MemoryContent)

# Test memory protocol
class BadMemory:
Original file line number Diff line number Diff line change
@@ -75,33 +75,10 @@
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"What is the weather in New York?\n",
"---------- assistant_agent ----------\n",
"[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None)]\n",
"---------- assistant_agent ----------\n",
"[FunctionCall(id='call_NR8vBXk0856yl9eYa8SMjYbo', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n",
"[Prompt tokens: 123, Completion tokens: 20]\n",
"---------- assistant_agent ----------\n",
"[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_NR8vBXk0856yl9eYa8SMjYbo')]\n",
"---------- assistant_agent ----------\n",
"The weather in New York is 23 °C and Sunny.\n",
"---------- Summary ----------\n",
"Number of messages: 5\n",
"Finish reason: None\n",
"Total prompt tokens: 123\n",
"Total completion tokens: 20\n",
"Duration: 1.27 seconds\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_NR8vBXk0856yl9eYa8SMjYbo', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_NR8vBXk0856yl9eYa8SMjYbo')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)"
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_pHq4p89gW6oGjGr3VsVETCYX', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_pHq4p89gW6oGjGr3VsVETCYX')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)"
]
},
"execution_count": 3,
@@ -132,8 +109,8 @@
"text/plain": [
"[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n",
" SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n",
" AssistantMessage(content=[FunctionCall(id='call_uvKugIKWzeCYK1px49HJhlku', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n",
" FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_uvKugIKWzeCYK1px49HJhlku')], type='FunctionExecutionResultMessage')]"
" AssistantMessage(content=[FunctionCall(id='call_pHq4p89gW6oGjGr3VsVETCYX', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n",
" FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_pHq4p89gW6oGjGr3VsVETCYX')], type='FunctionExecutionResultMessage')]"
]
},
"execution_count": 4,
@@ -159,57 +136,10 @@
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Write brief meal recipe with broth\n",
"---------- assistant_agent ----------\n",
"[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None)]\n",
"---------- assistant_agent ----------\n",
"Here's a simple vegan recipe for a vegetable broth soup:\n",
"\n",
"**Vegan Vegetable Broth Soup**\n",
"\n",
"**Ingredients:**\n",
"- 8 cups vegetable broth\n",
"- 2 carrots, chopped\n",
"- 2 celery stalks, chopped\n",
"- 1 onion, diced\n",
"- 3 cloves garlic, minced\n",
"- 1 zucchini, chopped\n",
"- 1 cup green beans, trimmed and halved\n",
"- 1 cup chopped kale\n",
"- 1 teaspoon dried thyme\n",
"- 1 teaspoon dried basil\n",
"- Salt and pepper to taste\n",
"\n",
"**Instructions:**\n",
"1. In a large pot, heat a splash of vegetable broth over medium heat. Add the onions and garlic and sauté until the onions are translucent.\n",
"2. Add the carrots, celery, zucchini, and green beans, and sauté for another 5 minutes.\n",
"3. Pour in the remaining vegetable broth and bring the mixture to a gentle boil.\n",
"4. Stir in the thyme, basil, salt, and pepper. Reduce the heat to a simmer and let the soup cook for about 25-30 minutes, or until the vegetables are tender.\n",
"5. Add the chopped kale and cook for an additional 5 minutes.\n",
"6. Taste and adjust the seasoning if needed.\n",
"7. Serve hot as a comforting and nourishing meal.\n",
"\n",
"Enjoy your delicious vegan vegetable broth soup! \n",
"\n",
"TERMINATE\n",
"[Prompt tokens: 207, Completion tokens: 271]\n",
"---------- Summary ----------\n",
"Number of messages: 3\n",
"Finish reason: None\n",
"Total prompt tokens: 207\n",
"Total completion tokens: 271\n",
"Duration: 6.22 seconds\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=207, completion_tokens=271), content=\"Here's a simple vegan recipe for a vegetable broth soup:\\n\\n**Vegan Vegetable Broth Soup**\\n\\n**Ingredients:**\\n- 8 cups vegetable broth\\n- 2 carrots, chopped\\n- 2 celery stalks, chopped\\n- 1 onion, diced\\n- 3 cloves garlic, minced\\n- 1 zucchini, chopped\\n- 1 cup green beans, trimmed and halved\\n- 1 cup chopped kale\\n- 1 teaspoon dried thyme\\n- 1 teaspoon dried basil\\n- Salt and pepper to taste\\n\\n**Instructions:**\\n1. In a large pot, heat a splash of vegetable broth over medium heat. Add the onions and garlic and sauté until the onions are translucent.\\n2. Add the carrots, celery, zucchini, and green beans, and sauté for another 5 minutes.\\n3. Pour in the remaining vegetable broth and bring the mixture to a gentle boil.\\n4. Stir in the thyme, basil, salt, and pepper. Reduce the heat to a simmer and let the soup cook for about 25-30 minutes, or until the vegetables are tender.\\n5. Add the chopped kale and cook for an additional 5 minutes.\\n6. Taste and adjust the seasoning if needed.\\n7. Serve hot as a comforting and nourishing meal.\\n\\nEnjoy your delicious vegan vegetable broth soup! \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)"
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=<MemoryMimeType.TEXT: 'text/plain'>, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=208, completion_tokens=253), content=\"Here's a brief vegan meal recipe using broth:\\n\\n**Vegan Mushroom & Herb Broth Soup**\\n\\n**Ingredients:**\\n- 1 tablespoon olive oil\\n- 1 onion, diced\\n- 2 cloves garlic, minced\\n- 250g mushrooms, sliced\\n- 1 carrot, diced\\n- 1 celery stalk, diced\\n- 4 cups vegetable broth\\n- 1 teaspoon thyme\\n- 1 teaspoon rosemary\\n- Salt and pepper to taste\\n- Fresh parsley for garnish\\n\\n**Instructions:**\\n1. Heat the olive oil in a large pot over medium heat. Add the diced onion and garlic, and sauté until the onion becomes translucent.\\n\\n2. Add the sliced mushrooms, carrot, and celery. Continue to sauté until the mushrooms are cooked through and the vegetables begin to soften, about 5 minutes.\\n\\n3. Pour in the vegetable broth. Stir in the thyme and rosemary, and bring the mixture to a boil.\\n\\n4. Reduce the heat to low and let the soup simmer for about 15 minutes, allowing the flavors to meld together.\\n\\n5. Season with salt and pepper to taste.\\n\\n6. Serve hot, garnished with fresh parsley.\\n\\nEnjoy your warm and comforting vegan mushroom & herb broth soup! \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)"
]
},
"execution_count": 5,
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from ._base_memory import Memory, MemoryContent, MemoryMimeType
from ._base_memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult
from ._list_memory import ListMemory

__all__ = [
"Memory",
"MemoryContent",
"MemoryQueryResult",
"UpdateContextResult",
"MemoryMimeType",
"ListMemory",
]
Original file line number Diff line number Diff line change
@@ -33,6 +33,14 @@ class MemoryContent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)


class MemoryQueryResult(BaseModel):
    """Result returned by ``Memory.query``.

    Wraps the matched memory entries in a model (rather than a bare list)
    so the return type stays extensible: new fields can be added later
    without breaking existing callers.
    """

    # Memory entries matched by the query.
    results: List[MemoryContent]


class UpdateContextResult(BaseModel):
    """Result returned by ``Memory.update_context``.

    Wraps the memories that were used to update the model context; a model
    wrapper (rather than a bare list) keeps the return type extensible.
    """

    # The memories that were added to the model context.
    memories: MemoryQueryResult


@runtime_checkable
class Memory(Protocol):
"""Protocol defining the interface for memory implementations."""
@@ -45,15 +53,15 @@ def name(self) -> str | None:
async def update_context(
self,
model_context: ChatCompletionContext,
) -> List[MemoryContent]:
) -> UpdateContextResult:
"""
Update the provided model context using relevant memory content.
Args:
model_context: The context to update.
Returns:
List of memory entries with relevance scores
UpdateContextResult containing relevant memories
"""
...

Check warning on line 66 in python/packages/autogen-core/src/autogen_core/memory/_base_memory.py

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/memory/_base_memory.py#L66

Added line #L66 was not covered by tests

@@ -62,7 +70,7 @@ async def query(
query: str | MemoryContent,
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> List[MemoryContent]:
) -> MemoryQueryResult:
"""
Query the memory store and return relevant entries.
@@ -72,7 +80,7 @@ async def query(
**kwargs: Additional implementation-specific parameters
Returns:
List of memory entries with relevance scores
MemoryQueryResult containing memory entries with relevance scores
"""
...

Check warning on line 85 in python/packages/autogen-core/src/autogen_core/memory/_base_memory.py

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/memory/_base_memory.py#L85

Added line #L85 was not covered by tests

Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from .._cancellation_token import CancellationToken
from ..model_context import ChatCompletionContext
from ..models import SystemMessage
from ._base_memory import Memory, MemoryContent
from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult


class ListMemory(Memory):
@@ -76,7 +76,7 @@ def content(self, value: List[MemoryContent]) -> None:
async def update_context(
self,
model_context: ChatCompletionContext,
) -> List[MemoryContent]:
) -> UpdateContextResult:
"""Update the model context by appending memory content.
This method mutates the provided model_context by adding all memories as a
@@ -86,25 +86,26 @@ async def update_context(
model_context: The context to update. Will be mutated if memories exist.
Returns:
List[MemoryContent]: List of memories that were added to the context
UpdateContextResult containing the memories that were added to the context
"""

if not self._contents:
return []
return UpdateContextResult(memories=MemoryQueryResult(results=[]))

memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(self._contents, 1)]

if memory_strings:
memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n"
await model_context.add_message(SystemMessage(content=memory_context))

return self._contents
return UpdateContextResult(memories=MemoryQueryResult(results=self._contents))

async def query(
self,
query: str | MemoryContent = "",
cancellation_token: CancellationToken | None = None,
**kwargs: Any,
) -> List[MemoryContent]:
) -> MemoryQueryResult:
"""Return all memories without any filtering.
Args:
@@ -113,10 +114,10 @@ async def query(
**kwargs: Additional parameters (ignored)
Returns:
List[MemoryContent]: All stored memories
MemoryQueryResult containing all stored memories
"""
_ = query, cancellation_token, kwargs
return self._contents
return MemoryQueryResult(results=self._contents)

async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None:
"""Add new content to memory.
Loading

0 comments on commit 61bcf34

Please sign in to comment.