Skip to content

Commit eb8d396

Browse files
committed
Integrate ShardConnectionBackoffPolicy
Add code that integrates ShardConnectionBackoffPolicy into: 1. Cluster 2. Session 3. HostConnection Main idea is to put ShardConnectionBackoffPolicy in control of shard connection creation proccess. Removing duplicate logic from HostConnection that tracks pending connection creation requests.
1 parent a7294c0 commit eb8d396

File tree

4 files changed

+119
-51
lines changed

4 files changed

+119
-51
lines changed

cassandra/cluster.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@
7373
ExponentialReconnectionPolicy, HostDistance,
7474
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
7575
NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy,
76-
NeverRetryPolicy)
76+
NeverRetryPolicy, ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy,
77+
ShardConnectionScheduler)
7778
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
7879
HostConnectionPool, HostConnection,
7980
NoConnectionsAvailable)
@@ -757,6 +758,11 @@ def auth_provider(self, value):
757758

758759
self._auth_provider = value
759760

761+
_shard_connection_backoff_policy: ShardConnectionBackoffPolicy
762+
@property
763+
def shard_connection_backoff_policy(self) -> ShardConnectionBackoffPolicy:
764+
return self._shard_connection_backoff_policy
765+
760766
_load_balancing_policy = None
761767
@property
762768
def load_balancing_policy(self):
@@ -1219,7 +1225,8 @@ def __init__(self,
12191225
shard_aware_options=None,
12201226
metadata_request_timeout=None,
12211227
column_encryption_policy=None,
1222-
application_info:Optional[ApplicationInfoBase]=None
1228+
application_info: Optional[ApplicationInfoBase] = None,
1229+
shard_connection_backoff_policy: Optional[ShardConnectionBackoffPolicy] = None,
12231230
):
12241231
"""
12251232
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
@@ -1325,6 +1332,13 @@ def __init__(self,
13251332
else:
13261333
self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode
13271334

1335+
if shard_connection_backoff_policy is not None:
1336+
if not isinstance(shard_connection_backoff_policy, ShardConnectionBackoffPolicy):
1337+
raise TypeError("shard_connection_backoff_policy should be an instance of class derived from ShardConnectionBackoffPolicy")
1338+
self._shard_connection_backoff_policy = shard_connection_backoff_policy
1339+
else:
1340+
self._shard_connection_backoff_policy = NoDelayShardConnectionBackoffPolicy()
1341+
13281342
if reconnection_policy is not None:
13291343
if isinstance(reconnection_policy, type):
13301344
raise TypeError("reconnection_policy should not be a class, it should be an instance of that class")
@@ -2716,6 +2730,7 @@ def default_serial_consistency_level(self, cl):
27162730
_metrics = None
27172731
_request_init_callbacks = None
27182732
_graph_paging_available = False
2733+
shard_connection_backoff_scheduler: ShardConnectionScheduler
27192734

27202735
def __init__(self, cluster, hosts, keyspace=None):
27212736
self.cluster = cluster
@@ -2730,6 +2745,7 @@ def __init__(self, cluster, hosts, keyspace=None):
27302745
self._protocol_version = self.cluster.protocol_version
27312746

27322747
self.encoder = Encoder()
2748+
self.shard_connection_backoff_scheduler = cluster.shard_connection_backoff_policy.new_connection_scheduler(self.cluster.scheduler)
27332749

27342750
# create connection pools in parallel
27352751
self._initial_connect_futures = set()
@@ -3340,6 +3356,7 @@ def shutdown(self):
33403356
else:
33413357
self.is_shutdown = True
33423358

3359+
self.shard_connection_backoff_scheduler.shutdown()
33433360
# PYTHON-673. If shutdown was called shortly after session init, avoid
33443361
# a race by cancelling any initial connection attempts haven't started,
33453362
# then blocking on any that have.

cassandra/pool.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Connection pooling and host management.
1717
"""
1818
from concurrent.futures import Future
19-
from functools import total_ordering
19+
from functools import total_ordering, partial
2020
import logging
2121
import socket
2222
import time
@@ -402,7 +402,6 @@ def __init__(self, host, host_distance, session):
402402
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
403403
self._stream_available_condition = Condition(Lock())
404404
self._is_replacing = False
405-
self._connecting = set()
406405
self._connections = {}
407406
self._pending_connections = []
408407
# A pool of additional connections which are not used but affect how Scylla
@@ -418,7 +417,6 @@ def __init__(self, host, host_distance, session):
418417
# and are waiting until all requests time out or complete
419418
# so that we can dispose of them.
420419
self._trash = set()
421-
self._shard_connections_futures = []
422420
self.advanced_shardaware_block_until = 0
423421

424422
if host_distance == HostDistance.IGNORED:
@@ -483,25 +481,25 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
483481
self.host,
484482
routing_key
485483
)
486-
if conn.orphaned_threshold_reached and shard_id not in self._connecting:
484+
if conn.orphaned_threshold_reached:
487485
# The connection has met its orphaned stream ID limit
488486
# and needs to be replaced. Start opening a connection
489487
# to the same shard and replace when it is opened.
490-
self._connecting.add(shard_id)
491-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
488+
self._session.shard_connection_backoff_scheduler.schedule(
489+
self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
492490
log.debug(
493-
"Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
491+
"Scheduling Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)",
494492
shard_id,
495493
self.host,
496494
len(self._connections.keys()),
497495
self.host.sharding_info.shards_count
498496
)
499-
elif shard_id not in self._connecting:
497+
else:
500498
# rate controlled optimistic attempt to connect to a missing shard
501-
self._connecting.add(shard_id)
502-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
499+
self._session.shard_connection_backoff_scheduler.schedule(
500+
self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
503501
log.debug(
504-
"Trying to connect to missing shard_id=%i on host %s (%s/%i)",
502+
"Scheduling connection to missing shard_id=%i on host %s (%s/%i)",
505503
shard_id,
506504
self.host,
507505
len(self._connections.keys()),
@@ -609,8 +607,8 @@ def _replace(self, connection):
609607
if connection.features.shard_id in self._connections.keys():
610608
del self._connections[connection.features.shard_id]
611609
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
612-
self._connecting.add(connection.features.shard_id)
613-
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
610+
self._session.shard_connection_backoff_scheduler.schedule(
611+
self.host.host_id, connection.features.shard_id, partial(self._open_connection_to_missing_shard, connection.features.shard_id))
614612
else:
615613
connection = self._session.cluster.connection_factory(self.host.endpoint,
616614
on_orphaned_stream_released=self.on_orphaned_stream_released)
@@ -635,9 +633,6 @@ def shutdown(self):
635633
with self._stream_available_condition:
636634
self._stream_available_condition.notify_all()
637635

638-
for future in self._shard_connections_futures:
639-
future.cancel()
640-
641636
connections_to_close = self._connections.copy()
642637
pending_connections_to_close = self._pending_connections.copy()
643638
self._connections.clear()
@@ -843,7 +838,6 @@ def _open_connection_to_missing_shard(self, shard_id):
843838
self._excess_connections.add(conn)
844839
if close_connection:
845840
conn.close()
846-
self._connecting.discard(shard_id)
847841

848842
def _open_connections_for_all_shards(self, skip_shard_id=None):
849843
"""
@@ -856,10 +850,8 @@ def _open_connections_for_all_shards(self, skip_shard_id=None):
856850
for shard_id in range(self.host.sharding_info.shards_count):
857851
if skip_shard_id is not None and skip_shard_id == shard_id:
858852
continue
859-
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
860-
if isinstance(future, Future):
861-
self._connecting.add(shard_id)
862-
self._shard_connections_futures.append(future)
853+
self._session.shard_connection_backoff_scheduler.schedule(
854+
self.host.host_id, shard_id, partial(self._open_connection_to_missing_shard, shard_id))
863855

864856
trash_conns = None
865857
with self._lock:

tests/unit/test_host_connection_pool.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,23 @@
2222
from threading import Thread, Event, Lock
2323
from unittest.mock import Mock, NonCallableMagicMock, MagicMock
2424

25-
from cassandra.cluster import Session, ShardAwareOptions
25+
from cassandra.cluster import Session, ShardAwareOptions, _Scheduler
2626
from cassandra.connection import Connection
2727
from cassandra.pool import HostConnection, HostConnectionPool
2828
from cassandra.pool import Host, NoConnectionsAvailable
29-
from cassandra.policies import HostDistance, SimpleConvictionPolicy
29+
from cassandra.policies import HostDistance, SimpleConvictionPolicy, _NoDelayShardConnectionBackoffScheduler
3030

3131
LOGGER = logging.getLogger(__name__)
3232

3333

34+
class FakeScheduler(_Scheduler):
35+
def __init__(self):
36+
super(FakeScheduler, self).__init__(ThreadPoolExecutor())
37+
38+
def schedule(self, delay, fn, *args, **kwargs):
39+
super().schedule(0, fn, *args, **kwargs)
40+
41+
3442
class _PoolTests(unittest.TestCase):
3543
__test__ = False
3644
PoolImpl = None
@@ -41,6 +49,9 @@ def make_session(self):
4149
session.cluster.get_core_connections_per_host.return_value = 1
4250
session.cluster.get_max_requests_per_connection.return_value = 1
4351
session.cluster.get_max_connections_per_host.return_value = 1
52+
session.shard_connection_backoff_scheduler = _NoDelayShardConnectionBackoffScheduler(FakeScheduler())
53+
session.shard_connection_backoff_scheduler.schedule = Mock(wraps=session.shard_connection_backoff_scheduler.schedule)
54+
session.is_shutdown = False
4455
return session
4556

4657
def test_borrow_and_return(self):
@@ -174,9 +185,9 @@ def test_return_defunct_connection_on_down_host(self):
174185
if self.PoolImpl is HostConnection:
175186
# on shard aware implementation we use submit function regardless
176187
self.assertTrue(host.signal_connection_failure.call_args)
177-
self.assertTrue(session.submit.called)
188+
self.assertTrue(session.shard_connection_backoff_scheduler.schedule.called)
178189
else:
179-
self.assertFalse(session.submit.called)
190+
self.assertFalse(session.shard_connection_backoff_scheduler.schedule.called)
180191
self.assertTrue(session.cluster.signal_connection_failure.call_args)
181192
self.assertTrue(pool.is_shutdown)
182193

tests/unit/test_shard_aware.py

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import uuid
15+
from unittest.mock import Mock
16+
17+
from cassandra.policies import NoDelayShardConnectionBackoffPolicy, _NoDelayShardConnectionBackoffScheduler
1418

1519
try:
1620
import unittest2 as unittest
@@ -21,7 +25,7 @@
2125
from mock import MagicMock
2226
from concurrent.futures import ThreadPoolExecutor
2327

24-
from cassandra.cluster import ShardAwareOptions
28+
from cassandra.cluster import ShardAwareOptions, _Scheduler
2529
from cassandra.pool import HostConnection, HostDistance
2630
from cassandra.connection import ShardingInfo, DefaultEndPoint
2731
from cassandra.metadata import Murmur3Token
@@ -53,11 +57,18 @@ class OptionsHolder(object):
5357
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4)
5458
self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2)
5559

56-
def test_advanced_shard_aware_port(self):
60+
def test_shard_aware_reconnection_policy_no_delay(self):
61+
# with NoDelayReconnectionPolicy all the connections should be created right away
62+
self._test_shard_aware_reconnection_policy(4, NoDelayShardConnectionBackoffPolicy(), 4)
63+
64+
def _test_shard_aware_reconnection_policy(self, shard_count, shard_connection_backoff_policy, expected_connections):
5765
"""
5866
Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
59-
the next connections would be open using this port
67+
It checks that:
68+
1. Next connections are opened using this port
69+
2. Connection creation pase matches `shard_connection_backoff_policy`
6070
"""
71+
6172
class MockSession(MagicMock):
6273
is_shutdown = False
6374
keyspace = "ks1"
@@ -71,45 +82,82 @@ def __init__(self, is_ssl=False, *args, **kwargs):
7182
self.cluster.ssl_options = None
7283
self.cluster.shard_aware_options = ShardAwareOptions()
7384
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
85+
self._executor_submit_original = self.cluster.executor.submit
86+
self.cluster.executor.submit = self._executor_submit
87+
self.cluster.scheduler = _Scheduler(self.cluster.executor)
88+
89+
# Collect scheduled calls and execute them right away
90+
self.scheduler_calls = []
91+
original_schedule = self.cluster.scheduler.schedule
92+
93+
def new_schedule(delay, fn, *args, **kwargs):
94+
self.scheduler_calls.append((delay, fn, args, kwargs))
95+
return original_schedule(0, fn, *args, **kwargs)
96+
97+
self.cluster.scheduler.schedule = Mock(side_effect=new_schedule)
7498
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
7599
self.cluster.connection_factory = self.mock_connection_factory
76100
self.connection_counter = 0
101+
self.shard_connection_backoff_scheduler = shard_connection_backoff_policy.new_connection_scheduler(
102+
self.cluster.scheduler)
77103
self.futures = []
78104

79105
def submit(self, fn, *args, **kwargs):
106+
if self.is_shutdown:
107+
return None
108+
return self.cluster.executor.submit(fn, *args, **kwargs)
109+
110+
def _executor_submit(self, fn, *args, **kwargs):
80111
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
81-
if not self.is_shutdown:
82-
f = self.cluster.executor.submit(fn, *args, **kwargs)
83-
self.futures += [f]
84-
return f
112+
f = self._executor_submit_original(fn, *args, **kwargs)
113+
self.futures += [f]
114+
return f
85115

86116
def mock_connection_factory(self, *args, **kwargs):
87117
connection = MagicMock()
88118
connection.is_shutdown = False
89119
connection.is_defunct = False
90120
connection.is_closed = False
91121
connection.orphaned_threshold_reached = False
92-
connection.endpoint = args[0]
93-
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
94-
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
122+
connection.endpoint = args[0]
123+
sharding_info = None
124+
if shard_count:
125+
sharding_info = ShardingInfo(shard_id=1, shards_count=shard_count, partitioner="",
126+
sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042,
127+
shard_aware_port_ssl=19045)
128+
connection.features = ProtocolFeatures(
129+
shard_id=kwargs.get('shard_id', self.connection_counter),
130+
sharding_info=sharding_info)
95131
self.connection_counter += 1
96132

97133
return connection
98134

99135
host = MagicMock()
136+
host.host_id = uuid.uuid4()
100137
host.endpoint = DefaultEndPoint("1.2.3.4")
138+
session = None
139+
try:
140+
for port, is_ssl in [(19042, False), (19045, True)]:
141+
session = MockSession(is_ssl=is_ssl)
142+
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
143+
for f in session.futures:
144+
f.result()
145+
assert len(pool._connections) == expected_connections
146+
for shard_id, connection in pool._connections.items():
147+
assert connection.features.shard_id == shard_id
148+
if shard_id == 0:
149+
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
150+
else:
151+
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
101152

102-
for port, is_ssl in [(19042, False), (19045, True)]:
103-
session = MockSession(is_ssl=is_ssl)
104-
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
105-
for f in session.futures:
106-
f.result()
107-
assert len(pool._connections) == 4
108-
for shard_id, connection in pool._connections.items():
109-
assert connection.features.shard_id == shard_id
110-
if shard_id == 0:
111-
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
112-
else:
113-
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
114-
115-
session.cluster.executor.shutdown(wait=True)
153+
sleep_time = 0
154+
found_related_calls = 0
155+
for delay, fn, args, kwargs in session.scheduler_calls:
156+
if fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler:
157+
found_related_calls += 1
158+
self.assertEqual(delay, sleep_time)
159+
self.assertLessEqual(shard_count - 1, found_related_calls)
160+
finally:
161+
if session:
162+
session.cluster.scheduler.shutdown()
163+
session.cluster.executor.shutdown(wait=True)

0 commit comments

Comments
 (0)