11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import uuid
15
+ from unittest .mock import Mock
16
+
17
+ from cassandra .policies import NoDelayShardConnectionBackoffPolicy , _NoDelayShardConnectionBackoffScheduler
14
18
15
19
try :
16
20
import unittest2 as unittest
21
25
from mock import MagicMock
22
26
from concurrent .futures import ThreadPoolExecutor
23
27
24
- from cassandra .cluster import ShardAwareOptions
28
+ from cassandra .cluster import ShardAwareOptions , _Scheduler
25
29
from cassandra .pool import HostConnection , HostDistance
26
30
from cassandra .connection import ShardingInfo , DefaultEndPoint
27
31
from cassandra .metadata import Murmur3Token
@@ -53,11 +57,18 @@ class OptionsHolder(object):
53
57
self .assertEqual (shard_info .shard_id_from_token (Murmur3Token .from_key (b"e" ).value ), 4 )
54
58
self .assertEqual (shard_info .shard_id_from_token (Murmur3Token .from_key (b"100000" ).value ), 2 )
55
59
56
- def test_advanced_shard_aware_port (self ):
60
+ def test_shard_aware_reconnection_policy_no_delay (self ):
61
+ # with NoDelayReconnectionPolicy all the connections should be created right away
62
+ self ._test_shard_aware_reconnection_policy (4 , NoDelayShardConnectionBackoffPolicy (), 4 )
63
+
64
+ def _test_shard_aware_reconnection_policy (self , shard_count , shard_connection_backoff_policy , expected_connections ):
57
65
"""
58
66
Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
59
- the next connections would be open using this port
67
+ It checks that:
68
+ 1. Next connections are opened using this port
69
+ 2. Connection creation pase matches `shard_connection_backoff_policy`
60
70
"""
71
+
61
72
class MockSession (MagicMock ):
62
73
is_shutdown = False
63
74
keyspace = "ks1"
@@ -71,45 +82,82 @@ def __init__(self, is_ssl=False, *args, **kwargs):
71
82
self .cluster .ssl_options = None
72
83
self .cluster .shard_aware_options = ShardAwareOptions ()
73
84
self .cluster .executor = ThreadPoolExecutor (max_workers = 2 )
85
+ self ._executor_submit_original = self .cluster .executor .submit
86
+ self .cluster .executor .submit = self ._executor_submit
87
+ self .cluster .scheduler = _Scheduler (self .cluster .executor )
88
+
89
+ # Collect scheduled calls and execute them right away
90
+ self .scheduler_calls = []
91
+ original_schedule = self .cluster .scheduler .schedule
92
+
93
+ def new_schedule (delay , fn , * args , ** kwargs ):
94
+ self .scheduler_calls .append ((delay , fn , args , kwargs ))
95
+ return original_schedule (0 , fn , * args , ** kwargs )
96
+
97
+ self .cluster .scheduler .schedule = Mock (side_effect = new_schedule )
74
98
self .cluster .signal_connection_failure = lambda * args , ** kwargs : False
75
99
self .cluster .connection_factory = self .mock_connection_factory
76
100
self .connection_counter = 0
101
+ self .shard_connection_backoff_scheduler = shard_connection_backoff_policy .new_connection_scheduler (
102
+ self .cluster .scheduler )
77
103
self .futures = []
78
104
79
105
def submit (self , fn , * args , ** kwargs ):
106
+ if self .is_shutdown :
107
+ return None
108
+ return self .cluster .executor .submit (fn , * args , ** kwargs )
109
+
110
+ def _executor_submit (self , fn , * args , ** kwargs ):
80
111
logging .info ("Scheduling %s with args: %s, kwargs: %s" , fn , args , kwargs )
81
- if not self .is_shutdown :
82
- f = self .cluster .executor .submit (fn , * args , ** kwargs )
83
- self .futures += [f ]
84
- return f
112
+ f = self ._executor_submit_original (fn , * args , ** kwargs )
113
+ self .futures += [f ]
114
+ return f
85
115
86
116
def mock_connection_factory (self , * args , ** kwargs ):
87
117
connection = MagicMock ()
88
118
connection .is_shutdown = False
89
119
connection .is_defunct = False
90
120
connection .is_closed = False
91
121
connection .orphaned_threshold_reached = False
92
- connection .endpoint = args [0 ]
93
- sharding_info = ShardingInfo (shard_id = 1 , shards_count = 4 , partitioner = "" , sharding_algorithm = "" , sharding_ignore_msb = 0 , shard_aware_port = 19042 , shard_aware_port_ssl = 19045 )
94
- connection .features = ProtocolFeatures (shard_id = kwargs .get ('shard_id' , self .connection_counter ), sharding_info = sharding_info )
122
+ connection .endpoint = args [0 ]
123
+ sharding_info = None
124
+ if shard_count :
125
+ sharding_info = ShardingInfo (shard_id = 1 , shards_count = shard_count , partitioner = "" ,
126
+ sharding_algorithm = "" , sharding_ignore_msb = 0 , shard_aware_port = 19042 ,
127
+ shard_aware_port_ssl = 19045 )
128
+ connection .features = ProtocolFeatures (
129
+ shard_id = kwargs .get ('shard_id' , self .connection_counter ),
130
+ sharding_info = sharding_info )
95
131
self .connection_counter += 1
96
132
97
133
return connection
98
134
99
135
host = MagicMock ()
136
+ host .host_id = uuid .uuid4 ()
100
137
host .endpoint = DefaultEndPoint ("1.2.3.4" )
138
+ session = None
139
+ try :
140
+ for port , is_ssl in [(19042 , False ), (19045 , True )]:
141
+ session = MockSession (is_ssl = is_ssl )
142
+ pool = HostConnection (host = host , host_distance = HostDistance .REMOTE , session = session )
143
+ for f in session .futures :
144
+ f .result ()
145
+ assert len (pool ._connections ) == expected_connections
146
+ for shard_id , connection in pool ._connections .items ():
147
+ assert connection .features .shard_id == shard_id
148
+ if shard_id == 0 :
149
+ assert connection .endpoint == DefaultEndPoint ("1.2.3.4" )
150
+ else :
151
+ assert connection .endpoint == DefaultEndPoint ("1.2.3.4" , port = port )
101
152
102
- for port , is_ssl in [(19042 , False ), (19045 , True )]:
103
- session = MockSession (is_ssl = is_ssl )
104
- pool = HostConnection (host = host , host_distance = HostDistance .REMOTE , session = session )
105
- for f in session .futures :
106
- f .result ()
107
- assert len (pool ._connections ) == 4
108
- for shard_id , connection in pool ._connections .items ():
109
- assert connection .features .shard_id == shard_id
110
- if shard_id == 0 :
111
- assert connection .endpoint == DefaultEndPoint ("1.2.3.4" )
112
- else :
113
- assert connection .endpoint == DefaultEndPoint ("1.2.3.4" , port = port )
114
-
115
- session .cluster .executor .shutdown (wait = True )
153
+ sleep_time = 0
154
+ found_related_calls = 0
155
+ for delay , fn , args , kwargs in session .scheduler_calls :
156
+ if fn .__self__ .__class__ is _NoDelayShardConnectionBackoffScheduler :
157
+ found_related_calls += 1
158
+ self .assertEqual (delay , sleep_time )
159
+ self .assertLessEqual (shard_count - 1 , found_related_calls )
160
+ finally :
161
+ if session :
162
+ session .cluster .scheduler .shutdown ()
163
+ session .cluster .executor .shutdown (wait = True )
0 commit comments