Skip to content

Commit c319539

Browse files
committed
Introduce tool_use_behavior on agents
1 parent b09a5bf commit c319539

9 files changed

+355
-24
lines changed

docs/agents.md

+13
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone(
130130
instructions="Write like a robot",
131131
)
132132
```
133+
134+
## Forcing tool use
135+
136+
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
137+
138+
1. `auto`, which allows the LLM to decide whether or not to use a tool.
139+
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
140+
3. `none`, which requires the LLM to _not_ use a tool.
141+
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
142+
143+
!!! note
144+
145+
If requiring tool use, you should consider setting [`Agent.tool_use_behavior`] to stop the Agent from running when a tool output is produced. Otherwise, the Agent might run in an infinite loop, where the LLM produces a tool call, the tool result is sent back to the LLM, and the loop repeats because the LLM is always forced to use a tool.

src/agents/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from openai import AsyncOpenAI
66

77
from . import _config
8-
from .agent import Agent
8+
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
99
from .agent_output import AgentOutputSchema
1010
from .computer import AsyncComputer, Button, Computer, Environment
1111
from .exceptions import (
@@ -57,6 +57,7 @@
5757
ComputerTool,
5858
FileSearchTool,
5959
FunctionTool,
60+
FunctionToolResult,
6061
Tool,
6162
WebSearchTool,
6263
default_tool_error_function,
@@ -136,6 +137,8 @@ def enable_verbose_stdout_logging():
136137

137138
__all__ = [
138139
"Agent",
140+
"ToolsToFinalOutputFunction",
141+
"ToolsToFinalOutputResult",
139142
"Runner",
140143
"Model",
141144
"ModelProvider",
@@ -189,6 +192,7 @@ def enable_verbose_stdout_logging():
189192
"AgentUpdatedStreamEvent",
190193
"StreamEvent",
191194
"FunctionTool",
195+
"FunctionToolResult",
192196
"ComputerTool",
193197
"FileSearchTool",
194198
"Tool",

src/agents/_run_impl.py

+75-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import inspect
5+
from collections.abc import Awaitable
46
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, cast
68

79
from openai.types.responses import (
810
ResponseComputerToolCall,
@@ -25,7 +27,7 @@
2527
from openai.types.responses.response_input_param import ComputerCallOutput
2628
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
2729

28-
from .agent import Agent
30+
from .agent import Agent, ToolsToFinalOutputResult
2931
from .agent_output import AgentOutputSchema
3032
from .computer import AsyncComputer, Computer
3133
from .exceptions import AgentsException, ModelBehaviorError, UserError
@@ -48,7 +50,7 @@
4850
from .models.interface import ModelTracing
4951
from .run_context import RunContextWrapper, TContext
5052
from .stream_events import RunItemStreamEvent, StreamEvent
51-
from .tool import ComputerTool, FunctionTool
53+
from .tool import ComputerTool, FunctionTool, FunctionToolResult
5254
from .tracing import (
5355
SpanError,
5456
Trace,
@@ -70,6 +72,8 @@ class QueueCompleteSentinel:
7072

7173
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
7274

75+
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
76+
7377

7478
@dataclass
7579
class ToolRunHandoff:
@@ -199,7 +203,7 @@ async def execute_tools_and_side_effects(
199203
config=run_config,
200204
),
201205
)
202-
new_step_items.extend(function_results)
206+
new_step_items.extend([result.run_item for result in function_results])
203207
new_step_items.extend(computer_results)
204208

205209
# Second, check if there are any handoffs
@@ -216,6 +220,30 @@ async def execute_tools_and_side_effects(
216220
run_config=run_config,
217221
)
218222

223+
# Third, we'll check if the tool use should result in a final output
224+
check_tool_use = await cls._check_for_final_output_from_tools(
225+
agent=agent,
226+
tool_results=function_results,
227+
context_wrapper=context_wrapper,
228+
config=run_config,
229+
)
230+
231+
if check_tool_use.is_final_output:
232+
# If the output type is str, then let's just stringify it
233+
if not agent.output_type or agent.output_type is str:
234+
check_tool_use.final_output = str(check_tool_use.final_output)
235+
236+
return await cls.execute_final_output(
237+
agent=agent,
238+
original_input=original_input,
239+
new_response=new_response,
240+
pre_step_items=pre_step_items,
241+
new_step_items=new_step_items,
242+
final_output=check_tool_use.final_output,
243+
hooks=hooks,
244+
context_wrapper=context_wrapper,
245+
)
246+
219247
# Now we can check if the model also produced a final output
220248
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
221249

@@ -355,10 +383,10 @@ async def execute_function_tool_calls(
355383
hooks: RunHooks[TContext],
356384
context_wrapper: RunContextWrapper[TContext],
357385
config: RunConfig,
358-
) -> list[RunItem]:
386+
) -> list[FunctionToolResult]:
359387
async def run_single_tool(
360388
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
361-
) -> str:
389+
) -> Any:
362390
with function_span(func_tool.name) as span_fn:
363391
if config.trace_include_sensitive_data:
364392
span_fn.span_data.input = tool_call.arguments
@@ -404,10 +432,14 @@ async def run_single_tool(
404432
results = await asyncio.gather(*tasks)
405433

406434
return [
407-
ToolCallOutputItem(
408-
output=str(result),
409-
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
410-
agent=agent,
435+
FunctionToolResult(
436+
tool=tool_run.function_tool,
437+
output=result,
438+
run_item=ToolCallOutputItem(
439+
output=result,
440+
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
441+
agent=agent,
442+
),
411443
)
412444
for tool_run, result in zip(tool_runs, results)
413445
]
@@ -646,6 +678,39 @@ def stream_step_result_to_queue(
646678
if event:
647679
queue.put_nowait(event)
648680

681+
@classmethod
682+
async def _check_for_final_output_from_tools(
683+
cls,
684+
*,
685+
agent: Agent[TContext],
686+
tool_results: list[FunctionToolResult],
687+
context_wrapper: RunContextWrapper[TContext],
688+
config: RunConfig,
689+
) -> ToolsToFinalOutputResult:
690+
"""Returns (i, final_output)."""
691+
if not tool_results:
692+
return _NOT_FINAL_OUTPUT
693+
694+
if agent.tool_use_behavior == "run_llm_again":
695+
return _NOT_FINAL_OUTPUT
696+
elif agent.tool_use_behavior == "stop_on_first_tool":
697+
return ToolsToFinalOutputResult(
698+
is_final_output=True, final_output=tool_results[0].output
699+
)
700+
elif callable(agent.tool_use_behavior):
701+
if inspect.iscoroutinefunction(agent.tool_use_behavior):
702+
return await cast(
703+
Awaitable[ToolsToFinalOutputResult],
704+
agent.tool_use_behavior(context_wrapper, tool_results),
705+
)
706+
else:
707+
return cast(
708+
ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
709+
)
710+
else:
711+
logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
712+
raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
713+
649714

650715
class TraceCtxManager:
651716
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""

src/agents/items.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
129129
raw_item: FunctionCallOutput | ComputerCallOutput
130130
"""The raw item from the model."""
131131

132-
output: str
133-
"""The output of the tool call."""
132+
output: Any
133+
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
134+
contains a string representation of the output.
135+
"""
134136

135137
type: Literal["tool_call_output_item"] = "tool_call_output_item"
136138

src/agents/tool.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .computer import AsyncComputer, Computer
1616
from .exceptions import ModelBehaviorError
1717
from .function_schema import DocstringStyle, function_schema
18+
from .items import RunItem
1819
from .logger import logger
1920
from .run_context import RunContextWrapper
2021
from .tracing import SpanError
@@ -29,6 +30,18 @@
2930
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
3031

3132

33+
@dataclass
34+
class FunctionToolResult:
35+
tool: FunctionTool
36+
"""The tool that was run."""
37+
38+
output: Any
39+
"""The output of the tool."""
40+
41+
run_item: RunItem
42+
"""The run item that was produced as a result of the tool call."""
43+
44+
3245
@dataclass
3346
class FunctionTool:
3447
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
@@ -44,15 +57,15 @@ class FunctionTool:
4457
params_json_schema: dict[str, Any]
4558
"""The JSON schema for the tool's parameters."""
4659

47-
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
60+
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
4861
"""A function that invokes the tool with the given context and parameters. The params passed
4962
are:
5063
1. The tool run context.
5164
2. The arguments from the LLM, as a JSON string.
5265
53-
You must return a string representation of the tool output. In case of errors, you can either
54-
raise an Exception (which will cause the run to fail) or return a string error message (which
55-
will be sent back to the LLM).
66+
You must return a string representation of the tool output, or something we can call `str()` on.
67+
In case of errors, you can either raise an Exception (which will cause the run to fail) or
68+
return a string error message (which will be sent back to the LLM).
5669
"""
5770

5871
strict_json_schema: bool = True
@@ -204,7 +217,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
204217
strict_json_schema=strict_mode,
205218
)
206219

207-
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
220+
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
208221
try:
209222
json_data: dict[str, Any] = json.loads(input) if input else {}
210223
except Exception as e:
@@ -251,9 +264,9 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
251264
else:
252265
logger.debug(f"Tool {schema.name} returned {result}")
253266

254-
return str(result)
267+
return result
255268

256-
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
269+
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
257270
try:
258271
return await _on_invoke_tool_impl(ctx, input)
259272
except Exception as e:

src/agents/tracing/span_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def export(self) -> dict[str, Any]:
5151
class FunctionSpanData(SpanData):
5252
__slots__ = ("name", "input", "output")
5353

54-
def __init__(self, name: str, input: str | None, output: str | None):
54+
def __init__(self, name: str, input: str | None, output: Any | None):
5555
self.name = name
5656
self.input = input
5757
self.output = output
@@ -65,7 +65,7 @@ def export(self) -> dict[str, Any]:
6565
"type": self.type,
6666
"name": self.name,
6767
"input": self.input,
68-
"output": self.output,
68+
"output": str(self.output) if self.output else None,
6969
}
7070

7171

tests/test_agent_runner.py

+82
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
UserError,
2222
handoff,
2323
)
24+
from agents.agent import ToolsToFinalOutputResult
25+
from agents.tool import FunctionToolResult, function_tool
2426

2527
from .fake_model import FakeModel
2628
from .test_responses import (
@@ -552,3 +554,83 @@ def guardrail_function(
552554

553555
with pytest.raises(OutputGuardrailTripwireTriggered):
554556
await Runner.run(agent, input="user_message")
557+
558+
559+
@function_tool
560+
def test_tool_one():
561+
return Foo(bar="tool_one_result")
562+
563+
564+
@function_tool
565+
def test_tool_two():
566+
return "tool_two_result"
567+
568+
569+
@pytest.mark.asyncio
570+
async def test_tool_use_behavior_first_output():
571+
model = FakeModel()
572+
agent = Agent(
573+
name="test",
574+
model=model,
575+
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
576+
tool_use_behavior="stop_on_first_tool",
577+
output_type=Foo,
578+
)
579+
580+
model.add_multiple_turn_outputs(
581+
[
582+
# First turn: a message and tool call
583+
[
584+
get_text_message("a_message"),
585+
get_function_tool_call("test_tool_one", None),
586+
get_function_tool_call("test_tool_two", None),
587+
],
588+
]
589+
)
590+
591+
result = await Runner.run(agent, input="user_message")
592+
593+
assert result.final_output == Foo(bar="tool_one_result"), (
594+
"should have used the first tool result"
595+
)
596+
597+
598+
def custom_tool_use_behavior(
599+
context: RunContextWrapper[Any], results: list[FunctionToolResult]
600+
) -> ToolsToFinalOutputResult:
601+
if "test_tool_one" in [result.tool.name for result in results]:
602+
return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
603+
else:
604+
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
605+
606+
607+
@pytest.mark.asyncio
608+
async def test_tool_use_behavior_custom_function():
609+
model = FakeModel()
610+
agent = Agent(
611+
name="test",
612+
model=model,
613+
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
614+
tool_use_behavior=custom_tool_use_behavior,
615+
)
616+
617+
model.add_multiple_turn_outputs(
618+
[
619+
# First turn: a message and tool call
620+
[
621+
get_text_message("a_message"),
622+
get_function_tool_call("test_tool_two", None),
623+
],
624+
# Second turn: a message and tool call
625+
[
626+
get_text_message("a_message"),
627+
get_function_tool_call("test_tool_one", None),
628+
get_function_tool_call("test_tool_two", None),
629+
],
630+
]
631+
)
632+
633+
result = await Runner.run(agent, input="user_message")
634+
635+
assert len(result.raw_responses) == 2, "should have two model responses"
636+
assert result.final_output == "the_final_output", "should have used the custom function"

0 commit comments

Comments
 (0)