Description
I'm trying to implement an agent that streams its output to the user and uses tools. Some tools require user approval, which means we need to:
- suspend the agent until the user approves the tool call, and therefore
- serialize all the messages so far, so that we can wake the agent up once we have the user's approval.
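As a library-agnostic illustration of the suspend/serialize step, here is a minimal sketch that assumes the paused state can be captured as a single JSON record. `SuspendedRun`, its fields and the `delete_files` tool are hypothetical, and the messages are plain dicts standing in for pydantic-ai's `ModelMessage` objects. With pydantic-ai itself, the real message list would more likely be dumped and loaded through `ModelMessagesTypeAdapter` from `pydantic_ai.messages`.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any


@dataclass
class SuspendedRun:
    """Hypothetical record of an agent run paused for tool-call approval."""

    run_id: str
    # Stand-in for the pydantic-ai message history; plain dicts here so the
    # sketch runs with the standard library only.
    messages: list[dict[str, Any]]
    # The tool call that is waiting for the user's decision.
    pending_tool_call_id: str
    pending_tool_name: str
    pending_tool_args: dict[str, Any] = field(default_factory=dict)

    def to_json(self) -> str:
        # asdict() recursively copies the nested lists/dicts into plain data.
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, data: str) -> "SuspendedRun":
        return cls(**json.loads(data))


# Suspend: persist everything needed to wake the agent up later.
paused = SuspendedRun(
    run_id="run-1",
    messages=[{"role": "user", "content": "Delete old backups"}],
    pending_tool_call_id="call-42",
    pending_tool_name="delete_files",
    pending_tool_args={"pattern": "*.bak"},
)
blob = paused.to_json()

# Resume: once the user approves, restore the record and continue from it.
restored = SuspendedRun.from_json(blob)
```

Keeping the pending tool call next to the history means the resume path knows exactly which call the user approved or rejected.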
I'm trying to use `Agent.iter` to achieve this, with a function along the lines of:
```python
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Literal, Optional

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.messages import (
    FinalResultEvent,
    FunctionToolCallEvent,
    FunctionToolResultEvent,
    ModelMessage,
    PartDeltaEvent,
    PartStartEvent,
    TextPart,
    TextPartDelta,
    ToolCallPartDelta,
)


# Allows streaming output to the user, e.g. over a WebSocket:
class OutputStreamBase(ABC):
    @abstractmethod
    def on_tool_call_start(self, event: FunctionToolCallEvent) -> None:
        """Handle the start of a tool call event."""
        ...

    @abstractmethod
    def on_tool_call_end(self, event: FunctionToolResultEvent) -> None:
        """Handle the end of a tool call event."""
        ...

    @abstractmethod
    def on_part_start(self, index: int, part: TextPart) -> None:
        """Handle the start of a part event."""
        ...

    @abstractmethod
    def on_part_delta(self, delta: TextPartDelta) -> None:
        """Handle the delta of a part event."""
        ...

    @abstractmethod
    def on_final_result(self, event: FinalResultEvent) -> None:
        """Handle the final result event."""
        ...


# Allows storing the current state to e.g. a database:
class BaseStore(ABC):
    @abstractmethod
    async def add_messages(self, messages: Iterable[ModelMessage]) -> None:
        ...

    @abstractmethod
    async def get_messages(self) -> list[ModelMessage]:
        ...


# Track the status of the agent (run):
RunStatus = Literal['COMPLETED', 'AWAITING_APPROVAL', 'RUNNING']


class RunResult(BaseModel):
    status: RunStatus
    output: str | None = None


# This is the interesting part:
async def run_agent_graph_step(
    agent: Agent,
    store: BaseStore,
    sink: OutputStreamBase,
    user_prompt: Optional[str] = None,
) -> RunResult:
    messages = await store.get_messages()
    async with agent.iter(user_prompt, message_history=messages) as run:
        async for node in run:
            if Agent.is_user_prompt_node(node):
                pass
            elif Agent.is_model_request_node(node):
                async with node.stream(run.ctx) as request_stream:
                    async for event in request_stream:
                        if isinstance(event, PartStartEvent):
                            # A new part arrives as a complete TextPart, not a delta.
                            if isinstance(event.part, TextPart):
                                sink.on_part_start(event.index, event.part)
                        elif isinstance(event, PartDeltaEvent):
                            if isinstance(event.delta, TextPartDelta):
                                sink.on_part_delta(event.delta)
                            elif isinstance(event.delta, ToolCallPartDelta):
                                pass
                        elif isinstance(event, FinalResultEvent):
                            pass
            elif Agent.is_call_tools_node(node):
                async with node.stream(run.ctx) as handle_stream:
                    async for tool_event in handle_stream:
                        if isinstance(tool_event, FunctionToolCallEvent):
                            # requires_approval() is my own helper (not shown).
                            if requires_approval(tool_event):
                                # ----> HERE: we need to get new_messages() and serialize them <----
                                return RunResult(status='AWAITING_APPROVAL')
                            sink.on_tool_call_start(tool_event)
                        elif isinstance(tool_event, FunctionToolResultEvent):
                            sink.on_tool_call_end(tool_event)
            elif Agent.is_end_node(node):
                assert (
                    run.result is not None
                    and run.result.output == node.data.output
                )
                # Persist this run's messages before reporting completion;
                # otherwise the early return would skip the store entirely.
                await store.add_messages(run.result.new_messages())
                return RunResult(
                    status='COMPLETED',
                    output=node.data.output,
                )
    # Fallback only: the loop normally ends on an End node, which returns above.
    return RunResult(status='RUNNING')
```
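The loop calls `requires_approval(tool_event)` without defining it. A minimal sketch, assuming approval is decided purely by tool name: `TOOLS_REQUIRING_APPROVAL` and the tool names in it are hypothetical, and the helper relies only on `event.part.tool_name`, which pydantic-ai's `FunctionToolCallEvent` exposes through its `ToolCallPart`. Stand-in events built with `SimpleNamespace` keep the sketch runnable without pydantic-ai.

```python
from types import SimpleNamespace
from typing import Any

# Hypothetical policy: names of tools that must not run without user approval.
TOOLS_REQUIRING_APPROVAL: frozenset[str] = frozenset({"delete_files", "send_email"})


def requires_approval(tool_event: Any) -> bool:
    """Return True if the tool call carried by ``tool_event`` needs approval.

    Assumes the event exposes ``tool_event.part.tool_name``, as pydantic-ai's
    FunctionToolCallEvent does via its ToolCallPart.
    """
    return tool_event.part.tool_name in TOOLS_REQUIRING_APPROVAL


# Stand-in events so the sketch runs without pydantic-ai installed.
risky = SimpleNamespace(part=SimpleNamespace(tool_name="delete_files"))
safe = SimpleNamespace(part=SimpleNamespace(tool_name="get_weather"))
print(requires_approval(risky), requires_approval(safe))  # True False
```

A policy that also inspects the call's arguments (for example, allowing reads but not deletes) would slot into the same function without touching the agent loop.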
Is there a way to get `new_messages()` (i.e. the messages so far) in the middle of an `iter` run? Is there a better way to model user-in-the-loop pausing and resuming of an agent whose tool calls need approval?