5
5
6
6
Useful for testing.
7
7
"""
8
+ import asyncio
8
9
import logging
9
10
from dataclasses import dataclass , replace
10
11
from datetime import datetime , timezone
11
- from typing import AsyncIterator
12
+ from typing import AsyncIterator , TypeVar
12
13
13
14
import grpc
14
15
import grpc .aio
44
45
_logger = logging .getLogger (__name__ )
45
46
46
47
48
+ T = TypeVar ("T" )
49
+
50
+
51
+ class _MockStream (AsyncIterator [T ]):
52
+ """A mock stream that wraps an async iterator and adds initial_metadata."""
53
+
54
+ def __init__ (self , stream : AsyncIterator [T ]) -> None :
55
+ """Initialize the mock stream.
56
+
57
+ Args:
58
+ stream: The stream to wrap.
59
+ """
60
+ self ._iterator = stream .__aiter__ ()
61
+
62
+ async def initial_metadata (self ) -> None :
63
+ """Do nothing, just to mock the grpc call."""
64
+ _logger .debug ("Called initial_metadata()" )
65
+
66
+ def __aiter__ (self ) -> AsyncIterator [T ]:
67
+ """Return the async iterator."""
68
+ return self
69
+
70
+ async def __anext__ (self ) -> T :
71
+ """Return the next item from the stream."""
72
+ return await self ._iterator .__anext__ ()
73
+
74
+
47
75
class FakeService :
48
76
"""Dispatch mock service for testing."""
49
77
@@ -113,7 +141,7 @@ async def StreamMicrogridDispatches(
113
141
self ,
114
142
request : StreamMicrogridDispatchesRequest ,
115
143
timeout : int = 5 , # pylint: disable=unused-argument
116
- ) -> AsyncIterator [StreamMicrogridDispatchesResponse ]:
144
+ ) -> _MockStream [StreamMicrogridDispatchesResponse ]:
117
145
"""Stream microgrid dispatches changes.
118
146
119
147
Args:
@@ -122,20 +150,28 @@ async def StreamMicrogridDispatches(
122
150
123
151
Returns:
124
152
An async generator for dispatch changes.
125
-
126
- Yields:
127
- An event for each dispatch change.
128
153
"""
129
- receiver = self ._stream_channel .new_receiver ()
130
154
131
- async for message in receiver :
132
- _logger .debug ("Received message: %s" , message )
133
- if message .microgrid_id == MicrogridId (request .microgrid_id ):
134
- response = StreamMicrogridDispatchesResponse (
135
- event = message .event .event .value ,
136
- dispatch = message .event .dispatch .to_protobuf (),
137
- )
138
- yield response
155
+ async def stream () -> AsyncIterator [StreamMicrogridDispatchesResponse ]:
156
+ """Stream microgrid dispatches changes."""
157
+ _logger .debug ("Starting stream for microgrid %s" , request .microgrid_id )
158
+ receiver = self ._stream_channel .new_receiver ()
159
+
160
+ async for message in receiver :
161
+ _logger .debug ("Received message: %s" , message )
162
+ if message .microgrid_id == MicrogridId (request .microgrid_id ):
163
+ response = StreamMicrogridDispatchesResponse (
164
+ event = message .event .event .value ,
165
+ dispatch = message .event .dispatch .to_protobuf (),
166
+ )
167
+ yield response
168
+ else :
169
+ _logger .debug (
170
+ "Skipping message for microgrid %s" ,
171
+ message .microgrid_id ,
172
+ )
173
+
174
+ return _MockStream (stream ())
139
175
140
176
# pylint: disable=too-many-branches
141
177
@staticmethod
@@ -196,12 +232,18 @@ async def CreateMicrogridDispatch(
196
232
# implicitly create the list if it doesn't exist
197
233
self .dispatches .setdefault (microgrid_id , []).append (new_dispatch )
198
234
235
+ _logger .debug ("Created new dispatch: %s" , new_dispatch )
236
+
199
237
await self ._stream_sender .send (
200
238
self .StreamEvent (
201
239
microgrid_id ,
202
240
DispatchEvent (dispatch = new_dispatch , event = Event .CREATED ),
203
241
)
204
242
)
243
+ # Give the stream a chance to process the message
244
+ await asyncio .sleep (0 )
245
+
246
+ _logger .debug ("Sent create event for dispatch: %s" , new_dispatch )
205
247
206
248
return CreateMicrogridDispatchResponse (dispatch = new_dispatch .to_protobuf ())
207
249
@@ -293,6 +335,8 @@ async def UpdateMicrogridDispatch(
293
335
DispatchEvent (dispatch = dispatch , event = Event .UPDATED ),
294
336
)
295
337
)
338
+ # Give the stream a chance to process the message
339
+ await asyncio .sleep (0 )
296
340
297
341
return UpdateMicrogridDispatchResponse (dispatch = dispatch .to_protobuf ())
298
342
@@ -352,6 +396,8 @@ async def DeleteMicrogridDispatch(
352
396
),
353
397
)
354
398
)
399
+ # Give the stream a chance to process the message
400
+ await asyncio .sleep (0 )
355
401
356
402
return Empty ()
357
403
0 commit comments