Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] add service retry logic #39394

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
resource_id: Optional[str],
resource_type: str,
options: Mapping[str, Any],
operation_type: str,
partition_key_range_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Gets HTTP request headers.
Expand Down Expand Up @@ -323,6 +324,11 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
if resource_type != 'dbs' and options.get("containerRID"):
headers[http_constants.HttpHeaders.IntendedCollectionRID] = options["containerRID"]

if resource_type == "":
resource_type = "databaseaccount"
headers[http_constants.HttpHeaders.ThinClientProxyResourceType] = resource_type
headers[http_constants.HttpHeaders.ThinClientProxyOperationType] = operation_type

return headers


Expand Down
33 changes: 22 additions & 11 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2038,7 +2038,8 @@ def PatchItem(
if options is None:
options = {}

headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, options)
headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type,
documents._OperationType.Patch, options)
# Patch will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(resource_type, documents._OperationType.Patch)
request_data = {}
Expand Down Expand Up @@ -2126,7 +2127,8 @@ def _Batch(
) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]:
initial_headers = self.default_headers.copy()
base._populate_batch_headers(initial_headers)
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", options)
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
documents._OperationType.Batch, options)
request_params = RequestObject("docs", documents._OperationType.Batch)
return cast(
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
Expand Down Expand Up @@ -2185,7 +2187,8 @@ def DeleteAllItemsByPartitionKey(
# Specified url to perform background operation to delete all items by partition key
path = '{}{}/{}'.format(path, "operations", "partitionkeydelete")
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", options)
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
"partitionkey", documents._OperationType.Delete, options)
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
_, last_response_headers = self.__Post(
path=path,
Expand Down Expand Up @@ -2353,7 +2356,8 @@ def ExecuteStoredProcedure(

path = base.GetPathFromLink(sproc_link)
sproc_id = base.GetResourceIdOrFullNameFromLink(sproc_link)
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs", options)
headers = base.GetHeaders(self, initial_headers, "post", path, sproc_id, "sprocs",
documents._OperationType.ExecuteJavaScript, options)

# ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation
request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript)
Expand Down Expand Up @@ -2550,7 +2554,8 @@ def GetDatabaseAccount(
if url_connection is None:
url_connection = self.url_connection

headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", {})
headers = base.GetHeaders(self, self.default_headers, "get", "", "", "",
documents._OperationType.Read,{})
request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection)
result, last_response_headers = self.__Get("", request_params, headers, **kwargs)
self.last_response_headers = last_response_headers
Expand Down Expand Up @@ -2615,7 +2620,8 @@ def Create(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create,
options)
# Create will use WriteEndpoint since it uses POST operation

request_params = RequestObject(typ, documents._OperationType.Create)
Expand Down Expand Up @@ -2659,7 +2665,8 @@ def Upsert(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert,
options)
headers[http_constants.HttpHeaders.IsUpsert] = True

# Upsert will use WriteEndpoint since it uses POST operation
Expand Down Expand Up @@ -2703,7 +2710,8 @@ def Replace(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace,
options)
# Replace will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(typ, documents._OperationType.Replace)
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
Expand Down Expand Up @@ -2744,7 +2752,7 @@ def Read(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
# Read will use ReadEndpoint since it uses GET operation
request_params = RequestObject(typ, documents._OperationType.Read)
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
Expand Down Expand Up @@ -2782,7 +2790,8 @@ def DeleteResource(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, options)
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete,
options)
# Delete will use WriteEndpoint since it uses DELETE operation
request_params = RequestObject(typ, documents._OperationType.Delete)
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
Expand Down Expand Up @@ -3027,6 +3036,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
path,
resource_id,
resource_type,
request_params.operation_type,
options,
partition_key_range_id
)
Expand Down Expand Up @@ -3064,6 +3074,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
path,
resource_id,
resource_type,
documents._OperationType.SqlQuery,
options,
partition_key_range_id
)
Expand Down Expand Up @@ -3334,4 +3345,4 @@ def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[s
container = self.ReadContainer(collection_link)
partition_key_definition = container.get("partitionKey")
self.__container_properties_cache[collection_link] = _set_properties_cache(container)
return partition_key_definition
return partition_key_definition
17 changes: 8 additions & 9 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,22 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement

should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations

if most_preferred_location:
if self.available_read_endpoint_by_locations:
most_preferred_read_endpoint = self.available_read_endpoint_by_locations[most_preferred_location]
if most_preferred_read_endpoint and most_preferred_read_endpoint != self.read_endpoints[0]:
# For reads, we can always refresh in background as we can alternate to
# other available read endpoints
return True
else:
if most_preferred_location and most_preferred_location in self.available_read_endpoint_by_locations:
most_preferred_read_endpoint = self.available_read_endpoint_by_locations[most_preferred_location]
if most_preferred_read_endpoint and most_preferred_read_endpoint != self.read_endpoints[0]:
# For reads, we can always refresh in background as we can alternate to
# other available read endpoints
return True
else:
return True

if not self.can_use_multiple_write_locations():
if self.is_endpoint_unavailable(self.write_endpoints[0], EndpointOperationType.WriteType):
# Since most preferred write endpoint is unavailable, we can only refresh in background if
# we have an alternate write endpoint
return True
return should_refresh
if most_preferred_location:
if most_preferred_location and most_preferred_location in self.available_write_endpoint_by_locations:
most_preferred_write_endpoint = self.available_write_endpoint_by_locations[most_preferred_location]
if most_preferred_write_endpoint:
should_refresh |= most_preferred_write_endpoint != self.write_endpoints[0]
Expand Down
54 changes: 48 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
"""Internal methods for executing functions in the Azure Cosmos database service.
"""
import json
from requests.exceptions import ReadTimeout, ConnectTimeout
import time
from typing import Optional

from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError
from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import RetryPolicy
from azure.core.pipeline.transport._base import HttpRequest

from . import exceptions
from . import _endpoint_discovery_retry_policy
Expand All @@ -38,6 +38,7 @@
from . import _gone_retry_policy
from . import _timeout_failover_retry_policy
from . import _container_recreate_retry_policy
from . import _service_response_retry_policy
from .http_constants import HttpHeaders, StatusCodes, SubStatusCodes


Expand Down Expand Up @@ -78,8 +79,11 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy(
client.connection_policy, global_endpoint_manager, *args
)
service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy(
client.connection_policy, global_endpoint_manager, *args,
)
# HttpRequest we would need to modify for Container Recreate Retry Policy
request: Optional[HttpRequest] = None
request = None
if args and len(args) > 3:
# Reference HttpRequest instance in args
request = args[3]
Expand Down Expand Up @@ -188,6 +192,14 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
if kwargs['timeout'] <= 0:
raise exceptions.CosmosClientTimeoutError()

except ServiceResponseError as e:
if e.exc_type in [ReadTimeout, ConnectTimeout]:
_handle_service_retries(request, client, service_response_retry_policy, args)
else:
raise

except ServiceRequestError:
_handle_service_retries(request, client, service_response_retry_policy, args)

def ExecuteFunction(function, *args, **kwargs):
"""Stub method so that it can be used for mocking purposes as well.
Expand All @@ -198,6 +210,24 @@ def ExecuteFunction(function, *args, **kwargs):
"""
return function(*args, **kwargs)

def _has_retryable_headers(request_headers):
    """Return True when the request is a document read/query operation,
    i.e. safe to retry against another region on a transport failure.
    """
    resource_type = request_headers.get(HttpHeaders.ThinClientProxyResourceType)
    operation_type = request_headers.get(HttpHeaders.ThinClientProxyOperationType)
    return (resource_type == "docs"
            and operation_type in ("Read", "Query", "QueryPlan", "ReadFeed", "SqlQuery"))

def _handle_service_retries(request, client, response_retry_policy, args=()):
    """Decide whether a transport-level failure should be retried or re-raised.

    Must be called from inside an ``except`` block: the bare ``raise``
    statements re-raise the exception currently being handled. Returning
    normally means "retry" (the caller's retry loop continues).

    :param request: the pipeline HTTP request whose proxy headers decide retryability
    :param client: the CosmosClientConnection that issued the request
    :param response_retry_policy: the ServiceResponseRetryPolicy driving cross-region failover
    :param args: positional args of the executed function; ``args[0]``, when present,
        is expected to expose ``should_clear_session_token_on_session_read_failure``
        -- TODO confirm against the Execute call site
    """
    # Only document read/query operations are safe to retry elsewhere.
    if not _has_retryable_headers(request.headers):
        raise
    # We resolve the request endpoint to the next preferred region;
    # once we are out of preferred regions we stop retrying.
    if not response_retry_policy.ShouldRetry():
        # Fix: the original signature was ``*args`` while the call site passes
        # the executor's args tuple as a single positional, so ``args[0]`` was
        # that tuple and the attribute lookup below could never succeed.
        # Binding ``args`` directly matches the call
        # ``_handle_service_retries(request, client, policy, args)``.
        if args and args[0].should_clear_session_token_on_session_read_failure and client.session:
            client.session.clear_session_token(client.last_response_headers)
        raise

def _configure_timeout(request: PipelineRequest, absolute: Optional[int], per_request: int) -> None:
if absolute is not None:
Expand Down Expand Up @@ -243,7 +273,6 @@ def send(self, request):
start_time = time.time()
try:
_configure_timeout(request, absolute_timeout, per_request_timeout)

response = self.next.send(request)
if self.is_retry(retry_settings, response):
retry_active = self.increment(retry_settings, response=response)
Expand All @@ -261,6 +290,9 @@ def send(self, request):
timeout_error.history = retry_settings['history']
raise
except ServiceRequestError as err:
if _has_retryable_headers(request.http_request.headers):
# raise exception immediately to be dealt with in client retry policies
raise err
# the request ran into a socket timeout or failed to establish a new connection
# since request wasn't sent, we retry up to however many connection retries are configured (default 3)
if retry_settings['connect'] > 0:
Expand All @@ -269,18 +301,28 @@ def send(self, request):
self.sleep(retry_settings, request.context.transport)
continue
raise err
except ServiceResponseError as err:
retry_error = err
if err.exc_type in [ReadTimeout, ConnectTimeout]:
if _has_retryable_headers(request.http_request.headers):
# raise exception immediately to be dealt with in client retry policies
raise err
retry_active = self.increment(retry_settings, response=request, error=err)
if retry_active:
self.sleep(retry_settings, request.context.transport)
continue
raise err
except AzureError as err:
retry_error = err
if self._is_method_retryable(retry_settings, request.http_request):
retry_active = self.increment(retry_settings, response=request, error=err)
if retry_active:
self.sleep(retry_settings, request.context.transport)
continue
raise err
finally:
end_time = time.time()
if absolute_timeout:
absolute_timeout -= (end_time - start_time)

self.update_context(response.context, retry_settings)
return response
return response
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

"""Internal class for service response read errors implementation in the Azure
Cosmos database service.
"""

class ServiceResponseRetryPolicy(object):
    """Retries a request against the next preferred region after a
    service response/request failure (e.g. read/connect timeout),
    allowing one attempt per available read region.
    """

    def __init__(self, connection_policy, global_endpoint_manager, *args):
        """
        :param connection_policy: client ConnectionPolicy; EnableEndpointDiscovery gates retries
        :param global_endpoint_manager: resolves and routes regional endpoints
        :param args: positional args of the executed function; ``args[0]``,
            when present, is the request object being retried
        """
        self.args = args
        self.global_endpoint_manager = global_endpoint_manager
        # One failover attempt per available read region.
        self.total_retries = len(self.global_endpoint_manager.location_cache.read_endpoints)
        self.failover_retry_count = 0
        self.connection_policy = connection_policy
        self.request = args[0] if args else None
        if self.request:
            self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request)

    def ShouldRetry(self):
        """Returns true if the request should retry based on preferred regions and retries already done.

        As a side effect, routes the request to the next preferred region.
        """
        if not self.connection_policy.EnableEndpointDiscovery:
            return False
        # Fix: the original indexed self.args[0] unconditionally and raised
        # IndexError when no request args were supplied, even though __init__
        # already computed self.request defensively; bail out instead.
        if self.request is None:
            return False
        # NOTE(review): with `and`, this retries whenever the operation is a
        # Read OR the resource is docs (e.g. a Delete on docs would retry);
        # confirm this isn't meant to be `or`. Left as-is to preserve behavior.
        if self.request.operation_type != 'Read' and self.request.resource_type != 'docs':
            return False

        self.failover_retry_count += 1
        if self.failover_retry_count > self.total_retries:
            return False

        # clear previous location-based routing directive
        self.request.clear_route_to_location()

        # set location-based routing directive based on retry count
        # ensuring usePreferredLocations is set to True for retry
        self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True)

        # Resolve the endpoint for the request and pin the resolution to the resolved endpoint
        # This enables marking the endpoint unavailability on endpoint failover/unreachability
        self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request)
        self.request.route_to_location(self.location_endpoint)
        return True
Loading
Loading