Skip to content

Commit eaacfb9

Browse files
committed
Make blocking set keyspace query to fail by timeout
1 parent 7e0b02d commit eaacfb9

File tree

4 files changed

+45
-8
lines changed

4 files changed

+45
-8
lines changed

cassandra/cluster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host):
23882388
else:
23892389
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
23902390
if keyspace is not None:
2391-
connection.set_keyspace_blocking(keyspace)
2391+
connection.set_keyspace_blocking(keyspace, self.control_connection_timeout)
23922392

23932393
# prepare 10 statements at a time
23942394
ks_statements = list(ks_statements)

cassandra/connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response):
14981498
log.error(msg, self.endpoint, auth_response)
14991499
raise ProtocolError(msg % (self.endpoint, auth_response))
15001500

1501-
def set_keyspace_blocking(self, keyspace):
1501+
def set_keyspace_blocking(self, keyspace, timeout=None):
15021502
if not keyspace or keyspace == self.keyspace:
15031503
return
15041504

15051505
query = QueryMessage(query='USE "%s"' % (keyspace,),
15061506
consistency_level=ConsistencyLevel.ONE)
15071507
try:
1508-
result = self.wait_for_response(query)
1508+
result = self.wait_for_response(query, timeout=timeout)
15091509
except InvalidRequestException as ire:
15101510
# the keyspace probably doesn't exist
15111511
raise ire.to_exception()

cassandra/pool.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):
435435
self._keyspace = session.keyspace
436436

437437
if self._keyspace:
438-
first_connection.set_keyspace_blocking(self._keyspace)
438+
first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout)
439439
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
440440
self.host.sharding_info = first_connection.features.sharding_info
441441
self._open_connections_for_all_shards(first_connection.features.shard_id)
@@ -615,7 +615,7 @@ def _replace(self, connection):
615615
connection = self._session.cluster.connection_factory(self.host.endpoint,
616616
on_orphaned_stream_released=self.on_orphaned_stream_released)
617617
if self._keyspace:
618-
connection.set_keyspace_blocking(self._keyspace)
618+
connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
619619
self._connections[connection.features.shard_id] = connection
620620
except Exception:
621621
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
@@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id):
766766
self.host
767767
)
768768
if self._keyspace:
769-
conn.set_keyspace_blocking(self._keyspace)
769+
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
770770

771771
self._connections[conn.features.shard_id] = conn
772772
if old_conn is not None:
@@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session):
953953
self._keyspace = session.keyspace
954954
if self._keyspace:
955955
for conn in self._connections:
956-
conn.set_keyspace_blocking(self._keyspace)
956+
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
957957

958958
self._trash = set()
959959
self._next_trash_allowed_at = time.time()
@@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self):
10531053
try:
10541054
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
10551055
if self._keyspace:
1056-
conn.set_keyspace_blocking(self._session.keyspace)
1056+
conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout)
10571057
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
10581058
with self._lock:
10591059
new_connections = self._connections[:] + [conn]

tests/integration/standard/test_cluster.py

+37
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
RetryPolicy, SimpleConvictionPolicy, HostDistance,
3333
AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
3434
from cassandra import ConsistencyLevel
35+
from cassandra.protocol import ProtocolHandler, QueryMessage
3536

3637
from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
3738
from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider
@@ -484,6 +485,42 @@ def test_refresh_schema_table(self):
484485
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
485486
cluster.shutdown()
486487

488+
def test_use_keyspace_blocking(self):
489+
ks = "test_refresh_schema_type"
490+
491+
cluster = TestCluster()
492+
send_msg_orig = cluster.connection_class.send_msg
493+
494+
def send_msg_patched(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
495+
decoder=ProtocolHandler.decode_message, result_metadata=None):
496+
if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query:
497+
orig_decoder = decoder
498+
def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode, body,
499+
decompressor, result_metadata):
500+
time.sleep(cluster.control_connection_timeout + 0.1)
501+
return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags,
502+
opcode, body, decompressor, result_metadata)
503+
504+
decoder = decode_patched
505+
506+
return send_msg_orig(self, msg, request_id, cb, encoder, decoder, result_metadata)
507+
508+
cluster.connection_class.send_msg = send_msg_patched
509+
510+
cluster.connect().execute("""
511+
CREATE KEYSPACE IF NOT EXISTS %s
512+
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
513+
""" % ks)
514+
515+
try:
516+
cluster.connect(ks)
517+
except NoHostAvailable:
518+
pass
519+
except Exception as e:
520+
self.fail(f"got unexpected exception {e}")
521+
else:
522+
self.fail("connection should fail, but was not")
523+
487524
def test_refresh_schema_type(self):
488525
if get_server_versions()[0] < (2, 1, 0):
489526
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')

0 commit comments

Comments
 (0)