
Commit b504635

qeternity authored and Shi Feng committed
Refactored PubSub shard connection. (django#326)
1 parent ce1ce03 commit b504635

3 files changed, +128 -167 lines changed

channels_redis/pubsub.py

Lines changed: 62 additions & 167 deletions
@@ -4,7 +4,6 @@
 import types
 import uuid
 
-import async_timeout
 import msgpack
 from redis import asyncio as aioredis
 
@@ -178,7 +177,7 @@ async def receive(self, channel):
         q = self.channels[channel]
         try:
             message = await q.get()
-        except asyncio.CancelledError:
+        except (asyncio.CancelledError, asyncio.TimeoutError, GeneratorExit):
             # We assume here that the reason we are cancelled is because the consumer
             # is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
             # currently the way that Django Channels works, this is a safe assumption.
@@ -266,153 +265,81 @@ def __init__(self, host, channel_layer):
         self.master_name = self.host.pop("master_name", None)
         self.channel_layer = channel_layer
         self._subscribed_to = set()
-        self._lock = None
+        self._lock = asyncio.Lock()
         self._redis = None
-        self._pub_conn = None
-        self._sub_conn = None
-        self._receiver = None
+        self._pubsub = None
         self._receive_task = None
-        self._keepalive_task = None
 
     async def publish(self, channel, message):
-        conn = await self._get_pub_conn()
-        await conn.publish(channel, message)
+        async with self._lock:
+            self._ensure_redis()
+            await self._redis.publish(channel, message)
 
     async def subscribe(self, channel):
-        if channel not in self._subscribed_to:
-            self._subscribed_to.add(channel)
-            await self._get_sub_conn()
-            await self._receiver.subscribe(channel)
+        async with self._lock:
+            if channel not in self._subscribed_to:
+                self._ensure_redis()
+                self._ensure_receiver()
+                await self._pubsub.subscribe(channel)
+                self._subscribed_to.add(channel)
 
     async def unsubscribe(self, channel):
-        if channel in self._subscribed_to:
-            self._subscribed_to.remove(channel)
-            conn = await self._get_sub_conn()
-            await conn.unsubscribe(channel)
+        async with self._lock:
+            if channel in self._subscribed_to:
+                self._ensure_redis()
+                self._ensure_receiver()
+                await self._pubsub.unsubscribe(channel)
+                self._subscribed_to.remove(channel)
 
     async def flush(self):
-        for task in [self._keepalive_task, self._receive_task]:
-            if task is not None:
-                task.cancel()
+        async with self._lock:
+            if self._receive_task is not None:
+                self._receive_task.cancel()
                 try:
-                    await task
+                    await self._receive_task
                 except asyncio.CancelledError:
                     pass
-        self._keepalive_task = None
-        self._receive_task = None
-        self._receiver = None
-        if self._sub_conn is not None:
-            await self._sub_conn.close()
-            await self._put_redis_conn(self._sub_conn)
-            self._sub_conn = None
-        if self._pub_conn is not None:
-            await self._pub_conn.close()
-            await self._put_redis_conn(self._pub_conn)
-            self._pub_conn = None
-        self._subscribed_to = set()
-
-    async def _get_pub_conn(self):
-        """
-        Return the connection to this shard that is used for *publishing* messages.
-
-        If the connection is dead, automatically reconnect.
-        """
-        if self._lock is None:
-            self._lock = asyncio.Lock()
-        async with self._lock:
-            if self._pub_conn is not None and self._pub_conn.connection is None:
-                await self._put_redis_conn(self._pub_conn)
-                self._pub_conn = None
-            while self._pub_conn is None:
-                try:
-                    self._pub_conn = await self._get_redis_conn()
-                except BaseException:
-                    await self._put_redis_conn(self._pub_conn)
-                    logger.warning(
-                        f"Failed to connect to Redis publish host: {self.host}; will try again in 1 second..."
-                    )
-                    await asyncio.sleep(1)
-            return self._pub_conn
-
-    async def _get_sub_conn(self):
-        """
-        Return the connection to this shard that is used for *subscribing* to channels.
-
-        If the connection is dead, automatically reconnect and resubscribe to all our channels!
-        """
-        if self._keepalive_task is None:
-            self._keepalive_task = asyncio.ensure_future(self._do_keepalive())
-        if self._lock is None:
-            self._lock = asyncio.Lock()
-        async with self._lock:
-            if self._sub_conn is not None and self._sub_conn.connection is None:
-                await self._put_redis_conn(self._sub_conn)
-                self._sub_conn = None
-                self._notify_consumers(self.channel_layer.on_disconnect)
-            if self._sub_conn is None:
-                if self._receive_task is not None:
-                    self._receive_task.cancel()
-                    try:
-                        await self._receive_task
-                    except asyncio.CancelledError:
-                        # This is the normal case, that `asyncio.CancelledError` is throw. All good.
-                        pass
-                    except BaseException:
-                        logger.exception(
-                            "Unexpected exception while canceling the receiver task:"
-                        )
-                        # Don't re-raise here. We don't actually care why `_receive_task` didn't exit cleanly.
-                    self._receive_task = None
-                while self._sub_conn is None:
-                    try:
-                        self._sub_conn = await self._get_redis_conn()
-                    except BaseException:
-                        await self._put_redis_conn(self._sub_conn)
-                        logger.warning(
-                            f"Failed to connect to Redis subscribe host: {self.host}; will try again in 1 second..."
-                        )
-                        await asyncio.sleep(1)
-                self._receiver = self._sub_conn.pubsub()
-                if not self._receiver.subscribed:
-                    await self._receiver.subscribe(*self._subscribed_to)
-                    self._notify_consumers(self.channel_layer.on_reconnect)
-                self._receive_task = asyncio.ensure_future(self._do_receiving())
-            return self._sub_conn
+                self._receive_task = None
+            if self._redis is not None:
+                await self._redis.close()
+                self._redis = None
+            self._pubsub = None
+            self._subscribed_to = set()
 
     async def _do_receiving(self):
         while True:
             try:
-                async with async_timeout.timeout(1):
-                    message = await self._receiver.get_message(
-                        ignore_subscribe_messages=True
+                if self._pubsub and self._pubsub.subscribed:
+                    message = await self._pubsub.get_message(
+                        ignore_subscribe_messages=True, timeout=0.1
                     )
-                if message is not None:
-                    name = message["channel"]
-                    data = message["data"]
-                    if isinstance(name, bytes):
-                        # Reversing what happens here:
-                        # https://github.com/aio-libs/aioredis-py/blob/8a207609b7f8a33e74c7c8130d97186e78cc0052/aioredis/util.py#L17
-                        name = name.decode()
-                    if name in self.channel_layer.channels:
-                        self.channel_layer.channels[name].put_nowait(data)
-                    elif name in self.channel_layer.groups:
-                        for channel_name in self.channel_layer.groups[name]:
-                            if channel_name in self.channel_layer.channels:
-                                self.channel_layer.channels[
-                                    channel_name
-                                ].put_nowait(data)
-                await asyncio.sleep(0.01)
-            except asyncio.TimeoutError:
-                pass
-
-    def _notify_consumers(self, mtype):
-        if mtype is not None:
-            for channel in self.channel_layer.channels.values():
-                channel.put_nowait(
-                    self.channel_layer.channel_layer.serialize({"type": mtype})
-                )
-
-    async def _ensure_redis(self):
+                    self._receive_message(message)
+                else:
+                    await asyncio.sleep(0.1)
+            except (
+                asyncio.CancelledError,
+                asyncio.TimeoutError,
+                GeneratorExit,
+            ):
+                raise
+            except BaseException:
+                logger.exception("Unexpected exception in receive task")
+                await asyncio.sleep(1)
+
+    def _receive_message(self, message):
+        if message is not None:
+            name = message["channel"]
+            data = message["data"]
+            if isinstance(name, bytes):
+                name = name.decode()
+            if name in self.channel_layer.channels:
+                self.channel_layer.channels[name].put_nowait(data)
+            elif name in self.channel_layer.groups:
+                for channel_name in self.channel_layer.groups[name]:
+                    if channel_name in self.channel_layer.channels:
+                        self.channel_layer.channels[channel_name].put_nowait(data)
+
+    def _ensure_redis(self):
         if self._redis is None:
             if self.master_name is None:
                 pool = aioredis.ConnectionPool.from_url(self.host["address"])
@@ -425,40 +352,8 @@ async def _ensure_redis(self):
                 ),
             )
             self._redis = aioredis.Redis(connection_pool=pool)
+            self._pubsub = self._redis.pubsub()
 
-    async def _get_redis_conn(self):
-        await self._ensure_redis()
-        return self._redis
-
-    async def _put_redis_conn(self, conn):
-        if conn:
-            await conn.close()
-
-    async def _do_keepalive(self):
-        """
-        This task's simple job is just to call `self._get_sub_conn()` periodically.
-
-        Why? Well, calling `self._get_sub_conn()` has the nice side-effect that if
-        that connection has died (because Redis was restarted, or there was a networking
-        hiccup, for example), then calling `self._get_sub_conn()` will reconnect and
-        restore our old subscriptions. Thus, we want to do this on a predictable schedule.
-        This is kinda a sub-optimal way to achieve this, but I can't find a way in aioredis
-        to get a notification when the connection dies. I find this (sub-optimal) method
-        of checking the connection state works fine for my app; if Redis restarts, we reconnect
-        and resubscribe *quickly enough*; I mean, Redis restarting is already bad because it
-        will cause messages to get lost, and this periodic check at least minimizes the
-        damage *enough*.
-
-        Note you wouldn't need this if you were *sure* that there would be a lot of subscribe/
-        unsubscribe events on your site, because such events each call `self._get_sub_conn()`.
-        Thus, on a site with heavy traffic this task may not be necessary, but also maybe it is.
-        Why? Well, in a heavy traffic site you probably have more than one Django server replicas,
-        so it might be the case that one of your replicas is under-utilized and this periodic
-        connection check will be beneficial in the same way as it is for a low-traffic site.
-        """
-        while True:
-            await asyncio.sleep(1)
-            try:
-                await self._get_sub_conn()
-            except Exception:
-                logger.exception("Unexpected exception in keepalive task:")
+    def _ensure_receiver(self):
+        if self._receive_task is None:
+            self._receive_task = asyncio.ensure_future(self._do_receiving())
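
For context: the refactored shard drops the separate publish/subscribe connections and the keepalive task, and instead relies on redis-py's asyncio client re-establishing connections through its pool while _do_receiving() polls with a short timeout. A minimal standalone sketch of that pattern follows; the function name poll_messages, the channel name, and the redis://localhost:6379 URL are illustrative assumptions, not part of the commit.

import asyncio

from redis import asyncio as aioredis


async def poll_messages(url="redis://localhost:6379", channel="example-channel"):
    # One client plus one PubSub object, as in the refactored shard; the
    # connection pool reconnects on the next command after a dropped link,
    # so no separate keepalive task is needed.
    pool = aioredis.ConnectionPool.from_url(url)
    redis = aioredis.Redis(connection_pool=pool)
    pubsub = redis.pubsub()
    await pubsub.subscribe(channel)
    try:
        while True:
            # Mirrors _do_receiving(): a short timeout keeps the loop
            # responsive to cancellation and shutdown.
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=0.1
            )
            if message is not None:
                print(message["channel"], message["data"])
    finally:
        await pubsub.unsubscribe(channel)
        await redis.close()


if __name__ == "__main__":
    asyncio.run(poll_messages())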

tests/test_pubsub.py

Lines changed: 33 additions & 0 deletions
@@ -204,3 +204,36 @@ async def test_proxied_methods_coroutine_check(channel_layer):
     # below Python 3.8.
     if sys.version_info >= (3, 8):
         assert inspect.iscoroutinefunction(channel_layer.send)
+
+
+@pytest.mark.asyncio
+async def test_receive_hang(channel_layer):
+    channel_name = await channel_layer.new_channel(prefix="test-channel")
+    with pytest.raises(asyncio.TimeoutError):
+        await asyncio.wait_for(channel_layer.receive(channel_name), timeout=1)
+
+
+@pytest.mark.asyncio
+async def test_auto_reconnect(channel_layer):
+    """
+    Tests redis-py reconnect and resubscribe
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer._shards[0]._redis.close(close_connection_pool=True)
+    await channel_layer.group_add("test-group", channel_name3)
+    await channel_layer.group_discard("test-group", channel_name2)
+    await channel_layer._shards[0]._redis.close(close_connection_pool=True)
+    await asyncio.sleep(1)
+    await channel_layer.group_send("test-group", {"type": "message.1"})
+    # Make sure we get the message on the two channels that were in
+    async with async_timeout.timeout(5):
+        assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
+        assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
+    # Make sure the removed channel did not get the message
+    with pytest.raises(asyncio.TimeoutError):
+        async with async_timeout.timeout(1):
+            await channel_layer.receive(channel_name2)
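
The new tests depend on the test suite's existing channel_layer pytest fixture, which is not part of this diff. A rough sketch of what such a fixture presumably looks like is below; the TEST_HOSTS value and the fixture body are assumptions for illustration, not the repository's exact code.

import pytest

from channels_redis.pubsub import RedisPubSubChannelLayer

# Assumed local Redis address; the real test module defines its own hosts.
TEST_HOSTS = [("localhost", 6379)]


@pytest.fixture()
async def channel_layer():
    # Yield a pub/sub layer and flush it afterwards (sketch, not the repo's exact fixture).
    layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
    yield layer
    await layer.flush()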

tests/test_pubsub_sentinel.py

Lines changed: 33 additions & 0 deletions
@@ -161,3 +161,36 @@ def test_multi_event_loop_garbage_collection(channel_layer):
     assert len(channel_layer._layers.values()) == 0
     async_to_sync(test_send_receive)(channel_layer)
     assert len(channel_layer._layers.values()) == 0
+
+
+@pytest.mark.asyncio
+async def test_receive_hang(channel_layer):
+    channel_name = await channel_layer.new_channel(prefix="test-channel")
+    with pytest.raises(asyncio.TimeoutError):
+        await asyncio.wait_for(channel_layer.receive(channel_name), timeout=1)
+
+
+@pytest.mark.asyncio
+async def test_auto_reconnect(channel_layer):
+    """
+    Tests redis-py reconnect and resubscribe
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer._shards[0]._redis.close(close_connection_pool=True)
+    await channel_layer.group_add("test-group", channel_name3)
+    await channel_layer.group_discard("test-group", channel_name2)
+    await channel_layer._shards[0]._redis.close(close_connection_pool=True)
+    await asyncio.sleep(1)
+    await channel_layer.group_send("test-group", {"type": "message.1"})
+    # Make sure we get the message on the two channels that were in
+    async with async_timeout.timeout(5):
+        assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
+        assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
+    # Make sure the removed channel did not get the message
+    with pytest.raises(asyncio.TimeoutError):
+        async with async_timeout.timeout(1):
+            await channel_layer.receive(channel_name2)
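
The sentinel variant exercises the same reconnect path through aioredis.sentinel.SentinelConnectionPool (see _ensure_redis in the diff above), so its channel_layer fixture is built from sentinel-style host entries with "sentinels" and "master_name" keys. A hedged sketch of that host format follows; the addresses and master name are placeholders, not taken from the commit.

from channels_redis.pubsub import RedisPubSubChannelLayer

# Placeholder sentinel address and master name; the real test module defines its own.
SENTINEL_HOSTS = [
    {
        "sentinels": [("localhost", 26379)],
        "master_name": "mymaster",
    }
]

channel_layer = RedisPubSubChannelLayer(hosts=SENTINEL_HOSTS)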
