Skip to content

Commit dd894ad

Browse files
committed
implement a TTL for RedisChannelLayer.receive_buffer
see: #212
1 parent 2075071 commit dd894ad

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

channels_redis/core.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,40 @@ class UnsupportedRedis(Exception):
179179
pass
180180

181181

182+
class ExpiringCache(collections.defaultdict):
183+
def __init__(self, default, ttl=60, *args, **kw):
184+
collections.defaultdict.__init__(self, default)
185+
self._expires = collections.OrderedDict()
186+
self.ttl = ttl
187+
188+
def __setitem__(self, k, v):
189+
collections.defaultdict.__setitem__(self, k, v)
190+
self._expires[k] = time.time() + self.ttl
191+
192+
def __delitem__(self, k):
193+
try:
194+
collections.defaultdict.__delitem__(self, k)
195+
except KeyError:
196+
# RedisChannelLayer itself _does_ periodically clean up this
197+
# dictionary (e.g., when exceptions like asyncio.CancelledError
198+
# occur)
199+
pass
200+
201+
def expire(self):
202+
expired = []
203+
for k in self._expires.keys():
204+
if self._expires[k] < time.time():
205+
expired.append(k)
206+
else:
207+
# as this is an OrderedDict, every key after this
208+
# was inserted *later*, so if _this_ key is *not* expired,
209+
# the ones after it aren't either (so we can stop iterating)
210+
break
211+
for k in expired:
212+
del self._expires[k]
213+
del self[k]
214+
215+
182216
class RedisChannelLayer(BaseChannelLayer):
183217
"""
184218
Redis channel layer.
@@ -226,7 +260,7 @@ def __init__(
226260
# Event loop they are trying to receive on
227261
self.receive_event_loop = None
228262
# Buffered messages by process-local channel name
229-
self.receive_buffer = collections.defaultdict(asyncio.Queue)
263+
self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry)
230264
# Detached channel cleanup tasks
231265
self.receive_cleaners = []
232266
# Per-channel cleanup locks to prevent a receive starting and moving
@@ -616,6 +650,7 @@ async def group_discard(self, group, channel):
616650
key = self._group_key(group)
617651
async with self.connection(self.consistent_hash(group)) as connection:
618652
await connection.zrem(key, channel)
653+
self.receive_buffer.expire()
619654

620655
async def group_send(self, group, message):
621656
"""

tests/test_core.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import asyncio
22
import random
3+
import time
34

45
import async_timeout
56
import pytest
67
from async_generator import async_generator, yield_
78

89
from asgiref.sync import async_to_sync
9-
from channels_redis.core import ChannelFull, ConnectionPool, RedisChannelLayer
10+
from channels_redis.core import (
11+
ChannelFull,
12+
ConnectionPool,
13+
ExpiringCache,
14+
RedisChannelLayer,
15+
)
1016

1117
TEST_HOSTS = [("localhost", 6379)]
1218

@@ -627,3 +633,43 @@ def test_custom_group_key_format():
627633
channel_layer = RedisChannelLayer(prefix="test_prefix")
628634
group_name = channel_layer._group_key("test_group")
629635
assert group_name == b"test_prefix:group:test_group"
636+
637+
638+
def test_expiring_buffer_default_value():
639+
buff = ExpiringCache(asyncio.Queue)
640+
assert isinstance(buff["foo"], asyncio.Queue)
641+
642+
643+
def test_expiring_buffer_default_ttl():
644+
buff = ExpiringCache(None)
645+
assert buff.ttl == 60
646+
647+
648+
def test_expiring_buffer_ttl_expiration():
649+
past = time.time() - 60
650+
buff = ExpiringCache(None)
651+
652+
for x in range(100):
653+
buff[x] = "example"
654+
assert len(buff) == 100
655+
buff.expire()
656+
assert len(buff) == 100
657+
658+
for x in range(100):
659+
buff._expires[x] = past
660+
buff["extra"] = "extra"
661+
buff.expire()
662+
assert len(buff) == 1
663+
assert "extra" in buff
664+
assert len(buff._expires) == 1
665+
666+
667+
def test_expiring_buffer_ttl_already_gone():
668+
past = time.time() - 60
669+
buff = ExpiringCache(None)
670+
buff["delete"] = "example"
671+
buff._expires["delete"] = past
672+
del buff["delete"]
673+
buff.expire()
674+
assert len(buff) == 0
675+
assert len(buff._expires) == 0

0 commit comments

Comments
 (0)