Skip to content

Commit 3d97ecd

Browse files
committed
Implementa integration tests for shard connection backof policies
Tests cover: 1. LimitedConcurrencyShardConnectionBackoffPolicy 2. NoDelayShardConnectionBackoffPolicy For both Scylla and Cassandra backend.
1 parent 22c82c2 commit 3d97ecd

File tree

1 file changed

+98
-2
lines changed

1 file changed

+98
-2
lines changed

tests/integration/long/test_policies.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import inspect
15+
import os
16+
import time
1517
import unittest
18+
from typing import Optional
19+
from unittest.mock import Mock
1620

1721
from cassandra import ConsistencyLevel, Unavailable
18-
from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT
22+
from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT, Session
23+
from cassandra.policies import LimitedConcurrencyShardConnectionBackoffPolicy, ConstantReconnectionPolicy, \
24+
ShardConnectionBackoffPolicy, NoDelayShardConnectionBackoffPolicy, _ScopeBucket, \
25+
_NoDelayShardConnectionBackoffScheduler
26+
from cassandra.shard_info import _ShardingInfo
1927

2028
from tests.integration import use_cluster, get_cluster, get_node, TestCluster
2129

2230

2331
def setup_module():
32+
os.environ['SCYLLA_EXT_OPTS'] = "--smp 4"
2433
use_cluster('test_cluster', [4])
2534

2635

@@ -65,3 +74,90 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self):
6574
self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL)
6675
self.assertEqual(exception.required_replicas, 2)
6776
self.assertEqual(exception.alive_replicas, 1)
77+
78+
79+
class ShardBackoffPolicyTests(unittest.TestCase):
80+
@classmethod
81+
def tearDownClass(cls):
82+
cluster = get_cluster()
83+
cluster.start(wait_for_binary_proto=True, wait_other_notice=True) # make sure other nodes are restarted
84+
85+
def test_limited_concurrency_1_connection_per_host(self):
86+
self._test_backoff(
87+
LimitedConcurrencyShardConnectionBackoffPolicy(
88+
backoff_policy=ConstantReconnectionPolicy(0.1),
89+
max_concurrent=1,
90+
)
91+
)
92+
93+
def test_limited_concurrency_2_connection_per_host(self):
94+
self._test_backoff(
95+
LimitedConcurrencyShardConnectionBackoffPolicy(
96+
backoff_policy=ConstantReconnectionPolicy(0.1),
97+
max_concurrent=1,
98+
)
99+
)
100+
101+
def test_no_delay(self):
102+
self._test_backoff(NoDelayShardConnectionBackoffPolicy())
103+
104+
def _test_backoff(self, shard_connection_backoff_policy: ShardConnectionBackoffPolicy):
105+
backoff_policy = None
106+
if isinstance(shard_connection_backoff_policy, LimitedConcurrencyShardConnectionBackoffPolicy):
107+
backoff_policy = shard_connection_backoff_policy.backoff_policy
108+
109+
cluster = TestCluster(
110+
shard_connection_backoff_policy=shard_connection_backoff_policy,
111+
reconnection_policy=ConstantReconnectionPolicy(0),
112+
)
113+
114+
# Collect scheduled calls and execute them right away
115+
scheduler_calls = []
116+
original_schedule = cluster.scheduler.schedule
117+
118+
def new_schedule(delay, fn, *args, **kwargs):
119+
scheduler_calls.append((delay, fn, args, kwargs))
120+
return original_schedule(0, fn, *args, **kwargs)
121+
122+
cluster.scheduler.schedule = Mock(side_effect=new_schedule)
123+
124+
session = cluster.connect()
125+
sharding_info = get_sharding_info(session)
126+
127+
# Since scheduled calls executed in a separate thread we need to give them some time to complete
128+
time.sleep(0.2)
129+
130+
if not sharding_info:
131+
# If it is not scylla `ShardConnectionBackoffScheduler` should not be involved
132+
for delay, fn, args, kwargs in scheduler_calls:
133+
if fn.__self__.__class__ is _ScopeBucket or fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler:
134+
self.fail(
135+
"in non-shard-aware case connection should be created directly, not involving ShardConnectionBackoffScheduler")
136+
return
137+
138+
sleep_time = 0
139+
if backoff_policy:
140+
schedule = backoff_policy.new_schedule()
141+
sleep_time = next(iter(schedule))
142+
143+
# Make sure that all scheduled calls have delay according to policy
144+
found_related_calls = 0
145+
for delay, fn, args, kwargs in scheduler_calls:
146+
if fn.__self__.__class__ is _ScopeBucket or fn.__self__.__class__ is _NoDelayShardConnectionBackoffScheduler:
147+
found_related_calls += 1
148+
self.assertEqual(delay, sleep_time)
149+
self.assertLessEqual(len(session.hosts) * (sharding_info.shards_count - 1), found_related_calls)
150+
151+
152+
def get_connections_per_host(session: Session) -> dict[str, int]:
153+
host_connections = {}
154+
for host, pool in session._pools.items():
155+
host_connections[host.host_id] = len(pool._connections)
156+
return host_connections
157+
158+
159+
def get_sharding_info(session: Session) -> Optional[_ShardingInfo]:
160+
for host in session.hosts:
161+
if host.sharding_info:
162+
return host.sharding_info
163+
return None

0 commit comments

Comments
 (0)