Skip to content

Commit a911830

Browse files
committed
Implement LimitedConcurrencyShardConnectionBackoffPolicy
This policy is an implementation of `ShardConnectionBackoffPolicy`. Its primary purpose is to prevent connection storms by imposing restrictions on the number of concurrent pending connections per host and backoff time between each connection attempt.
1 parent eb8d396 commit a911830

File tree

3 files changed

+430
-8
lines changed

3 files changed

+430
-8
lines changed

cassandra/policies.py

Lines changed: 241 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
from __future__ import annotations
1515

1616
import random
17-
1817
from collections import namedtuple
18+
from functools import partial
1919
from itertools import islice, cycle, groupby, repeat
2020
import logging
2121
from random import randint, shuffle
2222
from threading import Lock
2323
import socket
2424
import warnings
25-
from typing import Callable, TYPE_CHECKING
25+
from typing import Callable, TYPE_CHECKING, Iterator, List, Tuple
2626
from abc import ABC, abstractmethod
2727
from cassandra import WriteType as WT
2828

@@ -999,6 +999,245 @@ def shutdown(self):
999999
self.is_shutdown = True
10001000

10011001

1002+
class ShardConnectionBackoffSchedule(ABC):
    """
    Interface for objects that produce schedules of delays between shard connection attempts.
    """

    @abstractmethod
    def new_schedule(self) -> Iterator[float]:
        """
        Return a fresh iterator of delays, each one a floating point number of seconds.

        The iterator may be finite or infinite. When a finite iterator is exhausted,
        a new schedule is requested and the consumer carries on with it.
        """
        raise NotImplementedError
1011+
1012+
1013+
class ConstantShardConnectionBackoffSchedule(ShardConnectionBackoffSchedule):
    """
    A :class:`.ShardConnectionBackoffSchedule` subclass which introduces a constant delay,
    optionally extended by a random jitter, between shard connections.
    """

    def __init__(self, delay: float, jitter: float = 0.0):
        """
        `delay` should be a floating point number of seconds to wait in-between
        each connection attempt.

        `jitter` is the upper bound, in seconds, of the random amount added to `delay`
        on every step of the schedule; `0` disables jitter.

        Raises `ValueError` if `delay` or `jitter` is negative.
        """
        if delay < 0:
            raise ValueError("delay must not be negative")
        if jitter < 0:
            raise ValueError("jitter must not be negative")

        self.delay = delay
        self.jitter = jitter

    def new_schedule(self) -> Iterator[float]:
        """
        Return an infinite iterator of delays in seconds.

        Without jitter every delay equals `delay`; with jitter every delay is
        `delay` plus a value drawn uniformly between `0.0` and `jitter`.
        """
        # This method must not contain `yield` itself: that would turn it into a
        # generator function, and returning another iterator from a generator
        # produces an empty schedule instead of handing that iterator to the caller.
        if self.jitter == 0:
            return repeat(self.delay)
        return self._jittered_delays()

    def _jittered_delays(self) -> Iterator[float]:
        """Yield `delay` plus a fresh random jitter, forever."""
        while True:
            yield self.delay + random.uniform(0.0, self.jitter)
1041+
1042+
1043+
class LimitedConcurrencyShardConnectionBackoffPolicy(ShardConnectionBackoffPolicy):
    """
    A shard connection backoff policy that allows at most `max_concurrent` concurrent
    connections per `host_id`.

    Delays between attempts are taken from `backoff_policy`, which is either a
    `cassandra.policies.ShardConnectionBackoffSchedule` or a
    `cassandra.policies.ReconnectionPolicy`, as both expose the same API.

    A worker is spawned when there are pending requests; the maximum number of workers is
    `max_concurrent` multiplied by the number of nodes in the cluster.
    Each worker starts its own backoff schedule when it is spawned and is stopped
    once there are no remaining requests for its `host_id`.

    This policy also prevents multiple pending or scheduled connections for the same
    (host, shard) pair; any duplicate attempts to schedule a connection are silently ignored.
    """
    backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy

    max_concurrent: int
    """
    Max concurrent connection creation requests per scope.
    """

    def __init__(
        self,
        backoff_policy: ShardConnectionBackoffSchedule | ReconnectionPolicy,
        max_concurrent: int = 1,
    ):
        # Validate the schedule source first, then the concurrency limit.
        accepted_types = (ShardConnectionBackoffSchedule, ReconnectionPolicy)
        if not isinstance(backoff_policy, accepted_types):
            raise ValueError("backoff_policy must be a ShardConnectionBackoffSchedule or ReconnectionPolicy")
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be a positive integer")

        self.max_concurrent = max_concurrent
        self.backoff_policy = backoff_policy

    def new_connection_scheduler(self, scheduler: _Scheduler) -> ShardConnectionScheduler:
        """
        Build the scheduler that enforces this policy, running its delayed work on `scheduler`.
        """
        return _LimitedConcurrencyShardConnectionScheduler(
            scheduler,
            self.backoff_policy,
            self.max_concurrent,
        )
1078+
1079+
1080+
class _ScopeBucket:
    """
    Holds information for a shard connection backoff policy scope.
    Responsible for scheduling and executing requests to create connections,
    while ensuring that no more than `max_concurrent` requests run at the same time.

    When `schedule_new_connection` is called, it adds an item to the queue. If `workers_count`
    is less than `max_concurrent`, it starts a new worker.

    A worker is a chain of `_worker_body` executions on `scheduler`.
    Once `_worker_body` finishes processing the current item, it schedules the
    next execution using the same schedule. If the queue is empty, the worker stops:
    it does not schedule another `_worker_body` execution and decrements `workers_count`.
    """
    scheduler: _Scheduler
    backoff_policy: ShardConnectionBackoffSchedule
    lock: Lock
    is_shutdown: bool = False

    max_concurrent: int
    """
    Max concurrent connection creation requests in the scope.
    """

    workers_count: int
    """
    Number of currently active workers.
    """

    items: List[Callable[[], None]]
    """
    List of scheduled create connections requests.
    """

    def __init__(
        self,
        scheduler: _Scheduler,
        backoff_policy: ShardConnectionBackoffSchedule,
        max_concurrent: int,
    ):
        self.items = []
        self.scheduler = scheduler
        self.backoff_policy = backoff_policy
        self.lock = Lock()
        self.max_concurrent = max_concurrent
        self.workers_count = 0

    def _get_delay(self, schedule: Iterator[float]) -> Tuple[Iterator[float], float]:
        """
        Return the schedule to keep using together with its next delay.

        When `schedule` is exhausted, a new schedule is created from `backoff_policy`
        and its first delay is returned.
        """
        try:
            return schedule, next(schedule)
        except StopIteration:
            schedule = self.backoff_policy.new_schedule()
            return schedule, next(schedule)

    def _worker_body(self, schedule: Iterator[float]):
        """
        Run one queued request, then re-schedule this worker after the next delay.

        Does nothing after shutdown. When the queue is empty the worker stops
        and gives its slot back.
        """
        if self.is_shutdown:
            return

        with self.lock:
            try:
                request = self.items.pop(0)
            except IndexError:
                # Guard against the counter going negative, just in case.
                if self.workers_count > 0:
                    self.workers_count -= 1
                # The worker ends here together with its schedule, so items that arrive
                # later are served by a new worker with a fresh schedule.
                # It is important for exponential policy.
                return

        # The request runs outside of the lock so that new items can be queued meanwhile.
        try:
            request()
        finally:
            # Keep the worker alive even if the request raised.
            schedule, delay = self._get_delay(schedule)
            self.scheduler.schedule(delay, self._worker_body, schedule)

    def _start_new_worker(self):
        """
        Start a worker with its own fresh schedule. Called with `lock` held.

        Any exception raised while creating the schedule, taking its first delay or
        submitting the worker to `scheduler` propagates to the caller.
        """
        schedule = self.backoff_policy.new_schedule()
        delay = next(schedule)
        self.scheduler.schedule(delay, self._worker_body, schedule)
        # Count the worker only once it is really scheduled. Counting it up front would
        # leak the slot on failure and could block every later worker from starting.
        self.workers_count += 1

    def schedule_new_connection(self, cb: Callable[[], None]):
        """
        Queue `cb` for execution and start one more worker if the limit allows it.

        The call is ignored after shutdown.
        """
        with self.lock:
            if self.is_shutdown:
                return
            self.items.append(cb)
            if self.workers_count < self.max_concurrent:
                self._start_new_worker()

    def shutdown(self):
        """
        Mark the bucket as shut down: new requests are ignored and workers stop on their next run.
        """
        with self.lock:
            self.is_shutdown = True
1170+
1171+
1172+
class _LimitedConcurrencyShardConnectionScheduler(ShardConnectionScheduler):
    """
    A scheduler for ``cassandra.policies.LimitedConcurrencyShardConnectionBackoffPolicy``.

    Limits concurrency for connection creation requests to ``max_concurrent`` per host_id
    and ignores requests for a (host_id, shard_id) pair that is already scheduled or pending.
    """

    already_scheduled: set[tuple[str, int]]
    """
    Set of (host_id, shard_id) of scheduled or pending requests.
    """

    per_host_scope: dict[str, _ScopeBucket]
    """
    Scopes storage, key is host_id, value is an instance that holds scope data.
    """

    backoff_policy: ShardConnectionBackoffSchedule
    scheduler: _Scheduler
    lock: Lock
    is_shutdown: bool = False

    max_concurrent: int
    """
    Max concurrent connection creation requests per host_id.
    """

    def __init__(
        self,
        scheduler: _Scheduler,
        backoff_policy: ShardConnectionBackoffSchedule,
        max_concurrent: int,
    ):
        self.lock = Lock()
        self.scheduler = scheduler
        self.backoff_policy = backoff_policy
        self.max_concurrent = max_concurrent
        self.per_host_scope = {}
        self.already_scheduled = set()

    def _execute(self, host_id: str, shard_id: int, method: Callable[[], None]):
        """
        Run `method` and then forget its (host_id, shard_id) pair, whether or not it raised.

        Does nothing once the scheduler is shut down.
        """
        if self.is_shutdown:
            return
        key = (host_id, shard_id)
        try:
            method()
        finally:
            with self.lock:
                self.already_scheduled.remove(key)

    def schedule(self, host_id: str, shard_id: int, method: Callable[[], None]) -> bool:
        """
        Queue `method` in the bucket of `host_id`.

        Returns `False` when the scheduler is shut down or the (host_id, shard_id) pair
        is already scheduled or pending, `True` otherwise.
        """
        key = (host_id, shard_id)
        with self.lock:
            if self.is_shutdown or key in self.already_scheduled:
                return False
            self.already_scheduled.add(key)

            # Buckets are created lazily, one per host.
            bucket = self.per_host_scope.get(host_id)
            if bucket is None:
                bucket = _ScopeBucket(self.scheduler, self.backoff_policy, self.max_concurrent)
                self.per_host_scope[host_id] = bucket
            bucket.schedule_new_connection(partial(self._execute, host_id, shard_id, method))
        return True

    def shutdown(self):
        """
        Shut down this scheduler and every per-host bucket created so far.
        """
        with self.lock:
            self.is_shutdown = True
            for bucket in self.per_host_scope.values():
                bucket.shutdown()
1239+
1240+
10021241
class RetryPolicy(object):
10031242
"""
10041243
A policy that describes whether to retry, rethrow, or ignore coordinator

0 commit comments

Comments
 (0)