Skip to content

Commit c498de5

Browse files
Support composite arg type parsing in dspy.Tool (#8095)
* tool calling args parsing * fix wrong comment * Add test
1 parent 049ff6d commit c498de5

File tree

3 files changed

+53
-16
lines changed

3 files changed

+53
-16
lines changed

dspy/predict/react.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import logging
2-
from typing import Any, Callable, Literal, get_origin
2+
from typing import Any, Callable, Literal
33

44
from litellm import ContextWindowExceededError
5-
from pydantic import BaseModel
65

76
import dspy
87
from dspy.primitives.program import Module
@@ -82,18 +81,7 @@ def forward(self, **input_args):
8281
trajectory[f"tool_args_{idx}"] = pred.next_tool_args
8382

8483
try:
85-
parsed_tool_args = {}
86-
tool = self.tools[pred.next_tool_name]
87-
for k, v in pred.next_tool_args.items():
88-
if hasattr(tool, "arg_types") and k in tool.arg_types:
89-
arg_type = tool.arg_types[k]
90-
if isinstance((origin := get_origin(arg_type) or arg_type), type) and issubclass(
91-
origin, BaseModel
92-
):
93-
parsed_tool_args[k] = arg_type.model_validate(v)
94-
continue
95-
parsed_tool_args[k] = v
96-
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**parsed_tool_args)
84+
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
9785
except Exception as e:
9886
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"
9987

dspy/primitives/tool.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Callable, Optional, get_origin, get_type_hints
33

44
from jsonschema import ValidationError, validate
5-
from pydantic import BaseModel, TypeAdapter
5+
from pydantic import BaseModel, TypeAdapter, create_model
66

77
from dspy.utils.callback import with_callbacks
88

@@ -130,6 +130,19 @@ def _parse_function(self, func: Callable, arg_desc: dict[str, str] = None):
130130
self.args = self.args or args
131131
self.arg_types = self.arg_types or arg_types
132132

133+
def _parse_args(self, **kwargs):
134+
parsed_kwargs = {}
135+
for k, v in kwargs.items():
136+
if k in self.arg_types and self.arg_types[k] != any:
137+
# Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type.
138+
# This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]`
139+
pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...))
140+
parsed = pydantic_wrapper.model_validate({"value": v})
141+
parsed_kwargs[k] = parsed.value
142+
else:
143+
parsed_kwargs[k] = v
144+
return parsed_kwargs
145+
133146
@with_callbacks
134147
def __call__(self, **kwargs):
135148
for k, v in kwargs.items():
@@ -142,4 +155,6 @@ def __call__(self, **kwargs):
142155
validate(instance=instance, schema=self.args[k])
143156
except ValidationError as e:
144157
raise ValueError(f"Arg {k} is invalid: {e.message}")
145-
return self.func(**kwargs)
158+
159+
parsed_kwargs = self._parse_args(**kwargs)
160+
return self.func(**parsed_kwargs)

tests/primitives/test_tool.py

+34
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,37 @@ def foo(x=100):
159159
tool = Tool(foo)
160160
assert tool.args["x"]["default"] == 100
161161
assert not hasattr(tool.args["x"], "type")
162+
163+
164+
def test_tool_call_parses_args():
165+
tool = Tool(dummy_with_pydantic)
166+
167+
args = {
168+
"model": {
169+
"field1": "hello",
170+
"field2": 123,
171+
}
172+
}
173+
174+
result = tool(**args)
175+
assert result == "hello 123"
176+
177+
178+
def test_tool_call_parses_nested_list_of_pydantic_model():
179+
def dummy_function(x: list[list[DummyModel]]):
180+
return x
181+
182+
tool = Tool(dummy_function)
183+
args = {
184+
"x": [
185+
[
186+
{
187+
"field1": "hello",
188+
"field2": 123,
189+
}
190+
]
191+
]
192+
}
193+
194+
result = tool(**args)
195+
assert result == [[DummyModel(field1="hello", field2=123)]]

0 commit comments

Comments
 (0)