Skip to content

Commit

Permalink
Impl register and add sub RPC (#5191)
Browse files Browse the repository at this point in the history
* Refactor client id retrieval

* WIP

* fixes

* future annotations

* Fix tests

* remove import
  • Loading branch information
jackgerrits authored Jan 24, 2025
1 parent db2410c commit 55e929d
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 120 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from autogen_core._subscription import Subscription
from autogen_core._type_prefix_subscription import TypePrefixSubscription
from autogen_core._type_subscription import TypeSubscription

from .protos import agent_worker_pb2


def subscription_to_proto(subscription: Subscription) -> agent_worker_pb2.Subscription:
match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id):
return agent_worker_pb2.Subscription(
id=id,
typeSubscription=agent_worker_pb2.TypeSubscription(topic_type=topic_type, agent_type=agent_type),
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id):
return agent_worker_pb2.Subscription(
id=id,
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
),
)
case _:
raise ValueError("Unsupported subscription type.")


def subscription_from_proto(subscription: agent_worker_pb2.Subscription) -> Subscription:
oneofcase = subscription.WhichOneof("subscription")
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = subscription.typeSubscription
return TypeSubscription(
topic_type=type_subscription_msg.topic_type,
agent_type=type_subscription_msg.agent_type,
id=subscription.id,
)

case "typePrefixSubscription":
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = subscription.typePrefixSubscription
return TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
id=subscription.id,
)
case None:
raise ValueError("Invalid subscription message.")
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import inspect
import json
Expand All @@ -23,6 +25,7 @@
ParamSpec,
Sequence,
Set,
Tuple,
Type,
TypeVar,
cast,
Expand All @@ -43,8 +46,6 @@
MessageSerializer,
Subscription,
TopicId,
TypePrefixSubscription,
TypeSubscription,
)
from autogen_core._runtime_impl_helpers import SubscriptionManager, get_impl
from autogen_core._serialization import (
Expand All @@ -55,6 +56,8 @@
from opentelemetry.trace import TracerProvider
from typing_extensions import Self

from autogen_ext.runtimes.grpc._utils import subscription_to_proto

from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._type_helpers import ChannelArgumentType
Expand Down Expand Up @@ -112,13 +115,22 @@ class HostConnection:
)
]

def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore
def __init__(self, channel: grpc.aio.Channel, stub: Any) -> None: # type: ignore
self._channel = channel
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._connection_task: Task[None] | None = None
self._stub: AgentRpcAsyncStub = stub
self._client_id = str(uuid.uuid4())

@property
def stub(self) -> Any:
return self._stub

@property
def metadata(self) -> Sequence[Tuple[str, str]]:
return [("client-id", self._client_id)]

@classmethod
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
logger.info("Connecting to %s", host_address)
Expand All @@ -131,9 +143,10 @@ def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgument
host_address,
options=merged_options,
)
instance = cls(channel)
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
instance = cls(channel, stub)
instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue, instance._client_id)
instance._connect(stub, instance._send_queue, instance._recv_queue, instance._client_id)
)
return instance

Expand All @@ -144,28 +157,25 @@ async def close(self) -> None:
await self._connection_task

@staticmethod
async def _connect( # type: ignore
channel: grpc.aio.Channel,
async def _connect(
stub: Any, # AgentRpcAsyncStub
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> None:
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore

from grpc.aio import StreamStreamCall

# TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
) # type: ignore
)

while True:
logger.info("Waiting for message from host")
message = await recv_stream.read() # type: ignore
message = cast(agent_worker_pb2.Message, await recv_stream.read()) # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
message = cast(agent_worker_pb2.Message, message)
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
Expand Down Expand Up @@ -258,10 +268,11 @@ def _raise_on_exception(self, task: Task[Any]) -> None:

async def _run_read_loop(self) -> None:
logger.info("Starting read loop")
assert self._host_connection is not None
# TODO: catch exceptions and reconnect
while self._running:
try:
message = await self._host_connection.recv() # type: ignore
message = await self._host_connection.recv()
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentTypeRequest" | "addSubscriptionRequest":
Expand All @@ -277,9 +288,7 @@ async def _run_read_loop(self) -> None:
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
# The proto typing doesnt resolve this one
cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
task = asyncio.create_task(self._process_event(cloud_event))
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
Expand Down Expand Up @@ -734,14 +743,13 @@ async def factory_wrapper() -> T:
self._pending_requests[request_id] = future

# Send the registration request message to the host.
message = agent_worker_pb2.Message(
registerAgentTypeRequest=agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
message = agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message, metadata=self._host_connection.metadata
)
await self._host_connection.send(message)

# Wait for the registration response.
await future

# TODO: just use grpc error handling
if not response.success:
raise RuntimeError(response.error)
return type

async def _process_register_agent_type_response(self, response: agent_worker_pb2.RegisterAgentTypeResponse) -> None:
Expand Down Expand Up @@ -805,49 +813,20 @@ async def add_subscription(self, subscription: Subscription) -> None:
raise RuntimeError("Host connection is not set.")

# Create a future for the subscription response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()

match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
id=id,
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=topic_type, agent_type=agent_type
),
),
)
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
id=id,
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
),
),
)
)
case _:
raise ValueError("Unsupported subscription type.")

# Add the future to the pending requests.
self._pending_requests[request_id] = future
message = agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id, subscription=subscription_to_proto(subscription)
)
response: agent_worker_pb2.AddSubscriptionResponse = await self._host_connection.stub.AddSubscription(
message, metadata=self._host_connection.metadata
)
if not response.success:
raise RuntimeError(response.error)

# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Send the subscription to the host.
await self._host_connection.send(message)

# Wait for the subscription response.
await future

async def _process_add_subscription_response(self, response: agent_worker_pb2.AddSubscriptionResponse) -> None:
future = self._pending_requests.pop(response.request_id)
if response.HasField("error") and response.error != "":
Expand Down
Loading

0 comments on commit 55e929d

Please sign in to comment.