Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Impl register and add sub RPC #5191

Merged
merged 8 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from autogen_core._subscription import Subscription
from autogen_core._type_prefix_subscription import TypePrefixSubscription
from autogen_core._type_subscription import TypeSubscription

from .protos import agent_worker_pb2


def subscription_to_proto(subscription: Subscription) -> agent_worker_pb2.Subscription:
match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id):
return agent_worker_pb2.Subscription(
id=id,
typeSubscription=agent_worker_pb2.TypeSubscription(topic_type=topic_type, agent_type=agent_type),
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id):
return agent_worker_pb2.Subscription(
id=id,
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
),
)
case _:
raise ValueError("Unsupported subscription type.")

Check warning on line 23 in python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_utils.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_utils.py#L22-L23

Added lines #L22 - L23 were not covered by tests


def subscription_from_proto(subscription: agent_worker_pb2.Subscription) -> Subscription:
oneofcase = subscription.WhichOneof("subscription")
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = subscription.typeSubscription
return TypeSubscription(
topic_type=type_subscription_msg.topic_type,
agent_type=type_subscription_msg.agent_type,
id=subscription.id,
)

case "typePrefixSubscription":
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = subscription.typePrefixSubscription
return TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
id=subscription.id,
)
case None:
raise ValueError("Invalid subscription message.")

Check warning on line 45 in python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_utils.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_utils.py#L44-L45

Added lines #L44 - L45 were not covered by tests
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import inspect
import json
Expand All @@ -23,6 +25,7 @@
ParamSpec,
Sequence,
Set,
Tuple,
Type,
TypeVar,
cast,
Expand All @@ -43,8 +46,6 @@
MessageSerializer,
Subscription,
TopicId,
TypePrefixSubscription,
TypeSubscription,
)
from autogen_core._runtime_impl_helpers import SubscriptionManager, get_impl
from autogen_core._serialization import (
Expand All @@ -55,6 +56,8 @@
from opentelemetry.trace import TracerProvider
from typing_extensions import Self

from autogen_ext.runtimes.grpc._utils import subscription_to_proto

from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._type_helpers import ChannelArgumentType
Expand Down Expand Up @@ -112,13 +115,22 @@
)
]

def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore
def __init__(self, channel: grpc.aio.Channel, stub: Any) -> None: # type: ignore
self._channel = channel
self._send_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._recv_queue = asyncio.Queue[agent_worker_pb2.Message]()
self._connection_task: Task[None] | None = None
self._stub: AgentRpcAsyncStub = stub
self._client_id = str(uuid.uuid4())

@property
def stub(self) -> Any:
return self._stub

@property
def metadata(self) -> Sequence[Tuple[str, str]]:
return [("client-id", self._client_id)]

@classmethod
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
logger.info("Connecting to %s", host_address)
Expand All @@ -131,9 +143,10 @@
host_address,
options=merged_options,
)
instance = cls(channel)
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore
instance = cls(channel, stub)
instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue, instance._client_id)
instance._connect(stub, instance._send_queue, instance._recv_queue, instance._client_id)
)
return instance

Expand All @@ -144,28 +157,25 @@
await self._connection_task

@staticmethod
async def _connect( # type: ignore
channel: grpc.aio.Channel,
async def _connect(
stub: Any, # AgentRpcAsyncStub
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> None:
stub: AgentRpcAsyncStub = agent_worker_pb2_grpc.AgentRpcStub(channel) # type: ignore

from grpc.aio import StreamStreamCall

# TODO: where do exceptions from reading the iterable go? How do we recover from those?
recv_stream: StreamStreamCall[agent_worker_pb2.Message, agent_worker_pb2.Message] = stub.OpenChannel( # type: ignore
QueueAsyncIterable(send_queue), metadata=[("client-id", client_id)]
) # type: ignore
)

while True:
logger.info("Waiting for message from host")
message = await recv_stream.read() # type: ignore
message = cast(agent_worker_pb2.Message, await recv_stream.read()) # type: ignore
if message == grpc.aio.EOF: # type: ignore
logger.info("EOF")
break
message = cast(agent_worker_pb2.Message, message)
logger.info(f"Received a message from host: {message}")
await receive_queue.put(message)
logger.info("Put message in receive queue")
Expand Down Expand Up @@ -258,10 +268,11 @@

async def _run_read_loop(self) -> None:
logger.info("Starting read loop")
assert self._host_connection is not None
# TODO: catch exceptions and reconnect
while self._running:
try:
message = await self._host_connection.recv() # type: ignore
message = await self._host_connection.recv()
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentTypeRequest" | "addSubscriptionRequest":
Expand All @@ -277,9 +288,7 @@
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
# The proto typing doesnt resolve this one
cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
task = asyncio.create_task(self._process_event(cloud_event))
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
Expand Down Expand Up @@ -734,14 +743,13 @@
self._pending_requests[request_id] = future

# Send the registration request message to the host.
message = agent_worker_pb2.Message(
registerAgentTypeRequest=agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
message = agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message, metadata=self._host_connection.metadata
)
await self._host_connection.send(message)

# Wait for the registration response.
await future

# TODO: just use grpc error handling
if not response.success:
raise RuntimeError(response.error)
return type

async def _process_register_agent_type_response(self, response: agent_worker_pb2.RegisterAgentTypeResponse) -> None:
Expand Down Expand Up @@ -805,49 +813,20 @@
raise RuntimeError("Host connection is not set.")

# Create a future for the subscription response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()

match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
id=id,
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=topic_type, agent_type=agent_type
),
),
)
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
id=id,
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
),
),
)
)
case _:
raise ValueError("Unsupported subscription type.")

# Add the future to the pending requests.
self._pending_requests[request_id] = future
message = agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id, subscription=subscription_to_proto(subscription)
)
response: agent_worker_pb2.AddSubscriptionResponse = await self._host_connection.stub.AddSubscription(
message, metadata=self._host_connection.metadata
)
if not response.success:
raise RuntimeError(response.error)

Check warning on line 825 in python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py#L825

Added line #L825 was not covered by tests

# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Send the subscription to the host.
await self._host_connection.send(message)

# Wait for the subscription response.
await future

async def _process_add_subscription_response(self, response: agent_worker_pb2.AddSubscriptionResponse) -> None:
future = self._pending_requests.pop(response.request_id)
if response.HasField("error") and response.error != "":
Expand Down
Loading
Loading