Skip to content

Commit e7736e3

Browse files
authored
Add client id to all request headers (#42104)
* Add client id to all request headers * Update CHANGELOG.md * Small update * remove passing in client id in requests update client connection to not pass in client id and instead have it set to headers in base.getheaders
1 parent 984d7e9 commit e7736e3

File tree

5 files changed

+50
-2
lines changed

5 files changed

+50
-2
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
* Fixed bug where container cache was not being properly updated resulting in unnecessary extra requests. See [PR 42143](https://github.com/Azure/azure-sdk-for-python/pull/42143).
1212

1313
#### Other Changes
14+
* Changed to include client id in headers for all requests. See [PR 42104](https://github.com/Azure/azure-sdk-for-python/pull/42104).
1415

1516
### 4.14.0b1 (2025-07-14)
1617

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
291291

292292
if client_id is not None:
293293
headers[http_constants.HttpHeaders.ClientId] = client_id
294+
elif cosmos_client_connection and cosmos_client_connection.client_id:
295+
headers[http_constants.HttpHeaders.ClientId] = cosmos_client_connection.client_id
294296

295297
if options.get("enableScriptLogging"):
296298
headers[http_constants.HttpHeaders.EnableScriptLogging] = options["enableScriptLogging"]

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2823,7 +2823,8 @@ def Read(
28232823
options = {}
28242824

28252825
initial_headers = initial_headers or self.default_headers
2826-
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
2826+
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read,
2827+
options)
28272828
# Read will use ReadEndpoint since it uses GET operation
28282829
request_params = RequestObject(typ, documents._OperationType.Read, headers)
28292830
request_params.set_excluded_location_from_options(options)

sdk/cosmos/azure-cosmos/tests/test_headers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def side_effect_correlated_activity_id(self, *args, **kwargs):
5656
assert args[3]["x-ms-cosmos-correlated-activityid"] # cspell:disable-line
5757
raise StopIteration
5858

59+
def side_effect_client_id(self, *args, **kwargs):
60+
# Extract request headers from args
61+
assert args[2][http_constants.HttpHeaders.ClientId]
62+
raise StopIteration
63+
5964
def test_correlated_activity_id(self):
6065
query = 'SELECT * from c ORDER BY c._ts'
6166

@@ -98,6 +103,20 @@ def test_negative_max_integrated_cache_staleness(self):
98103
except Exception as exception:
99104
assert isinstance(exception, ValueError)
100105

106+
def test_client_id(self):
107+
# Client ID should be sent on every request, Verify it is sent on a read_item request
108+
cosmos_client_connection = self.container.client_connection
109+
original_connection_get = cosmos_client_connection._CosmosClientConnection__Get
110+
cosmos_client_connection._CosmosClientConnection__Get = MagicMock(
111+
side_effect=self.side_effect_client_id)
112+
try:
113+
self.container.read_item(item="id-1", partition_key="pk-1")
114+
except StopIteration:
115+
pass
116+
finally:
117+
cosmos_client_connection._CosmosClientConnection__Get = original_connection_get
118+
119+
101120
def test_client_level_throughput_bucket(self):
102121
cosmos_client.CosmosClient(self.host, self.masterKey,
103122
throughput_bucket=client_throughput_bucket_number,

sdk/cosmos/azure-cosmos/tests/test_headers_async.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33

44
import unittest
5+
from unittest.mock import MagicMock
56

67
import pytest
78
import uuid
89

910

1011
import test_config
1112
from azure.cosmos import http_constants
12-
import azure.cosmos.exceptions as exceptions
1313
from azure.cosmos.aio import CosmosClient, _retry_utility_async, DatabaseProxy
1414
from azure.cosmos.partition_key import PartitionKey
1515

@@ -23,6 +23,12 @@ async def request_raw_response_hook(response):
2323
assert (response.http_request.headers[http_constants.HttpHeaders.ThroughputBucket]
2424
== str(request_throughput_bucket_number))
2525

26+
27+
class ClientIDVerificationError(Exception):
28+
"""Custom exception for client ID verification errors."""
29+
pass
30+
31+
2632
@pytest.mark.cosmosEmulator
2733
class TestHeadersAsync(unittest.IsolatedAsyncioTestCase):
2834
client: CosmosClient = None
@@ -206,5 +212,24 @@ async def test_container_read_item_negative_throughput_bucket_async(self):
206212
assert "specified for the header 'x-ms-cosmos-throughput-bucket' is invalid." in e.http_error_message
207213
"""
208214

215+
async def side_effect_client_id(self, *args, **kwargs):
216+
# This is a side effect to verify that the client ID is sent in the request headers
217+
assert args[2].get(http_constants.HttpHeaders.ClientId) is not None
218+
raise ClientIDVerificationError("Client ID verification complete")
219+
220+
async def test_client_id(self):
221+
# Client ID should be sent on every request, Verify it is sent on a read_item request
222+
cosmos_client_connection = self.container.client_connection
223+
original_connection_get = cosmos_client_connection._CosmosClientConnection__Get
224+
cosmos_client_connection._CosmosClientConnection__Get = MagicMock(
225+
side_effect=self.side_effect_client_id)
226+
try:
227+
await self.container.read_item(item="id-1", partition_key="pk-1")
228+
except ClientIDVerificationError:
229+
pass
230+
finally:
231+
cosmos_client_connection._CosmosClientConnection__Get = original_connection_get
232+
233+
209234
if __name__ == "__main__":
210235
unittest.main()

0 commit comments

Comments
 (0)