Skip to content

Commit d59e4a6

Browse files
committed
Add Mock for initial_metadata() for base-clients GrpcStreamer
Signed-off-by: Mathias L. Baumann <[email protected]>
1 parent afa6ae5 commit d59e4a6

File tree

1 file changed

+60
-14
lines changed

1 file changed

+60
-14
lines changed

src/frequenz/client/dispatch/test/_service.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
66
Useful for testing.
77
"""
8+
import asyncio
89
import logging
910
from dataclasses import dataclass, replace
1011
from datetime import datetime, timezone
11-
from typing import AsyncIterator
12+
from typing import AsyncIterator, TypeVar
1213

1314
import grpc
1415
import grpc.aio
@@ -44,6 +45,33 @@
4445
_logger = logging.getLogger(__name__)
4546

4647

48+
T = TypeVar("T")
49+
50+
51+
class _MockStream(AsyncIterator[T]):
52+
"""A mock stream that wraps an async iterator and adds initial_metadata."""
53+
54+
def __init__(self, stream: AsyncIterator[T]) -> None:
55+
"""Initialize the mock stream.
56+
57+
Args:
58+
stream: The stream to wrap.
59+
"""
60+
self._iterator = stream.__aiter__()
61+
62+
async def initial_metadata(self) -> None:
63+
"""Do nothing, just to mock the grpc call."""
64+
_logger.debug("Called initial_metadata()")
65+
66+
def __aiter__(self) -> AsyncIterator[T]:
67+
"""Return the async iterator."""
68+
return self
69+
70+
async def __anext__(self) -> T:
71+
"""Return the next item from the stream."""
72+
return await self._iterator.__anext__()
73+
74+
4775
class FakeService:
4876
"""Dispatch mock service for testing."""
4977

@@ -113,7 +141,7 @@ async def StreamMicrogridDispatches(
113141
self,
114142
request: StreamMicrogridDispatchesRequest,
115143
timeout: int = 5, # pylint: disable=unused-argument
116-
) -> AsyncIterator[StreamMicrogridDispatchesResponse]:
144+
) -> _MockStream[StreamMicrogridDispatchesResponse]:
117145
"""Stream microgrid dispatches changes.
118146
119147
Args:
@@ -122,20 +150,28 @@ async def StreamMicrogridDispatches(
122150
123151
Returns:
124152
An async generator for dispatch changes.
125-
126-
Yields:
127-
An event for each dispatch change.
128153
"""
129-
receiver = self._stream_channel.new_receiver()
130154

131-
async for message in receiver:
132-
_logger.debug("Received message: %s", message)
133-
if message.microgrid_id == MicrogridId(request.microgrid_id):
134-
response = StreamMicrogridDispatchesResponse(
135-
event=message.event.event.value,
136-
dispatch=message.event.dispatch.to_protobuf(),
137-
)
138-
yield response
155+
async def stream() -> AsyncIterator[StreamMicrogridDispatchesResponse]:
156+
"""Stream microgrid dispatches changes."""
157+
_logger.debug("Starting stream for microgrid %s", request.microgrid_id)
158+
receiver = self._stream_channel.new_receiver()
159+
160+
async for message in receiver:
161+
_logger.debug("Received message: %s", message)
162+
if message.microgrid_id == MicrogridId(request.microgrid_id):
163+
response = StreamMicrogridDispatchesResponse(
164+
event=message.event.event.value,
165+
dispatch=message.event.dispatch.to_protobuf(),
166+
)
167+
yield response
168+
else:
169+
_logger.debug(
170+
"Skipping message for microgrid %s",
171+
message.microgrid_id,
172+
)
173+
174+
return _MockStream(stream())
139175

140176
# pylint: disable=too-many-branches
141177
@staticmethod
@@ -196,12 +232,18 @@ async def CreateMicrogridDispatch(
196232
# implicitly create the list if it doesn't exist
197233
self.dispatches.setdefault(microgrid_id, []).append(new_dispatch)
198234

235+
_logger.debug("Created new dispatch: %s", new_dispatch)
236+
199237
await self._stream_sender.send(
200238
self.StreamEvent(
201239
microgrid_id,
202240
DispatchEvent(dispatch=new_dispatch, event=Event.CREATED),
203241
)
204242
)
243+
# Give the stream a chance to process the message
244+
await asyncio.sleep(0)
245+
246+
_logger.debug("Sent create event for dispatch: %s", new_dispatch)
205247

206248
return CreateMicrogridDispatchResponse(dispatch=new_dispatch.to_protobuf())
207249

@@ -293,6 +335,8 @@ async def UpdateMicrogridDispatch(
293335
DispatchEvent(dispatch=dispatch, event=Event.UPDATED),
294336
)
295337
)
338+
# Give the stream a chance to process the message
339+
await asyncio.sleep(0)
296340

297341
return UpdateMicrogridDispatchResponse(dispatch=dispatch.to_protobuf())
298342

@@ -352,6 +396,8 @@ async def DeleteMicrogridDispatch(
352396
),
353397
)
354398
)
399+
# Give the stream a chance to process the message
400+
await asyncio.sleep(0)
355401

356402
return Empty()
357403

0 commit comments

Comments
 (0)