-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
442 additions
and
316 deletions.
There are no files selected for viewing
147 changes: 147 additions & 0 deletions
147
python/packages/autogen-ext/tests/test_filesurfer_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
import os | ||
from datetime import datetime | ||
from typing import Any, AsyncGenerator, List | ||
|
||
import aiofiles | ||
import pytest | ||
from autogen_agentchat import EVENT_LOGGER_NAME | ||
from autogen_ext.agents.file_surfer import FileSurfer | ||
from autogen_ext.models.openai import OpenAIChatCompletionClient | ||
from openai.resources.chat.completions import AsyncCompletions | ||
from openai.types.chat.chat_completion import ChatCompletion, Choice | ||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk | ||
from openai.types.chat.chat_completion_message import ChatCompletionMessage | ||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function | ||
from openai.types.completion_usage import CompletionUsage | ||
from pydantic import BaseModel | ||
|
||
|
||
class FileLogHandler(logging.Handler): | ||
def __init__(self, filename: str) -> None: | ||
super().__init__() | ||
self.filename = filename | ||
self.file_handler = logging.FileHandler(filename) | ||
|
||
def emit(self, record: logging.LogRecord) -> None: | ||
ts = datetime.fromtimestamp(record.created).isoformat() | ||
if isinstance(record.msg, BaseModel): | ||
record.msg = json.dumps( | ||
{ | ||
"timestamp": ts, | ||
"message": record.msg.model_dump(), | ||
"type": record.msg.__class__.__name__, | ||
}, | ||
) | ||
self.file_handler.emit(record) | ||
|
||
|
||
class _MockChatCompletion: | ||
def __init__(self, chat_completions: List[ChatCompletion]) -> None: | ||
self._saved_chat_completions = chat_completions | ||
self._curr_index = 0 | ||
|
||
async def mock_create( | ||
self, *args: Any, **kwargs: Any | ||
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]: | ||
await asyncio.sleep(0.1) | ||
completion = self._saved_chat_completions[self._curr_index] | ||
self._curr_index += 1 | ||
return completion | ||
|
||
|
||
logger = logging.getLogger(EVENT_LOGGER_NAME) | ||
logger.setLevel(logging.DEBUG) | ||
logger.addHandler(FileLogHandler("test_filesurfer_agent.log")) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_run_filesurfer(monkeypatch: pytest.MonkeyPatch) -> None: | ||
# Create a test file | ||
test_file = os.path.abspath("test_filesurfer_agent.html") | ||
async with aiofiles.open(test_file, "wt") as file: | ||
await file.write("""<html> | ||
<head> | ||
<title>FileSurfer test file</title> | ||
</head> | ||
<body> | ||
<h1>FileSurfer test H1</h1> | ||
<p>FileSurfer test body</p> | ||
</body> | ||
</html>""") | ||
|
||
# Mock the API calls | ||
model = "gpt-4o-2024-05-13" | ||
chat_completions = [ | ||
ChatCompletion( | ||
id="id1", | ||
choices=[ | ||
Choice( | ||
finish_reason="tool_calls", | ||
index=0, | ||
message=ChatCompletionMessage( | ||
content=None, | ||
tool_calls=[ | ||
ChatCompletionMessageToolCall( | ||
id="1", | ||
type="function", | ||
function=Function( | ||
name="open_path", | ||
arguments=json.dumps({"path": test_file}), | ||
), | ||
) | ||
], | ||
role="assistant", | ||
), | ||
) | ||
], | ||
created=0, | ||
model=model, | ||
object="chat.completion", | ||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), | ||
), | ||
ChatCompletion( | ||
id="id2", | ||
choices=[ | ||
Choice( | ||
finish_reason="tool_calls", | ||
index=0, | ||
message=ChatCompletionMessage( | ||
content=None, | ||
tool_calls=[ | ||
ChatCompletionMessageToolCall( | ||
id="1", | ||
type="function", | ||
function=Function( | ||
name="open_path", | ||
arguments=json.dumps({"path": os.path.dirname(test_file)}), | ||
), | ||
) | ||
], | ||
role="assistant", | ||
), | ||
) | ||
], | ||
created=0, | ||
model=model, | ||
object="chat.completion", | ||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), | ||
), | ||
] | ||
mock = _MockChatCompletion(chat_completions) | ||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) | ||
agent = FileSurfer( | ||
"FileSurfer", | ||
model_client=OpenAIChatCompletionClient(model=model, api_key=""), | ||
) | ||
|
||
# Get the FileSurfer to read the file, and the directory | ||
assert agent._name == "FileSurfer" # pyright: ignore[reportPrivateUsage] | ||
result = await agent.run(task="Please read the test file") | ||
assert "# FileSurfer test H1" in result.messages[1].content | ||
|
||
result = await agent.run(task="Please read the test directory") | ||
assert "# Index of " in result.messages[1].content | ||
assert "test_filesurfer_agent.html" in result.messages[1].content |
Oops, something went wrong.