Skip to content

Commit 6fc2a3c

Browse files
jalazizcarltongibson
authored and
Shi Feng
committed
Fix workers support when using Redis PubSub layer (django#298)
* Fix workers support when using Redis PubSub layer The new Redis PubSub layer broke support for Channels workers. Add support for workers by subscribing to non-owned channels instead of throwing an exception. Co-authored-by: Carlton Gibson <[email protected]>
1 parent 2117830 commit 6fc2a3c

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

channels_redis/pubsub.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ def _get_group_channel_name(self, group):
138138
"""
139139
return f"{self.prefix}__group__{group}"
140140

141+
async def _subscribe_to_channel(self, channel):
142+
self.channels[channel] = asyncio.Queue()
143+
shard = self._get_shard(channel)
144+
await shard.subscribe(channel)
145+
141146
extensions = ["groups", "flush"]
142147

143148
################################################################################
@@ -157,9 +162,7 @@ async def new_channel(self, prefix="specific."):
157162
process as a specific channel.
158163
"""
159164
channel = f"{self.prefix}{prefix}{uuid.uuid4().hex}"
160-
self.channels[channel] = asyncio.Queue()
161-
shard = self._get_shard(channel)
162-
await shard.subscribe(channel)
165+
await self._subscribe_to_channel(channel)
163166
return channel
164167

165168
async def receive(self, channel):
@@ -169,9 +172,7 @@ async def receive(self, channel):
169172
of the waiting coroutines will get the result.
170173
"""
171174
if channel not in self.channels:
172-
raise RuntimeError(
173-
'You should only call receive() on channels that you "own" and that were created with `new_channel()`.'
174-
)
175+
await self._subscribe_to_channel(channel)
175176

176177
q = self.channels[channel]
177178

tests/test_pubsub.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ async def channel_layer():
2424
await channel_layer.flush()
2525

2626

27+
@pytest.fixture()
28+
@async_generator
29+
async def other_channel_layer():
30+
"""
31+
Channel layer fixture that flushes automatically.
32+
"""
33+
channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
34+
await yield_(channel_layer)
35+
await channel_layer.flush()
36+
37+
2738
@pytest.mark.asyncio
2839
async def test_send_receive(channel_layer):
2940
"""
@@ -118,6 +129,30 @@ async def test_groups_same_prefix(channel_layer):
118129
assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
119130

120131

132+
@pytest.mark.asyncio
133+
async def test_receive_on_non_owned_general_channel(channel_layer, other_channel_layer):
134+
"""
135+
Tests receive with general channel that is not owned by the layer
136+
"""
137+
receive_started = asyncio.Event()
138+
139+
async def receive():
140+
receive_started.set()
141+
return await other_channel_layer.receive("test-channel")
142+
143+
receive_task = asyncio.create_task(receive())
144+
await receive_started.wait()
145+
await asyncio.sleep(0.1) # Need to give time for "receive" to subscribe
146+
await channel_layer.send("test-channel", "message.1")
147+
148+
try:
149+
# Make sure we get the message on the channels that were in
150+
async with async_timeout.timeout(1):
151+
assert await receive_task == "message.1"
152+
finally:
153+
receive_task.cancel()
154+
155+
121156
@pytest.mark.asyncio
122157
async def test_random_reset__channel_name(channel_layer):
123158
"""

0 commit comments

Comments
 (0)