
Commit 6d35c33

Introduce tool_use_behavior on agents (#203)
## Context

By default, the outputs of tools are sent back to the LLM, which reads them and produces a new response. There are cases where this is not desired:

1. Every tool call results in another round trip, and sometimes the output of the tool is enough.
2. If you force tool use (via the model setting `tool_choice=required`), the agent will loop forever.

This change lets you configure a different behavior, e.g. use the first tool output as the final output, or supply a custom function that processes tool results and can produce a final output.

## Test plan

Added new tests and ran existing tests. Also added examples.

Closes #117
2 parents 48ff99b + 10aa555 commit 6d35c33
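The decision logic this commit introduces can be sketched independently of the SDK. The following stand-ins (`ToolResult`, `FinalOutputDecision`, `resolve_tool_use`) are illustrative names, not the library's actual types; they mirror how each `tool_use_behavior` value maps tool results to either "stop with a final output" or "send results back to the LLM":

```python
from dataclasses import dataclass
from typing import Any

# Illustrative stand-ins for the SDK types; this is a sketch of the
# dispatch logic, not the library's implementation.
@dataclass
class ToolResult:
    tool_name: str
    output: Any

@dataclass
class FinalOutputDecision:
    is_final_output: bool
    final_output: Any = None

def resolve_tool_use(behavior, results: list[ToolResult]) -> FinalOutputDecision:
    """Decide whether tool outputs end the run or go back to the LLM."""
    if behavior == "run_llm_again":          # default: keep looping
        return FinalOutputDecision(False)
    if behavior == "stop_on_first_tool":     # first tool output is final
        return FinalOutputDecision(True, results[0].output)
    if isinstance(behavior, dict):           # stop only for named tools
        names = behavior.get("stop_at_tool_names", [])
        for r in results:
            if r.tool_name in names:
                return FinalOutputDecision(True, r.output)
        return FinalOutputDecision(False)
    return behavior(results)                 # custom function decides

results = [ToolResult("get_weather", "14-20C, sunny")]
print(resolve_tool_use("stop_on_first_tool", results).final_output)  # 14-20C, sunny
```

The dict form is useful when only some tools (e.g. a final "answer" tool) should terminate the run while others feed back into the model.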

12 files changed: +594 -26 lines

docs/agents.md

+13
@@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone(
    instructions="Write like a robot",
)
```

## Forcing tool use

Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:

1. `auto`, which allows the LLM to decide whether or not to use a tool.
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool to use).
3. `none`, which requires the LLM to _not_ use a tool.
4. A specific string, e.g. `my_tool`, which requires the LLM to use that specific tool.

!!! note

    If you require tool use, consider setting [`Agent.tool_use_behavior`] to stop the Agent from running when a tool output is produced. Otherwise the Agent can run in an infinite loop: the LLM produces a tool call, the tool result is sent back to the LLM, and the cycle repeats because the LLM is always forced to use a tool.
examples/agent_patterns/forcing_tool_use.py

+99

@@ -0,0 +1,99 @@
from __future__ import annotations

import asyncio
from typing import Any, Literal

from pydantic import BaseModel

from agents import (
    Agent,
    FunctionToolResult,
    ModelSettings,
    RunContextWrapper,
    Runner,
    ToolsToFinalOutputFunction,
    ToolsToFinalOutputResult,
    function_tool,
)

"""
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
to force the agent to use any tool.

You can run it with 3 options:
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
   `tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
   call the tool, the tool would run and send the results to the LLM, and that would repeat
   (because the model is forced to use a tool every time.)
2. `first_tool`: The first tool result is used as the final output.
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
   results, and chooses to use the first tool result to generate the final output.

Usage:
python examples/agent_patterns/forcing_tool_use.py -t default
python examples/agent_patterns/forcing_tool_use.py -t first_tool
python examples/agent_patterns/forcing_tool_use.py -t custom
"""


class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str


@function_tool
def get_weather(city: str) -> Weather:
    print("[debug] get_weather called")
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")


async def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    weather: Weather = results[0].output
    return ToolsToFinalOutputResult(
        is_final_output=True, final_output=f"{weather.city} is {weather.conditions}."
    )


async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
    if tool_use_behavior == "default":
        behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction = (
            "run_llm_again"
        )
    elif tool_use_behavior == "first_tool":
        behavior = "stop_on_first_tool"
    elif tool_use_behavior == "custom":
        behavior = custom_tool_use_behavior

    agent = Agent(
        name="Weather agent",
        instructions="You are a helpful agent.",
        tools=[get_weather],
        tool_use_behavior=behavior,
        model_settings=ModelSettings(
            tool_choice="required" if tool_use_behavior != "default" else None
        ),
    )

    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tool-use-behavior",
        type=str,
        required=True,
        choices=["default", "first_tool", "custom"],
        help="The behavior to use for tool use. 'default' will cause tool outputs to be sent to the model. "
        "'first_tool' will cause the first tool result to be used as the final output. "
        "'custom' will use a custom tool use behavior function.",
    )
    args = parser.parse_args()
    asyncio.run(main(args.tool_use_behavior))

examples/basic/tools.py

+34
@@ -0,0 +1,34 @@
import asyncio

from pydantic import BaseModel

from agents import Agent, Runner, function_tool


class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str


@function_tool
def get_weather(city: str) -> Weather:
    print("[debug] get_weather called")
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")


agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)


async def main():
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    # The weather in Tokyo is sunny.


if __name__ == "__main__":
    asyncio.run(main())

src/agents/__init__.py

+5 -1

@@ -5,7 +5,7 @@
 from openai import AsyncOpenAI

 from . import _config
-from .agent import Agent
+from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
 from .agent_output import AgentOutputSchema
 from .computer import AsyncComputer, Button, Computer, Environment
 from .exceptions import (
@@ -57,6 +57,7 @@
     ComputerTool,
     FileSearchTool,
     FunctionTool,
+    FunctionToolResult,
     Tool,
     WebSearchTool,
     default_tool_error_function,
@@ -137,6 +138,8 @@ def enable_verbose_stdout_logging():

 __all__ = [
     "Agent",
+    "ToolsToFinalOutputFunction",
+    "ToolsToFinalOutputResult",
     "Runner",
     "Model",
     "ModelProvider",
@@ -190,6 +193,7 @@ def enable_verbose_stdout_logging():
     "AgentUpdatedStreamEvent",
     "StreamEvent",
     "FunctionTool",
+    "FunctionToolResult",
     "ComputerTool",
     "FileSearchTool",
     "Tool",

src/agents/_run_impl.py

+89 -10

@@ -1,8 +1,10 @@
 from __future__ import annotations

 import asyncio
+import inspect
+from collections.abc import Awaitable
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

 from openai.types.responses import (
     ResponseComputerToolCall,
@@ -25,7 +27,7 @@
 from openai.types.responses.response_input_param import ComputerCallOutput
 from openai.types.responses.response_reasoning_item import ResponseReasoningItem

-from .agent import Agent
+from .agent import Agent, ToolsToFinalOutputResult
 from .agent_output import AgentOutputSchema
 from .computer import AsyncComputer, Computer
 from .exceptions import AgentsException, ModelBehaviorError, UserError
@@ -48,7 +50,7 @@
 from .models.interface import ModelTracing
 from .run_context import RunContextWrapper, TContext
 from .stream_events import RunItemStreamEvent, StreamEvent
-from .tool import ComputerTool, FunctionTool
+from .tool import ComputerTool, FunctionTool, FunctionToolResult
 from .tracing import (
     SpanError,
     Trace,
@@ -70,6 +72,8 @@ class QueueCompleteSentinel:

 QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()

+_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
+

 @dataclass
 class ToolRunHandoff:
@@ -199,7 +203,7 @@ async def execute_tools_and_side_effects(
                 config=run_config,
             ),
         )
-        new_step_items.extend(function_results)
+        new_step_items.extend([result.run_item for result in function_results])
         new_step_items.extend(computer_results)

         # Second, check if there are any handoffs
@@ -216,6 +220,36 @@ async def execute_tools_and_side_effects(
                 run_config=run_config,
             )

+        # Third, we'll check if the tool use should result in a final output
+        check_tool_use = await cls._check_for_final_output_from_tools(
+            agent=agent,
+            tool_results=function_results,
+            context_wrapper=context_wrapper,
+            config=run_config,
+        )
+
+        if check_tool_use.is_final_output:
+            # If the output type is str, then let's just stringify it
+            if not agent.output_type or agent.output_type is str:
+                check_tool_use.final_output = str(check_tool_use.final_output)
+
+            if check_tool_use.final_output is None:
+                logger.error(
+                    "Model returned a final output of None. Not raising an error because we assume "
+                    "you know what you're doing."
+                )
+
+            return await cls.execute_final_output(
+                agent=agent,
+                original_input=original_input,
+                new_response=new_response,
+                pre_step_items=pre_step_items,
+                new_step_items=new_step_items,
+                final_output=check_tool_use.final_output,
+                hooks=hooks,
+                context_wrapper=context_wrapper,
+            )
+
         # Now we can check if the model also produced a final output
         message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]

@@ -355,10 +389,10 @@ async def execute_function_tool_calls(
         hooks: RunHooks[TContext],
         context_wrapper: RunContextWrapper[TContext],
         config: RunConfig,
-    ) -> list[RunItem]:
+    ) -> list[FunctionToolResult]:
         async def run_single_tool(
             func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
-        ) -> str:
+        ) -> Any:
             with function_span(func_tool.name) as span_fn:
                 if config.trace_include_sensitive_data:
                     span_fn.span_data.input = tool_call.arguments
@@ -404,10 +438,14 @@ async def run_single_tool(
         results = await asyncio.gather(*tasks)

         return [
-            ToolCallOutputItem(
-                output=str(result),
-                raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
-                agent=agent,
+            FunctionToolResult(
+                tool=tool_run.function_tool,
+                output=result,
+                run_item=ToolCallOutputItem(
+                    output=result,
+                    raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
+                    agent=agent,
+                ),
             )
             for tool_run, result in zip(tool_runs, results)
         ]
@@ -646,6 +684,47 @@ def stream_step_result_to_queue(
         if event:
             queue.put_nowait(event)

+    @classmethod
+    async def _check_for_final_output_from_tools(
+        cls,
+        *,
+        agent: Agent[TContext],
+        tool_results: list[FunctionToolResult],
+        context_wrapper: RunContextWrapper[TContext],
+        config: RunConfig,
+    ) -> ToolsToFinalOutputResult:
+        """Determines whether the tool results should produce a final output."""
+        if not tool_results:
+            return _NOT_FINAL_OUTPUT
+
+        if agent.tool_use_behavior == "run_llm_again":
+            return _NOT_FINAL_OUTPUT
+        elif agent.tool_use_behavior == "stop_on_first_tool":
+            return ToolsToFinalOutputResult(
+                is_final_output=True, final_output=tool_results[0].output
+            )
+        elif isinstance(agent.tool_use_behavior, dict):
+            names = agent.tool_use_behavior.get("stop_at_tool_names", [])
+            for tool_result in tool_results:
+                if tool_result.tool.name in names:
+                    return ToolsToFinalOutputResult(
+                        is_final_output=True, final_output=tool_result.output
+                    )
+            return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
+        elif callable(agent.tool_use_behavior):
+            if inspect.iscoroutinefunction(agent.tool_use_behavior):
+                return await cast(
+                    Awaitable[ToolsToFinalOutputResult],
+                    agent.tool_use_behavior(context_wrapper, tool_results),
+                )
+            else:
+                return cast(
+                    ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
+                )
+
+        logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
+        raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
+

 class TraceCtxManager:
     """Creates a trace only if there is no current trace, and manages the trace lifecycle."""
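One subtlety in the `_check_for_final_output_from_tools` branch for callables is that `tool_use_behavior` may be a sync or an async function, so the code inspects it and awaits only when needed. The sketch below isolates that pattern with stand-in names (`call_behavior`, `sync_behavior`, `async_behavior` are illustrative, not SDK API):

```python
import asyncio
import inspect
from typing import Any, Awaitable, Callable, Union, cast

# A behavior may return its result directly (sync) or wrapped in an
# awaitable (async); the caller must handle both.
Behavior = Callable[[list[Any]], Union[str, Awaitable[str]]]

async def call_behavior(behavior: Behavior, results: list[Any]) -> str:
    if inspect.iscoroutinefunction(behavior):
        # An async callable returns a coroutine; await it for the result.
        return await cast(Awaitable[str], behavior(results))
    # A sync callable returns the result directly.
    return cast(str, behavior(results))

def sync_behavior(results: list[Any]) -> str:
    return f"sync:{results[0]}"

async def async_behavior(results: list[Any]) -> str:
    return f"async:{results[0]}"

print(asyncio.run(call_behavior(sync_behavior, ["x"])))   # sync:x
print(asyncio.run(call_behavior(async_behavior, ["x"])))  # async:x
```

Accepting both shapes keeps the public API ergonomic: simple behaviors stay plain functions, while behaviors that need I/O can be coroutines.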
