Skip to content

Commit f4ce81d

Browse files
Add excluded locations on client and request levels (Azure#40298)
* Add Excluded Locations Feature * Added multi-region tests * Fix _AddParitionKey to pass options to sub methods * Added initial live tests * Updated live-platform-matrix for multi-region tests * Add cosmosQuery mark to TestQuery * Correct spelling * Fixed live platform matrix syntax * Changed Multi-regions * Added client level ExcludedLocation for async * Update Live test settings * Added Async tests * Add more live tests for all other Python versions * Fix Async test failure * Fix live test failures * Fix live test failures * Fix live test failures * Add test_delete_all_items_by_partition_key * Remove test_delete_all_items_by_partition_key * Added missing doc for excluded_locations in async client * Remove duplicate functions * Fix live tests with multi write locations * Fixed bug with endpoint routing with multi write region partition key API calls * Adding emulator tests for delete_all_items_by_partition_key API * minimized duplicate codes * Added Async emulator tests * Nit: Changed test names * Addressed comments about documents * Address comments about method naming * Updated document to add more details of request level excluded_locations --------- Co-authored-by: Kushagra Thapar <[email protected]>
1 parent f7384ce commit f4ce81d

28 files changed

+1749
-78
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
'priority': 'priorityLevel',
6464
'no_response': 'responsePayloadOnWriteDisabled',
6565
'max_item_count': 'maxItemCount',
66+
'excluded_locations': 'excludedLocations',
6667
}
6768

6869
# Cosmos resource ID validation regex breakdown:

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,7 @@ def PatchItem(
20442044
documents._OperationType.Patch, options)
20452045
# Patch will use WriteEndpoint since it uses PUT operation
20462046
request_params = RequestObject(resource_type, documents._OperationType.Patch)
2047+
request_params.set_excluded_location_from_options(options)
20472048
request_data = {}
20482049
if options.get("filterPredicate"):
20492050
request_data["condition"] = options.get("filterPredicate")
@@ -2132,6 +2133,7 @@ def _Batch(
21322133
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
21332134
documents._OperationType.Batch, options)
21342135
request_params = RequestObject("docs", documents._OperationType.Batch)
2136+
request_params.set_excluded_location_from_options(options)
21352137
return cast(
21362138
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
21372139
self.__Post(path, request_params, batch_operations, headers, **kwargs)
@@ -2190,8 +2192,9 @@ def DeleteAllItemsByPartitionKey(
21902192
path = '{}{}/{}'.format(path, "operations", "partitionkeydelete")
21912193
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
21922194
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
2193-
"partitionkey", documents._OperationType.Delete, options)
2194-
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
2195+
http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options)
2196+
request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete)
2197+
request_params.set_excluded_location_from_options(options)
21952198
_, last_response_headers = self.__Post(
21962199
path=path,
21972200
request_params=request_params,
@@ -2647,6 +2650,7 @@ def Create(
26472650
# Create will use WriteEndpoint since it uses POST operation
26482651

26492652
request_params = RequestObject(typ, documents._OperationType.Create)
2653+
request_params.set_excluded_location_from_options(options)
26502654
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
26512655
self.last_response_headers = last_response_headers
26522656

@@ -2693,6 +2697,7 @@ def Upsert(
26932697

26942698
# Upsert will use WriteEndpoint since it uses POST operation
26952699
request_params = RequestObject(typ, documents._OperationType.Upsert)
2700+
request_params.set_excluded_location_from_options(options)
26962701
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
26972702
self.last_response_headers = last_response_headers
26982703
# update session for write request
@@ -2736,6 +2741,7 @@ def Replace(
27362741
options)
27372742
# Replace will use WriteEndpoint since it uses PUT operation
27382743
request_params = RequestObject(typ, documents._OperationType.Replace)
2744+
request_params.set_excluded_location_from_options(options)
27392745
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
27402746
self.last_response_headers = last_response_headers
27412747

@@ -2777,6 +2783,7 @@ def Read(
27772783
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
27782784
# Read will use ReadEndpoint since it uses GET operation
27792785
request_params = RequestObject(typ, documents._OperationType.Read)
2786+
request_params.set_excluded_location_from_options(options)
27802787
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
27812788
self.last_response_headers = last_response_headers
27822789
if response_hook:
@@ -2816,6 +2823,7 @@ def DeleteResource(
28162823
options)
28172824
# Delete will use WriteEndpoint since it uses DELETE operation
28182825
request_params = RequestObject(typ, documents._OperationType.Delete)
2826+
request_params.set_excluded_location_from_options(options)
28192827
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
28202828
self.last_response_headers = last_response_headers
28212829

@@ -3052,6 +3060,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
30523060
resource_type,
30533061
documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed
30543062
)
3063+
request_params.set_excluded_location_from_options(options)
30553064
headers = base.GetHeaders(
30563065
self,
30573066
initial_headers,
@@ -3090,6 +3099,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
30903099

30913100
# Query operations will use ReadEndpoint even though it uses POST(for regular query operations)
30923101
request_params = RequestObject(resource_type, documents._OperationType.SqlQuery)
3102+
request_params.set_excluded_location_from_options(options)
30933103
req_headers = base.GetHeaders(
30943104
self,
30953105
initial_headers,
@@ -3256,7 +3266,7 @@ def _AddPartitionKey(
32563266
options: Mapping[str, Any]
32573267
) -> Dict[str, Any]:
32583268
collection_link = base.TrimBeginningAndEndingSlashes(collection_link)
3259-
partitionKeyDefinition = self._get_partition_key_definition(collection_link)
3269+
partitionKeyDefinition = self._get_partition_key_definition(collection_link, options)
32603270
new_options = dict(options)
32613271
# If the collection doesn't have a partition key definition, skip it as it's a legacy collection
32623272
if partitionKeyDefinition:
@@ -3358,15 +3368,19 @@ def _UpdateSessionIfRequired(
33583368
# update session
33593369
self.session.update_session(response_result, response_headers)
33603370

3361-
def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]:
3371+
def _get_partition_key_definition(
3372+
self,
3373+
collection_link: str,
3374+
options: Mapping[str, Any]
3375+
) -> Optional[Dict[str, Any]]:
33623376
partition_key_definition: Optional[Dict[str, Any]]
33633377
# If the document collection link is present in the cache, then use the cached partitionkey definition
33643378
if collection_link in self.__container_properties_cache:
33653379
cached_container: Dict[str, Any] = self.__container_properties_cache.get(collection_link, {})
33663380
partition_key_definition = cached_container.get("partitionKey")
33673381
# Else read the collection from backend and add it to the cache
33683382
else:
3369-
container = self.ReadContainer(collection_link)
3383+
container = self.ReadContainer(collection_link, options)
33703384
partition_key_definition = container.get("partitionKey")
33713385
self.__container_properties_cache[collection_link] = _set_properties_cache(container)
33723386
return partition_key_definition

sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@ def __init__(self, client):
5050
self.DefaultEndpoint = client.url_connection
5151
self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub()
5252
self.location_cache = LocationCache(
53-
self.PreferredLocations,
5453
self.DefaultEndpoint,
55-
self.EnableEndpointDiscovery,
56-
client.connection_policy.UseMultipleWriteLocations
54+
client.connection_policy
5755
)
5856
self.refresh_needed = False
5957
self.refresh_lock = threading.RLock()

sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525
import collections
2626
import logging
2727
import time
28-
from typing import Set
28+
from typing import Set, Mapping, List
2929
from urllib.parse import urlparse
3030

3131
from . import documents
3232
from . import http_constants
3333
from .documents import _OperationType
34+
from ._request_object import RequestObject
3435

3536
# pylint: disable=protected-access
3637

@@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations,
113114
except Exception as e:
114115
raise e
115116

116-
return endpoints_by_location, parsed_locations
117+
# Also store a hash map of endpoints for each location
118+
locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()}
119+
120+
return endpoints_by_location, locations_by_endpoints, parsed_locations
117121

118122
def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool:
119123
if endpoint in preferred_endpoints:
@@ -150,31 +154,44 @@ def _get_health_check_endpoints(
150154

151155
return endpoints
152156

157+
def _get_applicable_regional_routing_contexts(regional_routing_contexts: List[RegionalRoutingContext],
158+
location_name_by_endpoint: Mapping[str, str],
159+
fall_back_regional_routing_context: RegionalRoutingContext,
160+
exclude_location_list: List[str]) -> List[RegionalRoutingContext]:
161+
# filter endpoints by excluded locations
162+
applicable_regional_routing_contexts = []
163+
for regional_routing_context in regional_routing_contexts:
164+
if location_name_by_endpoint.get(regional_routing_context.get_primary()) not in exclude_location_list:
165+
applicable_regional_routing_contexts.append(regional_routing_context)
166+
167+
# if endpoint is empty add fallback endpoint
168+
if not applicable_regional_routing_contexts:
169+
applicable_regional_routing_contexts.append(fall_back_regional_routing_context)
170+
171+
return applicable_regional_routing_contexts
153172

154173
class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes
155174
def current_time_millis(self):
156175
return int(round(time.time() * 1000))
157176

158177
def __init__(
159178
self,
160-
preferred_locations,
161179
default_endpoint,
162-
enable_endpoint_discovery,
163-
use_multiple_write_locations,
180+
connection_policy,
164181
):
165-
self.preferred_locations = preferred_locations
166182
self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint)
167-
self.enable_endpoint_discovery = enable_endpoint_discovery
168-
self.use_multiple_write_locations = use_multiple_write_locations
169183
self.enable_multiple_writable_locations = False
170184
self.write_regional_routing_contexts = [self.default_regional_routing_context]
171185
self.read_regional_routing_contexts = [self.default_regional_routing_context]
172186
self.location_unavailability_info_by_endpoint = {}
173187
self.last_cache_update_time_stamp = 0
174188
self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
175189
self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
190+
self.account_locations_by_read_regional_routing_context = {} # pylint: disable=name-too-long
191+
self.account_locations_by_write_regional_routing_context = {} # pylint: disable=name-too-long
176192
self.account_write_locations = []
177193
self.account_read_locations = []
194+
self.connection_policy = connection_policy
178195

179196
def get_write_regional_routing_contexts(self):
180197
return self.write_regional_routing_contexts
@@ -207,6 +224,44 @@ def get_ordered_write_locations(self):
207224
def get_ordered_read_locations(self):
208225
return self.account_read_locations
209226

227+
def _get_configured_excluded_locations(self, request: RequestObject) -> List[str]:
228+
# If excluded locations were configured on request, use request level excluded locations.
229+
excluded_locations = request.excluded_locations
230+
if excluded_locations is None:
231+
# If excluded locations were only configured on client(connection_policy), use client level
232+
excluded_locations = self.connection_policy.ExcludedLocations
233+
return excluded_locations
234+
235+
def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) -> List[RegionalRoutingContext]:
236+
# Get configured excluded locations
237+
excluded_locations = self._get_configured_excluded_locations(request)
238+
239+
# If excluded locations were configured, return filtered regional endpoints by excluded locations.
240+
if excluded_locations:
241+
return _get_applicable_regional_routing_contexts(
242+
self.get_read_regional_routing_contexts(),
243+
self.account_locations_by_read_regional_routing_context,
244+
self.get_write_regional_routing_contexts()[0],
245+
excluded_locations)
246+
247+
# Else, return all regional endpoints
248+
return self.get_read_regional_routing_contexts()
249+
250+
def _get_applicable_write_regional_routing_contexts(self, request: RequestObject) -> List[RegionalRoutingContext]:
251+
# Get configured excluded locations
252+
excluded_locations = self._get_configured_excluded_locations(request)
253+
254+
# If excluded locations were configured, return filtered regional endpoints by excluded locations.
255+
if excluded_locations:
256+
return _get_applicable_regional_routing_contexts(
257+
self.get_write_regional_routing_contexts(),
258+
self.account_locations_by_write_regional_routing_context,
259+
self.default_regional_routing_context,
260+
excluded_locations)
261+
262+
# Else, return all regional endpoints
263+
return self.get_write_regional_routing_contexts()
264+
210265
def resolve_service_endpoint(self, request):
211266
if request.location_endpoint_to_route:
212267
return request.location_endpoint_to_route
@@ -227,7 +282,7 @@ def resolve_service_endpoint(self, request):
227282
# For non-document resource types in case of client can use multiple write locations
228283
# or when client cannot use multiple write locations, flip-flop between the
229284
# first and the second writable region in DatabaseAccount (for manual failover)
230-
if self.enable_endpoint_discovery and self.account_write_locations:
285+
if self.connection_policy.EnableEndpointDiscovery and self.account_write_locations:
231286
location_index = min(location_index % 2, len(self.account_write_locations) - 1)
232287
write_location = self.account_write_locations[location_index]
233288
if (self.account_write_regional_routing_contexts_by_location
@@ -247,9 +302,9 @@ def resolve_service_endpoint(self, request):
247302
return self.default_regional_routing_context.get_primary()
248303

249304
regional_routing_contexts = (
250-
self.get_write_regional_routing_contexts()
305+
self._get_applicable_write_regional_routing_contexts(request)
251306
if documents._OperationType.IsWriteOperation(request.operation_type)
252-
else self.get_read_regional_routing_contexts()
307+
else self._get_applicable_read_regional_routing_contexts(request)
253308
)
254309
regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)]
255310
if (
@@ -263,12 +318,14 @@ def resolve_service_endpoint(self, request):
263318
return regional_routing_context.get_primary()
264319

265320
def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements
266-
most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None
321+
most_preferred_location = self.connection_policy.PreferredLocations[0] \
322+
if self.connection_policy.PreferredLocations else None
267323

268324
# we should schedule refresh in background if we are unable to target the user's most preferredLocation.
269-
if self.enable_endpoint_discovery:
325+
if self.connection_policy.EnableEndpointDiscovery:
270326

271-
should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations
327+
should_refresh = (self.connection_policy.UseMultipleWriteLocations
328+
and not self.enable_multiple_writable_locations)
272329

273330
if (most_preferred_location and most_preferred_location in
274331
self.account_read_regional_routing_contexts_by_location):
@@ -358,25 +415,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
358415
if enable_multiple_writable_locations:
359416
self.enable_multiple_writable_locations = enable_multiple_writable_locations
360417

361-
if self.enable_endpoint_discovery:
418+
if self.connection_policy.EnableEndpointDiscovery:
362419
if read_locations:
363420
(self.account_read_regional_routing_contexts_by_location,
421+
self.account_locations_by_read_regional_routing_context,
364422
self.account_read_locations) = get_endpoints_by_location(
365423
read_locations,
366424
self.account_read_regional_routing_contexts_by_location,
367425
self.default_regional_routing_context,
368426
False,
369-
self.use_multiple_write_locations
427+
self.connection_policy.UseMultipleWriteLocations
370428
)
371429

372430
if write_locations:
373431
(self.account_write_regional_routing_contexts_by_location,
432+
self.account_locations_by_write_regional_routing_context,
374433
self.account_write_locations) = get_endpoints_by_location(
375434
write_locations,
376435
self.account_write_regional_routing_contexts_by_location,
377436
self.default_regional_routing_context,
378437
True,
379-
self.use_multiple_write_locations
438+
self.connection_policy.UseMultipleWriteLocations
380439
)
381440

382441
self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts(
@@ -399,18 +458,18 @@ def get_preferred_regional_routing_contexts(
399458
regional_endpoints = []
400459
# if enableEndpointDiscovery is false, we always use the defaultEndpoint that
401460
# user passed in during documentClient init
402-
if self.enable_endpoint_discovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks
461+
if self.connection_policy.EnableEndpointDiscovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks
403462
if (
404463
self.can_use_multiple_write_locations()
405464
or expected_available_operation == EndpointOperationType.ReadType
406465
):
407466
unavailable_endpoints = []
408-
if self.preferred_locations:
467+
if self.connection_policy.PreferredLocations:
409468
# When client can not use multiple write locations, preferred locations
410469
# list should only be used determining read endpoints order. If client
411470
# can use multiple write locations, preferred locations list should be
412471
# used for determining both read and write endpoints order.
413-
for location in self.preferred_locations:
472+
for location in self.connection_policy.PreferredLocations:
414473
regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \
415474
else None
416475
if regional_endpoint:
@@ -436,11 +495,12 @@ def get_preferred_regional_routing_contexts(
436495
return regional_endpoints
437496

438497
def can_use_multiple_write_locations(self):
439-
return self.use_multiple_write_locations and self.enable_multiple_writable_locations
498+
return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations
440499

441500
def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long
442501
return self.can_use_multiple_write_locations() and (
443502
request.resource_type == http_constants.ResourceType.Document
503+
or request.resource_type == http_constants.ResourceType.PartitionKey
444504
or (
445505
request.resource_type == http_constants.ResourceType.StoredProcedure
446506
and request.operation_type == documents._OperationType.ExecuteJavaScript

0 commit comments

Comments
 (0)