4
4
import types
5
5
import uuid
6
6
7
- import async_timeout
8
7
import msgpack
9
8
from redis import asyncio as aioredis
10
9
@@ -178,7 +177,7 @@ async def receive(self, channel):
178
177
q = self .channels [channel ]
179
178
try :
180
179
message = await q .get ()
181
- except asyncio .CancelledError :
180
+ except ( asyncio .CancelledError , asyncio . TimeoutError , GeneratorExit ) :
182
181
# We assume here that the reason we are cancelled is because the consumer
183
182
# is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
184
183
# currently the way that Django Channels works, this is a safe assumption.
@@ -266,153 +265,81 @@ def __init__(self, host, channel_layer):
266
265
self .master_name = self .host .pop ("master_name" , None )
267
266
self .channel_layer = channel_layer
268
267
self ._subscribed_to = set ()
269
- self ._lock = None
268
+ self ._lock = asyncio . Lock ()
270
269
self ._redis = None
271
- self ._pub_conn = None
272
- self ._sub_conn = None
273
- self ._receiver = None
270
+ self ._pubsub = None
274
271
self ._receive_task = None
275
- self ._keepalive_task = None
276
272
277
273
async def publish(self, channel, message):
    """
    Publish `message` on `channel` through this shard's Redis client.

    The shard lock is held for the whole call, so the client is created
    (if it does not exist yet) and used without interleaving with
    `subscribe`, `unsubscribe` or `flush`.
    """
    async with self._lock:
        # Create the client lazily; this is a no-op when it already exists.
        self._ensure_redis()
        client = self._redis
        await client.publish(channel, message)
280
277
281
278
async def subscribe(self, channel):
    """
    Subscribe this shard to `channel` unless it is already subscribed.

    The channel is recorded in `_subscribed_to` only after the pub/sub
    subscribe call has returned, so a failed subscribe leaves the
    bookkeeping untouched.
    """
    async with self._lock:
        if channel in self._subscribed_to:
            return
        # Make sure the client/pubsub objects and the receive task exist
        # before talking to Redis.
        self._ensure_redis()
        self._ensure_receiver()
        await self._pubsub.subscribe(channel)
        self._subscribed_to.add(channel)
286
285
287
286
async def unsubscribe(self, channel):
    """
    Unsubscribe this shard from `channel` if it is currently subscribed.

    The channel is dropped from `_subscribed_to` only after the pub/sub
    unsubscribe call has returned.
    """
    async with self._lock:
        if channel not in self._subscribed_to:
            return
        # Same preparation as in `subscribe`: client, pubsub and receive
        # task are created on demand.
        self._ensure_redis()
        self._ensure_receiver()
        await self._pubsub.unsubscribe(channel)
        self._subscribed_to.remove(channel)
292
293
293
294
async def flush(self):
    """
    Stop the receive task, close the Redis client and forget all
    subscriptions, returning this shard to its freshly-constructed state.

    Raises whatever the receive task raised if it had already failed with
    something other than `asyncio.CancelledError`; the Redis client is
    still closed and the state still reset in that case.
    """
    async with self._lock:
        # Detach the task first so the shard never keeps a reference to a
        # task that is cancelled or finished.
        receive_task, self._receive_task = self._receive_task, None
        try:
            if receive_task is not None:
                receive_task.cancel()
                try:
                    await receive_task
                except asyncio.CancelledError:
                    # Expected outcome of the cancel() above.
                    pass
        finally:
            # Runs even when awaiting the task raised, so a crashed
            # receiver cannot leave an open client behind.
            redis_client, self._redis = self._redis, None
            self._pubsub = None
            self._subscribed_to = set()
            if redis_client is not None:
                # The pool is created in `_ensure_redis` and handed to the
                # client explicitly, so the client will not disconnect it
                # unless asked to.
                await redis_client.close(close_connection_pool=True)
381
308
382
309
async def _do_receiving (self ):
383
310
while True :
384
311
try :
385
- async with async_timeout . timeout ( 1 ) :
386
- message = await self ._receiver .get_message (
387
- ignore_subscribe_messages = True
312
+ if self . _pubsub and self . _pubsub . subscribed :
313
+ message = await self ._pubsub .get_message (
314
+ ignore_subscribe_messages = True , timeout = 0.1
388
315
)
389
- if message is not None :
390
- name = message [ "channel" ]
391
- data = message [ "data" ]
392
- if isinstance ( name , bytes ):
393
- # Reversing what happens here:
394
- # https://github.com/aio-libs/aioredis-py/blob/8a207609b7f8a33e74c7c8130d97186e78cc0052/aioredis/util.py#L17
395
- name = name . decode ()
396
- if name in self . channel_layer . channels :
397
- self . channel_layer . channels [ name ]. put_nowait ( data )
398
- elif name in self . channel_layer . groups :
399
- for channel_name in self . channel_layer . groups [ name ]:
400
- if channel_name in self . channel_layer . channels :
401
- self . channel_layer . channels [
402
- channel_name
403
- ]. put_nowait ( data )
404
- await asyncio . sleep ( 0.01 )
405
- except asyncio . TimeoutError :
406
- pass
407
-
408
- def _notify_consumers ( self , mtype ) :
409
- if mtype is not None :
410
- for channel in self .channel_layer .channels . values () :
411
- channel . put_nowait (
412
- self .channel_layer .channel_layer . serialize ({ "type" : mtype })
413
- )
414
-
415
- async def _ensure_redis (self ):
316
+ self . _receive_message ( message )
317
+ else :
318
+ await asyncio . sleep ( 0.1 )
319
+ except (
320
+ asyncio . CancelledError ,
321
+ asyncio . TimeoutError ,
322
+ GeneratorExit ,
323
+ ) :
324
+ raise
325
+ except BaseException :
326
+ logger . exception ( "Unexpected exception in receive task" )
327
+ await asyncio . sleep ( 1 )
328
+
329
+ def _receive_message ( self , message ):
330
+ if message is not None :
331
+ name = message [ "channel" ]
332
+ data = message [ "data" ]
333
+ if isinstance ( name , bytes ):
334
+ name = name . decode ()
335
+ if name in self . channel_layer . channels :
336
+ self . channel_layer . channels [ name ]. put_nowait ( data )
337
+ elif name in self .channel_layer .groups :
338
+ for channel_name in self . channel_layer . groups [ name ]:
339
+ if channel_name in self .channel_layer .channels :
340
+ self . channel_layer . channels [ channel_name ]. put_nowait ( data )
341
+
342
+ def _ensure_redis (self ):
416
343
if self ._redis is None :
417
344
if self .master_name is None :
418
345
pool = aioredis .ConnectionPool .from_url (self .host ["address" ])
@@ -425,40 +352,8 @@ async def _ensure_redis(self):
425
352
),
426
353
)
427
354
self ._redis = aioredis .Redis (connection_pool = pool )
355
+ self ._pubsub = self ._redis .pubsub ()
428
356
429
- async def _get_redis_conn (self ):
430
- await self ._ensure_redis ()
431
- return self ._redis
432
-
433
- async def _put_redis_conn (self , conn ):
434
- if conn :
435
- await conn .close ()
436
-
437
- async def _do_keepalive (self ):
438
- """
439
- This task's simple job is just to call `self._get_sub_conn()` periodically.
440
-
441
- Why? Well, calling `self._get_sub_conn()` has the nice side-effect that if
442
- that connection has died (because Redis was restarted, or there was a networking
443
- hiccup, for example), then calling `self._get_sub_conn()` will reconnect and
444
- restore our old subscriptions. Thus, we want to do this on a predictable schedule.
445
- This is kinda a sub-optimal way to achieve this, but I can't find a way in aioredis
446
- to get a notification when the connection dies. I find this (sub-optimal) method
447
- of checking the connection state works fine for my app; if Redis restarts, we reconnect
448
- and resubscribe *quickly enough*; I mean, Redis restarting is already bad because it
449
- will cause messages to get lost, and this periodic check at least minimizes the
450
- damage *enough*.
451
-
452
- Note you wouldn't need this if you were *sure* that there would be a lot of subscribe/
453
- unsubscribe events on your site, because such events each call `self._get_sub_conn()`.
454
- Thus, on a site with heavy traffic this task may not be necessary, but also maybe it is.
455
- Why? Well, in a heavy traffic site you probably have more than one Django server replicas,
456
- so it might be the case that one of your replicas is under-utilized and this periodic
457
- connection check will be beneficial in the same way as it is for a low-traffic site.
458
- """
459
- while True :
460
- await asyncio .sleep (1 )
461
- try :
462
- await self ._get_sub_conn ()
463
- except Exception :
464
- logger .exception ("Unexpected exception in keepalive task:" )
357
+ def _ensure_receiver (self ):
358
+ if self ._receive_task is None :
359
+ self ._receive_task = asyncio .ensure_future (self ._do_receiving ())
0 commit comments