Skip to content

Commit 88435f4

Browse files
committed
implement a TTL for RedisChannelLayer.receive_buffer
see: django#212
1 parent 2075071 commit 88435f4

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

channels_redis/core.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,37 @@ class UnsupportedRedis(Exception):
179179
pass
180180

181181

182+
class ExpiringCache(collections.defaultdict):
183+
def __init__(self, default, ttl=60, *args, **kw):
184+
collections.defaultdict.__init__(self, default)
185+
self._expires = collections.OrderedDict()
186+
self.ttl = ttl
187+
188+
def __setitem__(self, k, v):
189+
collections.defaultdict.__setitem__(self, k, v)
190+
self._expires[k] = time.time() + self.ttl
191+
192+
def __delitem__(self, k):
193+
try:
194+
collections.defaultdict.__delitem__(self, k)
195+
except KeyError:
196+
pass
197+
198+
def expire(self):
199+
expired = []
200+
for k in self._expires.keys():
201+
if self._expires[k] < time.time():
202+
expired.append(k)
203+
else:
204+
# as this is an OrderedDict, every key after this
205+
# was inserted *later*, so if _this_ key is *not* expired,
206+
# the ones after it aren't either (so we can stop iterating)
207+
break
208+
for k in expired:
209+
del self._expires[k]
210+
del self[k]
211+
212+
182213
class RedisChannelLayer(BaseChannelLayer):
183214
"""
184215
Redis channel layer.
@@ -226,7 +257,7 @@ def __init__(
226257
# Event loop they are trying to receive on
227258
self.receive_event_loop = None
228259
# Buffered messages by process-local channel name
229-
self.receive_buffer = collections.defaultdict(asyncio.Queue)
260+
self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry)
230261
# Detached channel cleanup tasks
231262
self.receive_cleaners = []
232263
# Per-channel cleanup locks to prevent a receive starting and moving
@@ -616,6 +647,7 @@ async def group_discard(self, group, channel):
616647
key = self._group_key(group)
617648
async with self.connection(self.consistent_hash(group)) as connection:
618649
await connection.zrem(key, channel)
650+
self.receive_buffer.expire()
619651

620652
async def group_send(self, group, message):
621653
"""

tests/test_core.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import asyncio
22
import random
3+
import time
34

45
import async_timeout
56
import pytest
67
from async_generator import async_generator, yield_
78

89
from asgiref.sync import async_to_sync
9-
from channels_redis.core import ChannelFull, ConnectionPool, RedisChannelLayer
10+
from channels_redis.core import (
11+
ChannelFull,
12+
ConnectionPool,
13+
RedisChannelLayer,
14+
ExpiringCache,
15+
)
1016

1117
TEST_HOSTS = [("localhost", 6379)]
1218

@@ -627,3 +633,43 @@ def test_custom_group_key_format():
627633
channel_layer = RedisChannelLayer(prefix="test_prefix")
628634
group_name = channel_layer._group_key("test_group")
629635
assert group_name == b"test_prefix:group:test_group"
636+
637+
638+
def test_expiring_buffer_default_value():
639+
buff = ExpiringCache(asyncio.Queue)
640+
assert isinstance(buff["foo"], asyncio.Queue)
641+
642+
643+
def test_expiring_buffer_default_ttl():
644+
buff = ExpiringCache(None)
645+
assert buff.ttl == 60
646+
647+
648+
def test_expiring_buffer_ttl_expiration():
649+
past = time.time() - 60
650+
buff = ExpiringCache(None)
651+
652+
for x in range(100):
653+
buff[x] = "example"
654+
assert len(buff) == 100
655+
buff.expire()
656+
assert len(buff) == 100
657+
658+
for x in range(100):
659+
buff._expires[x] = past
660+
buff["extra"] = "extra"
661+
buff.expire()
662+
assert len(buff) == 1
663+
assert "extra" in buff
664+
assert len(buff._expires) == 1
665+
666+
667+
def test_expiring_buffer_ttl_already_gone():
668+
past = time.time() - 60
669+
buff = ExpiringCache(None)
670+
buff["delete"] = "example"
671+
buff._expires["delete"] = past
672+
del buff["delete"]
673+
buff.expire()
674+
assert len(buff) == 0
675+
assert len(buff._expires) == 0

0 commit comments

Comments
 (0)