|
1 | 1 | import asyncio
|
2 | 2 | import random
|
3 | 3 | import weakref
|
4 |
| -from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type |
| 4 | +from typing import ( |
| 5 | + AsyncIterator, |
| 6 | + Dict, |
| 7 | + Iterable, |
| 8 | + Mapping, |
| 9 | + Optional, |
| 10 | + Sequence, |
| 11 | + Tuple, |
| 12 | + Type, |
| 13 | +) |
5 | 14 |
|
6 | 15 | from redis.asyncio.client import Redis
|
7 | 16 | from redis.asyncio.connection import (
|
|
17 | 26 | ResponseError,
|
18 | 27 | TimeoutError,
|
19 | 28 | )
|
| 29 | +from redis.utils import deprecated_args |
20 | 30 |
|
21 | 31 |
|
22 | 32 | class MasterNotFoundError(ConnectionError):
|
@@ -121,6 +131,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
|
121 | 131 | self.sentinel_manager = sentinel_manager
|
122 | 132 | self.master_address = None
|
123 | 133 | self.slave_rr_counter = None
|
| 134 | + self._iter_req_connections: Dict[str, tuple] = {} |
124 | 135 |
|
125 | 136 | def __repr__(self):
|
126 | 137 | return (
|
@@ -166,6 +177,57 @@ async def rotate_slaves(self) -> AsyncIterator:
|
166 | 177 | pass
|
167 | 178 | raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
|
168 | 179 |
|
| 180 | + async def cleanup(self, iter_req_id: str): |
| 181 | + """Remove tracking for a completed iteration request.""" |
| 182 | + self._iter_req_connections.pop(iter_req_id, None) |
| 183 | + |
| 184 | + @deprecated_args( |
| 185 | + args_to_warn=["*"], |
| 186 | + reason="Use get_connection() without args instead", |
| 187 | + version="5.3.0", |
| 188 | + ) |
| 189 | + async def get_connection(self, command_name=None, *keys, **options): |
| 190 | + """ |
| 191 | + Get a connection from the pool, with special handling for scan commands. |
| 192 | +
|
| 193 | + For scan commands with iter_req_id, ensures the same replica is used |
| 194 | + throughout the iteration to maintain cursor consistency. |
| 195 | + """ |
| 196 | + iter_req_id = options.get("iter_req_id") |
| 197 | + |
| 198 | + # For scan commands with iter_req_id, ensure we use the same replica |
| 199 | + if iter_req_id and not self.is_master: |
| 200 | + # Check if we've already established a connection for this iteration |
| 201 | + if iter_req_id in self._iter_req_connections: |
| 202 | + target_address = self._iter_req_connections[iter_req_id] |
| 203 | + connection = await super().get_connection() |
| 204 | + # If the connection doesn't match our target, try to get the right one |
| 205 | + if (connection.host, connection.port) != target_address: |
| 206 | + # Release this connection and try to find one for the target replica |
| 207 | + await self.release(connection) |
| 208 | + # For now, use the connection we got and update tracking |
| 209 | + connection = await super().get_connection() |
| 210 | + await connection.connect_to(target_address) |
| 211 | + return connection |
| 212 | + else: |
| 213 | + # First time for this iter_req_id, get a connection and track its replica |
| 214 | + connection = await super().get_connection() |
| 215 | + # Get the replica address this connection will use |
| 216 | + if hasattr(connection, "connect_to"): |
| 217 | + # Let the connection establish to its target replica |
| 218 | + try: |
| 219 | + replica_address = await self.rotate_slaves().__anext__() |
| 220 | + await connection.connect_to(replica_address) |
| 221 | + # Track this replica for future requests with this iter_req_id |
| 222 | + self._iter_req_connections[iter_req_id] = replica_address |
| 223 | + except (SlaveNotFoundError, StopAsyncIteration): |
| 224 | + # Fallback to normal connection if no slaves available |
| 225 | + pass |
| 226 | + return connection |
| 227 | + |
| 228 | + # For non-scan commands or master connections, use normal behavior |
| 229 | + return await super().get_connection() |
| 230 | + |
169 | 231 |
|
170 | 232 | class Sentinel(AsyncSentinelCommands):
|
171 | 233 | """
|
|
0 commit comments