Skip to content

Commit

Permalink
unit tests are working
Browse files Browse the repository at this point in the history
  • Loading branch information
EItanya committed Jan 25, 2025
1 parent 9c4a00b commit 95fc65e
Show file tree
Hide file tree
Showing 4 changed files with 317 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._http_tool import HttpTool, HttpToolConfig
from ._http_tool import HttpTool

__all__ = ["HttpTool", "HttpToolConfig"]
__all__ = ["HttpTool"]
Original file line number Diff line number Diff line change
Expand Up @@ -35,46 +35,115 @@ class HttpToolConfig(BaseModel):


class HttpTool(BaseTool[BaseModel, Any], Component[HttpToolConfig]):
"""Adapter for MCP tools to make them compatible with AutoGen.
"""A wrapper for using an HTTP server as a tool.
Args:
server_params (StdioServerParameters): Parameters for the MCP server connection
tool (Tool): The MCP tool to wrap
name (str): The name of the tool.
description (str, optional): A description of the tool.
url (str): The URL to send the request to.
method (str, optional): The HTTP method to use, will default to POST if not provided.
Must be one of "GET", "POST", "PUT", "DELETE", "PATCH".
headers (dict[str, Any], optional): A dictionary of headers to send with the request.
json_schema (dict[str, Any]): A JSON Schema object defining the expected parameters for the tool.
Example:
Simple usage case::
import asyncio
from autogen_ext.tools.http import HttpTool
from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient
# Define a JSON schema for a weather API
weather_schema = {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city to get weather for"},
"country": {"type": "string", "description": "The country code"}
},
"required": ["city"]
}
# Create an HTTP tool for the weather API
weather_tool = HttpTool(
name="get_weather",
description="Get the current weather for a city",
url="https://api.weatherapi.com/v1/current.json",
method="GET",
headers={"key": "your-api-key"},
json_schema=weather_schema
)
async def main():
# Create an assistant with the weather tool
model = OpenAIChatCompletionClient(model="gpt-4")
assistant = AssistantAgent(
"weather_assistant",
model_client=model,
tools=[weather_tool]
)
# The assistant can now use the weather tool to get weather data
response = await assistant.on_messages([
TextMessage(content="What's the weather like in London?")
])
print(response.chat_message.content)
asyncio.run(main())
"""

def __init__(self, server_params: HttpToolConfig) -> None:
self.server_params = server_params
component_type = "agent"
component_provider_override = "autogen_ext.tools.http.HttpTool"
component_config_schema = HttpToolConfig

def __init__(
self,
name: str,
url: str,
json_schema: dict[str, Any],
headers: Optional[dict[str, Any]],
description: str = "HTTP tool",
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "POST",
) -> None:
self.server_params = HttpToolConfig(
name=name,
description=description,
url=url,
method=method,
headers=headers,
json_schema=json_schema,
)

# Extract name and description
name = server_params.name
description = server_params.description or ""
name = self.server_params.name
description = self.server_params.description or ""

# Create the input model from the tool's schema
input_model = create_model(server_params.json_schema)
input_model = create_model(self.server_params.json_schema)

# Use Any as return type since MCP tool returns can vary
return_type: Type[Any] = object

super().__init__(input_model, return_type, name, description)

def _to_config(self) -> HttpToolConfig:
copied_config = self.server_params.copy()
copied_config = self.server_params.model_copy()
return copied_config

@classmethod
def _from_config(cls, config: HttpToolConfig):
copied_config = config.model_copy().model_dump(exclude_none=True)
copied_config = config.model_copy().model_dump()
return cls(**copied_config)

async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
"""Execute the MCP tool with the given arguments.
"""Execute the HTTO tool with the given arguments.
Args:
args: The validated input arguments
cancellation_token: Token for cancelling the operation
Returns:
The result from the MCP tool
The response body from the HTTP call in JSON format
Raises:
Exception: If tool execution fails
Expand All @@ -90,7 +159,7 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A
response = await client.delete(self.server_params.url, params=args.model_dump())
case "PATCH":
response = await client.patch(self.server_params.url, json=args.model_dump())
case _: # Default case
case _: # Default case
response = await client.post(self.server_params.url, json=args.model_dump())

return response.json()
89 changes: 89 additions & 0 deletions python/packages/autogen-ext/tests/tools/http/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
from typing import AsyncGenerator

import pytest
import pytest_asyncio
import uvicorn
from autogen_core import CancellationToken, ComponentModel
from autogen_ext.tools.http import HttpTool
from fastapi import Body, FastAPI
from pydantic import BaseModel, Field


class TestArgs(BaseModel):
query: str = Field(description="The test query")
value: int = Field(description="A test value")


class TestResponse(BaseModel):
result: str = Field(description="The test result")


# Create a test FastAPI app
app = FastAPI()


@app.post("/test")
async def test_endpoint(body: TestArgs = Body(...)) -> TestResponse:
return TestResponse(result=f"Received: {body.query} with value {body.value}")


@app.get("/test")
async def test_get_endpoint(query: str, value: int) -> TestResponse:
return TestResponse(result=f"Received: {query} with value {value}")


@app.put("/test")
async def test_put_endpoint(body: TestArgs = Body(...)) -> TestResponse:
return TestResponse(result=f"Received: {body.query} with value {body.value}")


@app.delete("/test")
async def test_delete_endpoint(query: str, value: int) -> TestResponse:
return TestResponse(result=f"Received: {query} with value {value}")


@app.patch("/test")
async def test_patch_endpoint(body: TestArgs = Body(...)) -> TestResponse:
return TestResponse(result=f"Received: {body.query} with value {body.value}")


@pytest.fixture
def test_config() -> ComponentModel:
return ComponentModel(
provider="autogen_ext.tools.http.HttpTool",
config={
"name": "TestHttpTool",
"description": "A test HTTP tool",
"url": "http://localhost:8000/test",
"method": "POST",
"headers": {"Content-Type": "application/json"},
"json_schema": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "The test query"},
"value": {"type": "integer", "description": "A test value"},
},
"required": ["query", "value"],
},
},
)


@pytest_asyncio.fixture
async def test_server() -> AsyncGenerator[None, None]:
# Start the test server
config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="error")
server = uvicorn.Server(config)

# Create a task for the server
server_task = asyncio.create_task(server.serve())

# Wait a bit for server to start
await asyncio.sleep(0.5) # Increased sleep time to ensure server is ready

yield

# Cleanup
server.should_exit = True
await server_task
144 changes: 144 additions & 0 deletions python/packages/autogen-ext/tests/tools/http/test_http_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import pytest
import httpx
from pydantic import ValidationError
from autogen_core import CancellationToken
from autogen_ext.tools.http import HttpTool
from autogen_core import Component, ComponentModel


def test_tool_schema_generation(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
schema = tool.schema

assert schema["name"] == "TestHttpTool"
assert "description" in schema
assert schema["description"] == "A test HTTP tool"
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert "properties" in schema["parameters"]
assert schema["parameters"]["properties"]["query"]["description"] == "The test query"
assert schema["parameters"]["properties"]["query"]["type"] == "string"
assert schema["parameters"]["properties"]["value"]["description"] == "A test value"
assert schema["parameters"]["properties"]["value"]["type"] == "integer"
assert "required" in schema["parameters"]
assert set(schema["parameters"]["required"]) == {"query", "value"}


def test_tool_properties(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)

assert tool.name == "TestHttpTool"
assert tool.description == "A test HTTP tool"
assert tool.server_params.url == "http://localhost:8000/test"
assert tool.server_params.method == "POST"


def test_component_base_class(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
assert tool.dump_component() is not None
assert HttpTool.load_component(tool.dump_component(), HttpTool) is not None
assert isinstance(tool, Component)


@pytest.mark.asyncio
async def test_post_request(test_config: ComponentModel, test_server: None) -> None:
tool = HttpTool.load_component(test_config)
result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())

assert isinstance(result, dict)
assert result["result"] == "Received: test query with value 42"


@pytest.mark.asyncio
async def test_get_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for GET request
config = test_config.model_copy()
config.config["method"] = "GET"
tool = HttpTool.load_component(config)

result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())

assert isinstance(result, dict)
assert result["result"] == "Received: test query with value 42"


@pytest.mark.asyncio
async def test_put_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for PUT request
config = test_config.model_copy()
config.config["method"] = "PUT"
tool = HttpTool.load_component(config)

result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())

assert isinstance(result, dict)
assert result["result"] == "Received: test query with value 42"


@pytest.mark.asyncio
async def test_delete_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for DELETE request
config = test_config.model_copy()
config.config["method"] = "DELETE"
tool = HttpTool.load_component(config)

result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())

assert isinstance(result, dict)
assert result["result"] == "Received: test query with value 42"


@pytest.mark.asyncio
async def test_patch_request(test_config: ComponentModel, test_server: None) -> None:
# Modify config for PATCH request
config = test_config.model_copy()
config.config["method"] = "PATCH"
tool = HttpTool.load_component(config)

result = await tool.run_json({"query": "test query", "value": 42}, CancellationToken())

assert isinstance(result, dict)
assert result["result"] == "Received: test query with value 42"


@pytest.mark.asyncio
async def test_invalid_schema(test_config: ComponentModel, test_server: None) -> None:
# Create an invalid schema missing required properties
config: ComponentModel = test_config.model_copy()
config.config["url"] = True # Incorrect type

with pytest.raises(ValidationError):
# Should fail when trying to create model from invalid schema
HttpTool.load_component(config)


@pytest.mark.asyncio
async def test_invalid_request(test_config: ComponentModel, test_server: None) -> None:
# Use an invalid URL
config = test_config.model_copy()
config.config["url"] = "http://fake:8000/nonexistent"
tool = HttpTool.load_component(config)

with pytest.raises(httpx.ConnectError):
await tool.run_json({"query": "test query", "value": 42}, CancellationToken())


def test_config_serialization(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)
config = tool._to_config()

assert config.name == test_config.config["name"]
assert config.description == test_config.config["description"]
assert config.url == test_config.config["url"]
assert config.method == test_config.config["method"]
assert config.headers == test_config.config["headers"]


def test_config_deserialization(test_config: ComponentModel) -> None:
tool = HttpTool.load_component(test_config)

assert tool.name == test_config.config["name"]
assert tool.description == test_config.config["description"]
assert tool.server_params.url == test_config.config["url"]
assert tool.server_params.method == test_config.config["method"]
assert tool.server_params.headers == test_config.config["headers"]

0 comments on commit 95fc65e

Please sign in to comment.