Skip to content

Commit cad0464

Browse files
committed
fix: address failing hitl error scenarios
1 parent fdc1bbb commit cad0464

File tree

5 files changed

+228
-32
lines changed

5 files changed

+228
-32
lines changed

src/agents/_run_impl.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,33 @@ def get_approval_identity(approval: ToolApprovalItem) -> str | None:
695695
config=run_config,
696696
)
697697

698+
# Execute shell calls that were approved
699+
shell_results = await cls.execute_shell_calls(
700+
agent=agent,
701+
calls=processed_response.shell_calls,
702+
hooks=hooks,
703+
context_wrapper=context_wrapper,
704+
config=run_config,
705+
)
706+
707+
# Execute local shell calls that were approved
708+
local_shell_results = await cls.execute_local_shell_calls(
709+
agent=agent,
710+
calls=processed_response.local_shell_calls,
711+
hooks=hooks,
712+
context_wrapper=context_wrapper,
713+
config=run_config,
714+
)
715+
716+
# Execute apply_patch calls that were approved
717+
apply_patch_results = await cls.execute_apply_patch_calls(
718+
agent=agent,
719+
calls=processed_response.apply_patch_calls,
720+
hooks=hooks,
721+
context_wrapper=context_wrapper,
722+
config=run_config,
723+
)
724+
698725
# When resuming we receive the original RunItem references; suppress duplicates
699726
# so history and streaming do not double-emit the same items.
700727
# Use object IDs since RunItem objects are not hashable
@@ -715,6 +742,15 @@ def append_if_new(item: RunItem) -> None:
715742
for computer_result in computer_results:
716743
append_if_new(computer_result)
717744

745+
for shell_result in shell_results:
746+
append_if_new(shell_result)
747+
748+
for local_shell_result in local_shell_results:
749+
append_if_new(local_shell_result)
750+
751+
for apply_patch_result in apply_patch_results:
752+
append_if_new(apply_patch_result)
753+
718754
# Run MCP tools that require approval after they get their approval results
719755
# Find MCP approval requests that have corresponding ToolApprovalItems in interruptions
720756
mcp_approval_runs = []
@@ -1043,23 +1079,24 @@ def process_model_response(
10431079
tools_used.append("code_interpreter")
10441080
elif isinstance(output, LocalShellCall):
10451081
items.append(ToolCallItem(raw_item=output, agent=agent))
1046-
if shell_tool:
1082+
if local_shell_tool:
1083+
tools_used.append("local_shell")
1084+
local_shell_calls.append(
1085+
ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
1086+
)
1087+
elif shell_tool:
10471088
tools_used.append(shell_tool.name)
10481089
shell_calls.append(ToolRunShellCall(tool_call=output, shell_tool=shell_tool))
10491090
else:
10501091
tools_used.append("local_shell")
1051-
if not local_shell_tool:
1052-
_error_tracing.attach_error_to_current_span(
1053-
SpanError(
1054-
message="Local shell tool not found",
1055-
data={},
1056-
)
1057-
)
1058-
raise ModelBehaviorError(
1059-
"Model produced local shell call without a local shell tool."
1092+
_error_tracing.attach_error_to_current_span(
1093+
SpanError(
1094+
message="Local shell tool not found",
1095+
data={},
10601096
)
1061-
local_shell_calls.append(
1062-
ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
1097+
)
1098+
raise ModelBehaviorError(
1099+
"Model produced local shell call without a local shell tool."
10631100
)
10641101
elif isinstance(output, ResponseCustomToolCall) and _is_apply_patch_name(
10651102
output.name, apply_patch_tool

src/agents/items.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,43 @@ def __post_init__(self) -> None:
441441
else:
442442
self.tool_name = None
443443

444+
def __hash__(self) -> int:
445+
"""Make ToolApprovalItem hashable so it can be added to sets.
446+
447+
This is required for line 783 in _run_impl.py where pending_hosted_mcp_approvals.add()
448+
is called with a ToolApprovalItem.
449+
"""
450+
# Extract call_id or id from raw_item for hashing
451+
if isinstance(self.raw_item, dict):
452+
call_id = self.raw_item.get("call_id") or self.raw_item.get("id")
453+
else:
454+
call_id = getattr(self.raw_item, "call_id", None) or getattr(self.raw_item, "id", None)
455+
456+
# Hash using call_id and tool_name for uniqueness
457+
return hash((call_id, self.tool_name))
458+
459+
def __eq__(self, other: object) -> bool:
460+
"""Check equality based on call_id and tool_name."""
461+
if not isinstance(other, ToolApprovalItem):
462+
return False
463+
464+
# Extract call_id from both items
465+
if isinstance(self.raw_item, dict):
466+
self_call_id = self.raw_item.get("call_id") or self.raw_item.get("id")
467+
else:
468+
self_call_id = getattr(self.raw_item, "call_id", None) or getattr(
469+
self.raw_item, "id", None
470+
)
471+
472+
if isinstance(other.raw_item, dict):
473+
other_call_id = other.raw_item.get("call_id") or other.raw_item.get("id")
474+
else:
475+
other_call_id = getattr(other.raw_item, "call_id", None) or getattr(
476+
other.raw_item, "id", None
477+
)
478+
479+
return self_call_id == other_call_id and self.tool_name == other.tool_name
480+
444481
@property
445482
def name(self) -> str | None:
446483
"""Returns the tool name if available on the raw item or provided explicitly.

src/agents/result.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ class RunResult(RunResultBase):
162162
_original_input: str | list[TResponseInputItem] | None = field(default=None, repr=False)
163163
"""The original input from the first turn. Unlike `input`, this is never updated during the run.
164164
Used by to_state() to preserve the correct originalInput when serializing state."""
165+
max_turns: int = 10
166+
"""The maximum number of turns allowed for this run."""
165167

166168
def __post_init__(self) -> None:
167169
self._last_agent_ref = weakref.ref(self._last_agent)
@@ -218,7 +220,7 @@ def to_state(self) -> Any:
218220
if original_input_for_state is not None
219221
else self.input,
220222
starting_agent=self.last_agent,
221-
max_turns=10, # This will be overridden by the runner
223+
max_turns=self.max_turns,
222224
)
223225

224226
# Populate the state with data from the result

src/agents/run.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,9 @@ async def run(
919919
# Override context with the state's context if not provided
920920
if context is None and run_state._context is not None:
921921
context = run_state._context.context
922+
923+
# Override max_turns with the state's max_turns to preserve it across resumption
924+
max_turns = run_state._max_turns
922925
else:
923926
# Keep original user input separate from session-prepared input
924927
raw_input = cast(Union[str, list[TResponseInputItem]], input)
@@ -1240,6 +1243,7 @@ def _get_approval_identity(
12401243
_tool_use_tracker_snapshot=self._serialize_tool_use_tracker(
12411244
tool_use_tracker
12421245
),
1246+
max_turns=max_turns,
12431247
)
12441248
result._original_input = _copy_str_or_list(original_input)
12451249
return result
@@ -1284,6 +1288,7 @@ def _get_approval_identity(
12841288
_tool_use_tracker_snapshot=self._serialize_tool_use_tracker(
12851289
tool_use_tracker
12861290
),
1291+
max_turns=max_turns,
12871292
)
12881293
if server_conversation_tracker is None:
12891294
# Save both input and output items together at the end.
@@ -1648,6 +1653,7 @@ def _get_approval_identity(
16481653
_tool_use_tracker_snapshot=self._serialize_tool_use_tracker(
16491654
tool_use_tracker
16501655
),
1656+
max_turns=max_turns,
16511657
)
16521658
if run_state is not None:
16531659
result._current_turn_persisted_item_count = (
@@ -1702,6 +1708,7 @@ def _get_approval_identity(
17021708
_tool_use_tracker_snapshot=self._serialize_tool_use_tracker(
17031709
tool_use_tracker
17041710
),
1711+
max_turns=max_turns,
17051712
)
17061713
if run_state is not None:
17071714
result._current_turn_persisted_item_count = (
@@ -1940,6 +1947,10 @@ def run_streamed(
19401947
# Use context from RunState if not provided
19411948
if context is None and run_state._context is not None:
19421949
context = run_state._context.context
1950+
1951+
# Override max_turns with the state's max_turns to preserve it across resumption
1952+
max_turns = run_state._max_turns
1953+
19431954
# Use context wrapper from RunState
19441955
context_wrapper = cast(RunContextWrapper[TContext], run_state._context)
19451956
else:

src/agents/run_state.py

Lines changed: 128 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -681,22 +681,27 @@ def _serialize_current_step(self) -> dict[str, Any] | None:
681681
return None
682682

683683
# Interruptions are wrapped in a "data" field
684+
interruptions_data = []
685+
for item in self._current_step.interruptions:
686+
if isinstance(item, ToolApprovalItem):
687+
interruption_dict = {
688+
"type": "tool_approval_item",
689+
"rawItem": self._camelize_field_names(
690+
item.raw_item.model_dump(exclude_unset=True)
691+
if hasattr(item.raw_item, "model_dump")
692+
else item.raw_item
693+
),
694+
"agent": {"name": item.agent.name},
695+
}
696+
# Include tool_name if present
697+
if item.tool_name is not None:
698+
interruption_dict["toolName"] = item.tool_name
699+
interruptions_data.append(interruption_dict)
700+
684701
return {
685702
"type": "next_step_interruption",
686703
"data": {
687-
"interruptions": [
688-
{
689-
"type": "tool_approval_item",
690-
"rawItem": self._camelize_field_names(
691-
item.raw_item.model_dump(exclude_unset=True)
692-
if hasattr(item.raw_item, "model_dump")
693-
else item.raw_item
694-
),
695-
"agent": {"name": item.agent.name},
696-
}
697-
for item in self._current_step.interruptions
698-
if isinstance(item, ToolApprovalItem)
699-
],
704+
"interruptions": interruptions_data,
700705
},
701706
}
702707

@@ -994,8 +999,44 @@ async def from_string(
994999
# Normalize field names from JSON format (camelCase)
9951000
# to Python format (snake_case)
9961001
normalized_raw_item = _normalize_field_names(item_data["rawItem"])
997-
raw_item = ResponseFunctionToolCall(**normalized_raw_item)
998-
approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item)
1002+
1003+
# Extract tool_name if present (for backwards compatibility)
1004+
tool_name = item_data.get("toolName")
1005+
1006+
# Tool call items can be function calls, shell calls, apply_patch calls,
1007+
# MCP calls, etc. Check the type field to determine which type to deserialize as
1008+
tool_type = normalized_raw_item.get("type")
1009+
1010+
# Try to deserialize based on the type field
1011+
try:
1012+
if tool_type == "function_call":
1013+
raw_item = ResponseFunctionToolCall(**normalized_raw_item)
1014+
elif tool_type == "shell_call":
1015+
# Shell calls use dict format, not a specific type
1016+
raw_item = normalized_raw_item # type: ignore[assignment]
1017+
elif tool_type == "apply_patch_call":
1018+
# Apply patch calls use dict format
1019+
raw_item = normalized_raw_item # type: ignore[assignment]
1020+
elif tool_type == "hosted_tool_call":
1021+
# MCP/hosted tool calls use dict format
1022+
raw_item = normalized_raw_item # type: ignore[assignment]
1023+
elif tool_type == "local_shell_call":
1024+
# Local shell calls use dict format
1025+
raw_item = normalized_raw_item # type: ignore[assignment]
1026+
else:
1027+
# Default to trying ResponseFunctionToolCall for backwards compatibility
1028+
try:
1029+
raw_item = ResponseFunctionToolCall(**normalized_raw_item)
1030+
except Exception:
1031+
# If that fails, use dict as-is
1032+
raw_item = normalized_raw_item # type: ignore[assignment]
1033+
except Exception:
1034+
# If deserialization fails, use dict for flexibility
1035+
raw_item = normalized_raw_item # type: ignore[assignment]
1036+
1037+
approval_item = ToolApprovalItem(
1038+
agent=agent, raw_item=raw_item, tool_name=tool_name
1039+
)
9991040
interruptions.append(approval_item)
10001041

10011042
# Import at runtime to avoid circular import
@@ -1172,8 +1213,44 @@ async def from_json(
11721213
# Normalize field names from JSON format (camelCase)
11731214
# to Python format (snake_case)
11741215
normalized_raw_item = _normalize_field_names(item_data["rawItem"])
1175-
raw_item = ResponseFunctionToolCall(**normalized_raw_item)
1176-
approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item)
1216+
1217+
# Extract tool_name if present (for backwards compatibility)
1218+
tool_name = item_data.get("toolName")
1219+
1220+
# Tool call items can be function calls, shell calls, apply_patch calls,
1221+
# MCP calls, etc. Check the type field to determine which type to deserialize as
1222+
tool_type = normalized_raw_item.get("type")
1223+
1224+
# Try to deserialize based on the type field
1225+
try:
1226+
if tool_type == "function_call":
1227+
raw_item = ResponseFunctionToolCall(**normalized_raw_item)
1228+
elif tool_type == "shell_call":
1229+
# Shell calls use dict format, not a specific type
1230+
raw_item = normalized_raw_item # type: ignore[assignment]
1231+
elif tool_type == "apply_patch_call":
1232+
# Apply patch calls use dict format
1233+
raw_item = normalized_raw_item # type: ignore[assignment]
1234+
elif tool_type == "hosted_tool_call":
1235+
# MCP/hosted tool calls use dict format
1236+
raw_item = normalized_raw_item # type: ignore[assignment]
1237+
elif tool_type == "local_shell_call":
1238+
# Local shell calls use dict format
1239+
raw_item = normalized_raw_item # type: ignore[assignment]
1240+
else:
1241+
# Default to trying ResponseFunctionToolCall for backwards compatibility
1242+
try:
1243+
raw_item = ResponseFunctionToolCall(**normalized_raw_item)
1244+
except Exception:
1245+
# If that fails, use dict as-is
1246+
raw_item = normalized_raw_item # type: ignore[assignment]
1247+
except Exception:
1248+
# If deserialization fails, use dict for flexibility
1249+
raw_item = normalized_raw_item # type: ignore[assignment]
1250+
1251+
approval_item = ToolApprovalItem(
1252+
agent=agent, raw_item=raw_item, tool_name=tool_name
1253+
)
11771254
interruptions.append(approval_item)
11781255

11791256
# Import at runtime to avoid circular import
@@ -1575,8 +1652,40 @@ def _deserialize_items(
15751652
result.append(MessageOutputItem(agent=agent, raw_item=raw_item_msg))
15761653

15771654
elif item_type == "tool_call_item":
1578-
raw_item_tool = ResponseFunctionToolCall(**normalized_raw_item)
1579-
result.append(ToolCallItem(agent=agent, raw_item=raw_item_tool))
1655+
# Tool call items can be function calls, shell calls, apply_patch calls,
1656+
# MCP calls, etc. Check the type field to determine which type to deserialize as
1657+
tool_type = normalized_raw_item.get("type")
1658+
1659+
# Try to deserialize based on the type field
1660+
# If deserialization fails, fall back to using the dict as-is
1661+
try:
1662+
if tool_type == "function_call":
1663+
raw_item_tool = ResponseFunctionToolCall(**normalized_raw_item)
1664+
elif tool_type == "shell_call":
1665+
# Shell calls use dict format, not a specific type
1666+
raw_item_tool = normalized_raw_item # type: ignore[assignment]
1667+
elif tool_type == "apply_patch_call":
1668+
# Apply patch calls use dict format
1669+
raw_item_tool = normalized_raw_item # type: ignore[assignment]
1670+
elif tool_type == "hosted_tool_call":
1671+
# MCP/hosted tool calls use dict format
1672+
raw_item_tool = normalized_raw_item # type: ignore[assignment]
1673+
elif tool_type == "local_shell_call":
1674+
# Local shell calls use dict format
1675+
raw_item_tool = normalized_raw_item # type: ignore[assignment]
1676+
else:
1677+
# Default to trying ResponseFunctionToolCall for backwards compatibility
1678+
try:
1679+
raw_item_tool = ResponseFunctionToolCall(**normalized_raw_item)
1680+
except Exception:
1681+
# If that fails, use dict as-is
1682+
raw_item_tool = normalized_raw_item # type: ignore[assignment]
1683+
1684+
result.append(ToolCallItem(agent=agent, raw_item=raw_item_tool))
1685+
except Exception:
1686+
# If deserialization fails, use dict for flexibility
1687+
raw_item_tool = normalized_raw_item # type: ignore[assignment]
1688+
result.append(ToolCallItem(agent=agent, raw_item=raw_item_tool))
15801689

15811690
elif item_type == "tool_call_output_item":
15821691
# For tool call outputs, validate and convert the raw dict

0 commit comments

Comments
 (0)