Skip to content

Commit 11d714f

Browse files
committed
Initial crack at moving to redis-py
1 parent 9821ab1 commit 11d714f

File tree

4 files changed

+31
-30
lines changed

4 files changed

+31
-30
lines changed

channels_redis/core.py

+23-22
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import types
1212
import uuid
1313

14-
import aioredis
14+
from redis import asyncio as aioredis
1515
import msgpack
1616

1717
from channels.exceptions import ChannelFull
@@ -21,8 +21,6 @@
2121

2222
logger = logging.getLogger(__name__)
2323

24-
AIOREDIS_VERSION = tuple(map(int, aioredis.__version__.split(".")))
25-
2624

2725
def _wrap_close(loop, pool):
2826
"""
@@ -49,8 +47,9 @@ class ConnectionPool:
4947
"""
5048

5149
def __init__(self, host):
52-
self.host = host.copy()
53-
self.master_name = self.host.pop("master_name", None)
50+
self.host = host
51+
# TODO: re-add support for master_name
52+
self.master_name = None # self.host.pop("master_name", None)
5453
self.conn_map = {}
5554
self.sentinel_map = {}
5655
self.in_use = {}
@@ -72,11 +71,11 @@ def _ensure_loop(self, loop):
7271

7372
async def create_conn(self, loop):
7473
# One connection per pool since we are emulating a single connection
75-
kwargs = {"minsize": 1, "maxsize": 1, **self.host}
76-
if not (sys.version_info >= (3, 8, 0) and AIOREDIS_VERSION >= (1, 3, 1)):
77-
kwargs["loop"] = loop
74+
kwargs = {"max_connections": 1}
75+
# if not sys.version_info >= (3, 8, 0):
76+
# kwargs["loop"] = loop
7877
if self.master_name is None:
79-
return await aioredis.create_redis_pool(**kwargs)
78+
return aioredis.ConnectionPool.from_url(self.host, )
8079
else:
8180
kwargs = {"timeout": 2, **kwargs} # aioredis default is way too low
8281
sentinel = await aioredis.sentinel.create_sentinel(**kwargs)
@@ -93,7 +92,9 @@ async def pop(self, loop=None):
9392
conn = await self.create_conn(loop)
9493
conns.append(conn)
9594
conn = conns.pop()
96-
if conn.closed:
95+
# Redis ConnectionPool has no closed attribute
96+
# if conn.closed:
97+
if False:
9798
conn = await self.pop(loop=loop)
9899
return conn
99100
self.in_use[conn] = loop
@@ -131,8 +132,7 @@ async def _close_conn(self, conn, sentinel_map=None):
131132
sentinel_map[conn].close()
132133
await sentinel_map[conn].wait_closed()
133134
del sentinel_map[conn]
134-
conn.close()
135-
await conn.wait_closed()
135+
await conn.disconnect()
136136

137137
async def close_loop(self, loop):
138138
"""
@@ -279,7 +279,7 @@ def decode_hosts(self, hosts):
279279
"""
280280
# If no hosts were provided, return a default value
281281
if not hosts:
282-
return [{"address": ("localhost", 6379)}]
282+
return ["redis://localhost:6379"]
283283
# If they provided just a string, scold them.
284284
if isinstance(hosts, (str, bytes)):
285285
raise ValueError(
@@ -289,10 +289,11 @@ def decode_hosts(self, hosts):
289289
# Decode each hosts entry into a kwargs dict
290290
result = []
291291
for entry in hosts:
292+
# TODO: re-add support for dict-based connections
292293
if isinstance(entry, dict):
293294
result.append(entry)
294295
else:
295-
result.append({"address": entry})
296+
result.append(entry)
296297
return result
297298

298299
def _setup_encryption(self, symmetric_encryption_keys):
@@ -348,11 +349,11 @@ async def send(self, channel, message):
348349

349350
# Check the length of the list before send
350351
# This can allow the list to leak slightly over capacity, but that's fine.
351-
if await connection.zcount(channel_key) >= self.get_capacity(channel):
352+
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(channel):
352353
raise ChannelFull()
353354

354355
# Push onto the list then set it to expire in case it's not consumed
355-
await connection.zadd(channel_key, time.time(), self.serialize(message))
356+
await connection.zadd(channel_key, {self.serialize(message): time.time()})
356357
await connection.expire(channel_key, int(self.expiry))
357358

358359
def _backup_channel_name(self, channel):
@@ -380,13 +381,13 @@ async def _brpop_with_clean(self, index, channel, timeout):
380381
async with self.connection(index) as connection:
381382
# Cancellation here doesn't matter, we're not doing anything destructive
382383
# and the script executes atomically...
383-
await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
384+
await connection.eval(cleanup_script, 0, channel, backup_queue)
384385
# ...and it doesn't matter here either, the message will be safe in the backup.
385386
result = await connection.bzpopmin(channel, timeout=timeout)
386387

387388
if result is not None:
388389
_, member, timestamp = result
389-
await connection.zadd(backup_queue, float(timestamp), member)
390+
await connection.zadd(backup_queue, {member: float(timestamp)})
390391
else:
391392
member = None
392393

@@ -610,7 +611,7 @@ async def flush(self):
610611
# Go through each connection and remove all with prefix
611612
for i in range(self.ring_size):
612613
async with self.connection(i) as connection:
613-
await connection.eval(delete_prefix, keys=[], args=[self.prefix + "*"])
614+
await connection.eval(delete_prefix, 0, self.prefix + "*")
614615
# Now clear the pools as well
615616
await self.close_pools()
616617

@@ -645,7 +646,7 @@ async def group_add(self, group, channel):
645646
group_key = self._group_key(group)
646647
async with self.connection(self.consistent_hash(group)) as connection:
647648
# Add to group sorted set with creation time as timestamp
648-
await connection.zadd(group_key, time.time(), channel)
649+
await connection.zadd(group_key, {channel: time.time()})
649650
# Set expiration to be group_expiry, since everything in
650651
# it at this point is guaranteed to expire before that
651652
await connection.expire(group_key, self.group_expiry)
@@ -730,7 +731,7 @@ async def group_send(self, group, message):
730731
# channel_keys does not contain a single redis key more than once
731732
async with self.connection(connection_index) as connection:
732733
channels_over_capacity = await connection.eval(
733-
group_send_lua, keys=channel_redis_keys, args=args
734+
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
734735
)
735736
if channels_over_capacity > 0:
736737
logger.info(
@@ -900,7 +901,7 @@ def __init__(self, pool):
900901

901902
async def __aenter__(self):
902903
self.conn = await self.pool.pop()
903-
return self.conn
904+
return aioredis.Redis(connection_pool=self.conn)
904905

905906
async def __aexit__(self, exc_type, exc, tb):
906907
if exc:

channels_redis/pubsub.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import types
66
import uuid
77

8-
import aioredis
8+
from redis import asyncio as aioredis
99
import msgpack
1010

1111
from .utils import _consistent_hash
@@ -105,7 +105,7 @@ def __init__(
105105
**kwargs,
106106
):
107107
if hosts is None:
108-
hosts = [("localhost", 6379)]
108+
hosts = ["redis://localhost:6379"]
109109
assert (
110110
isinstance(hosts, list) and len(hosts) > 0
111111
), "`hosts` must be a list with at least one Redis server"
@@ -427,7 +427,7 @@ def _notify_consumers(self, mtype):
427427
async def _ensure_redis(self):
428428
if self._redis is None:
429429
if self.master_name is None:
430-
self._redis = await aioredis.create_redis_pool(**self.host)
430+
self._redis = aioredis.ConnectionPool(**self.host)
431431
else:
432432
# aioredis default timeout is way too low
433433
self._redis = await aioredis.sentinel.create_sentinel(
@@ -443,7 +443,7 @@ def _get_aioredis_pool(self):
443443
async def _get_redis_conn(self):
444444
await self._ensure_redis()
445445
conn = await self._get_aioredis_pool().acquire()
446-
return aioredis.Redis(conn)
446+
return aioredis.ConnectionPool(conn)
447447

448448
def _put_redis_conn(self, conn):
449449
if conn:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
include_package_data=True,
3232
python_requires=">=3.6",
3333
install_requires=[
34-
"aioredis~=1.0",
34+
"redis>=4.2.0-rc1",
3535
"msgpack~=1.0",
3636
"asgiref>=3.2.10,<4",
3737
"channels<4",

tests/test_core.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from asgiref.sync import async_to_sync
99
from channels_redis.core import ChannelFull, ConnectionPool, RedisChannelLayer
1010

11-
TEST_HOSTS = [("localhost", 6379)]
11+
TEST_HOSTS = ["redis://localhost:6379"]
1212

1313
MULTIPLE_TEST_HOSTS = [
1414
"redis://localhost:6379/0",
@@ -411,11 +411,11 @@ async def test_connection_pool_pop():
411411
"""
412412

413413
# Setup scenario
414-
connection_pool = ConnectionPool({"address": TEST_HOSTS[0]})
414+
connection_pool = ConnectionPool(TEST_HOSTS[0])
415415
conn = await connection_pool.pop()
416416

417417
# Emulate a disconnect and return it to the pool
418-
conn.close()
418+
conn.disconnect()
419419
assert conn.closed
420420
connection_pool.push(conn)
421421

0 commit comments

Comments
 (0)