Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3624,11 +3624,11 @@ def refresh_routing_map_provider(
)
else:
# Full refresh - create a new provider instance. This clears all cached routing maps.
self._routing_map_provider = routing_map_provider.SmartRoutingMapProvider(self)
self._routing_map_provider.clear_cache()
return

# Fallback to full refresh when targeted refresh fails transiently.
self._routing_map_provider = routing_map_provider.SmartRoutingMapProvider(self)
self._routing_map_provider.clear_cache()

def _refresh_container_properties_cache(self, container_link: str):
# If container properties cache is stale, refresh it by reading the container.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ class _PartitionHealthInfo(object):
"""
This internal class keeps the health and statistics for a partition.
"""
# __slots__ reduces per-instance memory by using a fixed-size C array
# instead of a per-instance __dict__. Significant when tracking many partitions.
__slots__ = (
'write_failure_count',
'read_failure_count',
'write_success_count',
'read_success_count',
'read_consecutive_failure_count',
'write_consecutive_failure_count',
'unavailability_info',
)

def __init__(self) -> None:
self.write_failure_count: int = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .collection_routing_map import CollectionRoutingMap, _build_routing_map_from_ranges
from . import routing_range
from .routing_range import (
PKRange,
PartitionKeyRange,
_is_sorted_and_non_overlapping,
_subtract_range,
Expand Down Expand Up @@ -186,7 +187,7 @@ def process_fetched_ranges(
# Incremental update -- merge deltas into the existing map.
# Resolve parent chains transitively within this single delta so cascading
# splits (A->B+C and B->D+E in one payload) can be merged incrementally.
range_tuples: List[Tuple[Dict[str, Any], Any]] = []
range_tuples: List[Tuple[Any, Any]] = []
known_range_info_by_id = {
pkr_id: pkr_tuple[1]
for pkr_id, pkr_tuple in previous_routing_map._rangeById.items() # pylint: disable=protected-access
Expand All @@ -209,7 +210,11 @@ def process_fetched_ranges(
next_unresolved.append(r)
continue

range_tuples.append((r, range_info))
range_tuples.append((PKRange(
id=r[PartitionKeyRange.Id],
minInclusive=r[PartitionKeyRange.MinInclusive],
maxExclusive=r[PartitionKeyRange.MaxExclusive],
parents=r.get(PartitionKeyRange.Parents)), range_info))
known_range_info_by_id[r[PartitionKeyRange.Id]] = range_info
progress_made = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""
import asyncio # pylint: disable=do-not-import-asyncio
import logging
import threading
from typing import Dict, Any, Optional, List, TYPE_CHECKING
from azure.core.utils import CaseInsensitiveDict
from ... import _base, http_constants
Expand All @@ -41,6 +42,11 @@

if TYPE_CHECKING:
from ...aio._cosmos_client_connection_async import CosmosClientConnection

# Shared routing map cache across all clients targeting the same endpoint.
# Outer key: the endpoint string a provider reads from its client's ``url_connection``.
# Value: that endpoint's routing-map dictionary, which providers bind directly as
# ``_collection_routing_map_by_item``.
# NOTE(review): nothing visible here ever removes an endpoint entry -- confirm the
# intended lifetime for long-running processes that talk to many accounts.
_shared_routing_map_cache: dict = {}
# Held while a per-endpoint entry is looked up, created, or reset.
_shared_cache_lock = threading.Lock()

# pylint: disable=protected-access

logger = logging.getLogger(__name__)
Expand All @@ -64,14 +70,27 @@ def __init__(self, client: Any):
"""

self._document_client = client
self._endpoint = getattr(client, 'url_connection', '')

# keeps the cached collection routing map by collection id
self._collection_routing_map_by_item: Dict[str, CollectionRoutingMap] = {}
# Share routing map cache across clients with the same endpoint
with _shared_cache_lock:
if self._endpoint not in _shared_routing_map_cache:
_shared_routing_map_cache[self._endpoint] = {}
self._collection_routing_map_by_item = _shared_routing_map_cache[self._endpoint]
# A lock to control access to the locks dictionary itself
self._locks_lock = asyncio.Lock()
# A dictionary to hold a lock for each collection ID
self._collection_locks: Dict[str, asyncio.Lock] = {}

def clear_cache(self):
    """Clear the shared routing map cache for this endpoint.

    The per-endpoint dictionary is emptied *in place* rather than swapped for a
    new one. Every provider created for the same endpoint holds a reference to
    that same dictionary, so an in-place clear is the only way all of them
    observe the reset and keep sharing afterwards; rebinding the entry would
    leave the other providers pointing at the stale, now-detached dictionary.

    If the endpoint entry is missing from the module-level registry, it is
    re-registered so this provider does not end up with a private, unshared
    dictionary.

    The per-collection lock table of this provider is reset as well.
    """
    with _shared_cache_lock:
        shared_maps = _shared_routing_map_cache.setdefault(self._endpoint, {})
        shared_maps.clear()
        self._collection_routing_map_by_item = shared_maps

    self._collection_locks = {}

async def _get_lock_for_collection(self, collection_id: str) -> asyncio.Lock:
"""Safely gets or creates a lock for a given collection ID.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Optional, Union

from azure.cosmos._routing import routing_range
from azure.cosmos._routing.routing_range import PartitionKeyRange
from azure.cosmos._routing.routing_range import PartitionKeyRange, PKRange

# pylint: disable=line-too-long
class CollectionRoutingMap(object):
Expand Down Expand Up @@ -288,7 +288,13 @@ def _build_routing_map_from_ranges(
if PartitionKeyRange.Parents in r and r[PartitionKeyRange.Parents]:
gone_range_ids.update(r[PartitionKeyRange.Parents])

filtered_ranges = [r for r in ranges if r[PartitionKeyRange.Id] not in gone_range_ids]
filtered_ranges = [
PKRange(id=r[PartitionKeyRange.Id],
minInclusive=r[PartitionKeyRange.MinInclusive],
maxExclusive=r[PartitionKeyRange.MaxExclusive],
parents=r.get(PartitionKeyRange.Parents))
for r in ranges if r[PartitionKeyRange.Id] not in gone_range_ids
]
range_tuples = [(r, True) for r in filtered_ranges]

routing_map = CollectionRoutingMap.CompleteRoutingMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@

if TYPE_CHECKING:
from .._cosmos_client_connection import CosmosClientConnection

# Shared routing map cache across all clients targeting the same endpoint.
# Outer key: the endpoint string a provider reads from its client's ``url_connection``.
# Value: that endpoint's routing-map dictionary, which providers bind directly as
# ``_collection_routing_map_by_item``.
# NOTE(review): nothing visible here ever removes an endpoint entry -- confirm the
# intended lifetime for long-running processes that talk to many accounts.
_shared_routing_map_cache: dict = {}
# Held while a per-endpoint entry is looked up, created, or reset.
_shared_cache_lock = threading.Lock()

# pylint: disable=protected-access, line-too-long


Expand All @@ -63,14 +68,29 @@ def __init__(self, client: Any):
"""

self._document_client = client
self._endpoint = getattr(client, 'url_connection', '')

# Share routing map cache across clients with the same endpoint
with _shared_cache_lock:
if self._endpoint not in _shared_routing_map_cache:
_shared_routing_map_cache[self._endpoint] = {}
self._collection_routing_map_by_item = _shared_routing_map_cache[self._endpoint]

# keeps the cached collection routing map by collection id
self._collection_routing_map_by_item: Dict[str, CollectionRoutingMap] = {}
# A lock to control access to the locks dictionary itself
self._locks_lock = threading.Lock()
# A dictionary to hold a lock for each collection ID
self._collection_locks: Dict[str, threading.Lock] = {}

def clear_cache(self):
    """Clear the shared routing map cache for this endpoint.

    The per-endpoint dictionary is emptied *in place* rather than swapped for a
    new one. Every provider created for the same endpoint holds a reference to
    that same dictionary, so an in-place clear is the only way all of them
    observe the reset and keep sharing afterwards; rebinding the entry would
    leave the other providers pointing at the stale, now-detached dictionary.

    If the endpoint entry is missing from the module-level registry, it is
    re-registered so this provider does not end up with a private, unshared
    dictionary.

    The lock bookkeeping of this provider is reset as well.
    """
    with _shared_cache_lock:
        shared_maps = _shared_routing_map_cache.setdefault(self._endpoint, {})
        shared_maps.clear()
        self._collection_routing_map_by_item = shared_maps

    # NOTE(review): rebinding _locks_lock while another thread still holds the
    # previous lock object would let both proceed -- confirm clear_cache is not
    # expected to race with lock lookups.
    self._locks_lock = threading.Lock()
    self._collection_locks = {}

def _get_lock_for_collection(self, collection_id: str) -> threading.Lock:

"""Safely gets or creates a lock for a given collection ID.
Expand Down
34 changes: 31 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,31 @@
import json


from collections import namedtuple

_PKRangeBase = namedtuple('_PKRangeBase', ['id', 'minInclusive', 'maxExclusive', 'parents'])


class PKRange(_PKRangeBase):
    """Compact partition key range with dict-compatible access.

    Instances are named tuples with the fields ``id``, ``minInclusive``,
    ``maxExclusive`` and ``parents``. On top of the regular tuple behaviour they
    answer the read-only mapping operations ``obj[key]``, ``obj.get(key)``,
    ``key in obj`` and ``obj.items()`` for those four field names, so code that
    used to receive a plain dictionary keeps working.

    Only the declared fields are treated as keys. Tuple methods and private
    attributes (``count``, ``index``, ``_fields`` ...) are never returned through
    the mapping-style accessors.
    """
    __slots__ = ()

    def __getitem__(self, key):
        """Return a field by name, or fall back to tuple indexing.

        :param key: A field name, or an integer / slice for positional access.
        :returns: The field value (or the tuple item / slice).
        :raises KeyError: If ``key`` is a string that is not a field name.
        :raises IndexError: If ``key`` is an out-of-range integer.
        """
        if isinstance(key, str):
            if key in self._fields:
                return getattr(self, key)
            raise KeyError(key)
        # Integers and slices keep the normal named-tuple semantics.
        return super().__getitem__(key)

    def get(self, key, default=None):
        """Return the field named ``key``, or ``default`` if it is not a field.

        :param key: Field name to look up.
        :param default: Value returned when ``key`` is not a declared field.
        :returns: The field value or ``default``.
        """
        if isinstance(key, str) and key in self._fields:
            return getattr(self, key)
        return default

    def __contains__(self, key):
        """Report whether ``key`` is a field name (mapping semantics, not value membership)."""
        return key in self._fields

    def items(self):
        """Return an iterator of ``(field_name, value)`` pairs in field order."""
        return zip(self._fields, self)


class PartitionKeyRange(object):
"""Partition Key Range Constants"""

Expand All @@ -37,7 +62,10 @@ class PartitionKeyRange(object):


class Range(object):
"""description of class"""
"""Range of a partition key."""
# __slots__ reduces per-instance memory from ~250 bytes to ~64 bytes.
# Significant when 100K+ partition ranges are cached per client.
__slots__ = ('min', 'max', 'isMinInclusive', 'isMaxInclusive')

MinPath = "min"
MaxPath = "max"
Expand All @@ -50,8 +78,8 @@ def __init__(self, range_min, range_max, isMinInclusive, isMaxInclusive):
if range_max is None:
raise ValueError("max is missing")

self.min = range_min.upper()
self.max = range_max.upper()
self.min = range_min if range_min == range_min.upper() else range_min.upper()
self.max = range_max if range_max == range_max.upper() else range_max.upper()
self.isMinInclusive = isMinInclusive
self.isMaxInclusive = isMaxInclusive

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3495,7 +3495,7 @@ async def refresh_routing_map_provider(
return

# Fallback to full refresh when targeted refresh fails transiently.
self._routing_map_provider = SmartRoutingMapProvider(self)
self._routing_map_provider.clear_cache()

async def _refresh_container_properties_cache(self, container_link: str):
# If container properties cache is stale, refresh it by reading the container.
Expand Down
10 changes: 10 additions & 0 deletions sdk/cosmos/azure-cosmos/cspell.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"words": [
"hdrh",
"hdrhistogram",
"perfdb",
"perfresults",
"pkrange",
"ppcb"
]
}
117 changes: 117 additions & 0 deletions sdk/cosmos/azure-cosmos/tests/routing/test_shared_pk_range_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

import sys
import unittest

import pytest

from azure.cosmos._routing.routing_range import Range, PKRange
from azure.cosmos._routing.collection_routing_map import CollectionRoutingMap
from azure.cosmos._routing.routing_map_provider import (
PartitionKeyRangeCache,
_shared_routing_map_cache,
_shared_cache_lock,
)


class MockClient:
    """Minimal stand-in for a client connection: it only carries an endpoint URL."""

    def __init__(self, url_connection):
        # The only attribute the tests need the stub to expose.
        self.url_connection = url_connection


@pytest.mark.cosmosEmulator
class TestSharedPartitionKeyRangeCache(unittest.TestCase):
    """Checks for the endpoint-keyed shared routing map cache, PKRange and Range."""

    @staticmethod
    def _provider_for(endpoint):
        """Build a PartitionKeyRangeCache around a stub client for ``endpoint``."""
        return PartitionKeyRangeCache(MockClient(endpoint))

    def tearDown(self):
        # The registry is module-level state; wipe it so each case starts clean.
        with _shared_cache_lock:
            _shared_routing_map_cache.clear()

    def test_same_endpoint_shares_cache(self):
        """Two providers built for one endpoint hold the very same dictionary."""
        endpoint = "https://account1.documents.azure.com:443/"
        first = self._provider_for(endpoint)
        second = self._provider_for(endpoint)
        self.assertIs(
            first._collection_routing_map_by_item,
            second._collection_routing_map_by_item,
        )

    def test_different_endpoints_isolated(self):
        """Providers built for different endpoints hold distinct dictionaries."""
        first = self._provider_for("https://account1.documents.azure.com:443/")
        second = self._provider_for("https://account2.documents.azure.com:443/")
        self.assertIsNot(
            first._collection_routing_map_by_item,
            second._collection_routing_map_by_item,
        )

    def test_shared_cache_populated_by_first_client(self):
        """A routing map stored through one provider is visible through the other."""
        endpoint = "https://account1.documents.azure.com:443/"
        writer = self._provider_for(endpoint)
        reader = self._provider_for(endpoint)

        range_tuples = [({"id": "0", "minInclusive": "", "maxExclusive": "FF"}, True)]
        routing_map = CollectionRoutingMap.CompleteRoutingMap(range_tuples, "test-collection")
        writer._collection_routing_map_by_item["test-collection"] = routing_map

        reader_maps = reader._collection_routing_map_by_item
        self.assertIn("test-collection", reader_maps)
        self.assertIs(reader_maps["test-collection"], routing_map)

    def test_clear_cache_resets_for_endpoint(self):
        """clear_cache leaves the calling provider without the earlier entry."""
        provider = self._provider_for("https://account1.documents.azure.com:443/")
        provider._collection_routing_map_by_item["coll1"] = "dummy"

        provider.clear_cache()

        self.assertNotIn("coll1", provider._collection_routing_map_by_item)

    def test_clear_cache_does_not_affect_other_endpoints(self):
        """Clearing one endpoint leaves another endpoint's entries in place."""
        cleared = self._provider_for("https://account1.documents.azure.com:443/")
        untouched = self._provider_for("https://account2.documents.azure.com:443/")
        cleared._collection_routing_map_by_item["coll1"] = "data1"
        untouched._collection_routing_map_by_item["coll2"] = "data2"

        cleared.clear_cache()

        self.assertNotIn("coll1", cleared._collection_routing_map_by_item)
        self.assertIn("coll2", untouched._collection_routing_map_by_item)

    def test_pkrange_dict_access(self):
        """PKRange supports dict-style [key] access."""
        entry = PKRange(id="1", minInclusive="00", maxExclusive="FF", parents=["0"])

        # Subscript and get() on declared fields.
        self.assertEqual(entry["id"], "1")
        self.assertEqual(entry["minInclusive"], "00")
        self.assertEqual(entry.get("parents"), ["0"])
        # Unknown key falls back to the supplied default.
        self.assertEqual(entry.get("_rid", "default"), "default")
        # Membership is by field name.
        self.assertIn("id", entry)
        self.assertNotIn("_rid", entry)

    def test_pkrange_in_collection_routing_map(self):
        """CollectionRoutingMap works with PKRange namedtuples."""
        lower_half = PKRange(id="0", minInclusive="", maxExclusive="80", parents=None)
        upper_half = PKRange(id="1", minInclusive="80", maxExclusive="FF", parents=None)
        range_tuples = [(lower_half, True), (upper_half, True)]

        routing_map = CollectionRoutingMap.CompleteRoutingMap(range_tuples, "test")

        self.assertIsNotNone(routing_map)
        full_span = Range("", "FF", True, False)
        matches = routing_map.get_overlapping_ranges(full_span)
        self.assertEqual(len(matches), 2)

    def test_range_has_slots(self):
        """Range instances carry no per-instance __dict__ and stay small."""
        span = Range("00", "FF", True, False)
        self.assertFalse(hasattr(span, "__dict__"))
        self.assertLess(sys.getsizeof(span), 100)

    def test_range_skips_upper_when_already_uppercase(self):
        """An already-uppercase bound is stored as the same string object."""
        original = "05C1C9CD673398"
        span = Range(original, original, True, False)
        self.assertIs(span.min, original)

    def test_range_applies_upper_when_lowercase(self):
        """A lowercase bound is stored uppercased."""
        span = Range("05c1c9cd", "05c1d9cd", True, False)
        self.assertEqual(span.min, "05C1C9CD")


# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()