
Commit

Add lock for input and output management in m1
jackgerrits committed Jan 10, 2025
1 parent c59cfdd commit 9486169
Showing 5 changed files with 50 additions and 12 deletions.
File 1 of 5

@@ -12,7 +12,6 @@
 # Define input function types more precisely
 SyncInputFunc = Callable[[str], str]
 AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
-InputFuncType = Union[SyncInputFunc, AsyncInputFunc]


 # TODO: ainput doesn't seem to play nicely with jupyter.
@@ -109,6 +108,8 @@ async def cancellable_user_agent():
             print(f"BaseException: {e}")
     """

+    InputFuncType = Union[SyncInputFunc, AsyncInputFunc]
+
     def __init__(
         self,
         name: str,
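Moving the InputFuncType alias into the class body makes it addressable as UserProxyAgent.InputFuncType, which is how the MagenticOne constructor later in this commit declares its new input_func parameter. A minimal sketch of typing a caller-supplied input function against the class-level alias (the function body is illustrative, not part of the commit):

from typing import Optional

from autogen_agentchat.agents import UserProxyAgent
from autogen_core import CancellationToken


async def my_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
    # Illustrative stand-in: return a canned reply instead of blocking on stdin.
    return "APPROVE"


# The alias now resolves through the class rather than the module namespace.
func: UserProxyAgent.InputFuncType = my_input
user_proxy = UserProxyAgent("User", input_func=func)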
File 2 of 5

@@ -395,8 +395,7 @@ async def stop_runtime() -> None:
             await self._runtime.stop_when_idle()
             await self._output_message_queue.put(None)

-        shutdown_task = asyncio.create_task(stop_runtime())
-
+        shutdown_task: asyncio.Task[None] | None = None
         try:
             # Run the team by sending the start message to the group chat manager.
             # The group chat manager will start the group chat by relaying the message to the participants
@@ -406,6 +405,10 @@ async def stop_runtime() -> None:
                 recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
                 cancellation_token=cancellation_token,
             )
+
+            # After sending the start message, we create the shutdown task that will wait for the runtime to become idle.
+            shutdown_task = asyncio.create_task(stop_runtime())
+
             # Collect the output messages in order.
             output_messages: List[AgentEvent | ChatMessage] = []
             # Yield the messages until the queue is empty.
@@ -425,7 +428,8 @@

         finally:
             # Wait for the shutdown task to finish.
-            await shutdown_task
+            if shutdown_task is not None:
+                await shutdown_task

             # Clear the output message queue.
             while not self._output_message_queue.empty():
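The reordering matters: if the watcher task is created before the start message is sent, the runtime can look idle immediately and queue the shutdown sentinel before any work begins. A toy sketch of the corrected ordering, using only asyncio (all names are illustrative stand-ins, not the real runtime API):

import asyncio


async def send_start() -> None:
    # Stand-in for publishing the start message to the group chat manager.
    await asyncio.sleep(0.01)


async def stop_when_idle() -> None:
    # Stand-in for runtime.stop_when_idle() plus the None queue sentinel.
    await asyncio.sleep(0.01)


async def run() -> None:
    shutdown_task: asyncio.Task[None] | None = None
    try:
        await send_start()
        # Only after work has started do we begin watching for idleness.
        shutdown_task = asyncio.create_task(stop_when_idle())
    finally:
        if shutdown_task is not None:
            await shutdown_task


asyncio.run(run())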
File 3 of 5

@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+from asyncio import Lock
 from typing import AsyncGenerator, List, Optional, TypeVar, cast

 from aioconsole import aprint  # type: ignore
@@ -22,11 +23,20 @@ def _is_output_a_tty() -> bool:
 T = TypeVar("T", bound=TaskResult | Response)


+class NoopLock:
+    async def __aenter__(self) -> None:
+        pass
+
+    async def __aexit__(self, exc_type: Optional[type], exc: Optional[BaseException], tb: Optional[object]) -> None:
+        pass
+
+
 async def Console(
     stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
     *,
     no_inline_images: bool = False,
     output_stats: bool = True,
+    output_lock: Lock | None = None,
 ) -> T:
     """
     Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
@@ -47,6 +57,8 @@ async def Console(
     start_time = time.time()
     total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

+    actual_lock: Lock | NoopLock = output_lock or NoopLock()
+
     last_processed: Optional[T] = None

     async for message in stream:
@@ -61,7 +73,8 @@
                 f"Total completion tokens: {total_usage.completion_tokens}\n"
                 f"Duration: {duration:.2f} seconds\n"
             )
-            await aprint(output, end="")
+            async with actual_lock:
+                await aprint(output, end="")
         # mypy ignore
         last_processed = message  # type: ignore

@@ -75,7 +88,8 @@
             output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
             total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens
             total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens
-            await aprint(output, end="")
+            async with actual_lock:
+                await aprint(output, end="")

         # Print summary.
         if output_stats:
@@ -90,7 +104,8 @@
                 f"Total completion tokens: {total_usage.completion_tokens}\n"
                 f"Duration: {duration:.2f} seconds\n"
             )
-            await aprint(output, end="")
+            async with actual_lock:
+                await aprint(output, end="")
         # mypy ignore
         last_processed = message  # type: ignore

@@ -103,7 +118,8 @@
             output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
             total_usage.completion_tokens += message.models_usage.completion_tokens
             total_usage.prompt_tokens += message.models_usage.prompt_tokens
-            await aprint(output, end="")
+            async with actual_lock:
+                await aprint(output, end="")

     if last_processed is None:
         raise ValueError("No TaskResult or Response was processed.")
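With the new output_lock parameter, a caller that also reads from stdin can pass one shared asyncio.Lock so prompts and streamed output never interleave; passing nothing costs nothing, since NoopLock is the fallback. A minimal usage sketch with a stubbed stream (the TaskResult and TextMessage construction is a sketch against the autogen-agentchat API as of this commit):

import asyncio

from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.ui import Console


async def fake_stream():
    # Stand-in for team.run_stream(...): one chat message, then the final result.
    msg = TextMessage(content="hello", source="assistant")
    yield msg
    yield TaskResult(messages=[msg], stop_reason="done")


async def main() -> None:
    output_lock = asyncio.Lock()
    # Any other coroutine holding output_lock (e.g. an input prompt)
    # excludes Console output while it runs.
    await Console(fake_stream(), output_stats=False, output_lock=output_lock)


asyncio.run(main())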
File 4 of 5

@@ -116,7 +116,12 @@ async def example_usage_hil():
     """

-    def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
+    def __init__(
+        self,
+        client: ChatCompletionClient,
+        hil_mode: bool = False,
+        input_func: UserProxyAgent.InputFuncType | None = None,
+    ):
         self.client = client
         self._validate_client_capabilities(client)

@@ -126,7 +131,7 @@ def __init__(self, client: ChatCompletionClient, hil_mode: bool = False):
         executor = CodeExecutorAgent("Executor", code_executor=LocalCommandLineCodeExecutor())
         agents: List[ChatAgent] = [fs, ws, coder, executor]
         if hil_mode:
-            user_proxy = UserProxyAgent("User")
+            user_proxy = UserProxyAgent("User", input_func=input_func)
             agents.append(user_proxy)
         super().__init__(agents, model_client=client)
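Because input_func is typed with the class-level alias, both synchronous and asynchronous functions are accepted (InputFuncType is a Union of the two). A small sketch wiring a synchronous input function into the human-in-the-loop path (the prompt formatting is illustrative):

from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne


def terminal_input(prompt: str) -> str:
    # SyncInputFunc variant: block on stdin for the human-in-the-loop turn.
    return input(f"[human] {prompt}")


client = OpenAIChatCompletionClient(model="gpt-4o")
m1 = MagenticOne(client=client, hil_mode=True, input_func=terminal_input)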
File 5 of 5: python/packages/magentic-one-cli/src/magentic_one_cli/_m1.py (14 additions & 2 deletions)

@@ -1,8 +1,11 @@
 import argparse
 import asyncio
+import warnings
+from typing import Optional

 from aioconsole import ainput  # type: ignore
 from autogen_agentchat.ui import Console
+from autogen_core import CancellationToken
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from autogen_ext.teams.magentic_one import MagenticOne

@@ -37,9 +40,18 @@ def main() -> None:
     args = parser.parse_args()

     async def run_task(task: str, hil_mode: bool) -> None:
+        output_lock = asyncio.Lock()
+
+        async def cancellable_input(prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
+            async with output_lock:
+                task: asyncio.Task[str] = ainput(prompt)  # type: ignore
+                if cancellation_token is not None:
+                    cancellation_token.link_future(task)
+                return await task
+
         client = OpenAIChatCompletionClient(model="gpt-4o")
-        m1 = MagenticOne(client=client, hil_mode=hil_mode)
-        await Console(m1.run_stream(task=task), output_stats=False)
+        m1 = MagenticOne(client=client, hil_mode=hil_mode, input_func=cancellable_input)
+        await Console(m1.run_stream(task=task), output_stats=False, output_lock=output_lock)

     task = args.task[0]
     asyncio.run(run_task(task, not args.no_hil))
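The key point is that the same lock object guards both paths: cancellable_input holds it while prompting, and Console holds it around every aprint, so a prompt is never split by streamed agent output. A stripped-down sketch of the pattern, independent of autogen (all names are illustrative):

import asyncio


async def printer(lock: asyncio.Lock) -> None:
    for i in range(3):
        async with lock:
            print(f"agent output line {i}")
        await asyncio.sleep(0.01)


async def prompter(lock: asyncio.Lock) -> None:
    async with lock:
        # While the lock is held, printer() cannot interleave with the prompt.
        await asyncio.sleep(0.02)  # stand-in for awaiting ainput("User: ")
        print("User: <reply>")


async def main() -> None:
    lock = asyncio.Lock()
    await asyncio.gather(printer(lock), prompter(lock))


asyncio.run(main())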
