Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions src/strands_tools/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def batch(tool: ToolUse, **kwargs) -> ToolResult:
- If a tool function is not found or an error occurs, it will be captured in the results.
- This tool is designed to work with agents that support dynamic tool invocation.

Sammple output:
Sample output:
{
"status": "success",
"results": [
Expand All @@ -96,41 +96,83 @@ def batch(tool: ToolUse, **kwargs) -> ToolResult:
agent = kwargs.get("agent")
invocations = kwargs.get("invocations", [])
results = []

try:
if not hasattr(agent, "tool") or agent.tool is None:
raise AttributeError("Agent does not have a valid 'tool' attribute.")

for invocation in invocations:
tool_name = invocation.get("name")
arguments = invocation.get("arguments", {})
tool_fn = getattr(agent.tool, tool_name, None)

if callable(tool_fn):
try:
# Only pass JSON-serializable arguments to the tool
# Call the tool function with the provided arguments
result = tool_fn(**arguments)

if result["status"] == "success":
results.append({"json": {"name": tool_name, "status": "success", "result": result}})
else:
results.append(
{"toolUseId": tool_use_id, "status": "error", "content": [{"text": "Tool missing"}]}
)

# Create a consistent result structure
batch_result = {
"name": tool_name,
"status": "success",
"result": result
}
results.append(batch_result)

except Exception as e:
error_msg = f"Error in batch tool: {str(e)}\n{traceback.format_exc()}"
console.print(f"Error in batch tool: {str(e)}")
results.append({"toolUseId": tool_use_id, "status": "error", "content": [{"text": error_msg}]})
else:
results.append(
{
"toolUseId": tool_use_id,
error_msg = f"Error executing tool '{tool_name}': {str(e)}"
console.print(error_msg)

batch_result = {
"name": tool_name,
"status": "error",
"content": [{"text": f"Tool '{tool_name}' not found in agent or tool call failed."}],
"error": str(e),
"traceback": traceback.format_exc()
}
)
results.append(batch_result)
else:
error_msg = f"Tool '{tool_name}' not found in agent"
console.print(error_msg)

batch_result = {
"name": tool_name,
"status": "error",
"error": error_msg
}
results.append(batch_result)

# Create a readable summary for the agent
summary_lines = []
summary_lines.append(f"Batch execution completed with {len(results)} tool(s):")

for result in results:
if result["status"] == "success":
summary_lines.append(f"✓ {result['name']}: Success")
else:
summary_lines.append(f"✗ {result['name']}: Error - {result['error']}")

summary_text = "\n".join(summary_lines)

return {
"toolUseId": tool_use_id,
"status": "success",
"content": results,
"content": [
{
"text": summary_text
},
{
"json": {
"batch_summary": {
"total_tools": len(results),
"successful": len([r for r in results if r["status"] == "success"]),
"failed": len([r for r in results if r["status"] == "error"])
},
"results": results
}
}
]
}

except Exception as e:
error_msg = f"Error in batch tool: {str(e)}\n{traceback.format_exc()}"
console.print(f"Error in batch tool: {str(e)}")
Expand Down
107 changes: 89 additions & 18 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,29 @@
def mock_agent():
"""Fixture to create a mock agent with tools."""
agent = MagicMock()
agent.tool.http_request = MagicMock(return_value={"status": "success", "result": {"ip": "127.0.0.1"}})
agent.tool.use_aws = MagicMock(return_value={"status": "success", "result": {"buckets": ["bucket1", "bucket2"]}})
agent.tool.error_tool = MagicMock(side_effect=Exception("Tool execution failed"))

# Create a mock tool registry that mimics the real agent's tool access pattern
mock_tool_registry = MagicMock()
mock_tool_registry.registry = {
"http_request": MagicMock(return_value={"status": "success", "result": {"ip": "127.0.0.1"}}),
"use_aws": MagicMock(return_value={"status": "success", "result": {"buckets": ["bucket1", "bucket2"]}}),
"error_tool": MagicMock(side_effect=Exception("Tool execution failed"))
}
agent.tool_registry = mock_tool_registry

# Create a custom mock tool object that properly handles getattr
class MockTool:
def __init__(self):
self.http_request = mock_tool_registry.registry["http_request"]
self.use_aws = mock_tool_registry.registry["use_aws"]
self.error_tool = mock_tool_registry.registry["error_tool"]

def __getattr__(self, name):
# Return None for non-existent tools (this will make callable() return False)
return None

agent.tool = MockTool()

return agent


Expand All @@ -27,12 +47,26 @@ def test_batch_success(mock_agent):
assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 2
assert result["content"][0]["json"]["name"] == "http_request"
assert result["content"][0]["json"]["status"] == "success"
assert result["content"][0]["json"]["result"]["result"]["ip"] == "127.0.0.1"
assert result["content"][1]["json"]["name"] == "use_aws"
assert result["content"][1]["json"]["status"] == "success"
assert result["content"][1]["json"]["result"]["result"]["buckets"] == ["bucket1", "bucket2"]

# Check the summary text
assert "Batch execution completed with 2 tool(s):" in result["content"][0]["text"]
assert "✓ http_request: Success" in result["content"][0]["text"]
assert "✓ use_aws: Success" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 2
assert json_content["batch_summary"]["successful"] == 2
assert json_content["batch_summary"]["failed"] == 0

results = json_content["results"]
assert len(results) == 2
assert results[0]["name"] == "http_request"
assert results[0]["status"] == "success"
assert results[0]["result"]["result"]["ip"] == "127.0.0.1"
assert results[1]["name"] == "use_aws"
assert results[1]["status"] == "success"
assert results[1]["result"]["result"]["buckets"] == ["bucket1", "bucket2"]


def test_batch_missing_tool(mock_agent):
Expand All @@ -46,10 +80,23 @@ def test_batch_missing_tool(mock_agent):

assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 1
assert result["content"][0]["toolUseId"] == "mock_tool_id"
assert result["content"][0]["status"] == "error"
assert "Tool missing" in result["content"][0]["content"][0]["text"]
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"]
assert "✗ non_existent_tool: Error" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 1
assert json_content["batch_summary"]["successful"] == 0
assert json_content["batch_summary"]["failed"] == 1

results = json_content["results"]
assert len(results) == 1
assert results[0]["name"] == "non_existent_tool"
assert results[0]["status"] == "error"
assert "not found in agent" in results[0]["error"]


def test_batch_tool_error(mock_agent):
Expand All @@ -63,10 +110,24 @@ def test_batch_tool_error(mock_agent):

assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 1
assert result["content"][0]["toolUseId"] == "mock_tool_id"
assert result["content"][0]["status"] == "error"
assert "Error in batch tool" in result["content"][0]["content"][0]["text"]
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"]
assert "✗ error_tool: Error" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 1
assert json_content["batch_summary"]["successful"] == 0
assert json_content["batch_summary"]["failed"] == 1

results = json_content["results"]
assert len(results) == 1
assert results[0]["name"] == "error_tool"
assert results[0]["status"] == "error"
assert "Tool execution failed" in results[0]["error"]
assert "traceback" in results[0]


def test_batch_no_invocations(mock_agent):
Expand All @@ -78,7 +139,17 @@ def test_batch_no_invocations(mock_agent):

assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 0
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 0 tool(s):" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 0
assert json_content["batch_summary"]["successful"] == 0
assert json_content["batch_summary"]["failed"] == 0
assert len(json_content["results"]) == 0


def test_batch_top_level_error(mock_agent):
Expand Down