
Commit 427ff7c

fix: finish up human-in-the-loop port
1 parent 5dd00a2 commit 427ff7c

File tree

11 files changed: +5064 −746 lines


src/agents/_run_impl.py

Lines changed: 417 additions & 5 deletions
Large diffs are not rendered by default.

src/agents/agent.py

Lines changed: 55 additions & 1 deletion
@@ -29,12 +29,43 @@
     from .util._types import MaybeAwaitable

 if TYPE_CHECKING:
+    from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
+
     from .lifecycle import AgentHooks, RunHooks
     from .mcp import MCPServer
     from .memory.session import Session
     from .result import RunResult
     from .run import RunConfig

+# Per-process, ephemeral map linking a tool call ID to its nested
+# Agent run result within the same run; entry is removed after consumption.
+_agent_tool_run_results: dict[str, RunResult] = {}
+
+
+def save_agent_tool_run_result(
+    tool_call: ResponseFunctionToolCall | None,
+    run_result: RunResult,
+) -> None:
+    """Save the nested agent run result for later consumption.
+
+    This is used when an agent is used as a tool. The run result is stored
+    so that interruptions from the nested agent run can be collected.
+    """
+    if tool_call:
+        _agent_tool_run_results[tool_call.call_id] = run_result
+
+
+def consume_agent_tool_run_result(
+    tool_call: ResponseFunctionToolCall,
+) -> RunResult | None:
+    """Consume and return the nested agent run result for a tool call.
+
+    This retrieves and removes the stored run result. Returns None if
+    no result was stored for this tool call.
+    """
+    run_result = _agent_tool_run_results.pop(tool_call.call_id, None)
+    return run_result
+

 @dataclass
 class ToolsToFinalOutputResult:

@@ -385,6 +416,8 @@ def as_tool(
         custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
         is_enabled: bool
         | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
+        needs_approval: bool
+        | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] = False,
         run_config: RunConfig | None = None,
         max_turns: int | None = None,
         hooks: RunHooks[TContext] | None = None,

@@ -409,15 +442,24 @@ def as_tool(
             is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
                 context and agent and returns whether the tool is enabled. Disabled tools are hidden
                 from the LLM at runtime.
+            needs_approval: Whether the tool needs approval before execution.
+                If True, the run will be interrupted and the tool call will need
+                to be approved using RunState.approve() or rejected using
+                RunState.reject() before continuing. Can be a bool
+                (always/never needs approval) or a function that takes
+                (run_context, tool_parameters, call_id) and returns whether this
+                specific call needs approval.
         """

         @function_tool(
             name_override=tool_name or _transforms.transform_string_function_style(self.name),
             description_override=tool_description or "",
             is_enabled=is_enabled,
+            needs_approval=needs_approval,
         )
         async def run_agent(context: RunContextWrapper, input: str) -> Any:
             from .run import DEFAULT_MAX_TURNS, Runner
+            from .tool_context import ToolContext

             resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS

@@ -432,12 +474,24 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any:
                 conversation_id=conversation_id,
                 session=session,
             )
+
+            # Store the run result keyed by tool_call_id so it can be retrieved later
+            # when the tool_call is available during result processing
+            # At runtime, context is actually a ToolContext which has tool_call_id
+            if isinstance(context, ToolContext):
+                _agent_tool_run_results[context.tool_call_id] = output
+
             if custom_output_extractor:
                 return await custom_output_extractor(output)

             return output.final_output

-        return run_agent
+        # Mark the function tool as an agent tool
+        run_agent_tool = run_agent
+        run_agent_tool._is_agent_tool = True
+        run_agent_tool._agent_instance = self
+
+        return run_agent_tool

     async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
         if isinstance(self.instructions, str):
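
Taken together, the agent.py changes let a nested agent run surface approval interruptions to its parent run: as_tool() gains a needs_approval parameter, and the module-level save/consume helpers hand the nested RunResult back to the caller keyed by tool call ID. A hedged usage sketch follows; the predicate signature (run_context, tool_parameters, call_id) is taken from the docstring above, while the agent names, instructions, and the assumption that tool_parameters carries the tool's "input" string are illustrative only.

from agents import Agent, RunContextWrapper

async def needs_review(
    run_context: RunContextWrapper, tool_parameters: dict, call_id: str
) -> bool:
    # Hypothetical policy: only interrupt the run when the sub-agent is asked
    # to act on a refund; tool_parameters is assumed to carry the tool input.
    return "refund" in str(tool_parameters.get("input", "")).lower()

billing_agent = Agent(name="Billing agent", instructions="Handle billing requests.")

triage_agent = Agent(
    name="Triage agent",
    instructions="Route billing questions to the billing tool.",
    tools=[
        billing_agent.as_tool(
            tool_name="billing",
            tool_description="Answer billing questions.",
            needs_approval=needs_review,  # or simply True to always require approval
        )
    ],
)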

src/agents/memory/openai_conversations_session.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

     async def add_items(self, items: list[TResponseInputItem]) -> None:
         session_id = await self._get_session_id()
+        if not items:
+            return
+
         await self._openai_client.conversations.items.create(
             conversation_id=session_id,
             items=items,
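
With the new guard, add_items() is a no-op for an empty list instead of issuing a Conversations API request with nothing to persist (which can happen, for example, when a resumed or interrupted turn produced no new items). A minimal sketch, assuming the import path matches the file location and that the session can be constructed with defaults plus an OPENAI_API_KEY in the environment:

import asyncio

from agents.memory.openai_conversations_session import OpenAIConversationsSession

async def main() -> None:
    session = OpenAIConversationsSession()  # assumes OPENAI_API_KEY is set
    await session.add_items([])  # returns early; no API request is made
    await session.add_items([{"role": "user", "content": "hello"}])  # persisted

asyncio.run(main())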

src/agents/result.py

Lines changed: 39 additions & 2 deletions
@@ -155,6 +155,13 @@ class RunResult(RunResultBase):
     )
     _last_processed_response: ProcessedResponse | None = field(default=None, repr=False)
     """The last processed model response. This is needed for resuming from interruptions."""
+    _tool_use_tracker_snapshot: dict[str, list[str]] = field(default_factory=dict, repr=False)
+    _current_turn_persisted_item_count: int = 0
+    """Number of items from new_items already persisted to session for the
+    current turn."""
+    _original_input: str | list[TResponseInputItem] | None = field(default=None, repr=False)
+    """The original input from the first turn. Unlike `input`, this is never updated during the run.
+    Used by to_state() to preserve the correct originalInput when serializing state."""

     def __post_init__(self) -> None:
         self._last_agent_ref = weakref.ref(self._last_agent)

@@ -204,9 +211,12 @@ def to_state(self) -> Any:
         ```
         """
         # Create a RunState from the current result
+        original_input_for_state = getattr(self, "_original_input", None)
         state = RunState(
             context=self.context_wrapper,
-            original_input=self.input,
+            original_input=original_input_for_state
+            if original_input_for_state is not None
+            else self.input,
             starting_agent=self.last_agent,
             max_turns=10,  # This will be overridden by the runner
         )

@@ -217,6 +227,8 @@ def to_state(self) -> Any:
         state._input_guardrail_results = self.input_guardrail_results
         state._output_guardrail_results = self.output_guardrail_results
         state._last_processed_response = self._last_processed_response
+        state._current_turn_persisted_item_count = self._current_turn_persisted_item_count
+        state.set_tool_use_tracker_snapshot(self._tool_use_tracker_snapshot)

         # If there are interruptions, set the current step
         if self.interruptions:

@@ -279,11 +291,32 @@ class RunResultStreaming(RunResultBase):
     _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
     _stored_exception: Exception | None = field(default=None, repr=False)

+    _current_turn_persisted_item_count: int = 0
+    """Number of items from new_items already persisted to session for the
+    current turn."""
+
+    _stream_input_persisted: bool = False
+    """Whether the input has been persisted to the session. Prevents double-saving."""
+
+    _original_input_for_persistence: list[TResponseInputItem] = field(default_factory=list)
+    """Original turn input before session history was merged, used for
+    persistence (matches JS sessionInputOriginalSnapshot)."""
+
     # Soft cancel state
     _cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)

+    _original_input: str | list[TResponseInputItem] | None = field(default=None, repr=False)
+    """The original input from the first turn. Unlike `input`, this is never updated during the run.
+    Used by to_state() to preserve the correct originalInput when serializing state."""
+    _tool_use_tracker_snapshot: dict[str, list[str]] = field(default_factory=dict, repr=False)
+    _state: Any = field(default=None, repr=False)
+    """Internal reference to the RunState for streaming results."""
+
     def __post_init__(self) -> None:
         self._current_agent_ref = weakref.ref(self.current_agent)
+        # Store the original input at creation time (it will be set via input field)
+        if self._original_input is None:
+            self._original_input = self.input

     @property
     def last_agent(self) -> Agent[Any]:

@@ -508,9 +541,11 @@ def to_state(self) -> Any:
         ```
         """
         # Create a RunState from the current result
+        # Use _original_input (the input from the first turn) instead of input
+        # (which may have been updated during the run)
         state = RunState(
             context=self.context_wrapper,
-            original_input=self.input,
+            original_input=self._original_input if self._original_input is not None else self.input,
             starting_agent=self.last_agent,
             max_turns=self.max_turns,
         )

@@ -522,6 +557,8 @@ def to_state(self) -> Any:
         state._output_guardrail_results = self.output_guardrail_results
         state._current_turn = self.current_turn
         state._last_processed_response = self._last_processed_response
+        state._current_turn_persisted_item_count = self._current_turn_persisted_item_count
+        state.set_tool_use_tracker_snapshot(self._tool_use_tracker_snapshot)

         # If there are interruptions, set the current step
         if self.interruptions:
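
The result.py additions all serve to_state(): _original_input keeps the first-turn input even though input is rewritten as the run progresses, and the persisted-item count and tool-use tracker snapshot travel into the RunState so a resumed run does not double-persist items or lose tool-choice bookkeeping. Below is a hedged sketch of the round-trip this enables; interruptions, to_state(), and approve()/reject() appear in this commit, but feeding the state back into Runner.run() to resume is an assumption about the runner API.

from agents import Agent, Runner

async def run_with_approval(agent: Agent) -> None:
    result = await Runner.run(agent, "Please process a refund for order 1234")

    if result.interruptions:
        # Serialize the paused run; the original first-turn input, persisted
        # item count, and tool-use tracker snapshot are carried in the state.
        state = result.to_state()
        for interruption in result.interruptions:
            state.approve(interruption)  # or state.reject(interruption)
        # Assumed resume entry point: pass the state back to the runner.
        result = await Runner.run(agent, state)

    print(result.final_output)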
