
Extracting new_messages() in the middle of an Agent.iter run to support user-in-the-loop approval of tool calls #1995

Open
@tibbe


Question

I'm trying to implement an agent which streams outputs to the user and uses tools. Some tools require user approval, which means that we need to

  • suspend the agent pending user approval and thus
  • serialize all the messages so far so we can wake up the agent once we have the user's approval.

I'm trying to use Agent.iter to achieve this, with a function along the lines of:

# Imports added so the snippet is self-contained:
from abc import ABC, abstractmethod
from typing import Iterable, Literal

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.messages import (
    FinalResultEvent,
    FunctionToolCallEvent,
    FunctionToolResultEvent,
    ModelMessage,
    PartDeltaEvent,
    PartStartEvent,
    TextPart,
    TextPartDelta,
    ToolCallPartDelta,
)

# Allows sending the user messages over e.g. a WebSocket:
class OutputStreamBase(ABC):
    @abstractmethod
    def on_tool_call_start(self, event: FunctionToolCallEvent) -> None:
        """Handle the start of a tool call event."""
        ...

    @abstractmethod
    def on_tool_call_end(self, event: FunctionToolResultEvent) -> None:
        """Handle the end of a tool call event."""
        ...

    @abstractmethod
    def on_part_start(self, index: int, part: TextPart) -> None:
        """Handle the start of a part event."""
        ...

    @abstractmethod
    def on_part_delta(self, delta: TextPartDelta) -> None:
        """Handle the delta of a part event."""
        ...

    @abstractmethod
    def on_final_result(self, event: FinalResultEvent) -> None:
        """Handle the final result event."""
        ...

# Allows storing the current state to e.g. a database:
class BaseStore(ABC):
    @abstractmethod
    async def add_messages(self, messages: Iterable[ModelMessage]) -> None:
        ...

    @abstractmethod
    async def get_messages(self) -> list[ModelMessage]:
        ...

# Track the status of the agent (run):
RunStatus = Literal['COMPLETED', 'AWAITING_APPROVAL', 'RUNNING']

class RunResult(BaseModel):
    status: RunStatus
    output: str | None = None

# This is the interesting part:
async def run_agent_graph_step(
    agent: Agent,
    store: BaseStore,
    sink: OutputStreamBase,
    user_prompt: str | None = None,
) -> RunResult:
    messages = await store.get_messages()

    async with agent.iter(user_prompt, message_history=messages) as run:
        async for node in run:
            if Agent.is_user_prompt_node(node):
                pass
            elif Agent.is_model_request_node(node):
                async with node.stream(run.ctx) as request_stream:
                    async for event in request_stream:
                        if isinstance(event, PartStartEvent):
                            # A part start carries a complete part (e.g. TextPart), not a delta.
                            if isinstance(event.part, TextPart):
                                sink.on_part_start(event.index, event.part)
                        elif isinstance(event, PartDeltaEvent):
                            if isinstance(event.delta, TextPartDelta):
                                sink.on_part_delta(event.delta)
                            elif isinstance(event.delta, ToolCallPartDelta):
                                pass
                        elif isinstance(event, FinalResultEvent):
                            pass
            elif Agent.is_call_tools_node(node):
                async with node.stream(run.ctx) as handle_stream:
                    async for tool_event in handle_stream:
                        if isinstance(tool_event, FunctionToolCallEvent):
                            # requires_approval() is an application-defined predicate (not shown).
                            if requires_approval(tool_event):
                                # ----> HERE: we need to get new_messages() and serialize them <----
                                return RunResult(status='AWAITING_APPROVAL')
                            sink.on_tool_call_start(tool_event)
                        elif isinstance(tool_event, FunctionToolResultEvent):
                            sink.on_tool_call_end(tool_event)
            elif Agent.is_end_node(node):
                assert (
                    run.result is not None
                    and run.result.output == node.data.output
                )
                # Persist this run's messages before returning: code after the
                # loop never runs once this branch has returned.
                await store.add_messages(run.result.new_messages())
                return RunResult(
                    status='COMPLETED',
                    output=node.data.output,
                )

    # Not normally reached: iteration finishes with an End node, handled above.
    return RunResult(status='RUNNING')
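For local testing, a store with the `BaseStore` interface above could simply be backed by a list. This is a minimal sketch: so that it runs standalone, it re-declares `BaseStore` with a generic message type `M` in place of pydantic-ai's `ModelMessage`, and `InMemoryStore` is a hypothetical name.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Generic, Iterable, TypeVar

M = TypeVar("M")  # stands in for pydantic-ai's ModelMessage in this sketch


class BaseStore(ABC, Generic[M]):
    @abstractmethod
    async def add_messages(self, messages: Iterable[M]) -> None: ...

    @abstractmethod
    async def get_messages(self) -> list[M]: ...


class InMemoryStore(BaseStore[M]):
    """Append-only message log kept in memory (e.g. for unit tests)."""

    def __init__(self) -> None:
        self._messages: list[M] = []

    async def add_messages(self, messages: Iterable[M]) -> None:
        self._messages.extend(messages)

    async def get_messages(self) -> list[M]:
        # Return a copy so callers cannot mutate the stored history.
        return list(self._messages)


async def demo() -> list[str]:
    store: InMemoryStore[str] = InMemoryStore()
    await store.add_messages(["request-1", "response-1"])
    await store.add_messages(["request-2"])
    return await store.get_messages()


history = asyncio.run(demo())
```

An append-only log fits the calling pattern in the question: `new_messages()` returns only the messages produced by the current run, so each step appends its own delta to the history loaded at the start.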

Is there a way to get new_messages() (so far) in the middle of an iter run? Is there a better way to model user-in-the-loop pausing/resuming of an agent that needs tool-call approvals?

