Skip to content

Commit c7b5cbc

Browse files
committed
Move repeated Receiver tests to test_receiver.py
There is no need to test the methods that are implemented in `Receiver` itself for every channel or receiver implementation. We also add a couple of extra trivial tests. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent d3f8eb7 commit c7b5cbc

File tree

3 files changed

+59
-136
lines changed

3 files changed

+59
-136
lines changed

tests/test_anycast.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -162,58 +162,3 @@ async def test_anycast_none_messages() -> None:
162162

163163
await sender.send(10)
164164
assert await receiver.receive() == 10
165-
166-
167-
async def test_anycast_async_iterator() -> None:
168-
"""Check that the anycast receiver works as an async iterator."""
169-
acast: Anycast[str] = Anycast(name="test")
170-
171-
sender = acast.new_sender()
172-
receiver = acast.new_receiver()
173-
174-
async def send_messages() -> None:
175-
for val in ["one", "two", "three", "four", "five"]:
176-
await sender.send(val)
177-
await acast.close()
178-
179-
sender_task = asyncio.create_task(send_messages())
180-
181-
received = []
182-
async for recv in receiver:
183-
received.append(recv)
184-
185-
assert received == ["one", "two", "three", "four", "five"]
186-
187-
await sender_task
188-
189-
190-
async def test_anycast_map() -> None:
191-
"""Ensure map runs on all incoming messages."""
192-
chan: Anycast[int] = Anycast(name="test")
193-
sender = chan.new_sender()
194-
195-
# transform int receiver into bool receiver.
196-
receiver: Receiver[bool] = chan.new_receiver().map(lambda num: num > 10)
197-
198-
await sender.send(8)
199-
await sender.send(12)
200-
201-
assert (await receiver.receive()) is False
202-
assert (await receiver.receive()) is True
203-
204-
205-
async def test_anycast_filter() -> None:
206-
"""Ensure filter keeps only the messages that pass the filter."""
207-
chan = Anycast[int](name="input-chan")
208-
sender = chan.new_sender()
209-
210-
# filter out all numbers less than 10.
211-
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)
212-
213-
await sender.send(8)
214-
await sender.send(12)
215-
await sender.send(5)
216-
await sender.send(15)
217-
218-
assert (await receiver.receive()) == 12
219-
assert (await receiver.receive()) == 15

tests/test_broadcast.py

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import asyncio
88
from dataclasses import dataclass
9-
from typing import TypeGuard, assert_never
109

1110
import pytest
1211

@@ -194,86 +193,6 @@ async def test_broadcast_no_resend_latest() -> None:
194193
assert await new_recv.receive() == 100
195194

196195

197-
async def test_broadcast_async_iterator() -> None:
198-
"""Check that the broadcast receiver works as an async iterator."""
199-
bcast: Broadcast[int] = Broadcast(name="iter_test")
200-
201-
sender = bcast.new_sender()
202-
receiver = bcast.new_receiver()
203-
204-
async def send_messages() -> None:
205-
for val in range(0, 10):
206-
await sender.send(val)
207-
await bcast.close()
208-
209-
sender_task = asyncio.create_task(send_messages())
210-
211-
received = []
212-
async for recv in receiver:
213-
received.append(recv)
214-
215-
assert received == list(range(0, 10))
216-
217-
await sender_task
218-
219-
220-
async def test_broadcast_map() -> None:
221-
"""Ensure map runs on all incoming messages."""
222-
chan = Broadcast[int](name="input-chan")
223-
sender = chan.new_sender()
224-
225-
# transform int receiver into bool receiver.
226-
receiver: Receiver[bool] = chan.new_receiver().map(lambda num: num > 10)
227-
228-
await sender.send(8)
229-
await sender.send(12)
230-
231-
assert (await receiver.receive()) is False
232-
assert (await receiver.receive()) is True
233-
234-
235-
async def test_broadcast_filter() -> None:
236-
"""Ensure filter keeps only the messages that pass the filter."""
237-
chan = Broadcast[int](name="input-chan")
238-
sender = chan.new_sender()
239-
240-
# filter out all numbers less than 10.
241-
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)
242-
243-
await sender.send(8)
244-
await sender.send(12)
245-
await sender.send(5)
246-
await sender.send(15)
247-
248-
assert (await receiver.receive()) == 12
249-
assert (await receiver.receive()) == 15
250-
251-
252-
async def test_broadcast_filter_type_guard() -> None:
253-
"""Ensure filter type guard works."""
254-
chan = Broadcast[int | str](name="input-chan")
255-
sender = chan.new_sender()
256-
257-
def _is_int(num: int | str) -> TypeGuard[int]:
258-
return isinstance(num, int)
259-
260-
# filter out objects that are not integers.
261-
receiver = chan.new_receiver().filter(_is_int)
262-
263-
await sender.send("hello")
264-
await sender.send(8)
265-
266-
message = await receiver.receive()
267-
assert message == 8
268-
is_int = False
269-
match message:
270-
case int():
271-
is_int = True
272-
case unexpected:
273-
assert_never(unexpected)
274-
assert is_int
275-
276-
277196
async def test_broadcast_receiver_drop() -> None:
278197
"""Ensure deleted receivers get cleaned up."""
279198
chan = Broadcast[int](name="input-chan")

tests/test_receiver.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
from collections.abc import Sequence
88
from typing import TypeGuard, assert_type
9+
from unittest.mock import MagicMock
910

1011
import pytest
1112
from typing_extensions import override
@@ -85,3 +86,61 @@ async def test_receiver_drop_while() -> None:
8586

8687
with pytest.raises(ReceiverStoppedError):
8788
await filtered_receiver.receive()
89+
90+
91+
async def test_receiver_async_iteration() -> None:
92+
"""Test async iteration over the receiver."""
93+
receiver = _MockReceiver([1, 2])
94+
95+
received = []
96+
async with asyncio.timeout(1):
97+
async for message in receiver:
98+
received.append(message)
99+
100+
assert received == [1, 2]
101+
102+
103+
async def test_receiver_map() -> None:
104+
"""Test mapping a function over the receiver's messages."""
105+
receiver = _MockReceiver([1, 2])
106+
107+
mapped_receiver = receiver.map(lambda x: f"{x} + 1")
108+
assert await mapped_receiver.receive() == "1 + 1"
109+
assert await mapped_receiver.receive() == "2 + 1"
110+
111+
112+
async def test_receiver_filter() -> None:
113+
"""Test filtering the receiver's messages."""
114+
receiver = _MockReceiver([1, 2, 3, 4, 5])
115+
116+
filtered_receiver = receiver.filter(lambda x: x % 2 == 0)
117+
async with asyncio.timeout(1):
118+
assert await filtered_receiver.receive() == 2
119+
assert await filtered_receiver.receive() == 4
120+
121+
with pytest.raises(ReceiverStoppedError):
122+
await filtered_receiver.receive()
123+
124+
125+
async def test_receiver_triggered() -> None:
126+
"""Test the triggered method."""
127+
receiver = _MockReceiver()
128+
selected = MagicMock()
129+
selected._recv = receiver # pylint: disable=protected-access
130+
131+
assert receiver.triggered(selected)
132+
assert selected._handled # pylint: disable=protected-access
133+
134+
135+
async def test_receiver_error_handling() -> None:
136+
"""Test error handling in the receiver."""
137+
receiver = _MockReceiver([1])
138+
receiver.stop()
139+
140+
async with asyncio.timeout(1):
141+
with pytest.raises(ReceiverStoppedError):
142+
await receiver.receive()
143+
144+
receiver = _MockReceiver()
145+
with pytest.raises(ReceiverError):
146+
await receiver.receive()

0 commit comments

Comments
 (0)