Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions fastapi_cache/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Union, Optional, AnyStr

import aioredis
Expand All @@ -6,8 +7,9 @@
from .base import BaseCacheBackend

DEFAULT_ENCODING = 'utf-8'
DEFAULT_POOL_MIN_SIZE = 5
CACHE_KEY = 'REDIS'
# a singleton sentinel value for parameter defaults
_sentinel = object()

# expected to be of bytearray, bytes, float, int, or str type

Expand All @@ -19,11 +21,16 @@ class RedisCacheBackend(BaseCacheBackend[RedisKey, RedisValue]):
def __init__(
self,
address: str,
pool_minsize: Optional[int] = DEFAULT_POOL_MIN_SIZE,
pool_minsize: Optional[int] = _sentinel,
encoding: Optional[str] = DEFAULT_ENCODING,
) -> None:
self._redis_address = address
self._redis_pool_minsize = pool_minsize
if pool_minsize is not _sentinel:
warnings.warn(
"Parameter 'pool_minsize' has been obsolete since aioredis 2.0.0.",
DeprecationWarning,
)

self._encoding = encoding

self._pool: Optional[Redis] = None
Expand All @@ -36,10 +43,7 @@ async def _client(self) -> Redis:
return self._pool

async def _create_connection(self) -> Redis:
return await aioredis.create_redis_pool(
self._redis_address,
minsize=self._redis_pool_minsize,
)
return aioredis.from_url(self._redis_address)

async def add(
self,
Expand Down Expand Up @@ -75,10 +79,12 @@ async def get(
default: RedisValue = None,
**kwargs,
) -> AnyStr:
kwargs.setdefault('encoding', self._encoding)
encoding = kwargs.pop("encoding", self._encoding)

client = await self._client
cached_value = await client.get(key, **kwargs)
if encoding is not None and isinstance(cached_value, bytes):
cached_value = cached_value.decode(encoding)

return cached_value if cached_value is not None else default

Expand Down Expand Up @@ -118,5 +124,6 @@ async def expire(

async def close(self) -> None:
client = await self._client
client.close()
await client.wait_closed()
# Redis.close() only close currrent connection, but not the pool
await client.connection_pool.disconnect()
await client.close()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
'redis', 'aioredis', 'asyncio', 'fastapi', 'starlette', 'cache'
],
install_requires=[
'aioredis==1.3.1',
'aioredis==2.0.0',
],
)
5 changes: 3 additions & 2 deletions tests/redis_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async def test_should_add_n_get_data_no_encoding(
) -> None:
NO_ENCODING_KEY = 'bytes'
NO_ENCODING_VALUE = b'test'
await f_backend.expire(NO_ENCODING_KEY, 0)
is_added = await f_backend.add(NO_ENCODING_KEY, NO_ENCODING_VALUE)

assert is_added is True
Expand Down Expand Up @@ -165,8 +166,8 @@ async def test_close_should_close_connection(
f_backend: RedisCacheBackend
) -> None:
await f_backend.close()
with pytest.raises(aioredis.errors.PoolClosedError):
await f_backend.add(TEST_KEY, TEST_VALUE)
assert len(f_backend._pool.connection_pool._available_connections) == 0
assert len(f_backend._pool.connection_pool._in_use_connections) == 0


@pytest.mark.asyncio
Expand Down