Skip to content

Commit

Permalink
Use event to signal it is time to prompt for input
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Jan 10, 2025
1 parent 9486169 commit 8a309b7
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 58 deletions.
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import asyncio
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast

from aioconsole import ainput # type: ignore
from autogen_core import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]


# TODO: ainput doesn't seem to play nicely with jupyter.
# No input window appears in this case.
Expand Down Expand Up @@ -108,7 +107,38 @@ async def cancellable_user_agent():
print(f"BaseException: {e}")
"""

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
DEFAULT_INPUT_FUNC: ClassVar[InputFuncType] = cancellable_input

class InputRequestContext:
    """Static helper exposing the id of the in-flight user input request.

    Not instantiable. ``populate_context`` is entered around the input
    callback; ``request_id`` reads the id from within that callback.
    """

    def __init__(self) -> None:
        raise RuntimeError(
            "InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
        )

    # Holds the current request id; unset outside populate_context.
    _INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")

    @classmethod
    @contextmanager
    def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
        """:meta private:"""
        # Use cls rather than a hardcoded outer-class path so the class is
        # self-contained; reset in finally so the var never leaks.
        token = cls._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
        try:
            yield
        finally:
            cls._INPUT_REQUEST_CONTEXT_VAR.reset(token)

    @classmethod
    def request_id(cls) -> str:
        """Return the id of the current user input request.

        Raises:
            RuntimeError: if called outside the input callback of a UserProxyAgent.
        """
        try:
            return cls._INPUT_REQUEST_CONTEXT_VAR.get()
        except LookupError as e:
            # Fixed: the message previously referenced a nonexistent
            # InputRequestContext.runtime() method.
            raise RuntimeError(
                "InputRequestContext.request_id() must be called within the input callback of a UserProxyAgent."
            ) from e

def __init__(
self,
Expand Down Expand Up @@ -141,11 +171,11 @@ async def _get_input(self, prompt: str, cancellation_token: Optional[Cancellatio
try:
if self._is_async:
# Cast to AsyncInputFunc for proper typing
async_func = cast(AsyncInputFunc, self.input_func)
async_func = cast(UserProxyAgent.AsyncInputFunc, self.input_func)
return await async_func(prompt, cancellation_token)
else:
# Cast to SyncInputFunc for proper typing
sync_func = cast(SyncInputFunc, self.input_func)
sync_func = cast(UserProxyAgent.SyncInputFunc, self.input_func)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, sync_func, prompt)

Expand All @@ -154,9 +184,15 @@ async def _get_input(self, prompt: str, cancellation_token: Optional[Cancellatio
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e

async def on_messages(
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
) -> Response:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
    """Consume the message stream and return its final Response."""
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        if not isinstance(item, Response):
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first
Expand All @@ -165,15 +201,18 @@ async def on_messages(
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
)

user_input = await self._get_input(prompt, cancellation_token)
request_id = str(uuid.uuid4())

input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
yield input_requested_event
with UserProxyAgent.InputRequestContext.populate_context(request_id):
user_input = await self._get_input(prompt, cancellation_token)

# Return appropriate message type based on handoff presence
if handoff:
return Response(
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
)
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
else:
return Response(chat_message=TextMessage(content=user_input, source=self.name))
yield Response(chat_message=TextMessage(content=user_input, source=self.name))

except asyncio.CancelledError:
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,40 @@ class ToolCallSummaryMessage(BaseChatMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


class UserInputRequestedEvent(BaseAgentEvent):
    """An event signaling that the user proxy has requested user input. Published prior to invoking the input callback."""

    request_id: str
    """Identifier for the user input request."""

    content: Literal[""] = ""
    """Empty content for compat with consumers expecting a content field."""

    type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type")
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


__all__ = [
"AgentEvent",
"BaseMessage",
"TextMessage",
"ChatMessage",
"HandoffMessage",
"MultiModalMessage",
"StopMessage",
"HandoffMessage",
"ToolCallRequestEvent",
"TextMessage",
"ToolCallExecutionEvent",
"ToolCallRequestEvent",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"UserInputRequestedEvent",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
This module implements utility classes for formatting/printing agent messages.
"""

# Re-export the console helpers; UserInputManager pairs with UserProxyAgent
# to gate input prompts on UserInputRequestedEvent delivery.
from ._console import Console, UserInputManager

__all__ = ["Console", "UserInputManager"]
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import asyncio
import os
import sys
import time
from asyncio import Lock
from typing import AsyncGenerator, List, Optional, TypeVar, cast
from inspect import iscoroutinefunction
from typing import AsyncGenerator, Dict, List, Optional, TypeVar, cast

from aioconsole import aprint # type: ignore
from autogen_core import Image
from autogen_core._cancellation_token import CancellationToken
from autogen_core.models import RequestUsage

from autogen_agentchat.agents._user_proxy_agent import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent


def _is_running_in_iterm() -> bool:
Expand All @@ -23,20 +26,53 @@ def _is_output_a_tty() -> bool:
T = TypeVar("T", bound=TaskResult | Response)


class UserInputManager:
    """Coordinates user input prompts with console output.

    The wrapped input callback blocks until the matching
    UserInputRequestedEvent has been observed on the output stream, so the
    prompt cannot interleave with other console output.
    """

    def __init__(self, callback: UserProxyAgent.InputFuncType = UserProxyAgent.DEFAULT_INPUT_FUNC):
        # request_id -> event set once the UserInputRequestedEvent is seen.
        self.input_events: Dict[str, asyncio.Event] = {}
        self.callback = callback

    def get_wrapped_callback(self) -> UserProxyAgent.AsyncInputFunc:
        """Return an async input function that waits for the request's event before prompting."""

        async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
            # Get (or create, if notify_event_received has not yet run) the
            # event for the current request, then wait for the notification.
            request_id = UserProxyAgent.InputRequestContext.request_id()
            event = self.input_events.setdefault(request_id, asyncio.Event())

            # NOTE(review): cancellation_token is not linked to this wait, so
            # a cancelled request may block here until the event fires — confirm.
            await event.wait()

            self.input_events.pop(request_id, None)

            if iscoroutinefunction(self.callback):
                # Cast to AsyncInputFunc for proper typing.
                async_func = cast(UserProxyAgent.AsyncInputFunc, self.callback)
                return await async_func(prompt, cancellation_token)
            else:
                # Cast to SyncInputFunc for proper typing; run the blocking
                # callback in an executor so the event loop stays responsive.
                sync_func = cast(UserProxyAgent.SyncInputFunc, self.callback)
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(None, sync_func, prompt)

        return user_input_func_wrapper

    def notify_event_received(self, request_id: str) -> None:
        """Record that the UserInputRequestedEvent for request_id was seen, unblocking the wrapped callback.

        Always sets the event: previously a notification arriving before the
        callback registered its event created the event unset, deadlocking
        the subsequent wait.
        """
        self.input_events.setdefault(request_id, asyncio.Event()).set()


async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = True,
output_lock: Lock | None = None,
user_input_manager: UserInputManager | None = None,
) -> T:
"""
Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
Expand All @@ -57,8 +93,6 @@ async def Console(
start_time = time.time()
total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

actual_lock: Lock | NoopLock = output_lock or NoopLock()

last_processed: Optional[T] = None

async for message in stream:
Expand All @@ -73,8 +107,8 @@ async def Console(
f"Total completion tokens: {total_usage.completion_tokens}\n"
f"Duration: {duration:.2f} seconds\n"
)
async with actual_lock:
await aprint(output, end="")
await aprint(output, end="")

# mypy ignore
last_processed = message # type: ignore

Expand All @@ -88,8 +122,7 @@ async def Console(
output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens
total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens
async with actual_lock:
await aprint(output, end="")
await aprint(output, end="")

# Print summary.
if output_stats:
Expand All @@ -104,11 +137,14 @@ async def Console(
f"Total completion tokens: {total_usage.completion_tokens}\n"
f"Duration: {duration:.2f} seconds\n"
)
async with actual_lock:
await aprint(output, end="")
await aprint(output, end="")

# mypy ignore
last_processed = message # type: ignore

# We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event.
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
Expand All @@ -118,8 +154,7 @@ async def Console(
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
async with actual_lock:
await aprint(output, end="")
await aprint(output, end="")

if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")
Expand Down
18 changes: 4 additions & 14 deletions python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import argparse
import asyncio
import warnings
from typing import Optional

from aioconsole import ainput # type: ignore
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken
from autogen_agentchat.ui._console import UserInputManager
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne

Expand Down Expand Up @@ -40,18 +38,10 @@ def main() -> None:
args = parser.parse_args()

async def run_task(task: str, hil_mode: bool) -> None:
    """Run a single MagenticOne task, optionally with human-in-the-loop input.

    User input is routed through UserInputManager so prompts are deferred
    until the Console has printed the corresponding UserInputRequestedEvent.
    """
    input_manager = UserInputManager()
    client = OpenAIChatCompletionClient(model="gpt-4o")
    m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback())
    await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)

task = args.task[0]
asyncio.run(run_task(task, not args.no_hil))
Expand Down

0 comments on commit 8a309b7

Please sign in to comment.