Skip to content

Commit e746f9b

Browse files
authored
Merge branch 'strands-agents:main' into feature/vincilb/config-loader
2 parents abda708 + 17ccdd2 commit e746f9b

File tree

2 files changed

+65
-95
lines changed

2 files changed

+65
-95
lines changed

src/strands/agent/agent.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,11 @@ def _record_tool_execution(
642642
tool_result: The result returned by the tool.
643643
user_message_override: Optional custom message to include.
644644
"""
645+
# Filter tool input parameters to only include those defined in tool spec
646+
filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])
647+
645648
# Create user message describing the tool call
646-
input_parameters = json.dumps(tool["input"], default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
649+
input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
647650

648651
user_msg_content: list[ContentBlock] = [
649652
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
@@ -653,14 +656,21 @@ def _record_tool_execution(
653656
if user_message_override:
654657
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
655658

659+
# Create filtered tool use for message history
660+
filtered_tool: ToolUse = {
661+
"toolUseId": tool["toolUseId"],
662+
"name": tool["name"],
663+
"input": filtered_input,
664+
}
665+
656666
# Create the message sequence
657667
user_msg: Message = {
658668
"role": "user",
659669
"content": user_msg_content,
660670
}
661671
tool_use_msg: Message = {
662672
"role": "assistant",
663-
"content": [{"toolUse": tool}],
673+
"content": [{"toolUse": filtered_tool}],
664674
}
665675
tool_result_msg: Message = {
666676
"role": "user",
@@ -717,6 +727,25 @@ def _end_agent_trace_span(
717727

718728
self.tracer.end_agent_span(**trace_attributes)
719729

730+
def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
731+
"""Filter input parameters to only include those defined in the tool specification.
732+
733+
Args:
734+
tool_name: Name of the tool to get specification for
735+
input_params: Original input parameters
736+
737+
Returns:
738+
Filtered parameters containing only those defined in tool spec
739+
"""
740+
all_tools_config = self.tool_registry.get_all_tools_config()
741+
tool_spec = all_tools_config.get(tool_name)
742+
743+
if not tool_spec or "inputSchema" not in tool_spec:
744+
return input_params.copy()
745+
746+
properties = tool_spec["inputSchema"]["json"]["properties"]
747+
return {k: v for k, v in input_params.items() if k in properties}
748+
720749
def _append_message(self, message: Message) -> None:
721750
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
722751
self.messages.append(message)

tests/strands/agent/test_agent.py

Lines changed: 34 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,99 +1738,7 @@ def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint):
17381738
tool_call_text = user_message["content"][1]["text"]
17391739
assert "agent.tool.tool_decorated direct tool call." in tool_call_text
17401740
assert '"random_string": "test_value"' in tool_call_text
1741-
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' in tool_call_text
1742-
1743-
1744-
def test_agent_tool_multiple_non_serializable_types(agent, mock_randint):
1745-
"""Test filtering of various non-serializable object types."""
1746-
mock_randint.return_value = 123
1747-
1748-
# Create various non-serializable objects
1749-
class CustomClass:
1750-
def __init__(self, value):
1751-
self.value = value
1752-
1753-
non_serializable_objects = {
1754-
"agent": Agent(),
1755-
"custom_object": CustomClass("test"),
1756-
"function": lambda x: x,
1757-
"set_object": {1, 2, 3},
1758-
"complex_number": 3 + 4j,
1759-
"serializable_string": "this_should_remain",
1760-
"serializable_number": 42,
1761-
"serializable_list": [1, 2, 3],
1762-
"serializable_dict": {"key": "value"},
1763-
}
1764-
1765-
# This should not crash
1766-
result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects)
1767-
1768-
# Verify tool executed successfully
1769-
expected_result = {
1770-
"content": [{"text": "test_filtering"}],
1771-
"status": "success",
1772-
"toolUseId": "tooluse_tool_decorated_123",
1773-
}
1774-
assert result == expected_result
1775-
1776-
# Check the recorded message for proper parameter filtering
1777-
assert len(agent.messages) > 0
1778-
user_message = agent.messages[0]
1779-
tool_call_text = user_message["content"][0]["text"]
1780-
1781-
# Verify serializable objects remain unchanged
1782-
assert '"serializable_string": "this_should_remain"' in tool_call_text
1783-
assert '"serializable_number": 42' in tool_call_text
1784-
assert '"serializable_list": [1, 2, 3]' in tool_call_text
1785-
assert '"serializable_dict": {"key": "value"}' in tool_call_text
1786-
1787-
# Verify non-serializable objects are replaced with descriptive strings
1788-
assert '"agent": "<<non-serializable: Agent>>"' in tool_call_text
1789-
assert (
1790-
'"custom_object": "<<non-serializable: test_agent_tool_multiple_non_serializable_types.<locals>.CustomClass>>"'
1791-
in tool_call_text
1792-
)
1793-
assert '"function": "<<non-serializable: function>>"' in tool_call_text
1794-
assert '"set_object": "<<non-serializable: set>>"' in tool_call_text
1795-
assert '"complex_number": "<<non-serializable: complex>>"' in tool_call_text
1796-
1797-
1798-
def test_agent_tool_serialization_edge_cases(agent, mock_randint):
1799-
"""Test edge cases in parameter serialization filtering."""
1800-
mock_randint.return_value = 999
1801-
1802-
# Test with None values, empty containers, and nested structures
1803-
edge_case_params = {
1804-
"none_value": None,
1805-
"empty_list": [],
1806-
"empty_dict": {},
1807-
"nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out
1808-
"nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain
1809-
}
1810-
1811-
result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params)
1812-
1813-
# Verify successful execution
1814-
expected_result = {
1815-
"content": [{"text": "edge_cases"}],
1816-
"status": "success",
1817-
"toolUseId": "tooluse_tool_decorated_999",
1818-
}
1819-
assert result == expected_result
1820-
1821-
# Check parameter filtering in recorded message
1822-
assert len(agent.messages) > 0
1823-
user_message = agent.messages[0]
1824-
tool_call_text = user_message["content"][0]["text"]
1825-
1826-
# Verify serializable values remain
1827-
assert '"none_value": null' in tool_call_text
1828-
assert '"empty_list": []' in tool_call_text
1829-
assert '"empty_dict": {}' in tool_call_text
1830-
assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text
1831-
1832-
# Verify non-serializable nested structure is replaced
1833-
assert '"nested_list_with_non_serializable": [1, 2, "<<non-serializable: Agent>>"]' in tool_call_text
1741+
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' not in tool_call_text
18341742

18351743

18361744
def test_agent_tool_no_non_serializable_parameters(agent, mock_randint):
@@ -1882,3 +1790,36 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent
18821790

18831791
# Verify no messages were recorded
18841792
assert len(agent.messages) == 0
1793+
1794+
1795+
def test_agent_tool_call_parameter_filtering_integration(mock_randint):
1796+
"""Test that tool calls properly filter parameters in message recording."""
1797+
mock_randint.return_value = 42
1798+
1799+
@strands.tool
1800+
def test_tool(action: str) -> str:
1801+
"""Test tool with single parameter."""
1802+
return action
1803+
1804+
agent = Agent(tools=[test_tool])
1805+
1806+
# Call tool with extra non-spec parameters
1807+
result = agent.tool.test_tool(
1808+
action="test_value",
1809+
agent=agent, # Should be filtered out
1810+
extra_param="filtered", # Should be filtered out
1811+
)
1812+
1813+
# Verify tool executed successfully
1814+
assert result["status"] == "success"
1815+
assert result["content"] == [{"text": "test_value"}]
1816+
1817+
# Check that only spec parameters are recorded in message history
1818+
assert len(agent.messages) > 0
1819+
user_message = agent.messages[0]
1820+
tool_call_text = user_message["content"][0]["text"]
1821+
1822+
# Should only contain the 'action' parameter
1823+
assert '"action": "test_value"' in tool_call_text
1824+
assert '"agent"' not in tool_call_text
1825+
assert '"extra_param"' not in tool_call_text

0 commit comments

Comments
 (0)