Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
FunctionToolResult,
Tool,
ToolErrorFunction,
ToolOrigin,
ToolOriginType,
_extract_tool_argument_json_error,
default_tool_error_function,
)
Expand Down Expand Up @@ -802,6 +804,11 @@ async def _run_agent_tool(context: ToolContext, input_json: str) -> Any:
)
run_agent_tool._is_agent_tool = True
run_agent_tool._agent_instance = self
# Set origin tracking on run_agent (the FunctionTool returned by @function_tool)
run_agent_tool._tool_origin = ToolOrigin(
type=ToolOriginType.AGENT_AS_TOOL,
agent_as_tool=self,
)

return run_agent_tool

Expand Down
19 changes: 19 additions & 0 deletions src/agents/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .exceptions import AgentsException, ModelBehaviorError
from .logger import logger
from .tool import (
ToolOrigin,
ToolOutputFileContent,
ToolOutputImage,
ToolOutputText,
Expand Down Expand Up @@ -248,6 +249,15 @@ class ToolCallItem(RunItemBase[Any]):
description: str | None = None
"""Optional tool description if known at item creation time."""

tool_origin: ToolOrigin | None = field(default=None, repr=False)
"""Information about the origin/source of the tool call. Only set for FunctionTool calls."""

def release_agent(self) -> None:
"""Release agent references including tool_origin.agent_as_tool."""
super().release_agent()
if self.tool_origin is not None:
self.tool_origin.release_agent()


ToolCallOutputTypes: TypeAlias = Union[
FunctionCallOutput,
Expand All @@ -271,6 +281,15 @@ class ToolCallOutputItem(RunItemBase[Any]):

type: Literal["tool_call_output_item"] = "tool_call_output_item"

tool_origin: ToolOrigin | None = field(default=None, repr=False)
"""Information about the origin/source of the tool call. Only set for FunctionTool calls."""

def release_agent(self) -> None:
"""Release agent references including tool_origin.agent_as_tool."""
super().release_agent()
if self.tool_origin is not None:
self.tool_origin.release_agent()

def to_input_item(self) -> TResponseInputItem:
"""Converts the tool output into an input item for the next model turn.

Expand Down
9 changes: 8 additions & 1 deletion src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
FunctionTool,
Tool,
ToolErrorFunction,
ToolOrigin,
ToolOriginType,
ToolOutputImageDict,
ToolOutputTextDict,
default_tool_error_function,
Expand Down Expand Up @@ -301,14 +303,19 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]]
) = server._get_needs_approval_for_tool(tool, agent)

return FunctionTool(
function_tool = FunctionTool(
name=tool.name,
description=tool.description or "",
params_json_schema=schema,
on_invoke_tool=invoke_func,
strict_json_schema=is_strict,
needs_approval=needs_approval,
)
function_tool._tool_origin = ToolOrigin(
type=ToolOriginType.MCP,
mcp_server=server,
)
return function_tool

@staticmethod
def _merge_mcp_meta(
Expand Down
9 changes: 7 additions & 2 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
RawResponsesStreamEvent,
RunItemStreamEvent,
)
from ..tool import Tool, dispose_resolved_computers
from ..tool import FunctionTool, Tool, _get_tool_origin_info, dispose_resolved_computers
from ..tracing import Span, SpanError, agent_span, get_current_trace
from ..tracing.model_tracing import get_model_tracing_impl
from ..tracing.span_data import AgentSpanData
Expand Down Expand Up @@ -1216,13 +1216,18 @@ async def run_single_turn_streamed(
# execution behavior in process_model_response).
tool_name = getattr(output_item, "name", None)
tool_description: str | None = None
tool_origin = None
if isinstance(tool_name, str) and tool_name in tool_map:
tool_description = getattr(tool_map[tool_name], "description", None)
tool = tool_map[tool_name]
tool_description = getattr(tool, "description", None)
if isinstance(tool, FunctionTool):
tool_origin = _get_tool_origin_info(tool)

tool_item = ToolCallItem(
raw_item=cast(ToolCallItemTypes, output_item),
agent=agent,
description=tool_description,
tool_origin=tool_origin,
)
streamed_result._event_queue.put_nowait(
RunItemStreamEvent(item=tool_item, name="tool_called")
Expand Down
3 changes: 3 additions & 0 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
ShellCallOutcome,
ShellCommandOutput,
Tool,
_get_tool_origin_info,
resolve_computer,
)
from ..tool_context import ToolContext
Expand Down Expand Up @@ -973,10 +974,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo

run_item: RunItem | None = None
if not nested_interruptions:
tool_origin = _get_tool_origin_info(tool_run.function_tool)
run_item = ToolCallOutputItem(
output=result,
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result),
agent=agent,
Comment on lines +977 to 981

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve tool_origin on rejected function outputs

This path only sets tool_origin when run_single_tool returns a plain value, but approval rejections return a prebuilt FunctionToolResult (run_item=function_rejection_item(...)) and go through the isinstance(result, FunctionToolResult) branch without enrichment, so rejected function calls emit ToolCallOutputItem with tool_origin=None. In runs that depend on approval workflows, this makes output-item origin metadata inconsistent with successful calls and breaks origin-based tracing/auditing.

Useful? React with 👍 / 👎.

tool_origin=tool_origin,
)
else:
# Skip tool output until nested interruptions are resolved.
Expand Down
9 changes: 8 additions & 1 deletion src/agents/run_internal/turn_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
LocalShellTool,
ShellTool,
Tool,
_get_tool_origin_info,
)
from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
from ..tracing import SpanError, handoff_span
Expand Down Expand Up @@ -1473,8 +1474,14 @@ def process_model_response(
raise ModelBehaviorError(error)

func_tool = function_map[output.name]
tool_origin = _get_tool_origin_info(func_tool)
items.append(
ToolCallItem(raw_item=output, agent=agent, description=func_tool.description)
ToolCallItem(
raw_item=output,
agent=agent,
description=func_tool.description,
tool_origin=tool_origin,
)
)
functions.append(
ToolRunFunction(
Expand Down
86 changes: 85 additions & 1 deletion src/agents/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
HostedMCPTool,
LocalShellTool,
ShellTool,
ToolOrigin,
ToolOriginType,
)
from .tool_guardrails import (
AllowBehavior,
Expand Down Expand Up @@ -635,6 +637,13 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]:
result["tool_name"] = item.tool_name
if hasattr(item, "description") and item.description is not None:
result["description"] = item.description
if hasattr(item, "tool_origin") and item.tool_origin is not None:
tool_origin_data: dict[str, Any] = {"type": item.tool_origin.type.value}
if item.tool_origin.agent_as_tool is not None:
tool_origin_data["agent_as_tool"] = {"name": item.tool_origin.agent_as_tool.name}
if item.tool_origin.mcp_server is not None:
tool_origin_data["mcp_server"] = {"name": item.tool_origin.mcp_server.name}
result["tool_origin"] = tool_origin_data

return result

Expand Down Expand Up @@ -1918,6 +1927,67 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]:
return agent_map


def _deserialize_tool_origin(
tool_origin_data: dict[str, Any] | None, agent_map: dict[str, Agent[Any]], agent: Agent[Any]
) -> ToolOrigin | None:
"""Deserialize ToolOrigin from JSON data.

Args:
tool_origin_data: Serialized tool origin dictionary.
agent_map: Map of agent names to agent instances.
agent: The agent associated with this item (used for MCP server lookup).

Returns:
ToolOrigin instance or None if data is missing/invalid.
"""
if not tool_origin_data:
return None

origin_type_str = tool_origin_data.get("type")
if not origin_type_str:
return None

try:
origin_type = ToolOriginType(origin_type_str)
except ValueError:
logger.warning(f"Unknown tool origin type: {origin_type_str}")
return None

agent_as_tool: Agent[Any] | None = None
mcp_server: Any | None = None

if origin_type == ToolOriginType.AGENT_AS_TOOL:
agent_data = tool_origin_data.get("agent_as_tool")
if agent_data and isinstance(agent_data, Mapping):
agent_name = agent_data.get("name")
if agent_name:
agent_as_tool = agent_map.get(agent_name)
if not agent_as_tool:
logger.warning(f"Agent {agent_name} not found in agent map for tool_origin")

elif origin_type == ToolOriginType.MCP:
mcp_data = tool_origin_data.get("mcp_server")
if mcp_data and isinstance(mcp_data, Mapping):
server_name = mcp_data.get("name")
if server_name:
# Try to find the MCP server from the agent's mcp_servers list
mcp_servers = getattr(agent, "mcp_servers", [])
for server in mcp_servers:
if hasattr(server, "name") and server.name == server_name:
mcp_server = server
break
if not mcp_server:
logger.debug(
f"MCP server {server_name} not found in agent's mcp_servers for tool_origin"
)

return ToolOrigin(
type=origin_type,
agent_as_tool=agent_as_tool,
mcp_server=mcp_server,
)


def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[ModelResponse]:
"""Deserialize model responses from JSON data.

Expand Down Expand Up @@ -2019,8 +2089,17 @@ def _resolve_agent_info(
raw_item_tool = _deserialize_tool_call_raw_item(normalized_raw_item)
# Preserve description if it was stored with the item
description = item_data.get("description")
# Preserve tool_origin if it was stored with the item
tool_origin = _deserialize_tool_origin(
item_data.get("tool_origin"), agent_map, agent
)
result.append(
ToolCallItem(agent=agent, raw_item=raw_item_tool, description=description)
ToolCallItem(
agent=agent,
raw_item=raw_item_tool,
description=description,
tool_origin=tool_origin,
)
)

elif item_type == "tool_call_output_item":
Expand All @@ -2029,11 +2108,16 @@ def _resolve_agent_info(
raw_item_output = _deserialize_tool_call_output_raw_item(normalized_raw_item)
if raw_item_output is None:
continue
# Preserve tool_origin if it was stored with the item
tool_origin = _deserialize_tool_origin(
item_data.get("tool_origin"), agent_map, agent
)
result.append(
ToolCallOutputItem(
agent=agent,
raw_item=raw_item_output,
output=item_data.get("output", ""),
tool_origin=tool_origin,
)
)

Expand Down
Loading