
Memory Interface in AgentChat #4438

Merged · 39 commits · Jan 14, 2025
Changes shown from 29 of the 39 commits.

Commits (39)
48d7ecb
initial base memroy impl
victordibia Nov 30, 2024
f70f61e
update, add example with chromadb
victordibia Dec 1, 2024
24fa684
include mimetype consideration
victordibia Dec 1, 2024
9e94ec8
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Dec 19, 2024
0b7469e
add transform method
victordibia Dec 20, 2024
138ee05
update to address feedback, will update after 4681 is merged
victordibia Dec 20, 2024
a94634b
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Dec 20, 2024
675924c
update memory impl,
victordibia Dec 25, 2024
b1da7e2
remove chroma db, typing fixes
victordibia Jan 3, 2025
f0812a3
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 3, 2025
32701db
format, add test
victordibia Jan 4, 2025
d7bf4d2
update uv lock
victordibia Jan 4, 2025
afbef4d
update docs
victordibia Jan 4, 2025
003bb2e
format updates
victordibia Jan 4, 2025
7b15c2e
update notebook
victordibia Jan 4, 2025
b353110
add memoryqueryevent message, yield message for observability.
victordibia Jan 4, 2025
e1a9be2
Merge branch 'main' into agentchat_memory_vd
victordibia Jan 4, 2025
c797f6a
minor fixes, make score optional/none
victordibia Jan 4, 2025
dfb1da6
Merge branch 'agentchat_memory_vd' of github.com:microsoft/autogen in…
victordibia Jan 4, 2025
97ed7f5
Update python/packages/autogen-agentchat/src/autogen_agentchat/agents…
victordibia Jan 6, 2025
5a74fdf
Merge branch 'main' into agentchat_memory_vd
victordibia Jan 6, 2025
24bd81e
update tests to improve cov
victordibia Jan 7, 2025
5b2c222
refactor, move memory to core.
victordibia Jan 8, 2025
30628f3
format fixxes
victordibia Jan 8, 2025
6b4a53a
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 8, 2025
2072c46
format updates
victordibia Jan 8, 2025
4382c86
format updates
victordibia Jan 9, 2025
0e6df1e
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 9, 2025
08d23cf
fix azure notebook import, other fixes
victordibia Jan 9, 2025
d34c07c
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 9, 2025
9316f6d
update notebook, support str query in Memory protocol
victordibia Jan 9, 2025
1ba5381
update test
victordibia Jan 9, 2025
4805c22
Merge branch 'main' into agentchat_memory_vd
victordibia Jan 9, 2025
3f9db61
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 9, 2025
7b5b97c
update cells
victordibia Jan 9, 2025
c6e4825
Merge branch 'main' into agentchat_memory_vd
victordibia Jan 10, 2025
61bcf34
add specific extensible return types to memory query and update_context
victordibia Jan 11, 2025
5ceb961
Merge remote-tracking branch 'origin/main' into agentchat_memory_vd
victordibia Jan 14, 2025
cb3b051
format update
victordibia Jan 14, 2025
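
For orientation before the diffs: a minimal usage sketch of the memory interface this PR adds, assembled from the test code further down. The OpenAI client import path, the model name, and the memory contents are illustrative assumptions, not part of this diff.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed import path


async def main() -> None:
    # Seed a list-backed memory store with a text entry.
    memory = ListMemory()
    await memory.add(
        MemoryContent(content="The user prefers responses in metric units.", mime_type=MemoryMimeType.TEXT)
    )

    # Pass one or more memory stores to the agent. Before each inference the agent
    # calls update_context() on every store and yields a MemoryQueryEvent with the results.
    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),  # placeholder model name
        memory=[memory],
    )
    result = await agent.run(task="How far is Paris from Berlin?")
    for message in result.messages:
        print(type(message).__name__, message.content)


asyncio.run(main())
```
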
@@ -14,6 +14,7 @@
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core.memory import Memory
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
@@ -35,6 +36,7 @@
AgentEvent,
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
@@ -120,6 +122,7 @@ class AssistantAgent(BaseChatAgent):
will be returned as the response.
Available variables: `{tool_name}`, `{arguments}`, `{result}`.
For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`.
memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`.

Raises:
ValueError: If tool names are not unique.
@@ -240,9 +243,20 @@ def __init__(
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._memory = None
if memory is not None:
if isinstance(memory, list):
self._memory = memory
else:
raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}")

self._system_messages: List[
SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage
] = []
if system_message is None:
self._system_messages = []
else:
@@ -325,6 +339,15 @@ async def on_messages_stream(
# Inner messages.
inner_messages: List[AgentEvent | ChatMessage] = []

# Update the model context with memory content.
if self._memory:
for memory in self._memory:
memory_query_result = await memory.update_context(self._model_context)
if memory_query_result and len(memory_query_result) > 0:
memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name)
inner_messages.append(memory_query_event_msg)
yield memory_query_event_msg

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
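
The memory hook added to `on_messages_stream` above is the whole integration point: each configured store mutates the model context via `update_context()`, and any returned content is surfaced as a `MemoryQueryEvent`. A standalone sketch of that step, assuming `ListMemory` behaves as the tests further down suggest:

```python
import asyncio

from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType
from autogen_core.model_context import BufferedChatCompletionContext


async def main() -> None:
    memory = ListMemory()
    await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT))

    context = BufferedChatCompletionContext(buffer_size=5)
    results = await memory.update_context(context)  # injects stored memory into the context
    print(len(results))                  # non-empty, so the agent would yield a MemoryQueryEvent
    print(await context.get_messages())  # the memory content now appears in the model context


asyncio.run(main())
```
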
@@ -8,6 +8,7 @@ class and includes specific fields relevant to the type of message being sent.
from typing import List, Literal

from autogen_core import FunctionCall, Image
from autogen_core.memory import MemoryContent
from autogen_core.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
@@ -103,13 +104,22 @@ class ToolCallSummaryMessage(BaseChatMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


class MemoryQueryEvent(BaseAgentEvent):
"""An event signaling the results of memory queries."""

content: List[MemoryContent]
"""The memory query results."""

type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent, Field(discriminator="type")]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


@@ -122,6 +132,7 @@ class ToolCallSummaryMessage(BaseChatMessage):
"ToolCallRequestEvent",
"ToolCallExecutionEvent",
"ToolCallSummaryMessage",
"MemoryQueryEvent",
"ChatMessage",
"AgentEvent",
]
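
Since `MemoryQueryEvent` joins the `AgentEvent` discriminated union above, it serializes and dispatches like the existing tool-call events. A small sketch with made-up field values:

```python
from autogen_agentchat.messages import MemoryQueryEvent
from autogen_core.memory import MemoryContent, MemoryMimeType

event = MemoryQueryEvent(
    content=[MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT)],
    source="assistant",
)
print(event.type)               # "MemoryQueryEvent", the union discriminator
print(event.model_dump_json())  # pydantic model, so it serializes like the other events
```
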
76 changes: 75 additions & 1 deletion python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -10,13 +10,15 @@
from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import LLMMessage
from autogen_core.models._model_client import ModelFamily
@@ -508,4 +510,76 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:

# Check if the mock client is called with only the last two messages.
assert len(mock.calls) == 1
    assert len(mock.calls[0]) == 3 # 2 messages from the context + 1 system message
    # 2 messages from the context + 1 system message
    assert len(mock.calls[0]) == 3


@pytest.mark.asyncio
async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Hello", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

# Test basic memory properties and empty context
memory = ListMemory(name="test_memory")
assert memory.name == "test_memory"

empty_context = BufferedChatCompletionContext(buffer_size=2)
empty_results = await memory.update_context(empty_context)
assert len(empty_results) == 0

# Test various content types
memory = ListMemory()
await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT))
await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON))
await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE))

# Test clear and cleanup
await memory.clear()
assert await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) == []
await memory.close() # Should not raise

# Test invalid memory type
with pytest.raises(TypeError):
AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
memory="invalid", # type: ignore
)

# Test with agent
memory2 = ListMemory()
await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT))

agent = AssistantAgent(
"test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2]
)

result = await agent.run(task="test task")
assert len(result.messages) > 0
memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None)
assert memory_event is not None

# Test memory protocol
class BadMemory:
pass

assert not isinstance(BadMemory(), Memory)
assert isinstance(ListMemory(), Memory)
@@ -91,6 +91,7 @@ tutorial/human-in-the-loop
tutorial/termination
tutorial/custom-agents
tutorial/state
tutorial/memory
```

```{toctree}