Skip to content

Add take_while and drop_while to Receiver #356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## New Features

- `Receiver`

* Add `take_while()` as a less ambiguous and more readable alternative to `filter()`.
* Add `drop_while()` as a convenience and more readable alternative to `filter()` with a negated predicate.
* The usage of `filter()` is discouraged in favor of `take_while()` and `drop_while()`.

### Experimental

- A new predicate, `OnlyIfPrevious`, to `filter()` messages based on the previous message.
Expand Down
149 changes: 144 additions & 5 deletions src/frequenz/channels/_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,27 @@
# Message Filtering

If you need to filter the received messages, receivers provide a
[`filter()`][frequenz.channels.Receiver.filter] method to easily do so:
[`take_while()`][frequenz.channels.Receiver.take_while] and a
[`drop_while()`][frequenz.channels.Receiver.drop_while]
method to easily do so:

```python show_lines="6:"
from frequenz.channels import Anycast

channel = Anycast[int](name="test-channel")
receiver = channel.new_receiver()

async for message in receiver.filter(lambda x: x % 2 == 0):
async for message in receiver.take_while(lambda x: x % 2 == 0):
print(message) # Only even numbers will be printed
```

As with [`map()`][frequenz.channels.Receiver.map],
[`filter()`][frequenz.channels.Receiver.filter] returns a new full receiver, so you can
use it in any of the ways described above.
[`take_while()`][frequenz.channels.Receiver.take_while] returns a new full receiver, so
you can use it in any of the ways described above.

[`take_while()`][frequenz.channels.Receiver.take_while] can even receive a
[type guard][typing.TypeGuard] as the predicate to narrow the type of the received
messages.

# Error Handling

Expand Down Expand Up @@ -280,6 +286,11 @@ def filter(
) -> Receiver[FilteredMessageT_co]:
"""Apply a type guard on the messages on a receiver.

Tip:
It is recommended to use the
[`take_while()`][frequenz.channels.Receiver.take_while] method instead of
this one, as it makes the intention more clear.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
Expand All @@ -301,6 +312,11 @@ def filter(
) -> Receiver[ReceiverMessageT_co]:
"""Apply a filter function on the messages on a receiver.

Tip:
It is recommended to use the
[`take_while()`][frequenz.channels.Receiver.take_while] method instead of
this one, as it makes the intention more clear.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
Expand All @@ -326,6 +342,11 @@ def filter(
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
"""Apply a filter function on the messages on a receiver.

Tip:
It is recommended to use the
[`take_while()`][frequenz.channels.Receiver.take_while] method instead of
this one, as it makes the intention more clear.

Note:
You can pass a [type guard][typing.TypeGuard] as the filter function to
narrow the type of the messages that pass the filter.
Expand All @@ -345,6 +366,117 @@ def filter(
"""
return _Filter(receiver=self, filter_function=filter_function)

@overload
def take_while(
self,
predicate: Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]],
/,
) -> Receiver[FilteredMessageT_co]:
"""Take only the messages that fulfill a predicate, narrowing the type.

The returned receiver will only receive messages that fulfill the predicate
(evaluates to `True`), and will drop messages that don't.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
predicate: The predicate to be applied on incoming messages to
determine if they should be taken.

Returns:
A new receiver that only receives messages that fulfill the predicate.
"""
... # pylint: disable=unnecessary-ellipsis

@overload
def take_while(
self, predicate: Callable[[ReceiverMessageT_co], bool], /
) -> Receiver[ReceiverMessageT_co]:
"""Take only the messages that fulfill a predicate.

The returned receiver will only receive messages that fulfill the predicate
(evaluates to `True`), and will drop messages that don't.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
predicate: The predicate to be applied on incoming messages to
determine if they should be taken.

Returns:
A new receiver that only receives messages that fulfill the predicate.
"""
... # pylint: disable=unnecessary-ellipsis

def take_while(
self,
predicate: (
Callable[[ReceiverMessageT_co], bool]
| Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]]
),
/,
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
"""Take only the messages that fulfill a predicate.

The returned receiver will only receive messages that fulfill the predicate
(evaluates to `True`), and will drop messages that don't.

Note:
You can pass a [type guard][typing.TypeGuard] as the predicate to narrow the
type of the received messages.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
predicate: The predicate to be applied on incoming messages to
determine if they should be taken.

Returns:
A new receiver that only receives messages that fulfill the predicate.
"""
return _Filter(receiver=self, filter_function=predicate)

def drop_while(
self,
predicate: Callable[[ReceiverMessageT_co], bool],
/,
) -> Receiver[ReceiverMessageT_co] | Receiver[ReceiverMessageT_co]:
"""Drop the messages that fulfill a predicate.

The returned receiver will drop messages that fulfill the predicate
(evaluates to `True`), and receive messages that don't.

Tip:
If you need to narrow the type of the received messages, you can use the
[`take_while()`][frequenz.channels.Receiver.take_while] method instead.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
predicate: The predicate to be applied on incoming messages to
determine if they should be dropped.

Returns:
A new receiver that only receives messages that don't fulfill the predicate.
"""
return _Filter(receiver=self, filter_function=predicate, negate=True)

def triggered(
self, selected: Selected[Any]
) -> TypeGuard[Selected[ReceiverMessageT_co]]:
Expand Down Expand Up @@ -492,12 +624,14 @@ def __init__(
*,
receiver: Receiver[ReceiverMessageT_co],
filter_function: Callable[[ReceiverMessageT_co], bool],
negate: bool = False,
) -> None:
"""Initialize this receiver filter.

Args:
receiver: The input receiver.
filter_function: The function to apply on the input data.
negate: Whether to negate the filter function.
"""
self._receiver: Receiver[ReceiverMessageT_co] = receiver
"""The input receiver."""
Expand All @@ -507,6 +641,8 @@ def __init__(

self._next_message: ReceiverMessageT_co | _Sentinel = _SENTINEL

self._negate: bool = negate

self._recv_closed = False

async def ready(self) -> bool:
Expand All @@ -522,7 +658,10 @@ async def ready(self) -> bool:
"""
while await self._receiver.ready():
message = self._receiver.consume()
if self._filter_function(message):
result = self._filter_function(message)
if self._negate:
result = not result
if result:
self._next_message = message
return True
self._recv_closed = True
Expand Down
55 changes: 0 additions & 55 deletions tests/test_anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,58 +162,3 @@ async def test_anycast_none_messages() -> None:

await sender.send(10)
assert await receiver.receive() == 10


async def test_anycast_async_iterator() -> None:
"""Check that the anycast receiver works as an async iterator."""
acast: Anycast[str] = Anycast(name="test")

sender = acast.new_sender()
receiver = acast.new_receiver()

async def send_messages() -> None:
for val in ["one", "two", "three", "four", "five"]:
await sender.send(val)
await acast.close()

sender_task = asyncio.create_task(send_messages())

received = []
async for recv in receiver:
received.append(recv)

assert received == ["one", "two", "three", "four", "five"]

await sender_task


async def test_anycast_map() -> None:
"""Ensure map runs on all incoming messages."""
chan: Anycast[int] = Anycast(name="test")
sender = chan.new_sender()

# transform int receiver into bool receiver.
receiver: Receiver[bool] = chan.new_receiver().map(lambda num: num > 10)

await sender.send(8)
await sender.send(12)

assert (await receiver.receive()) is False
assert (await receiver.receive()) is True


async def test_anycast_filter() -> None:
"""Ensure filter keeps only the messages that pass the filter."""
chan = Anycast[int](name="input-chan")
sender = chan.new_sender()

# filter out all numbers less than 10.
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)

await sender.send(8)
await sender.send(12)
await sender.send(5)
await sender.send(15)

assert (await receiver.receive()) == 12
assert (await receiver.receive()) == 15
81 changes: 0 additions & 81 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import asyncio
from dataclasses import dataclass
from typing import TypeGuard, assert_never

import pytest

Expand Down Expand Up @@ -194,86 +193,6 @@ async def test_broadcast_no_resend_latest() -> None:
assert await new_recv.receive() == 100


async def test_broadcast_async_iterator() -> None:
"""Check that the broadcast receiver works as an async iterator."""
bcast: Broadcast[int] = Broadcast(name="iter_test")

sender = bcast.new_sender()
receiver = bcast.new_receiver()

async def send_messages() -> None:
for val in range(0, 10):
await sender.send(val)
await bcast.close()

sender_task = asyncio.create_task(send_messages())

received = []
async for recv in receiver:
received.append(recv)

assert received == list(range(0, 10))

await sender_task


async def test_broadcast_map() -> None:
"""Ensure map runs on all incoming messages."""
chan = Broadcast[int](name="input-chan")
sender = chan.new_sender()

# transform int receiver into bool receiver.
receiver: Receiver[bool] = chan.new_receiver().map(lambda num: num > 10)

await sender.send(8)
await sender.send(12)

assert (await receiver.receive()) is False
assert (await receiver.receive()) is True


async def test_broadcast_filter() -> None:
"""Ensure filter keeps only the messages that pass the filter."""
chan = Broadcast[int](name="input-chan")
sender = chan.new_sender()

# filter out all numbers less than 10.
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)

await sender.send(8)
await sender.send(12)
await sender.send(5)
await sender.send(15)

assert (await receiver.receive()) == 12
assert (await receiver.receive()) == 15


async def test_broadcast_filter_type_guard() -> None:
"""Ensure filter type guard works."""
chan = Broadcast[int | str](name="input-chan")
sender = chan.new_sender()

def _is_int(num: int | str) -> TypeGuard[int]:
return isinstance(num, int)

# filter out objects that are not integers.
receiver = chan.new_receiver().filter(_is_int)

await sender.send("hello")
await sender.send(8)

message = await receiver.receive()
assert message == 8
is_int = False
match message:
case int():
is_int = True
case unexpected:
assert_never(unexpected)
assert is_int


async def test_broadcast_receiver_drop() -> None:
"""Ensure deleted receivers get cleaned up."""
chan = Broadcast[int](name="input-chan")
Expand Down
Loading
Loading