Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix user input in m1 #4995

Merged
merged 12 commits into from
Jan 13, 2025
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import asyncio
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast

from aioconsole import ainput # type: ignore
from autogen_core import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]


# TODO: ainput doesn't seem to play nicely with jupyter.
# No input window appears in this case.
Expand Down Expand Up @@ -109,6 +107,39 @@
print(f"BaseException: {e}")
"""

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
DEFAULT_INPUT_FUNC: ClassVar[InputFuncType] = cancellable_input

class InputRequestContext:
def __init__(self) -> None:
raise RuntimeError(

Check warning on line 118 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L118

Added line #L118 was not covered by tests
"InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
)

_INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")

@classmethod
@contextmanager
def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
""":meta private:"""
token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
try:
yield
finally:
UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token)

@classmethod
def request_id(cls) -> str:
try:
return cls._INPUT_REQUEST_CONTEXT_VAR.get()
except LookupError as e:
raise RuntimeError(

Check warning on line 139 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L136-L139

Added lines #L136 - L139 were not covered by tests
"InputRequestContext.runtime() must be called within the input callback of a UserProxyAgent."
) from e

def __init__(
self,
name: str,
Expand Down Expand Up @@ -140,11 +171,11 @@
try:
if self._is_async:
# Cast to AsyncInputFunc for proper typing
async_func = cast(AsyncInputFunc, self.input_func)
async_func = cast(UserProxyAgent.AsyncInputFunc, self.input_func)
return await async_func(prompt, cancellation_token)
else:
# Cast to SyncInputFunc for proper typing
sync_func = cast(SyncInputFunc, self.input_func)
sync_func = cast(UserProxyAgent.SyncInputFunc, self.input_func)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, sync_func, prompt)

Expand All @@ -153,9 +184,15 @@
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e

async def on_messages(
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
) -> Response:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")

Check warning on line 191 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L191

Added line #L191 was not covered by tests

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first
Expand All @@ -164,15 +201,18 @@
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
)

user_input = await self._get_input(prompt, cancellation_token)
request_id = str(uuid.uuid4())

input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
yield input_requested_event
with UserProxyAgent.InputRequestContext.populate_context(request_id):
user_input = await self._get_input(prompt, cancellation_token)

# Return appropriate message type based on handoff presence
if handoff:
return Response(
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
)
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
else:
return Response(chat_message=TextMessage(content=user_input, source=self.name))
yield Response(chat_message=TextMessage(content=user_input, source=self.name))

except asyncio.CancelledError:
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,40 @@ class ToolCallSummaryMessage(BaseChatMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


class UserInputRequestedEvent(BaseAgentEvent):
"""An event signaling a that the user proxy has requested user input. Published prior to invoking the input callback."""

request_id: str
"""Identifier for the user input request."""

content: Literal[""] = ""
"""Empty content for compat with consumers expecting a content field."""

type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type")
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


__all__ = [
"AgentEvent",
"BaseMessage",
"TextMessage",
"ChatMessage",
"HandoffMessage",
"MultiModalMessage",
"StopMessage",
"HandoffMessage",
"ToolCallRequestEvent",
"TextMessage",
"ToolCallExecutionEvent",
"ToolCallRequestEvent",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"UserInputRequestedEvent",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
This module implements utility classes for formatting/printing agent messages.
"""

from ._console import Console
from ._console import Console, UserInputManager

__all__ = ["Console"]
__all__ = ["Console", "UserInputManager"]
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import asyncio
import os
import sys
import time
from typing import AsyncGenerator, List, Optional, TypeVar, cast
from inspect import iscoroutinefunction
from typing import AsyncGenerator, Dict, List, Optional, TypeVar, cast

from aioconsole import aprint # type: ignore
from autogen_core import Image
from autogen_core._cancellation_token import CancellationToken
from autogen_core.models import RequestUsage

from autogen_agentchat.agents._user_proxy_agent import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent


def _is_running_in_iterm() -> bool:
Expand All @@ -22,11 +26,53 @@
T = TypeVar("T", bound=TaskResult | Response)


class UserInputManager:
def __init__(self, callback: UserProxyAgent.InputFuncType = UserProxyAgent.DEFAULT_INPUT_FUNC):
jackgerrits marked this conversation as resolved.
Show resolved Hide resolved
self.input_events: Dict[str, asyncio.Event] = {}
self.callback = callback

Check warning on line 32 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L31-L32

Added lines #L31 - L32 were not covered by tests

def get_wrapped_callback(self) -> UserProxyAgent.AsyncInputFunc:
async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:

Check warning on line 35 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L35

Added line #L35 was not covered by tests
# Lookup the event for the prompt, if it exists wait for it.
# If it doesn't exist, create it and store it.
# Get request ID:
request_id = UserProxyAgent.InputRequestContext.request_id()
if request_id in self.input_events:
event = self.input_events[request_id]

Check warning on line 41 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L39-L41

Added lines #L39 - L41 were not covered by tests
else:
event = asyncio.Event()
self.input_events[request_id] = event

Check warning on line 44 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L43-L44

Added lines #L43 - L44 were not covered by tests

await event.wait()

Check warning on line 46 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L46

Added line #L46 was not covered by tests

del self.input_events[request_id]

Check warning on line 48 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L48

Added line #L48 was not covered by tests

if iscoroutinefunction(self.callback):

Check warning on line 50 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L50

Added line #L50 was not covered by tests
# Cast to AsyncInputFunc for proper typing
async_func = cast(UserProxyAgent.AsyncInputFunc, self.callback)
return await async_func(prompt, cancellation_token)

Check warning on line 53 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L52-L53

Added lines #L52 - L53 were not covered by tests
else:
# Cast to SyncInputFunc for proper typing
sync_func = cast(UserProxyAgent.SyncInputFunc, self.callback)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, sync_func, prompt)

Check warning on line 58 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L56-L58

Added lines #L56 - L58 were not covered by tests

return user_input_func_wrapper

Check warning on line 60 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L60

Added line #L60 was not covered by tests

def notify_event_received(self, request_id: str) -> None:
if request_id in self.input_events:
self.input_events[request_id].set()

Check warning on line 64 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L63-L64

Added lines #L63 - L64 were not covered by tests
else:
event = asyncio.Event()
self.input_events[request_id] = event

Check warning on line 67 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L66-L67

Added lines #L66 - L67 were not covered by tests


async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = True,
user_input_manager: UserInputManager | None = None,
) -> T:
"""
Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
Expand Down Expand Up @@ -62,6 +108,7 @@
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")

# mypy ignore
last_processed = message # type: ignore

Expand Down Expand Up @@ -91,9 +138,13 @@
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")

# mypy ignore
last_processed = message # type: ignore

# We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event.
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)

Check warning on line 147 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L146-L147

Added lines #L146 - L147 were not covered by tests
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@

"""

def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
def __init__(

Check warning on line 119 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L119

Added line #L119 was not covered by tests
self,
client: ChatCompletionClient,
hil_mode: bool = False,
input_func: UserProxyAgent.InputFuncType | None = None,
):
self.client = client
self._validate_client_capabilities(client)

Expand All @@ -126,7 +131,7 @@
executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor())
agents: List[ChatAgent] = [fs, ws, coder, executor]
if hil_mode:
user_proxy = UserProxyAgent("User")
user_proxy = UserProxyAgent("User", input_func=input_func)

Check warning on line 134 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L134

Added line #L134 was not covered by tests
agents.append(user_proxy)
super().__init__(agents, model_client=client)

Expand Down
6 changes: 4 additions & 2 deletions python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

from autogen_agentchat.ui import Console
from autogen_agentchat.ui._console import UserInputManager
jackgerrits marked this conversation as resolved.
Show resolved Hide resolved
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne

Expand Down Expand Up @@ -37,9 +38,10 @@ def main() -> None:
args = parser.parse_args()

async def run_task(task: str, hil_mode: bool) -> None:
input_manager = UserInputManager()
client = OpenAIChatCompletionClient(model="gpt-4o")
m1 = MagenticOne(client=client, hil_mode=hil_mode)
await Console(m1.run_stream(task=task), output_stats=False)
m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback())
await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)

task = args.task[0]
asyncio.run(run_task(task, not args.no_hil))
Expand Down
Loading