Skip to content

Commit 5ba56ef

Browse files
committed
fix: make sure scan iterator commands are always issued to the same replica
1 parent 8403ddc commit 5ba56ef

File tree

7 files changed

+1011
-68
lines changed

7 files changed

+1011
-68
lines changed

redis/asyncio/client.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -717,12 +717,31 @@ async def execute_command(self, *args, **options):
717717
if self.single_connection_client:
718718
await self._single_conn_lock.acquire()
719719
try:
720-
return await conn.retry.call_with_retry(
720+
result = await conn.retry.call_with_retry(
721721
lambda: self._send_command_parse_response(
722722
conn, command_name, *args, **options
723723
),
724724
lambda _: self._close_connection(conn),
725725
)
726+
727+
# Clean up iter_req_id for SCAN family commands when the cursor returns to 0
728+
iter_req_id = options.get("iter_req_id")
729+
if iter_req_id and command_name.upper() in (
730+
"SCAN",
731+
"SSCAN",
732+
"HSCAN",
733+
"ZSCAN",
734+
):
735+
# If the result is a tuple with cursor as the first element and cursor is 0, cleanup
736+
if (
737+
isinstance(result, (list, tuple))
738+
and len(result) >= 2
739+
and result[0] == 0
740+
):
741+
if hasattr(pool, "cleanup"):
742+
await pool.cleanup(iter_req_id)
743+
744+
return result
726745
finally:
727746
if self.single_connection_client:
728747
self._single_conn_lock.release()

redis/asyncio/sentinel.py

Lines changed: 63 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,16 @@
11
import asyncio
22
import random
33
import weakref
4-
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
4+
from typing import (
5+
AsyncIterator,
6+
Dict,
7+
Iterable,
8+
Mapping,
9+
Optional,
10+
Sequence,
11+
Tuple,
12+
Type,
13+
)
514

615
from redis.asyncio.client import Redis
716
from redis.asyncio.connection import (
@@ -17,6 +26,7 @@
1726
ResponseError,
1827
TimeoutError,
1928
)
29+
from redis.utils import deprecated_args
2030

2131

2232
class MasterNotFoundError(ConnectionError):
@@ -121,6 +131,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
121131
self.sentinel_manager = sentinel_manager
122132
self.master_address = None
123133
self.slave_rr_counter = None
134+
self._iter_req_connections: Dict[str, tuple] = {}
124135

125136
def __repr__(self):
126137
return (
@@ -166,6 +177,57 @@ async def rotate_slaves(self) -> AsyncIterator:
166177
pass
167178
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
168179

180+
async def cleanup(self, iter_req_id: str):
    """
    Forget the replica recorded for a finished iteration request.

    Ids that are not being tracked are ignored, so calling this more than
    once for the same id is harmless.
    """
    tracked = self._iter_req_connections
    if iter_req_id in tracked:
        del tracked[iter_req_id]
183+
184+
# NOTE(review): with args_to_warn=["*"] this decorator presumably warns on any
# argument, including the ``iter_req_id`` option the method relies on --
# confirm whether that option should be allow-listed.
@deprecated_args(
    args_to_warn=["*"],
    reason="Use get_connection() without args instead",
    version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
    """
    Get a connection from the pool, with special handling for scan commands.

    When ``options`` carries an ``iter_req_id`` and this pool is not a master
    pool, every call made with the same id is pointed at the replica address
    recorded on the first call, so the cursor stays consistent for the whole
    iteration. All other calls fall through to the parent pool unchanged.

    ``command_name`` and ``keys`` are accepted but not used here.

    If picking a replica or re-targeting the connection fails, the connection
    is released back to the pool before the error propagates.
    """
    iter_req_id = options.get("iter_req_id")

    # Non-scan commands and master pools keep the normal behaviour.
    if not iter_req_id or self.is_master:
        return await super().get_connection()

    connection = await super().get_connection()

    # Later call of an iteration: reuse the replica recorded on the first call.
    if iter_req_id in self._iter_req_connections:
        target_address = self._iter_req_connections[iter_req_id]
        if (connection.host, connection.port) != target_address:
            await self._retarget_connection(connection, target_address)
        return connection

    # First call of an iteration: pick a replica and remember it.
    if not hasattr(connection, "connect_to"):
        return connection

    replicas = self.rotate_slaves()
    try:
        replica_address = await replicas.__anext__()
    except (SlaveNotFoundError, StopAsyncIteration):
        # Best effort: with no replica available, hand back the untracked
        # connection exactly as the parent pool produced it.
        return connection
    except BaseException:
        # Do not leak the pooled connection when replica discovery fails.
        await self.release(connection)
        raise
    finally:
        # Only one address is needed; close the generator instead of leaving
        # it suspended until garbage collection.
        await replicas.aclose()

    await self._retarget_connection(connection, replica_address)
    self._iter_req_connections[iter_req_id] = replica_address
    return connection

async def _retarget_connection(self, connection, address):
    """Point ``connection`` at ``address``; release it to the pool on failure."""
    try:
        if (connection.host, connection.port) != address:
            # NOTE(review): presumably connect_to() does not reopen a socket
            # that is already connected elsewhere -- drop it first so the
            # command really reaches ``address``. Confirm against connect_to().
            await connection.disconnect()
        await connection.connect_to(address)
    except BaseException:
        await self.release(connection)
        raise
230+
169231

170232
class Sentinel(AsyncSentinelCommands):
171233
"""

redis/client.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -658,13 +658,31 @@ def _execute_command(self, *args, **options):
658658
if self._single_connection_client:
659659
self.single_connection_lock.acquire()
660660
try:
661-
return conn.retry.call_with_retry(
661+
result = conn.retry.call_with_retry(
662662
lambda: self._send_command_parse_response(
663663
conn, command_name, *args, **options
664664
),
665665
lambda _: self._close_connection(conn),
666666
)
667667

668+
# Clean up iter_req_id for SCAN family commands when the cursor returns to 0
669+
iter_req_id = options.get("iter_req_id")
670+
if iter_req_id and command_name.upper() in (
671+
"SCAN",
672+
"SSCAN",
673+
"HSCAN",
674+
"ZSCAN",
675+
):
676+
if (
677+
isinstance(result, (list, tuple))
678+
and len(result) >= 2
679+
and result[0] == 0
680+
):
681+
if hasattr(pool, "cleanup"):
682+
pool.cleanup(iter_req_id)
683+
684+
return result
685+
668686
finally:
669687
if conn and conn.should_reconnect():
670688
self._close_connection(conn)

0 commit comments

Comments
 (0)