Skip to content

Commit c619e5b

Browse files
committed
Add unit/integration tests for shard aware
1 parent ae5b0a7 commit c619e5b

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
try:
16+
import unittest2 as unittest
17+
except ImportError:
18+
import unittest # noqa
19+
20+
from cassandra.cluster import Cluster
21+
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy
22+
23+
from tests.integration import use_singledc
24+
25+
26+
def setup_module():
27+
use_singledc()
28+
29+
30+
class TestShardAwareIntegration(unittest.TestCase):
31+
@classmethod
32+
def setup_class(cls):
33+
cls.cluster = Cluster(protocol_version=4, load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
34+
cls.session = cls.cluster.connect()
35+
36+
@classmethod
37+
def teardown_class(cls):
38+
cls.cluster.shutdown()
39+
40+
def verify_same_shard_in_tracing(self, results, shard_name):
41+
trace_id = results.response_future.get_query_trace_ids()[0]
42+
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id, ))
43+
events = [event for event in traces]
44+
for event in events:
45+
print(event.thread, event.activity)
46+
for event in traces:
47+
self.assertEqual(event.thread, shard_name)
48+
self.assertIn('querying locally', "\n".join([ event.activity for event in events ]))
49+
50+
def test_all_tracing_coming_one_shard(self):
51+
"""
52+
Testing that shard aware driver is sending the requests to the correct shards
53+
54+
using the traces to validate that all the action been executed on the the same shard.
55+
this test is using prepared SELECT statements for this validation
56+
"""
57+
58+
self.session.execute(
59+
"""
60+
DROP KEYSPACE IF EXISTS preparedtests
61+
"""
62+
)
63+
self.session.execute(
64+
"""
65+
CREATE KEYSPACE preparedtests
66+
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
67+
""")
68+
69+
self.session.execute("USE preparedtests")
70+
self.session.execute(
71+
"""
72+
CREATE TABLE cf0 (
73+
a text,
74+
b text,
75+
c text,
76+
PRIMARY KEY (a, b)
77+
)
78+
""")
79+
80+
prepared = self.session.prepare(
81+
"""
82+
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
83+
""")
84+
85+
bound = prepared.bind(('a', 'b', 'c'))
86+
87+
self.session.execute(bound)
88+
89+
bound = prepared.bind(('e', 'f', 'g'))
90+
91+
self.session.execute(bound)
92+
93+
bound = prepared.bind(('100000', 'f', 'g'))
94+
95+
self.session.execute(bound)
96+
97+
prepared = self.session.prepare(
98+
"""
99+
SELECT * FROM cf0 WHERE a=? AND b=?
100+
""")
101+
102+
bound = prepared.bind(('a', 'b'))
103+
results = self.session.execute(bound, trace=True)
104+
self.assertEqual(results, [('a', 'b', 'c')])
105+
106+
self.verify_same_shard_in_tracing(results, "shard 4")
107+
108+
bound = prepared.bind(('100000', 'f'))
109+
results = self.session.execute(bound, trace=True)
110+
self.assertEqual(results, [('100000', 'f', 'g')])
111+
112+
self.verify_same_shard_in_tracing(results, "shard 2")
113+
114+
bound = prepared.bind(('e', 'f'))
115+
results = self.session.execute(bound, trace=True)
116+
117+
self.verify_same_shard_in_tracing(results, "shard 4")

tests/unit/test_shard_aware.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
try:
16+
import unittest2 as unittest
17+
except ImportError:
18+
import unittest # noqa
19+
20+
from cassandra.connection import ShardingInfo
21+
from cassandra.metadata import Murmur3Token
22+
23+
class TestShardAware(unittest.TestCase):
24+
def test_parsing_and_calculating_shard_id(self):
25+
'''
26+
Testing the parsing of the options command
27+
and the calculation getting a shard id from a Murmur3 token
28+
'''
29+
class OptionsHolder():
30+
options = {
31+
'SCYLLA_SHARD': ['1'],
32+
'SCYLLA_NR_SHARDS': ['12'],
33+
'SCYLLA_PARTITIONER': ['org.apache.cassandra.dht.Murmur3Partitioner'],
34+
'SCYLLA_SHARDING_ALGORITHM': ['biased-token-round-robin'],
35+
'SCYLLA_SHARDING_IGNORE_MSB': ['12']
36+
}
37+
shard_id, shard_info = ShardingInfo.parse_sharding_info(OptionsHolder())
38+
39+
self.assertEqual(shard_id, 1)
40+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"a")), 4)
41+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"b")), 6)
42+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"c")), 6)
43+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"e")), 4)
44+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"100000")), 2)

0 commit comments

Comments
 (0)