25
25
import collections
26
26
import logging
27
27
import time
28
- from typing import Set
28
+ from typing import Set , Mapping , List
29
29
from urllib .parse import urlparse
30
30
31
31
from . import documents
32
32
from . import http_constants
33
33
from .documents import _OperationType
34
+ from ._request_object import RequestObject
34
35
35
36
# pylint: disable=protected-access
36
37
@@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations,
113
114
except Exception as e :
114
115
raise e
115
116
116
- return endpoints_by_location , parsed_locations
117
+ # Also store a hash map of endpoints for each location
118
+ locations_by_endpoints = {value .get_primary (): key for key , value in endpoints_by_location .items ()}
119
+
120
+ return endpoints_by_location , locations_by_endpoints , parsed_locations
117
121
118
122
def add_endpoint_if_preferred (endpoint : str , preferred_endpoints : Set [str ], endpoints : Set [str ]) -> bool :
119
123
if endpoint in preferred_endpoints :
@@ -150,31 +154,44 @@ def _get_health_check_endpoints(
150
154
151
155
return endpoints
152
156
157
+ def _get_applicable_regional_routing_contexts (regional_routing_contexts : List [RegionalRoutingContext ],
158
+ location_name_by_endpoint : Mapping [str , str ],
159
+ fall_back_regional_routing_context : RegionalRoutingContext ,
160
+ exclude_location_list : List [str ]) -> List [RegionalRoutingContext ]:
161
+ # filter endpoints by excluded locations
162
+ applicable_regional_routing_contexts = []
163
+ for regional_routing_context in regional_routing_contexts :
164
+ if location_name_by_endpoint .get (regional_routing_context .get_primary ()) not in exclude_location_list :
165
+ applicable_regional_routing_contexts .append (regional_routing_context )
166
+
167
+ # if endpoint is empty add fallback endpoint
168
+ if not applicable_regional_routing_contexts :
169
+ applicable_regional_routing_contexts .append (fall_back_regional_routing_context )
170
+
171
+ return applicable_regional_routing_contexts
153
172
154
173
class LocationCache (object ): # pylint: disable=too-many-public-methods,too-many-instance-attributes
155
174
def current_time_millis (self ):
156
175
return int (round (time .time () * 1000 ))
157
176
158
177
def __init__ (
159
178
self ,
160
- preferred_locations ,
161
179
default_endpoint ,
162
- enable_endpoint_discovery ,
163
- use_multiple_write_locations ,
180
+ connection_policy ,
164
181
):
165
- self .preferred_locations = preferred_locations
166
182
self .default_regional_routing_context = RegionalRoutingContext (default_endpoint , default_endpoint )
167
- self .enable_endpoint_discovery = enable_endpoint_discovery
168
- self .use_multiple_write_locations = use_multiple_write_locations
169
183
self .enable_multiple_writable_locations = False
170
184
self .write_regional_routing_contexts = [self .default_regional_routing_context ]
171
185
self .read_regional_routing_contexts = [self .default_regional_routing_context ]
172
186
self .location_unavailability_info_by_endpoint = {}
173
187
self .last_cache_update_time_stamp = 0
174
188
self .account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
175
189
self .account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
190
+ self .account_locations_by_read_regional_routing_context = {} # pylint: disable=name-too-long
191
+ self .account_locations_by_write_regional_routing_context = {} # pylint: disable=name-too-long
176
192
self .account_write_locations = []
177
193
self .account_read_locations = []
194
+ self .connection_policy = connection_policy
178
195
179
196
def get_write_regional_routing_contexts (self ):
180
197
return self .write_regional_routing_contexts
@@ -207,6 +224,44 @@ def get_ordered_write_locations(self):
207
224
def get_ordered_read_locations (self ):
208
225
return self .account_read_locations
209
226
227
+ def _get_configured_excluded_locations (self , request : RequestObject ) -> List [str ]:
228
+ # If excluded locations were configured on request, use request level excluded locations.
229
+ excluded_locations = request .excluded_locations
230
+ if excluded_locations is None :
231
+ # If excluded locations were only configured on client(connection_policy), use client level
232
+ excluded_locations = self .connection_policy .ExcludedLocations
233
+ return excluded_locations
234
+
235
+ def _get_applicable_read_regional_routing_contexts (self , request : RequestObject ) -> List [RegionalRoutingContext ]:
236
+ # Get configured excluded locations
237
+ excluded_locations = self ._get_configured_excluded_locations (request )
238
+
239
+ # If excluded locations were configured, return filtered regional endpoints by excluded locations.
240
+ if excluded_locations :
241
+ return _get_applicable_regional_routing_contexts (
242
+ self .get_read_regional_routing_contexts (),
243
+ self .account_locations_by_read_regional_routing_context ,
244
+ self .get_write_regional_routing_contexts ()[0 ],
245
+ excluded_locations )
246
+
247
+ # Else, return all regional endpoints
248
+ return self .get_read_regional_routing_contexts ()
249
+
250
+ def _get_applicable_write_regional_routing_contexts (self , request : RequestObject ) -> List [RegionalRoutingContext ]:
251
+ # Get configured excluded locations
252
+ excluded_locations = self ._get_configured_excluded_locations (request )
253
+
254
+ # If excluded locations were configured, return filtered regional endpoints by excluded locations.
255
+ if excluded_locations :
256
+ return _get_applicable_regional_routing_contexts (
257
+ self .get_write_regional_routing_contexts (),
258
+ self .account_locations_by_write_regional_routing_context ,
259
+ self .default_regional_routing_context ,
260
+ excluded_locations )
261
+
262
+ # Else, return all regional endpoints
263
+ return self .get_write_regional_routing_contexts ()
264
+
210
265
def resolve_service_endpoint (self , request ):
211
266
if request .location_endpoint_to_route :
212
267
return request .location_endpoint_to_route
@@ -227,7 +282,7 @@ def resolve_service_endpoint(self, request):
227
282
# For non-document resource types in case of client can use multiple write locations
228
283
# or when client cannot use multiple write locations, flip-flop between the
229
284
# first and the second writable region in DatabaseAccount (for manual failover)
230
- if self .enable_endpoint_discovery and self .account_write_locations :
285
+ if self .connection_policy . EnableEndpointDiscovery and self .account_write_locations :
231
286
location_index = min (location_index % 2 , len (self .account_write_locations ) - 1 )
232
287
write_location = self .account_write_locations [location_index ]
233
288
if (self .account_write_regional_routing_contexts_by_location
@@ -247,9 +302,9 @@ def resolve_service_endpoint(self, request):
247
302
return self .default_regional_routing_context .get_primary ()
248
303
249
304
regional_routing_contexts = (
250
- self .get_write_regional_routing_contexts ( )
305
+ self ._get_applicable_write_regional_routing_contexts ( request )
251
306
if documents ._OperationType .IsWriteOperation (request .operation_type )
252
- else self .get_read_regional_routing_contexts ( )
307
+ else self ._get_applicable_read_regional_routing_contexts ( request )
253
308
)
254
309
regional_routing_context = regional_routing_contexts [location_index % len (regional_routing_contexts )]
255
310
if (
@@ -263,12 +318,14 @@ def resolve_service_endpoint(self, request):
263
318
return regional_routing_context .get_primary ()
264
319
265
320
def should_refresh_endpoints (self ): # pylint: disable=too-many-return-statements
266
- most_preferred_location = self .preferred_locations [0 ] if self .preferred_locations else None
321
+ most_preferred_location = self .connection_policy .PreferredLocations [0 ] \
322
+ if self .connection_policy .PreferredLocations else None
267
323
268
324
# we should schedule refresh in background if we are unable to target the user's most preferredLocation.
269
- if self .enable_endpoint_discovery :
325
+ if self .connection_policy . EnableEndpointDiscovery :
270
326
271
- should_refresh = self .use_multiple_write_locations and not self .enable_multiple_writable_locations
327
+ should_refresh = (self .connection_policy .UseMultipleWriteLocations
328
+ and not self .enable_multiple_writable_locations )
272
329
273
330
if (most_preferred_location and most_preferred_location in
274
331
self .account_read_regional_routing_contexts_by_location ):
@@ -358,25 +415,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
358
415
if enable_multiple_writable_locations :
359
416
self .enable_multiple_writable_locations = enable_multiple_writable_locations
360
417
361
- if self .enable_endpoint_discovery :
418
+ if self .connection_policy . EnableEndpointDiscovery :
362
419
if read_locations :
363
420
(self .account_read_regional_routing_contexts_by_location ,
421
+ self .account_locations_by_read_regional_routing_context ,
364
422
self .account_read_locations ) = get_endpoints_by_location (
365
423
read_locations ,
366
424
self .account_read_regional_routing_contexts_by_location ,
367
425
self .default_regional_routing_context ,
368
426
False ,
369
- self .use_multiple_write_locations
427
+ self .connection_policy . UseMultipleWriteLocations
370
428
)
371
429
372
430
if write_locations :
373
431
(self .account_write_regional_routing_contexts_by_location ,
432
+ self .account_locations_by_write_regional_routing_context ,
374
433
self .account_write_locations ) = get_endpoints_by_location (
375
434
write_locations ,
376
435
self .account_write_regional_routing_contexts_by_location ,
377
436
self .default_regional_routing_context ,
378
437
True ,
379
- self .use_multiple_write_locations
438
+ self .connection_policy . UseMultipleWriteLocations
380
439
)
381
440
382
441
self .write_regional_routing_contexts = self .get_preferred_regional_routing_contexts (
@@ -399,18 +458,18 @@ def get_preferred_regional_routing_contexts(
399
458
regional_endpoints = []
400
459
# if enableEndpointDiscovery is false, we always use the defaultEndpoint that
401
460
# user passed in during documentClient init
402
- if self .enable_endpoint_discovery and endpoints_by_location : # pylint: disable=too-many-nested-blocks
461
+ if self .connection_policy . EnableEndpointDiscovery and endpoints_by_location : # pylint: disable=too-many-nested-blocks
403
462
if (
404
463
self .can_use_multiple_write_locations ()
405
464
or expected_available_operation == EndpointOperationType .ReadType
406
465
):
407
466
unavailable_endpoints = []
408
- if self .preferred_locations :
467
+ if self .connection_policy . PreferredLocations :
409
468
# When client can not use multiple write locations, preferred locations
410
469
# list should only be used determining read endpoints order. If client
411
470
# can use multiple write locations, preferred locations list should be
412
471
# used for determining both read and write endpoints order.
413
- for location in self .preferred_locations :
472
+ for location in self .connection_policy . PreferredLocations :
414
473
regional_endpoint = endpoints_by_location [location ] if location in endpoints_by_location \
415
474
else None
416
475
if regional_endpoint :
@@ -436,11 +495,12 @@ def get_preferred_regional_routing_contexts(
436
495
return regional_endpoints
437
496
438
497
def can_use_multiple_write_locations (self ):
439
- return self .use_multiple_write_locations and self .enable_multiple_writable_locations
498
+ return self .connection_policy . UseMultipleWriteLocations and self .enable_multiple_writable_locations
440
499
441
500
def can_use_multiple_write_locations_for_request (self , request ): # pylint: disable=name-too-long
442
501
return self .can_use_multiple_write_locations () and (
443
502
request .resource_type == http_constants .ResourceType .Document
503
+ or request .resource_type == http_constants .ResourceType .PartitionKey
444
504
or (
445
505
request .resource_type == http_constants .ResourceType .StoredProcedure
446
506
and request .operation_type == documents ._OperationType .ExecuteJavaScript
0 commit comments