|
14 | 14 | from __future__ import annotations
|
15 | 15 |
|
16 | 16 | import random
|
17 |
| - |
18 | 17 | from collections import namedtuple
|
| 18 | +from functools import partial |
19 | 19 | from itertools import islice, cycle, groupby, repeat
|
20 | 20 | import logging
|
21 | 21 | from random import randint, shuffle
|
22 | 22 | from threading import Lock
|
23 | 23 | import socket
|
24 | 24 | import warnings
|
25 |
| -from typing import Callable, TYPE_CHECKING |
| 25 | +from typing import Callable, TYPE_CHECKING, Iterator, List, Tuple |
26 | 26 | from abc import ABC, abstractmethod
|
27 | 27 | from cassandra import WriteType as WT
|
28 | 28 |
|
@@ -999,6 +999,245 @@ def shutdown(self):
|
999 | 999 | self.is_shutdown = True
|
1000 | 1000 |
|
1001 | 1001 |
|
class ShardConnectionBackoffSchedule(ABC):
    """
    Interface for objects that hand out delay schedules.

    Implementations describe how long to wait between attempts by returning
    a sequence of delays from :meth:`new_schedule`.
    """

    @abstractmethod
    def new_schedule(self) -> Iterator[float]:
        """
        Produce a fresh schedule: a finite or infinite iterable of delays,
        each expressed as a floating point number of seconds.

        A finite schedule does not end the backoff: as soon as it is
        exhausted, a new schedule is requested and used in its place.
        """
        raise NotImplementedError
| 1011 | + |
| 1012 | + |
class ConstantShardConnectionBackoffSchedule(ShardConnectionBackoffSchedule):
    """
    A :class:`.ShardConnectionBackoffSchedule` subclass which introduces a constant delay,
    optionally increased by a random jitter, between shard connections.
    """

    def __init__(self, delay: float, jitter: float = 0.0):
        """
        `delay` should be a floating point number of seconds to wait in-between
        each connection attempt.

        `jitter` is the upper bound, in seconds, of a random amount that is added
        on top of `delay` for every attempt. `0.0` disables jitter.

        Raises :class:`ValueError` if `delay` or `jitter` is negative.
        """
        if delay < 0:
            raise ValueError("delay must not be negative")
        if jitter < 0:
            raise ValueError("jitter must not be negative")

        self.delay = delay
        self.jitter = jitter

    def new_schedule(self) -> Iterator[float]:
        """
        Return an infinite iterator of delays in seconds.

        Without jitter every delay equals `delay`; with jitter each delay is drawn
        uniformly from `[delay, delay + jitter]`.
        """
        # This method must stay a plain function (no `yield` in its body).
        # Mixing `yield from` with `return <iterator>` turns it into a generator
        # whose `return` value is discarded, which produced an empty schedule
        # whenever jitter was enabled.
        if self.jitter == 0:
            return repeat(self.delay)
        return self._jittered_delays()

    def _jittered_delays(self) -> Iterator[float]:
        """Yield `delay` plus a fresh random jitter, forever."""
        while True:
            yield self.delay + random.uniform(0.0, self.jitter)
| 1041 | + |
| 1042 | + |
class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    Shard connection backoff policy that caps the number of concurrent connection
    attempts at `max_concurrent` per `host_id`.

    Delays are taken from either a `cassandra.policies.ShardConnectionBackoffSchedule` or
    a `cassandra.policies.ReconnectionPolicy`; both expose the same API.

    Workers are spawned while there are pending requests, up to `max_concurrent` per host,
    i.e. at most `max_concurrent` multiplied by the number of nodes in the cluster.
    Each worker starts its own backoff schedule when it is spawned and stops once
    no requests remain for its `host_id`.

    The policy also guarantees a single pending or scheduled connection per (host, shard) pair;
    duplicate attempts to schedule a connection are silently ignored.
    """
    backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy

    max_concurrent: int
    """
    Upper bound on concurrent connection creation requests per scope.
    """

    def __init__(
        self,
        backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
        max_concurrent: int = 1,
    ):
        """
        `backoff_policy` supplies the delay schedules; `max_concurrent` is the
        concurrency cap and must be at least 1.

        Raises :class:`ValueError` when either argument is not acceptable.
        """
        accepted_types = (ShardConnectionBackoffSchedule, ReconnectionPolicy)
        if not isinstance(backoff_policy, accepted_types):
            raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be a positive integer")

        self.max_concurrent = max_concurrent
        self.backoff_policy = backoff_policy

    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
        """Build the connection scheduler that applies this policy on top of `scheduler`."""
        return _LimitedConcurrencyShardConnectionScheduler(
            scheduler=scheduler,
            backoff_policy=self.backoff_policy,
            max_concurrent=self.max_concurrent,
        )
| 1078 | + |
| 1079 | + |
| 1080 | +class _ScopeBucket: |
| 1081 | + """ |
| 1082 | + Holds information for a shard connection backoff policy scope. |
| 1083 | + Responsible for scheduling and executing requests to create connections, |
| 1084 | + while ensuring that no more than `max_concurrent` requests run at the same time. |
| 1085 | +
|
| 1086 | + When `schedule_new_connection` is called, it adds an item to the queue. If `workers_count` is less than `max_concurrent`, it starts a new worker. |
| 1087 | +
|
| 1088 | + A worker is an instance of `_thread_body`. |
| 1089 | + Once `_thread_body` finishes processing the current item, it schedules the |
| 1090 | + next execution using the same schedule. If the queue is empty, the worker stops — it does not schedule another `_thread_body` execution and decrements `workers_count`. |
| 1091 | + """ |
| 1092 | + session: _Scheduler |
| 1093 | + backoff_policy: ShardConnectionBackoffSchedule |
| 1094 | + lock: Lock |
| 1095 | + is_shutdown: bool = False |
| 1096 | + |
| 1097 | + max_concurrent: int |
| 1098 | + """ |
| 1099 | + Max concurrent connection creation requests in the scope. |
| 1100 | + """ |
| 1101 | + |
| 1102 | + workers_count: int |
| 1103 | + """ |
| 1104 | + Number of currently pending workers. |
| 1105 | + """ |
| 1106 | + |
| 1107 | + items: List[Callable[[], None]] |
| 1108 | + """ |
| 1109 | + List of scheduled create connections requests. |
| 1110 | + """ |
| 1111 | + |
| 1112 | + def __init__( |
| 1113 | + self, |
| 1114 | + scheduler: _Scheduler, |
| 1115 | + backoff_policy: ShardConnectionBackoffSchedule, |
| 1116 | + max_concurrent: int, |
| 1117 | + ): |
| 1118 | + self.items = [] |
| 1119 | + self.scheduler = scheduler |
| 1120 | + self.backoff_policy = backoff_policy |
| 1121 | + self.lock = Lock() |
| 1122 | + self.max_concurrent = max_concurrent |
| 1123 | + self.workers_count = 0 |
| 1124 | + |
| 1125 | + def _get_delay(self, schedule: Iterator[float]) -> Tuple[Iterator[float], float]: |
| 1126 | + try: |
| 1127 | + return schedule, next(schedule) |
| 1128 | + except StopIteration: |
| 1129 | + schedule = self.backoff_policy.new_schedule() |
| 1130 | + return schedule, next(schedule) |
| 1131 | + |
| 1132 | + def _worker_body(self, schedule: Iterator[float]): |
| 1133 | + if self.is_shutdown: |
| 1134 | + return |
| 1135 | + |
| 1136 | + with self.lock: |
| 1137 | + try: |
| 1138 | + request = self.items.pop(0) |
| 1139 | + except IndexError: |
| 1140 | + # Just in case |
| 1141 | + if self.workers_count > 0: |
| 1142 | + self.workers_count -= 1 |
| 1143 | + # When items are exhausted reset schedule to ensure that new items going to get another schedule |
| 1144 | + # It is important for exponential policy |
| 1145 | + return |
| 1146 | + |
| 1147 | + try: |
| 1148 | + request() |
| 1149 | + finally: |
| 1150 | + schedule, delay = self._get_delay(schedule) |
| 1151 | + self.scheduler.schedule(delay, self._worker_body, schedule) |
| 1152 | + |
| 1153 | + def _start_new_worker(self): |
| 1154 | + self.workers_count += 1 |
| 1155 | + schedule = self.backoff_policy.new_schedule() |
| 1156 | + delay = next(schedule) |
| 1157 | + self.scheduler.schedule(delay, self._worker_body, schedule) |
| 1158 | + |
| 1159 | + def schedule_new_connection(self, cb: Callable[[], None]): |
| 1160 | + with self.lock: |
| 1161 | + if self.is_shutdown: |
| 1162 | + return |
| 1163 | + self.items.append(cb) |
| 1164 | + if self.workers_count < self.max_concurrent: |
| 1165 | + self._start_new_worker() |
| 1166 | + |
| 1167 | + def shutdown(self): |
| 1168 | + with self.lock: |
| 1169 | + self.is_shutdown = True |
| 1170 | + |
| 1171 | + |
| 1172 | +class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler): |
| 1173 | + """ |
| 1174 | + A scheduler for ``cassandra.policies.LimitedConcurrencyShardConnectionPolicy``. |
| 1175 | +
|
| 1176 | + Limits concurrency for connection creation requests to ``max_concurrent`` per host_id. |
| 1177 | + """ |
| 1178 | + |
| 1179 | + already_scheduled: set[tuple[str, int]] |
| 1180 | + """ |
| 1181 | + Set of (host_id, shard_id) of scheduled or pending requests. |
| 1182 | + """ |
| 1183 | + |
| 1184 | + per_host_scope: dict[str, _ScopeBucket] |
| 1185 | + """ |
| 1186 | + Scopes storage, key is host_id, value is an instance that holds scope data. |
| 1187 | + """ |
| 1188 | + |
| 1189 | + backoff_policy: ShardConnectionBackoffSchedule |
| 1190 | + scheduler: _Scheduler |
| 1191 | + lock: Lock |
| 1192 | + is_shutdown: bool = False |
| 1193 | + |
| 1194 | + max_concurrent: int |
| 1195 | + """ |
| 1196 | + Max concurrent connection creation requests per host_id. |
| 1197 | + """ |
| 1198 | + |
| 1199 | + def __init__( |
| 1200 | + self, |
| 1201 | + scheduler: _Scheduler, |
| 1202 | + backoff_policy: ShardConnectionBackoffSchedule, |
| 1203 | + max_concurrent: int, |
| 1204 | + ): |
| 1205 | + self.already_scheduled = set() |
| 1206 | + self.per_host_scope = {} |
| 1207 | + self.backoff_policy = backoff_policy |
| 1208 | + self.max_concurrent = max_concurrent |
| 1209 | + self.scheduler = scheduler |
| 1210 | + self.lock = Lock() |
| 1211 | + |
| 1212 | + def _execute(self, host_id: str, shard_id: int, method: Callable[[], None]): |
| 1213 | + if self.is_shutdown: |
| 1214 | + return |
| 1215 | + try: |
| 1216 | + method() |
| 1217 | + finally: |
| 1218 | + with self.lock: |
| 1219 | + self.already_scheduled.remove((host_id, shard_id)) |
| 1220 | + |
| 1221 | + def schedule(self, host_id: str, shard_id: int, method: Callable[[], None]) -> bool: |
| 1222 | + with self.lock: |
| 1223 | + if self.is_shutdown or (host_id, shard_id) in self.already_scheduled: |
| 1224 | + return False |
| 1225 | + self.already_scheduled.add((host_id, shard_id)) |
| 1226 | + |
| 1227 | + scope_info = self.per_host_scope.get(host_id) |
| 1228 | + if not scope_info: |
| 1229 | + scope_info = _ScopeBucket(self.scheduler, self.backoff_policy, self.max_concurrent) |
| 1230 | + self.per_host_scope[host_id] = scope_info |
| 1231 | + scope_info.schedule_new_connection(partial(self._execute, host_id, shard_id, method)) |
| 1232 | + return True |
| 1233 | + |
| 1234 | + def shutdown(self): |
| 1235 | + with self.lock: |
| 1236 | + self.is_shutdown = True |
| 1237 | + for scope in self.per_host_scope.values(): |
| 1238 | + scope.shutdown() |
| 1239 | + |
| 1240 | + |
1002 | 1241 | class RetryPolicy(object):
|
1003 | 1242 | """
|
1004 | 1243 | A policy that describes whether to retry, rethrow, or ignore coordinator
|
|
0 commit comments