Skip to content

Commit 049ff6d

Browse files
authored
Allow overwrite max_iter in ReAct (#8096)
* allow overriding max_iter in ReAct * add test for retrying
1 parent dc64e12 commit 049ff6d

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

dspy/predict/react.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def _format_trajectory(self, trajectory: dict[str, Any]):
7373

7474
def forward(self, **input_args):
7575
trajectory = {}
76-
for idx in range(self.max_iters):
76+
max_iters = input_args.pop("max_iters", self.max_iters)
77+
for idx in range(max_iters):
7778
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
7879

7980
trajectory[f"thought_{idx}"] = pred.next_thought
@@ -173,8 +174,4 @@ def truncate_trajectory(self, trajectory):
173174
TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls.
174175
* So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations.
175176
* This is pretty useful for allowing the agent to keep notes or count certain things, etc.
176-
177-
TOPIC 07: Make max_iters a bit more expressive.
178-
* Allow passing `max_iters` in forward to overwrite the default.
179-
* Get rid of `last_iteration: bool` in the format function. It's not necessary now.
180177
"""

tests/predict/test_react.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from dataclasses import dataclass
2-
31
from pydantic import BaseModel
42

53
import dspy
6-
from dspy.predict import react
7-
from dspy.utils.dummies import DummyLM, dummy_rm
4+
from dspy.utils.dummies import DummyLM
85
import litellm
96

107
# def test_example_no_tools():
@@ -276,3 +273,38 @@ def mock_react(**kwargs):
276273
assert "thought_0" not in result.trajectory
277274
assert "thought_2" in result.trajectory
278275
assert result.output_text == "Final output"
276+
277+
278+
def test_error_retry():
279+
def foo(a, b):
280+
raise Exception("tool error")
281+
282+
react = dspy.ReAct("a, b -> c:int", tools=[foo])
283+
max_iters = 2
284+
lm = DummyLM(
285+
[
286+
{"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}},
287+
{"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}},
288+
{"reasoning": "I added the numbers successfully", "c": 3},
289+
]
290+
)
291+
dspy.settings.configure(lm=lm)
292+
293+
outputs = react(a=1, b=2, max_iters=max_iters)
294+
expected_trajectory = {
295+
"thought_0": "I need to add two numbers.",
296+
"tool_name_0": "foo",
297+
"tool_args_0": {
298+
"a": 1,
299+
"b": 2,
300+
},
301+
'observation_0': 'Failed to execute: tool error',
302+
'thought_1': 'I need to add two numbers.',
303+
'tool_name_1': 'foo',
304+
"tool_args_1": {
305+
"a": 1,
306+
"b": 2,
307+
},
308+
'observation_1': 'Failed to execute: tool error',
309+
}
310+
assert outputs.trajectory == expected_trajectory

0 commit comments

Comments
 (0)