Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix user input in m1 #4995

Merged
merged 12 commits into from
Jan 13, 2025
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import asyncio
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, Awaitable, Callable, ClassVar, Generator, Optional, Sequence, Union, cast

from aioconsole import ainput # type: ignore
from autogen_core import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
Expand Down Expand Up @@ -109,6 +111,33 @@
print(f"BaseException: {e}")
"""

class InputRequestContext:
def __init__(self) -> None:
raise RuntimeError(

Check warning on line 116 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L116

Added line #L116 was not covered by tests
"InputRequestContext cannot be instantiated. It is a static class that provides context management for user input requests."
)

_INPUT_REQUEST_CONTEXT_VAR: ClassVar[ContextVar[str]] = ContextVar("_INPUT_REQUEST_CONTEXT_VAR")

@classmethod
@contextmanager
def populate_context(cls, ctx: str) -> Generator[None, Any, None]:
""":meta private:"""
token = UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.set(ctx)
try:
yield
finally:
UserProxyAgent.InputRequestContext._INPUT_REQUEST_CONTEXT_VAR.reset(token)

@classmethod
def request_id(cls) -> str:
try:
return cls._INPUT_REQUEST_CONTEXT_VAR.get()
except LookupError as e:
raise RuntimeError(

Check warning on line 137 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L134-L137

Added lines #L134 - L137 were not covered by tests
"InputRequestContext.runtime() must be called within the input callback of a UserProxyAgent."
) from e

def __init__(
self,
name: str,
Expand Down Expand Up @@ -153,9 +182,15 @@
except Exception as e:
raise RuntimeError(f"Failed to get user input: {str(e)}") from e

async def on_messages(
self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
) -> Response:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response):
return message
raise AssertionError("The stream should have returned the final result.")

Check warning on line 189 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py#L189

Added line #L189 was not covered by tests

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handle incoming messages by requesting user input."""
try:
# Check for handoff first
Expand All @@ -164,15 +199,18 @@
f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
)

user_input = await self._get_input(prompt, cancellation_token)
request_id = str(uuid.uuid4())

input_requested_event = UserInputRequestedEvent(request_id=request_id, source=self.name)
yield input_requested_event
with UserProxyAgent.InputRequestContext.populate_context(request_id):
user_input = await self._get_input(prompt, cancellation_token)

# Return appropriate message type based on handoff presence
if handoff:
return Response(
chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
)
yield Response(chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name))
else:
return Response(chat_message=TextMessage(content=user_input, source=self.name))
yield Response(chat_message=TextMessage(content=user_input, source=self.name))

except asyncio.CancelledError:
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,40 @@ class ToolCallSummaryMessage(BaseChatMessage):
type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


class UserInputRequestedEvent(BaseAgentEvent):
"""An event signaling a that the user proxy has requested user input. Published prior to invoking the input callback."""

request_id: str
"""Identifier for the user input request."""

content: Literal[""] = ""
"""Empty content for compat with consumers expecting a content field."""

type: Literal["UserInputRequestedEvent"] = "UserInputRequestedEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | UserInputRequestedEvent, Field(discriminator="type")
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


__all__ = [
"AgentEvent",
"BaseMessage",
"TextMessage",
"ChatMessage",
"HandoffMessage",
"MultiModalMessage",
"StopMessage",
"HandoffMessage",
"ToolCallRequestEvent",
"TextMessage",
"ToolCallExecutionEvent",
"ToolCallRequestEvent",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"UserInputRequestedEvent",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
This module implements utility classes for formatting/printing agent messages.
"""

from ._console import Console
from ._console import Console, UserInputManager

__all__ = ["Console"]
__all__ = ["Console", "UserInputManager"]
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import asyncio
import os
import sys
import time
from typing import AsyncGenerator, List, Optional, TypeVar, cast
from inspect import iscoroutinefunction
from typing import AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast

from aioconsole import aprint # type: ignore
from autogen_core import Image
from autogen_core import CancellationToken, Image
from autogen_core.models import RequestUsage

from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent


def _is_running_in_iterm() -> bool:
Expand All @@ -19,14 +22,60 @@
return sys.stdout.isatty()


SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]

T = TypeVar("T", bound=TaskResult | Response)


class UserInputManager:
def __init__(self, callback: InputFuncType):
self.input_events: Dict[str, asyncio.Event] = {}
self.callback = callback

Check warning on line 35 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L34-L35

Added lines #L34 - L35 were not covered by tests

def get_wrapped_callback(self) -> AsyncInputFunc:
async def user_input_func_wrapper(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:

Check warning on line 38 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L38

Added line #L38 was not covered by tests
# Lookup the event for the prompt, if it exists wait for it.
# If it doesn't exist, create it and store it.
# Get request ID:
request_id = UserProxyAgent.InputRequestContext.request_id()
if request_id in self.input_events:
event = self.input_events[request_id]

Check warning on line 44 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L42-L44

Added lines #L42 - L44 were not covered by tests
else:
event = asyncio.Event()
self.input_events[request_id] = event

Check warning on line 47 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L46-L47

Added lines #L46 - L47 were not covered by tests

await event.wait()

Check warning on line 49 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L49

Added line #L49 was not covered by tests

del self.input_events[request_id]

Check warning on line 51 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L51

Added line #L51 was not covered by tests

if iscoroutinefunction(self.callback):

Check warning on line 53 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L53

Added line #L53 was not covered by tests
# Cast to AsyncInputFunc for proper typing
async_func = cast(AsyncInputFunc, self.callback)
return await async_func(prompt, cancellation_token)

Check warning on line 56 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L55-L56

Added lines #L55 - L56 were not covered by tests
else:
# Cast to SyncInputFunc for proper typing
sync_func = cast(SyncInputFunc, self.callback)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, sync_func, prompt)

Check warning on line 61 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L59-L61

Added lines #L59 - L61 were not covered by tests

return user_input_func_wrapper

Check warning on line 63 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L63

Added line #L63 was not covered by tests

def notify_event_received(self, request_id: str) -> None:
if request_id in self.input_events:
self.input_events[request_id].set()

Check warning on line 67 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L66-L67

Added lines #L66 - L67 were not covered by tests
else:
event = asyncio.Event()
self.input_events[request_id] = event

Check warning on line 70 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L69-L70

Added lines #L69 - L70 were not covered by tests


async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = False,
user_input_manager: UserInputManager | None = None,
) -> T:
"""
Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
Expand Down Expand Up @@ -67,6 +116,7 @@
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")

# mypy ignore
last_processed = message # type: ignore

Expand Down Expand Up @@ -96,9 +146,13 @@
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")

# mypy ignore
last_processed = message # type: ignore

# We don't want to print UserInputRequestedEvent messages, we just use them to signal the user input event.
elif isinstance(message, UserInputRequestedEvent):
if user_input_manager is not None:
user_input_manager.notify_event_received(message.request_id)

Check warning on line 155 in python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py#L154-L155

Added lines #L154 - L155 were not covered by tests
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
Expand Down
16 changes: 13 additions & 3 deletions python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import warnings
from typing import List
from typing import Awaitable, Callable, List, Optional, Union

Check warning on line 2 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L2

Added line #L2 was not covered by tests

from autogen_agentchat.agents import CodeExecutorAgent, UserProxyAgent
from autogen_agentchat.base import ChatAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_core import CancellationToken

Check warning on line 7 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L7

Added line #L7 was not covered by tests
from autogen_core.models import ChatCompletionClient

from autogen_ext.agents.file_surfer import FileSurfer
Expand All @@ -12,6 +13,10 @@
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai._openai_client import BaseOpenAIChatCompletionClient

SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]

Check warning on line 18 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L16-L18

Added lines #L16 - L18 were not covered by tests


class MagenticOne(MagenticOneGroupChat):
"""
Expand Down Expand Up @@ -116,7 +121,12 @@

"""

def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
def __init__(

Check warning on line 124 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L124

Added line #L124 was not covered by tests
self,
client: ChatCompletionClient,
hil_mode: bool = False,
input_func: InputFuncType | None = None,
):
self.client = client
self._validate_client_capabilities(client)

Expand All @@ -126,7 +136,7 @@
executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor())
agents: List[ChatAgent] = [fs, ws, coder, executor]
if hil_mode:
user_proxy = UserProxyAgent("User")
user_proxy = UserProxyAgent("User", input_func=input_func)

Check warning on line 139 in python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py#L139

Added line #L139 was not covered by tests
agents.append(user_proxy)
super().__init__(agents, model_client=client)

Expand Down
17 changes: 14 additions & 3 deletions python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import argparse
import asyncio
import warnings
from typing import Optional

from autogen_agentchat.ui import Console
from aioconsole import ainput # type: ignore
from autogen_agentchat.ui import Console, UserInputManager
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne

# Suppress warnings about the requests.Session() not being closed
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)


async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
task: asyncio.Task[str] = asyncio.create_task(ainput(prompt)) # type: ignore
if cancellation_token is not None:
cancellation_token.link_future(task)
return await task


def main() -> None:
"""
Command-line interface for running a complex task using MagenticOne.
Expand Down Expand Up @@ -37,9 +47,10 @@ def main() -> None:
args = parser.parse_args()

async def run_task(task: str, hil_mode: bool) -> None:
input_manager = UserInputManager(callback=cancellable_input)
client = OpenAIChatCompletionClient(model="gpt-4o")
m1 = MagenticOne(client=client, hil_mode=hil_mode)
await Console(m1.run_stream(task=task), output_stats=False)
m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=input_manager.get_wrapped_callback())
await Console(m1.run_stream(task=task), output_stats=False, user_input_manager=input_manager)

task = args.task[0]
asyncio.run(run_task(task, not args.no_hil))
Expand Down
Loading