|
6 | 6 | import asyncio
|
7 | 7 | from collections.abc import Sequence
|
8 | 8 | from typing import TypeGuard, assert_type
|
| 9 | +from unittest.mock import MagicMock |
9 | 10 |
|
10 | 11 | import pytest
|
11 | 12 | from typing_extensions import override
|
@@ -85,3 +86,61 @@ async def test_receiver_drop_while() -> None:
|
85 | 86 |
|
86 | 87 | with pytest.raises(ReceiverStoppedError):
|
87 | 88 | await filtered_receiver.receive()
|
| 89 | + |
| 90 | + |
| 91 | +async def test_receiver_async_iteration() -> None: |
| 92 | + """Test async iteration over the receiver.""" |
| 93 | + receiver = _MockReceiver([1, 2]) |
| 94 | + |
| 95 | + received = [] |
| 96 | + async with asyncio.timeout(1): |
| 97 | + async for message in receiver: |
| 98 | + received.append(message) |
| 99 | + |
| 100 | + assert received == [1, 2] |
| 101 | + |
| 102 | + |
| 103 | +async def test_receiver_map() -> None: |
| 104 | + """Test mapping a function over the receiver's messages.""" |
| 105 | + receiver = _MockReceiver([1, 2]) |
| 106 | + |
| 107 | + mapped_receiver = receiver.map(lambda x: f"{x} + 1") |
| 108 | + assert await mapped_receiver.receive() == "1 + 1" |
| 109 | + assert await mapped_receiver.receive() == "2 + 1" |
| 110 | + |
| 111 | + |
| 112 | +async def test_receiver_filter() -> None: |
| 113 | + """Test filtering the receiver's messages.""" |
| 114 | + receiver = _MockReceiver([1, 2, 3, 4, 5]) |
| 115 | + |
| 116 | + filtered_receiver = receiver.filter(lambda x: x % 2 == 0) |
| 117 | + async with asyncio.timeout(1): |
| 118 | + assert await filtered_receiver.receive() == 2 |
| 119 | + assert await filtered_receiver.receive() == 4 |
| 120 | + |
| 121 | + with pytest.raises(ReceiverStoppedError): |
| 122 | + await filtered_receiver.receive() |
| 123 | + |
| 124 | + |
| 125 | +async def test_receiver_triggered() -> None: |
| 126 | + """Test the triggered method.""" |
| 127 | + receiver = _MockReceiver() |
| 128 | + selected = MagicMock() |
| 129 | + selected._recv = receiver # pylint: disable=protected-access |
| 130 | + |
| 131 | + assert receiver.triggered(selected) |
| 132 | + assert selected._handled # pylint: disable=protected-access |
| 133 | + |
| 134 | + |
| 135 | +async def test_receiver_error_handling() -> None: |
| 136 | + """Test error handling in the receiver.""" |
| 137 | + receiver = _MockReceiver([1]) |
| 138 | + receiver.stop() |
| 139 | + |
| 140 | + async with asyncio.timeout(1): |
| 141 | + with pytest.raises(ReceiverStoppedError): |
| 142 | + await receiver.receive() |
| 143 | + |
| 144 | + receiver = _MockReceiver() |
| 145 | + with pytest.raises(ReceiverError): |
| 146 | + await receiver.receive() |
0 commit comments