|
| 1 | +from asyncio import Future, Queue, ensure_future, sleep |
| 2 | +from inspect import isawaitable |
| 3 | +from typing import Any, AsyncIterator, Callable, Optional, Set |
| 4 | + |
| 5 | +try: |
| 6 | + from asyncio import get_running_loop |
| 7 | +except ImportError: |
| 8 | + from asyncio import get_event_loop as get_running_loop # Python < 3.7 |
| 9 | + |
| 10 | + |
| 11 | +__all__ = ["SimplePubSub", "SimplePubSubIterator"] |
| 12 | + |
| 13 | + |
| 14 | +class SimplePubSub: |
| 15 | + """A very simple publish-subscript system. |
| 16 | +
|
| 17 | + Creates an AsyncIterator from an EventEmitter. |
| 18 | +
|
| 19 | + Useful for mocking a PubSub system for tests. |
| 20 | + """ |
| 21 | + |
| 22 | + subscribers: Set[Callable] |
| 23 | + |
| 24 | + def __init__(self) -> None: |
| 25 | + self.subscribers = set() |
| 26 | + |
| 27 | + def emit(self, event: Any) -> bool: |
| 28 | + """Emit an event.""" |
| 29 | + for subscriber in self.subscribers: |
| 30 | + result = subscriber(event) |
| 31 | + if isawaitable(result): |
| 32 | + ensure_future(result) |
| 33 | + return bool(self.subscribers) |
| 34 | + |
| 35 | + def get_subscriber( |
| 36 | + self, transform: Optional[Callable] = None |
| 37 | + ) -> "SimplePubSubIterator": |
| 38 | + return SimplePubSubIterator(self, transform) |
| 39 | + |
| 40 | + |
| 41 | +class SimplePubSubIterator(AsyncIterator): |
| 42 | + def __init__(self, pubsub: SimplePubSub, transform: Optional[Callable]) -> None: |
| 43 | + self.pubsub = pubsub |
| 44 | + self.transform = transform |
| 45 | + self.pull_queue: Queue[Future] = Queue() |
| 46 | + self.push_queue: Queue[Any] = Queue() |
| 47 | + self.listening = True |
| 48 | + pubsub.subscribers.add(self.push_value) |
| 49 | + |
| 50 | + def __aiter__(self) -> "SimplePubSubIterator": |
| 51 | + return self |
| 52 | + |
| 53 | + async def __anext__(self) -> Any: |
| 54 | + if not self.listening: |
| 55 | + raise StopAsyncIteration |
| 56 | + await sleep(0) |
| 57 | + if not self.push_queue.empty(): |
| 58 | + return await self.push_queue.get() |
| 59 | + future = get_running_loop().create_future() |
| 60 | + await self.pull_queue.put(future) |
| 61 | + return future |
| 62 | + |
| 63 | + async def aclose(self) -> None: |
| 64 | + if self.listening: |
| 65 | + await self.empty_queue() |
| 66 | + |
| 67 | + async def empty_queue(self) -> None: |
| 68 | + self.listening = False |
| 69 | + self.pubsub.subscribers.remove(self.push_value) |
| 70 | + while not self.pull_queue.empty(): |
| 71 | + future = await self.pull_queue.get() |
| 72 | + future.cancel() |
| 73 | + while not self.push_queue.empty(): |
| 74 | + await self.push_queue.get() |
| 75 | + |
| 76 | + async def push_value(self, event: Any) -> None: |
| 77 | + value = event if self.transform is None else self.transform(event) |
| 78 | + if self.pull_queue.empty(): |
| 79 | + await self.push_queue.put(value) |
| 80 | + else: |
| 81 | + (await self.pull_queue.get()).set_result(value) |
0 commit comments