From 466848ac6517ff21f6555f40d094b8bd02b98602 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 13 Jan 2025 10:28:08 -0500 Subject: [PATCH] fix: fix user input in m1 (#4995) * Add lock for input and output management in m1 * Use event to signal it is time to prompt for input * undo stop change * undo changes * Update python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py Co-authored-by: Eric Zhu * reduce exported surface area * fix --------- Co-authored-by: Eric Zhu Co-authored-by: Hussein Mozannar --- .../agents/_user_proxy_agent.py | 60 ++++++++++++++---- .../src/autogen_agentchat/messages.py | 27 ++++++-- .../src/autogen_agentchat/ui/__init__.py | 4 +- .../src/autogen_agentchat/ui/_console.py | 62 +++++++++++++++++-- .../src/autogen_ext/teams/magentic_one.py | 16 ++++- .../src/magentic_one_cli/_m1.py | 17 ++++- 6 files changed, 157 insertions(+), 29 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py index 2ad9a24682f0..89e0b61a50ee 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py @@ -1,15 +1,17 @@ import asyncio +import uuid +from contextlib import contextmanager +from contextvars import ContextVar from inspect import iscoroutinefunction -from typing import Awaitable, Callable, Optional, Sequence, Union, cast +from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast from aioconsole import ainput # type: ignore from autogen_core import CancellationToken from ..base import Response -from ..messages import ChatMessage, HandoffMessage, TextMessage +from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent from ._base_chat_agent import BaseChatAgent -# Define input function types 
more precisely SyncInputFunc = Callable[[str], str] AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] InputFuncType = Union[SyncInputFunc, AsyncInputFunc] @@ -109,6 +111,33 @@ async def cancellable_user_agent(): print(f"BaseException: {e}") """ + class InputRequestContext: + def __init__(self) -> None: + raise RuntimeError( + "InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests." + ) + + _INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR") + + @classmethod + @contextmanager + def populate_context(cls, ctx: str) -> Generator[None, Any, None]: + """:meta private:""" + token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx) + try: + yield + finally: + UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token) + + @classmethod + def request_id(cls) -> str: + try: + return cls._INPUT_REQUEST_CONTEXT_VAR.get() + except LookupError as e: + raise RuntimeError( + "InputRequestContext.request_id() must be called within the input callback of a UserProxyAgent." 
+ ) from e + def __init__( self, name: str, @@ -153,9 +182,15 @@ async def _get_input(self, prompt: str, cancellation_token: Optional[Cancellatio except Exception as e: raise RuntimeError(f"Failed to get user input: {str(e)}") from e - async def on_messages( - self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None - ) -> Response: + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]: """Handle incoming messages by requesting user input.""" try: # Check for handoff first @@ -164,15 +199,18 @@ async def on_messages( f"Handoff received from {handoff.source}. 
Enter your response: " if handoff else "Enter your response: " ) - user_input = await self._get_input(prompt, cancellation_token) + request_id = str(uuid.uuid4()) + + input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name) + yield input_requested_event + with UserProxyAgent.InputRequestContext.populate_context(request_id): + user_input = await self._get_input(prompt, cancellation_token) # Return appropriate message type based on handoff presence if handoff: - return Response( - chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name) - ) + yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)) else: - return Response(chat_message=TextMessage(content=user_input, source=self.name)) + yield Response(chat_message=TextMessage(content=user_input, source=self.name)) except asyncio.CancelledError: raise diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 07fc3123eb4c..21fb32d9d584 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -103,25 +103,40 @@ class ToolCallSummaryMessage(BaseChatMessage): type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" +class UserInputRequestedEvent(BaseAgentEvent): + """An event signaling that the user proxy has requested user input.
Published prior to invoking the input callback.""" + + request_id: str + """Identifier for the user input request.""" + + content: Literal[""] = "" + """Empty content for compat with consumers expecting a content field.""" + + type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent" + + ChatMessage = Annotated[ TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] """Messages for agent-to-agent communication only.""" -AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")] +AgentEvent = Annotated[ + ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type") +] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" __all__ = [ + "AgentEvent", "BaseMessage", - "TextMessage", + "ChatMessage", + "HandoffMessage", "MultiModalMessage", "StopMessage", - "HandoffMessage", - "ToolCallRequestEvent", + "TextMessage", "ToolCallExecutionEvent", + "ToolCallRequestEvent", "ToolCallSummaryMessage", - "ChatMessage", - "AgentEvent", + "UserInputRequestedEvent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/__init__.py index 65c4f1e07ad9..9cc0837c58c2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/__init__.py @@ -2,6 +2,6 @@ This module implements utility classes for formatting/printing agent messages. 
""" -from ._console import Console +from ._console import Console, UserInputManager -__all__ = ["Console"] +__all__ = ["Console", "UserInputManager"] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py index 3c30d01bfd6b..767dc68d8b4e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py @@ -1,14 +1,17 @@ +import asyncio import os import sys import time -from typing import AsyncGenerator, List, Optional, TypeVar, cast +from inspect import iscoroutinefunction +from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast from aioconsole import aprint # type: ignore -from autogen_core import Image +from autogen_core import CancellationToken, Image from autogen_core.models import RequestUsage +from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.base import Response, TaskResult -from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage +from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent def _is_running_in_iterm() -> bool: @@ -19,14 +22,60 @@ def _is_output_a_tty() -> bool: return sys.stdout.isatty() +SyncInputFunc = Callable[[str], str] +AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] +InputFuncType = Union[SyncInputFunc, AsyncInputFunc] + T = TypeVar("T", bound=TaskResult | Response) +class UserInputManager: + def __init__(self, callback: InputFuncType): + self.input_events: Dict[str, asyncio.Event] = {} + self.callback = callback + + def get_wrapped_callback(self) -> AsyncInputFunc: + async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str: + # Lookup the event for the prompt, if it exists wait for it. 
+ # If it doesn't exist, create it and store it. + # Get request ID: + request_id = UserProxyAgent.InputRequestContext.request_id() + if request_id in self.input_events: + event = self.input_events[request_id] + else: + event = asyncio.Event() + self.input_events[request_id] = event + + await event.wait() + + del self.input_events[request_id] + + if iscoroutinefunction(self.callback): + # Cast to AsyncInputFunc for proper typing + async_func = cast(AsyncInputFunc, self.callback) + return await async_func(prompt, cancellation_token) + else: + # Cast to SyncInputFunc for proper typing + sync_func = cast(SyncInputFunc, self.callback) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, sync_func, prompt) + + return user_input_func_wrapper + + def notify_event_received(self, request_id: str) -> None: + if request_id in self.input_events: + self.input_events[request_id].set() + else: + event = asyncio.Event() + self.input_events[request_id] = event + + async def Console( stream: AsyncGenerator[AgentEvent | ChatMessage | T, None], *, no_inline_images: bool = False, output_stats: bool = False, + user_input_manager: UserInputManager | None = None, ) -> T: """ Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream` @@ -67,6 +116,7 @@ async def Console( f"Duration: {duration:.2f} seconds\n" ) await aprint(output, end="") + # mypy ignore last_processed = message # type: ignore @@ -96,9 +146,13 @@ async def Console( f"Duration: {duration:.2f} seconds\n" ) await aprint(output, end="") + # mypy ignore last_processed = message # type: ignore - + # We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event. 
+ elif isinstance(message, UserInputRequestedEvent): + if user_input_manager is not None: + user_input_manager.notify_event_received(message.request_id) else: # Cast required for mypy to be happy message = cast(AgentEvent | ChatMessage, message) # type: ignore diff --git a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py index 23aca97014c3..fc2e4f6b9129 100644 --- a/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py +++ b/python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py @@ -1,9 +1,10 @@ import warnings -from typing import List +from typing import Awaitable, Callable, List, Optional, Union from autogen_agentchat.agents import CodeExecutorAgent, UserProxyAgent from autogen_agentchat.base import ChatAgent from autogen_agentchat.teams import MagenticOneGroupChat +from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient from autogen_ext.agents.file_surfer import FileSurfer @@ -12,6 +13,10 @@ from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor from autogen_ext.models.openai._openai_client import BaseOpenAIChatCompletionClient +SyncInputFunc = Callable[[str], str] +AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]] +InputFuncType = Union[SyncInputFunc, AsyncInputFunc] + class MagenticOne(MagenticOneGroupChat): """ @@ -116,7 +121,12 @@ async def example_usage_hil(): """ - def __init__(self, client: ChatCompletionClient, hil_mode: bool = False): + def __init__( + self, + client: ChatCompletionClient, + hil_mode: bool = False, + input_func: InputFuncType | None = None, + ): self.client = client self._validate_client_capabilities(client) @@ -126,7 +136,7 @@ def __init__(self, client: ChatCompletionClient, hil_mode: bool = False): executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor()) agents: List[ChatAgent] = [fs, ws, coder, executor] if 
hil_mode: - user_proxy = UserProxyAgent("User") + user_proxy = UserProxyAgent("User", input_func=input_func) agents.append(user_proxy) super().__init__(agents, model_client=client) diff --git a/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py b/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py index e5a07b164939..e7a3f2ed1e89 100644 --- a/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py +++ b/python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py @@ -1,8 +1,11 @@ import argparse import asyncio import warnings +from typing import Optional -from autogen_agentchat.ui import Console +from aioconsole import ainput # type: ignore +from autogen_agentchat.ui import Console, UserInputManager +from autogen_core import CancellationToken from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.teams.magentic_one import MagenticOne @@ -10,6 +13,13 @@ warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning) +async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str: + task: asyncio.Task[str] = asyncio.create_task(ainput(prompt)) # type: ignore + if cancellation_token is not None: + cancellation_token.link_future(task) + return await task + + def main() -> None: """ Command-line interface for running a complex task using MagenticOne. @@ -37,9 +47,10 @@ def main() -> None: args = parser.parse_args() async def run_task(task: str, hil_mode: bool) -> None: + input_manager = UserInputManager(callback=cancellable_input) client = OpenAIChatCompletionClient(model="gpt-4o") - m1 = MagenticOne(client=client, hil_mode=hil_mode) - await Console(m1.run_stream(task=task), output_stats=False) + m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback()) + await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager) task = args.task[0] asyncio.run(run_task(task, not args.no_hil))