Skip to content

Commit 1b15e34

Browse files
kushagraThaparsimorenohtvaron3jeet1995
authored
Port cross region retries functionality shipped as hotfix to main branch (#39417)
* implementation * Update _retry_utility_async.py * changelog, versions, fixes * fixes * remove fake logic, count fix * Update _service_request_retry_policy.py * Update _retry_utility_async.py * retry utilities fixing * Update _retry_utility.py * additional enhancements * Update setup.py * Update _retry_utility_async.py * add tests, remove previous retry logic for ServiceRequestExceptions * clean up with finally * tests * retry utilities * disable tests * add logging to policies * GetDatabaseAccount Fix * Update _base.py * retry utilities fixes * Update _retry_utility.py * retry utulities part 34 * Update _service_request_retry_policy.py * remove extra logs * policy updates * Update _service_response_retry_policy.py * Update _service_response_retry_policy.py * policies updates and update operation types * trying out fixes * Update sdk/cosmos/azure-cosmos/CHANGELOG.md Co-authored-by: Abhijeet Mohanty <[email protected]> * Update sdk/cosmos/azure-cosmos/CHANGELOG.md Co-authored-by: Abhijeet Mohanty <[email protected]> * Skipped proxy test for debugging * annotation fix * Fixed some tests cases * test fixes * Update test_service_retry_policies_async.py * Fixed some mocking behavior * fixed pylint issues * Added aiohttp minimum dependency * Updated changelog and setup.py * Updated changelog --------- Co-authored-by: Simon Moreno <[email protected]> Co-authored-by: tvaron3 <[email protected]> Co-authored-by: Abhijeet Mohanty <[email protected]>
1 parent 8ac2529 commit 1b15e34

18 files changed

+850
-66
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
## Release History
22

3-
### 4.9.1b2 (Unreleased)
3+
### 4.9.1b2 (2025-01-24)
44

55
#### Features Added
6-
7-
#### Breaking Changes
6+
* Added new cross-regional retry logic for `ServiceRequestError` and `ServiceResponseError` exceptions. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).
87

98
#### Bugs Fixed
9+
* Fixed `KeyError` being raised by location cache when most preferred location is not present in cached regions. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).
10+
* Fixed cross-region retries on `CosmosClient` initialization. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).
1011

1112
#### Other Changes
13+
* This release requires aiohttp version 3.10.11 or above. See [PR 39396](https://github.com/Azure/azure-sdk-for-python/pull/39396).
1214

1315
### 4.9.1b1 (2024-12-13)
1416

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

+7
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
116116
path: str,
117117
resource_id: Optional[str],
118118
resource_type: str,
119+
operation_type: str,
119120
options: Mapping[str, Any],
120121
partition_key_range_id: Optional[str] = None,
121122
) -> Dict[str, Any]:
@@ -127,6 +128,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
127128
:param str path:
128129
:param str resource_id:
129130
:param str resource_type:
131+
:param str operation_type:
130132
:param dict options:
131133
:param str partition_key_range_id:
132134
:return: The HTTP request headers.
@@ -323,6 +325,11 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
323325
if resource_type != 'dbs' and options.get("containerRID"):
324326
headers[http_constants.HttpHeaders.IntendedCollectionRID] = options["containerRID"]
325327

328+
if resource_type == "":
329+
resource_type = http_constants.ResourceType.DatabaseAccount
330+
headers[http_constants.HttpHeaders.ThinClientProxyResourceType] = resource_type
331+
headers[http_constants.HttpHeaders.ThinClientProxyOperationType] = operation_type
332+
326333
return headers
327334

328335

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -2038,7 +2038,8 @@ def PatchItem(
20382038
if options is None:
20392039
options = {}
20402040

2041-
headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, options)
2041+
headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type,
2042+
documents._OperationType.Patch, options)
20422043
# Patch will use WriteEndpoint since it uses PATCH operation
20432044
request_params = RequestObject(resource_type, documents._OperationType.Patch)
20442045
request_data = {}
@@ -2126,7 +2127,8 @@ def _Batch(
21262127
) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]:
21272128
initial_headers = self.default_headers.copy()
21282129
base._populate_batch_headers(initial_headers)
2129-
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", options)
2130+
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
2131+
documents._OperationType.Batch, options)
21302132
request_params = RequestObject("docs", documents._OperationType.Batch)
21312133
return cast(
21322134
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
@@ -2185,7 +2187,8 @@ def DeleteAllItemsByPartitionKey(
21852187
# Specified url to perform background operation to delete all items by partition key
21862188
path = '{}{}/{}'.format(path, "operations", "partitionkeydelete")
21872189
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
2188-
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", options)
2190+
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
2191+
"partitionkey", documents._OperationType.Delete, options)
21892192
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
21902193
_, last_response_headers = self.__Post(
21912194
path=path,
@@ -2353,7 +2356,8 @@ def ExecuteStoredProcedure(
23532356

23542357
path = base.GetPathFromLink(sproc_link)
23552358
sproc_id = base.GetResourceIdOrFullNameFromLink(sproc_link)
2356-
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs", options)
2359+
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs",
2360+
documents._OperationType.ExecuteJavaScript, options)
23572361

23582362
# ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation
23592363
request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript)
@@ -2550,7 +2554,8 @@ def GetDatabaseAccount(
25502554
if url_connection is None:
25512555
url_connection = self.url_connection
25522556

2553-
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", {})
2557+
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "",
2558+
documents._OperationType.Read,{})
25542559
request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection)
25552560
result, last_response_headers = self.__Get("", request_params, headers, **kwargs)
25562561
self.last_response_headers = last_response_headers
@@ -2615,7 +2620,8 @@ def Create(
26152620
options = {}
26162621

26172622
initial_headers = initial_headers or self.default_headers
2618-
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
2623+
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create,
2624+
options)
26192625
# Create will use WriteEndpoint since it uses POST operation
26202626

26212627
request_params = RequestObject(typ, documents._OperationType.Create)
@@ -2659,7 +2665,8 @@ def Upsert(
26592665
options = {}
26602666

26612667
initial_headers = initial_headers or self.default_headers
2662-
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
2668+
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert,
2669+
options)
26632670
headers[http_constants.HttpHeaders.IsUpsert] = True
26642671

26652672
# Upsert will use WriteEndpoint since it uses POST operation
@@ -2703,7 +2710,8 @@ def Replace(
27032710
options = {}
27042711

27052712
initial_headers = initial_headers or self.default_headers
2706-
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, options)
2713+
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace,
2714+
options)
27072715
# Replace will use WriteEndpoint since it uses PUT operation
27082716
request_params = RequestObject(typ, documents._OperationType.Replace)
27092717
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
@@ -2744,7 +2752,7 @@ def Read(
27442752
options = {}
27452753

27462754
initial_headers = initial_headers or self.default_headers
2747-
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, options)
2755+
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
27482756
# Read will use ReadEndpoint since it uses GET operation
27492757
request_params = RequestObject(typ, documents._OperationType.Read)
27502758
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
@@ -2782,7 +2790,8 @@ def DeleteResource(
27822790
options = {}
27832791

27842792
initial_headers = initial_headers or self.default_headers
2785-
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, options)
2793+
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete,
2794+
options)
27862795
# Delete will use WriteEndpoint since it uses DELETE operation
27872796
request_params = RequestObject(typ, documents._OperationType.Delete)
27882797
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
@@ -3027,6 +3036,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
30273036
path,
30283037
resource_id,
30293038
resource_type,
3039+
request_params.operation_type,
30303040
options,
30313041
partition_key_range_id
30323042
)
@@ -3064,6 +3074,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
30643074
path,
30653075
resource_id,
30663076
resource_type,
3077+
documents._OperationType.SqlQuery,
30673078
options,
30683079
partition_key_range_id
30693080
)

sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
"""
2525

2626
import threading
27-
2827
from urllib.parse import urlparse
2928

29+
from azure.core.exceptions import AzureError
30+
3031
from . import _constants as constants
3132
from . import exceptions
3233
from ._location_cache import LocationCache
@@ -134,14 +135,14 @@ def _GetDatabaseAccount(self, **kwargs):
134135
# specified (by creating a locational endpoint) and keep swallowing the exception
135136
# until we get the database account, and re-raise the error at the end if we are not able
136137
# to get that info from any endpoints
137-
except exceptions.CosmosHttpResponseError:
138+
except (exceptions.CosmosHttpResponseError, AzureError):
138139
for location_name in self.PreferredLocations:
139140
locational_endpoint = _GlobalEndpointManager.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
140141
try:
141142
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
142143
self._database_account_cache = database_account
143144
return database_account
144-
except exceptions.CosmosHttpResponseError:
145+
except (exceptions.CosmosHttpResponseError, AzureError):
145146
pass
146147
raise
147148

sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,22 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement
159159

160160
should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations
161161

162-
if most_preferred_location:
163-
if self.available_read_endpoint_by_locations:
164-
most_preferred_read_endpoint = self.available_read_endpoint_by_locations[most_preferred_location]
165-
if most_preferred_read_endpoint and most_preferred_read_endpoint != self.read_endpoints[0]:
166-
# For reads, we can always refresh in background as we can alternate to
167-
# other available read endpoints
168-
return True
169-
else:
162+
if most_preferred_location and most_preferred_location in self.available_read_endpoint_by_locations:
163+
most_preferred_read_endpoint = self.available_read_endpoint_by_locations[most_preferred_location]
164+
if most_preferred_read_endpoint and most_preferred_read_endpoint != self.read_endpoints[0]:
165+
# For reads, we can always refresh in background as we can alternate to
166+
# other available read endpoints
170167
return True
168+
else:
169+
return True
171170

172171
if not self.can_use_multiple_write_locations():
173172
if self.is_endpoint_unavailable(self.write_endpoints[0], EndpointOperationType.WriteType):
174173
# Since most preferred write endpoint is unavailable, we can only refresh in background if
175174
# we have an alternate write endpoint
176175
return True
177176
return should_refresh
178-
if most_preferred_location:
177+
if most_preferred_location and most_preferred_location in self.available_write_endpoint_by_locations:
179178
most_preferred_write_endpoint = self.available_write_endpoint_by_locations[most_preferred_location]
180179
if most_preferred_write_endpoint:
181180
should_refresh |= most_preferred_write_endpoint != self.write_endpoints[0]

sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py

+57-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import time
2626
from typing import Optional
2727

28-
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError
28+
from requests.exceptions import ( # pylint: disable=networking-import-outside-azure-core-transport
29+
ReadTimeout, ConnectTimeout) # pylint: disable=networking-import-outside-azure-core-transport
30+
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError
2931
from azure.core.pipeline import PipelineRequest
3032
from azure.core.pipeline.policies import RetryPolicy
3133

@@ -37,6 +39,8 @@
3739
from . import _gone_retry_policy
3840
from . import _timeout_failover_retry_policy
3941
from . import _container_recreate_retry_policy
42+
from . import _service_request_retry_policy, _service_response_retry_policy
43+
from .documents import _OperationType
4044
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes
4145

4246

@@ -77,6 +81,12 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
7781
timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy(
7882
client.connection_policy, global_endpoint_manager, *args
7983
)
84+
service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy(
85+
client.connection_policy, global_endpoint_manager, *args,
86+
)
87+
service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy(
88+
client.connection_policy, global_endpoint_manager, *args,
89+
)
8090
# HttpRequest we would need to modify for Container Recreate Retry Policy
8191
request = None
8292
if args and len(args) > 3:
@@ -187,6 +197,16 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
187197
if kwargs['timeout'] <= 0:
188198
raise exceptions.CosmosClientTimeoutError()
189199

200+
except ServiceRequestError as e:
201+
_handle_service_request_retries(client, service_request_retry_policy, e, *args)
202+
203+
except ServiceResponseError as e:
204+
if e.exc_type == ReadTimeout:
205+
_handle_service_response_retries(request, client, service_response_retry_policy, e, *args)
206+
elif e.exc_type == ConnectTimeout:
207+
_handle_service_request_retries(client, service_request_retry_policy, e, *args)
208+
else:
209+
raise
190210

191211
def ExecuteFunction(function, *args, **kwargs):
192212
"""Stub method so that it can be used for mocking purposes as well.
@@ -197,6 +217,31 @@ def ExecuteFunction(function, *args, **kwargs):
197217
"""
198218
return function(*args, **kwargs)
199219

220+
def _has_read_retryable_headers(request_headers):
221+
if _OperationType.IsReadOnlyOperation(request_headers.get(HttpHeaders.ThinClientProxyOperationType)):
222+
return True
223+
return False
224+
225+
def _handle_service_request_retries(client, request_retry_policy, exception, *args):
226+
# we resolve the request endpoint to the next preferred region
227+
# once we are out of preferred regions we stop retrying
228+
retry_policy = request_retry_policy
229+
if not retry_policy.ShouldRetry():
230+
if args and args[0].should_clear_session_token_on_session_read_failure and client.session:
231+
client.session.clear_session_token(client.last_response_headers)
232+
raise exception
233+
234+
def _handle_service_response_retries(request, client, response_retry_policy, exception, *args):
235+
if _has_read_retryable_headers(request.headers):
236+
# we resolve the request endpoint to the next preferred region
237+
# once we are out of preferred regions we stop retrying
238+
retry_policy = response_retry_policy
239+
if not retry_policy.ShouldRetry():
240+
if args and args[0].should_clear_session_token_on_session_read_failure and client.session:
241+
client.session.clear_session_token(client.last_response_headers)
242+
raise exception
243+
else:
244+
raise exception
200245

201246
def _configure_timeout(request: PipelineRequest, absolute: Optional[int], per_request: int) -> None:
202247
if absolute is not None:
@@ -242,7 +287,6 @@ def send(self, request):
242287
start_time = time.time()
243288
try:
244289
_configure_timeout(request, absolute_timeout, per_request_timeout)
245-
246290
response = self.next.send(request)
247291
if self.is_retry(retry_settings, response):
248292
retry_active = self.increment(retry_settings, response=response)
@@ -261,8 +305,17 @@ def send(self, request):
261305
raise
262306
except ServiceRequestError as err:
263307
# the request ran into a socket timeout or failed to establish a new connection
264-
# since request wasn't sent, we retry up to however many connection retries are configured (default 3)
265-
if retry_settings['connect'] > 0:
308+
# since request wasn't sent, raise exception immediately to be dealt with in client retry policies
309+
raise err
310+
except ServiceResponseError as err:
311+
retry_error = err
312+
if err.exc_type == ReadTimeout:
313+
if _has_read_retryable_headers(request.http_request.headers):
314+
# raise exception immediately to be dealt with in client retry policies
315+
raise err
316+
elif err.exc_type == ConnectTimeout:
317+
raise err
318+
if self._is_method_retryable(retry_settings, request.http_request):
266319
retry_active = self.increment(retry_settings, response=request, error=err)
267320
if retry_active:
268321
self.sleep(retry_settings, request.context.transport)

0 commit comments

Comments
 (0)