Skip to content

Commit

Permalink
perf: Enhance memory usage by fixing SDK clients caching
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Jan 16, 2025
1 parent 7a1861a commit 3f72173
Show file tree
Hide file tree
Showing 32 changed files with 136 additions and 91 deletions.
48 changes: 44 additions & 4 deletions app/helpers/cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import functools
from collections import OrderedDict
from collections.abc import AsyncGenerator, Awaitable
from contextlib import asynccontextmanager
from functools import wraps

from aiojobs import Scheduler

Expand All @@ -20,17 +20,17 @@ async def get_scheduler() -> AsyncGenerator[Scheduler, None]:
yield scheduler


def async_lru_cache(maxsize: int = 128):
def lru_acache(maxsize: int = 128):
"""
Caches a function's return value each time it is called.
Caches an async function's return value each time it is called.
If the maxsize is reached, the least recently used value is removed.
"""

def decorator(func):
cache: OrderedDict[tuple, Awaitable] = OrderedDict()

@functools.wraps(func)
@wraps(func)
async def wrapper(*args, **kwargs) -> Awaitable:
# Create a cache key from event loop, args and kwargs, using frozenset for kwargs to ensure hashability
key = (
Expand All @@ -49,6 +49,46 @@ async def wrapper(*args, **kwargs) -> Awaitable:
cache[key] = value
cache.move_to_end(key)

# Remove the least recently used key if the cache is full
if len(cache) > maxsize:
cache.popitem(last=False)

return value

return wrapper

return decorator


def lru_cache(maxsize: int = 128):
"""
Caches a sync function's return value each time it is called.
If the maxsize is reached, the least recently used value is removed.
"""

def decorator(func):
cache: OrderedDict[tuple, Awaitable] = OrderedDict()

@wraps(func)
def wrapper(*args, **kwargs) -> Awaitable:
# Create a cache key from args and kwargs, using frozenset for kwargs to ensure hashability
key = (
args,
frozenset(kwargs.items()),
)

if key in cache:
# Move the recently accessed key to the end (most recently used)
cache.move_to_end(key)
return cache[key]

# Compute the value since it's not cached
value = func(*args, **kwargs)
cache[key] = value
cache.move_to_end(key)

# Remove the least recently used key if the cache is full
if len(cache) > maxsize:
cache.popitem(last=False)

Expand Down
4 changes: 2 additions & 2 deletions app/helpers/call_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
from app.models.next import NextModel
from app.models.synthesis import SynthesisModel

_sms = CONFIG.sms.instance()
_db = CONFIG.database.instance()
_sms = CONFIG.sms.instance
_db = CONFIG.database.instance


@tracer.start_as_current_span("on_new_call")
Expand Down
2 changes: 1 addition & 1 deletion app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
extract_message_style,
)

_db = CONFIG.database.instance()
_db = CONFIG.database.instance


# TODO: Refacto, this function is too long
Expand Down
6 changes: 3 additions & 3 deletions app/helpers/call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from noisereduce import reduce_noise

from app.helpers.cache import async_lru_cache
from app.helpers.cache import lru_acache
from app.helpers.config import CONFIG
from app.helpers.features import (
recognition_stt_complete_timeout_ms,
Expand Down Expand Up @@ -71,7 +71,7 @@
r"[^\w\sÀ-ÿ'«»“”\"\"‘’''(),.!?;:\-\+_@/&€$%=]" # noqa: RUF001
) # Sanitize text for TTS

_db = CONFIG.database.instance()
_db = CONFIG.database.instance


class CallHangupException(Exception):
Expand Down Expand Up @@ -526,7 +526,7 @@ def _detect_hangup() -> Generator[None, None, None]:
raise e


@async_lru_cache()
@lru_acache()
async def _use_call_client(
client: CallAutomationClient, voice_id: str
) -> CallConnectionClient:
Expand Down
6 changes: 3 additions & 3 deletions app/helpers/config_models/ai_search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import lru_cache
from functools import cached_property

from pydantic import BaseModel, Field

Expand All @@ -17,14 +17,14 @@ class AiSearchModel(BaseModel, frozen=True):
strictness: float = Field(default=2, ge=0, le=5)
top_n_documents: int = Field(default=5, ge=1)

@lru_cache
@cached_property
def instance(self) -> ISearch:
from app.helpers.config import CONFIG
from app.persistence.ai_search import (
AiSearchSearch,
)

return AiSearchSearch(
cache=CONFIG.cache.instance(),
cache=CONFIG.cache.instance,
config=self,
)
11 changes: 6 additions & 5 deletions app/helpers/config_models/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from functools import lru_cache
from functools import cached_property

from pydantic import BaseModel, Field, SecretStr, ValidationInfo, field_validator

Expand All @@ -16,7 +16,7 @@ class ModeEnum(str, Enum):
class MemoryModel(BaseModel, frozen=True):
max_size: int = Field(default=128, ge=10)

@lru_cache
@cached_property
def instance(self) -> ICache:
from app.persistence.memory import (
MemoryCache,
Expand All @@ -32,7 +32,7 @@ class RedisModel(BaseModel, frozen=True):
port: int = 6379
ssl: bool = True

@lru_cache
@cached_property
def instance(self) -> ICache:
from app.persistence.redis import (
RedisCache,
Expand Down Expand Up @@ -68,10 +68,11 @@ def _validate_memory(
raise ValueError("Memory config required")
return memory

@cached_property
def instance(self) -> ICache:
if self.mode == ModeEnum.MEMORY:
assert self.memory
return self.memory.instance()
return self.memory.instance

assert self.redis
return self.redis.instance()
return self.redis.instance
9 changes: 5 additions & 4 deletions app/helpers/config_models/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import lru_cache
from functools import cached_property

from pydantic import BaseModel

Expand All @@ -10,21 +10,22 @@ class CosmosDbModel(BaseModel, frozen=True):
database: str
endpoint: str

@lru_cache
@cached_property
def instance(self) -> IStore:
from app.helpers.config import CONFIG
from app.persistence.cosmos_db import (
CosmosDbStore,
)

return CosmosDbStore(
cache=CONFIG.cache.instance(),
cache=CONFIG.cache.instance,
config=self,
)


class DatabaseModel(BaseModel):
cosmos_db: CosmosDbModel

@cached_property
def instance(self) -> IStore:
return self.cosmos_db.instance()
return self.cosmos_db.instance
6 changes: 3 additions & 3 deletions app/helpers/config_models/llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from azure.ai.inference.aio import ChatCompletionsClient
from pydantic import BaseModel

from app.helpers.cache import async_lru_cache
from app.helpers.cache import lru_acache
from app.helpers.http import azure_transport
from app.helpers.identity import credential

Expand All @@ -14,8 +14,8 @@ class DeploymentModel(BaseModel, frozen=True):
seed: int = 42 # Reproducible results
temperature: float = 0.0 # Most focused and deterministic

@async_lru_cache()
async def instance(self) -> tuple[ChatCompletionsClient, "DeploymentModel"]:
@lru_acache()
async def client(self) -> tuple[ChatCompletionsClient, "DeploymentModel"]:
return ChatCompletionsClient(
# Reliability
seed=self.seed,
Expand Down
10 changes: 5 additions & 5 deletions app/helpers/config_models/queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import lru_cache
from functools import cached_property

from pydantic import BaseModel

Expand All @@ -10,7 +10,7 @@ class QueueModel(BaseModel, frozen=True):
sms_name: str
training_name: str

@lru_cache
@cached_property
def call(self):
from app.persistence.azure_queue_storage import AzureQueueStorage

Expand All @@ -19,7 +19,7 @@ def call(self):
name=self.call_name,
)

@lru_cache
@cached_property
def post(self):
from app.persistence.azure_queue_storage import AzureQueueStorage

Expand All @@ -28,7 +28,7 @@ def post(self):
name=self.post_name,
)

@lru_cache
@cached_property
def sms(self):
from app.persistence.azure_queue_storage import AzureQueueStorage

Expand All @@ -37,7 +37,7 @@ def sms(self):
name=self.sms_name,
)

@lru_cache
@cached_property
def training(self):
from app.persistence.azure_queue_storage import AzureQueueStorage

Expand Down
11 changes: 6 additions & 5 deletions app/helpers/config_models/sms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from functools import lru_cache
from functools import cached_property

from pydantic import BaseModel, SecretStr, ValidationInfo, field_validator

Expand All @@ -21,7 +21,7 @@ class CommunicationServiceModel(BaseModel, frozen=True):
Model is purely empty to fit to the `ISms` interface and the "mode" enum code organization. As the Communication Services is also used as the only call interface, it is not necessary to duplicate the models.
"""

@lru_cache
@cached_property
def instance(self) -> ISms:
from app.helpers.config import CONFIG
from app.persistence.communication_services import (
Expand All @@ -36,7 +36,7 @@ class TwilioModel(BaseModel, frozen=True):
auth_token: SecretStr
phone_number: PhoneNumber

@lru_cache
@cached_property
def instance(self) -> ISms:
from app.persistence.twilio import (
TwilioSms,
Expand Down Expand Up @@ -77,10 +77,11 @@ def _validate_twilio(
raise ValueError("Twilio config required")
return twilio

@cached_property
def instance(self) -> ISms:
if self.mode == ModeEnum.COMMUNICATION_SERVICES:
assert self.communication_services
return self.communication_services.instance()
return self.communication_services.instance

assert self.twilio
return self.twilio.instance()
return self.twilio.instance
4 changes: 2 additions & 2 deletions app/helpers/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from azure.appconfiguration.aio import AzureAppConfigurationClient
from azure.core.exceptions import ResourceNotFoundError

from app.helpers.cache import async_lru_cache
from app.helpers.cache import lru_acache
from app.helpers.config import CONFIG
from app.helpers.config_models.cache import MemoryModel
from app.helpers.http import azure_transport
Expand Down Expand Up @@ -233,7 +233,7 @@ async def _get(key: str, type_res: type[T]) -> T | None:
)


@async_lru_cache()
@lru_acache()
async def _use_client() -> AzureAppConfigurationClient:
"""
Generate the App Configuration client and close it after use.
Expand Down
10 changes: 5 additions & 5 deletions app/helpers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from azure.core.pipeline.transport._aiohttp import AioHttpTransport
from twilio.http.async_http_client import AsyncTwilioHttpClient

from app.helpers.cache import async_lru_cache
from app.helpers.cache import lru_acache


@async_lru_cache()
@lru_acache()
async def _aiohttp_cookie_jar() -> DummyCookieJar:
"""
Create a cookie jar mock for AIOHTTP.
Expand All @@ -24,7 +24,7 @@ async def _aiohttp_cookie_jar() -> DummyCookieJar:
return DummyCookieJar()


@async_lru_cache()
@lru_acache()
async def aiohttp_session() -> ClientSession:
"""
Create an AIOHTTP session.
Expand All @@ -48,7 +48,7 @@ async def aiohttp_session() -> ClientSession:
)


@async_lru_cache()
@lru_acache()
async def azure_transport() -> AioHttpTransport:
"""
Create an AIOHTTP transport, for Azure SDK.
Expand All @@ -64,7 +64,7 @@ async def azure_transport() -> AioHttpTransport:
)


@async_lru_cache()
@lru_acache()
async def twilio_http() -> AsyncTwilioHttpClient:
"""
Create a Twilio HTTP client.
Expand Down
6 changes: 3 additions & 3 deletions app/helpers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider

from app.helpers.cache import async_lru_cache
from app.helpers.cache import lru_acache
from app.helpers.http import azure_transport


@async_lru_cache()
@lru_acache()
async def credential() -> DefaultAzureCredential:
return DefaultAzureCredential(
# Performance
transport=await azure_transport(),
)


@async_lru_cache()
@lru_acache()
async def token(service: str) -> Callable[[], Awaitable[str]]:
return get_bearer_token_provider(await credential(), service)
Loading

0 comments on commit 3f72173

Please sign in to comment.