Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,21 @@ class EvictionPolicyType(Enum):

@dataclass(frozen=True)
class CacheKey:
"""
Represents a unique key for a cache entry.

Attributes:
command (str): The Redis command being cached.
redis_keys (tuple): The Redis keys involved in the command.
redis_args (tuple): Additional arguments for the Redis command.
This field is included in the cache key to ensure uniqueness
when commands have the same keys but different arguments.
Changing this field will affect cache key uniqueness.
"""

command: str
redis_keys: tuple
redis_args: tuple = () # Additional arguments for the Redis command; affects cache key uniqueness.


class CacheEntry:
Expand Down
6 changes: 4 additions & 2 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,9 @@ def send_command(self, *args, **kwargs):
with self._cache_lock:
# Command is write command or not allowed
# to be cached.
if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
if not self._cache.is_cachable(
CacheKey(command=args[0], redis_keys=(), redis_args=())
):
self._current_command_cache_key = None
self._conn.send_command(*args, **kwargs)
return
Expand All @@ -1224,7 +1226,7 @@ def send_command(self, *args, **kwargs):

# Creates cache key.
self._current_command_cache_key = CacheKey(
command=args[0], redis_keys=tuple(kwargs.get("keys"))
command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
)

with self._cache_lock:
Expand Down
53 changes: 43 additions & 10 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,17 @@ def test_format_error_message(conn, error, expected_message):


def test_network_connection_failure():
exp_err = f"Error {ECONNREFUSED} connecting to localhost:9999. Connection refused."
# Match only the stable part of the error message across OS
exp_err = rf"Error {ECONNREFUSED} connecting to localhost:9999\."
with pytest.raises(ConnectionError, match=exp_err):
redis = Redis(port=9999)
redis.set("a", "b")


@pytest.mark.skipif(
not hasattr(socket, "AF_UNIX"),
reason="Unix domain sockets not supported on this platform",
)
def test_unix_socket_connection_failure():
exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory."
with pytest.raises(ConnectionError, match=exp_err):
Expand Down Expand Up @@ -463,25 +468,33 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
None,
None,
CacheEntry(
cache_key=CacheKey(command="GET", redis_keys=("foo",)),
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=mock_connection,
),
CacheEntry(
cache_key=CacheKey(command="GET", redis_keys=("foo",)),
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=b"bar",
status=CacheEntryStatus.VALID,
connection_ref=mock_connection,
),
CacheEntry(
cache_key=CacheKey(command="GET", redis_keys=("foo",)),
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=b"bar",
status=CacheEntryStatus.VALID,
connection_ref=mock_connection,
),
CacheEntry(
cache_key=CacheKey(command="GET", redis_keys=("foo",)),
cache_key=CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
),
cache_value=b"bar",
status=CacheEntryStatus.VALID,
connection_ref=mock_connection,
Expand All @@ -503,15 +516,23 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
[
call(
CacheEntry(
cache_key=CacheKey(command="GET", redis_keys=("foo",)),
cache_key=CacheKey(
command="GET",
redis_keys=("foo",),
redis_args=("GET", "foo"),
),
cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
status=CacheEntryStatus.IN_PROGRESS,
connection_ref=mock_connection,
)
),
call(
CacheEntry(
cache_key=CacheKey(command="GET", redis_keys=("foo",)),
cache_key=CacheKey(
command="GET",
redis_keys=("foo",),
redis_args=("GET", "foo"),
),
cache_value=b"bar",
status=CacheEntryStatus.VALID,
connection_ref=mock_connection,
Expand All @@ -522,9 +543,21 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):

mock_cache.get.assert_has_calls(
[
call(CacheKey(command="GET", redis_keys=("foo",))),
call(CacheKey(command="GET", redis_keys=("foo",))),
call(CacheKey(command="GET", redis_keys=("foo",))),
call(
CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
)
),
call(
CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
)
),
call(
CacheKey(
command="GET", redis_keys=("foo",), redis_args=("GET", "foo")
)
),
]
)

Expand Down
Loading