Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RESP3 support #360

Draft
wants to merge 37 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions docs/supported-commands/Redis/CONNECTION.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
# Redis `connection` commands (4/24 implemented)
# Redis `connection` commands (9/24 implemented)

## [AUTH](https://redis.io/commands/auth/)

Authenticates the connection.

## [CLIENT GETNAME](https://redis.io/commands/client-getname/)

Returns the name of the connection.

## [CLIENT INFO](https://redis.io/commands/client-info/)

Returns information about the connection.

## [CLIENT SETINFO](https://redis.io/commands/client-setinfo/)

Sets information specific to the client or connection.

## [CLIENT SETNAME](https://redis.io/commands/client-setname/)

Sets the connection name.

## [ECHO](https://redis.io/commands/echo/)

Returns the given string.

## [HELLO](https://redis.io/commands/hello/)

Handshakes with the Redis server.

## [PING](https://redis.io/commands/ping/)

Returns the server's liveliness response.
Expand All @@ -20,10 +40,6 @@ Changes the selected database.
## Unsupported connection commands
> To implement support for a command, see [here](/guides/implement-command/)

#### [AUTH](https://redis.io/commands/auth/) <small>(not implemented)</small>

Authenticates the connection.

#### [CLIENT](https://redis.io/commands/client/) <small>(not implemented)</small>

A container for client connection commands.
Expand All @@ -32,10 +48,6 @@ A container for client connection commands.

Instructs the server whether to track the keys in the next request.

#### [CLIENT GETNAME](https://redis.io/commands/client-getname/) <small>(not implemented)</small>

Returns the name of the connection.

#### [CLIENT GETREDIR](https://redis.io/commands/client-getredir/) <small>(not implemented)</small>

Returns the client ID to which the connection's tracking notifications are redirected.
Expand All @@ -44,10 +56,6 @@ Returns the client ID to which the connection's tracking notifications are redir

Returns the unique client ID of the connection.

#### [CLIENT INFO](https://redis.io/commands/client-info/) <small>(not implemented)</small>

Returns information about the connection.

#### [CLIENT KILL](https://redis.io/commands/client-kill/) <small>(not implemented)</small>

Terminates open connections.
Expand All @@ -72,10 +80,6 @@ Suspends commands processing.

Instructs the server whether to reply to commands.

#### [CLIENT SETNAME](https://redis.io/commands/client-setname/) <small>(not implemented)</small>

Sets the connection name.

#### [CLIENT TRACKING](https://redis.io/commands/client-tracking/) <small>(not implemented)</small>

Controls server-assisted client-side caching for the connection.
Expand All @@ -92,10 +96,6 @@ Unblocks a client blocked by a blocking command from a different connection.

Resumes processing commands from paused clients.

#### [HELLO](https://redis.io/commands/hello/) <small>(not implemented)</small>

Handshakes with the Redis server.

#### [RESET](https://redis.io/commands/reset/) <small>(not implemented)</small>

Resets the connection.
Expand Down
18 changes: 17 additions & 1 deletion fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@
)


def _convert_to_resp2(val: Any) -> Any:
if isinstance(val, str):
return val.encode()
if isinstance(val, float):
return Float.encode(val, humanfriendly=False)
if isinstance(val, dict):
result = list(itertools.chain(*val.items()))
return [_convert_to_resp2(item) for item in result]
if isinstance(val, (list, tuple)):
return [_convert_to_resp2(item) for item in val]
return val


def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]:
"""Extracts the command and command arguments from a list of `bytes` fields.

Expand Down Expand Up @@ -253,7 +266,10 @@ def _run_command(
else:
args, command_items = ret
result = func(*args) # type: ignore
assert valid_response_type(result)
if self.protocol_version == 2 and msgs.FLAG_SKIP_CONVERT_TO_RESP2 not in sig.flags:
result = _convert_to_resp2(result)
if msgs.FLAG_SKIP_CONVERT_TO_RESP2 not in sig.flags:
assert valid_response_type(result, self.protocol_version)
except SimpleError as exc:
result = exc
for command_item in command_items:
Expand Down
1 change: 1 addition & 0 deletions fakeredis/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
"client_name",
"connected",
"server",
"protocol",
}
connection_kwargs = {
"connection_class": FakeConnection,
Expand Down
14 changes: 11 additions & 3 deletions fakeredis/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def __init__(self, value: bytes) -> None:
def decode(cls, value: bytes) -> bytes:
return value

def __repr__(self):
return f"{self.__class__.__name__}({self.value!r})"


class SimpleError(Exception):
"""Exception that will be turned into a frontend-specific exception."""
Expand Down Expand Up @@ -203,13 +206,18 @@ def __eq__(self, other: object) -> bool:
return super(object, self) == other


def valid_response_type(value: Any, nested: bool = False) -> bool:
_VALID_RESPONSE_TYPES_RESP2 = (bytes, SimpleString, SimpleError, float, int, list)
_VALID_RESPONSE_TYPES_RESP3 = (bytes, SimpleString, SimpleError, float, int, list, dict, str)


def valid_response_type(value: Any, protocol_version: int, nested: bool = False) -> bool:
if isinstance(value, NoResponse) and not nested:
return True
if value is not None and not isinstance(value, (bytes, SimpleString, SimpleError, float, int, list)):
allowed_types = _VALID_RESPONSE_TYPES_RESP2 if protocol_version == 2 else _VALID_RESPONSE_TYPES_RESP3
if value is not None and not isinstance(value, allowed_types):
return False
if isinstance(value, list):
if any(not valid_response_type(item, True) for item in value):
if any(not valid_response_type(item, protocol_version, True) for item in value):
return False
return True

Expand Down
1 change: 1 addition & 0 deletions fakeredis/_msgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,4 @@
FLAG_LEAVE_EMPTY_VAL = "v"
FLAG_TRANSACTION = "t"
FLAG_DO_NOT_CREATE = "i"
FLAG_SKIP_CONVERT_TO_RESP2 = "2"
2 changes: 1 addition & 1 deletion fakeredis/commands.json

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions fakeredis/commands_mixins/hash_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import math
import random
from collections.abc import Mapping
from typing import Callable, List, Tuple, Any, Optional

from fakeredis import _msgs as msgs
Expand Down Expand Up @@ -50,8 +50,8 @@ def hget(self, key: CommandItem, field: bytes) -> Any:
return key.value.get(field)

@command((Key(Hash),))
def hgetall(self, key: CommandItem) -> List[bytes]:
return list(itertools.chain(*key.value.items()))
def hgetall(self, key: CommandItem) -> Mapping[str, str]:
return key.value.getall()

@command(fixed=(Key(Hash), bytes, bytes))
def hincrby(self, key: CommandItem, field: bytes, amount_bytes: bytes) -> int:
Expand Down
27 changes: 13 additions & 14 deletions fakeredis/commands_mixins/sortedset_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def zcount(self, key, _min, _max):
return key.value.zcount(_min.lower_bound, _max.upper_bound)

@command((Key(ZSet), Float, bytes))
def zincrby(self, key, increment, member) -> float:
def zincrby(self, key: CommandItem, increment: float, member: bytes) -> float:
# Can't just default the old score to 0.0, because in IEEE754, adding
# 0.0 to something isn't a nop (e.g., 0.0 + -0.0 == 0.0).
try:
Expand All @@ -200,10 +200,7 @@ def zincrby(self, key, increment, member) -> float:
raise SimpleError(msgs.SCORE_NAN_MSG)
key.value[member] = score
key.updated()
# For some reason, here it does not ignore the version
# https://github.com/cunla/fakeredis-py/actions/runs/3377186364/jobs/5605815202
return Float.encode(score, False)
# return self._encodefloat(score, False)
return score

@command((Key(ZSet), StringTest, StringTest))
def zlexcount(self, key, _min, _max):
Expand All @@ -218,7 +215,9 @@ def _zrangebyscore(self, key, _min, _max, reverse, withscores, offset, count) ->
items = self._apply_withscores(items, withscores)
return items

def _zrange(self, key, start, stop, reverse, withscores, byscore) -> List[bytes]:
def _zrange(
self, key: CommandItem, start: ScoreTest, stop: ScoreTest, reverse: bool, withscores: bool, byscore: bool
) -> List[bytes]:
zset = key.value
if byscore:
items = zset.irange_score(start.lower_bound, stop.upper_bound, reverse=reverse)
Expand All @@ -244,7 +243,7 @@ def _zrangebylex(self, key, _min, _max, reverse, offset, count) -> List[bytes]:
items = self._limit_items(items, offset, count)
return items

def _zrange_args(self, key, start, stop, *args):
def _zrange_args(self, key: CommandItem, start: bytes, stop: bytes, *args: bytes):
(bylex, byscore, rev, (offset, count), withscores), _ = extract_args(
args, ("bylex", "byscore", "rev", "++limit", "withscores")
)
Expand Down Expand Up @@ -300,7 +299,7 @@ def zrangestore(self, dest: CommandItem, src, start, stop, *args):
return len(res)

@command((Key(ZSet), ScoreTest, ScoreTest), (bytes,))
def zrevrange(self, key, start, stop, *args):
def zrevrange(self, key: CommandItem, start: ScoreTest, stop: ScoreTest, *args):
(withscores, byscore), _ = extract_args(args, ("withscores", "byscore"))
return self._zrange(key, start, stop, True, withscores, byscore)

Expand Down Expand Up @@ -356,7 +355,7 @@ def zrevrank(self, key: CommandItem, member: bytes, *args: bytes) -> Union[None,
return None

@command((Key(ZSet), bytes), (bytes,))
def zrem(self, key, *members):
def zrem(self, key: CommandItem, *members: bytes) -> int:
old_size = len(key.value)
for member in members:
key.value.discard(member)
Expand All @@ -371,15 +370,15 @@ def zremrangebylex(self, key, _min, _max):
return self.zrem(key, *items)

@command((Key(ZSet), ScoreTest, ScoreTest))
def zremrangebyscore(self, key, _min, _max):
items = key.value.irange_score(_min.lower_bound, _max.upper_bound)
def zremrangebyscore(self, key: CommandItem, _min: ScoreTest, _max: ScoreTest):
items = key.value.irange_score(_min.lower_bound, _max.upper_bound, reverse=False)
return self.zrem(key, *[item[1] for item in items])

@command((Key(ZSet), Int, Int))
def zremrangebyrank(self, key, start: int, stop: int):
zset = key.value
start, stop = fix_range(start, stop, len(zset))
items = zset.islice_score(start, stop)
items = zset.islice_score(start, stop, reverse=False)
return self.zrem(key, *[item[1] for item in items])

@command((Key(ZSet), Int), (bytes, bytes))
Expand All @@ -392,9 +391,9 @@ def zscan(self, key: CommandItem, cursor: int, *args: bytes) -> List[Union[int,
return [new_cursor, flat]

@command((Key(ZSet), bytes))
def zscore(self, key: CommandItem, member: bytes) -> Optional[bytes]:
def zscore(self, key: CommandItem, member: bytes) -> Union[None, bytes]:
try:
return self._encodefloat(key.value[member], False)
return key.value[member]
except KeyError:
return None

Expand Down
32 changes: 22 additions & 10 deletions fakeredis/commands_mixins/streams_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import List, Union, Tuple, Callable, Optional, Any
from typing import List, Union, Tuple, Callable, Optional, Any, Dict

import fakeredis._msgs as msgs
from fakeredis._command_args_parsing import extract_args
Expand Down Expand Up @@ -58,18 +58,28 @@ def xtrim(self, key: CommandItem, *args: bytes) -> int:
def xlen(self, key: CommandItem) -> int:
return len(key.value)

@command(name="XRANGE", fixed=(Key(XStream), StreamRangeTest, StreamRangeTest), repeat=(bytes,))
@command(
name="XRANGE",
fixed=(Key(XStream), StreamRangeTest, StreamRangeTest),
repeat=(bytes,),
flags=msgs.FLAG_SKIP_CONVERT_TO_RESP2,
)
def xrange(self, key: CommandItem, _min: StreamRangeTest, _max: StreamRangeTest, *args: bytes) -> List[bytes]:
(count,), _ = extract_args(args, ("+count",))
return self._xrange(key.value, _min, _max, False, count)

@command(name="XREVRANGE", fixed=(Key(XStream), StreamRangeTest, StreamRangeTest), repeat=(bytes,))
@command(
name="XREVRANGE",
fixed=(Key(XStream), StreamRangeTest, StreamRangeTest),
repeat=(bytes,),
flags=msgs.FLAG_SKIP_CONVERT_TO_RESP2,
)
def xrevrange(self, key: CommandItem, _min: StreamRangeTest, _max: StreamRangeTest, *args: bytes) -> List[bytes]:
(count,), _ = extract_args(args, ("+count",))
return self._xrange(key.value, _max, _min, True, count)

@command(name="XREAD", fixed=(bytes,), repeat=(bytes,))
def xread(self, *args: bytes) -> Optional[List[List[Union[bytes, List[Tuple[bytes, List[bytes]]]]]]]:
@command(name="XREAD", fixed=(bytes,), repeat=(bytes,), flags=msgs.FLAG_SKIP_CONVERT_TO_RESP2)
def xread(self, *args: bytes) -> Optional[Dict[str, Any]]:
(
count,
timeout,
Expand Down Expand Up @@ -241,10 +251,10 @@ def xgroup_delconsumer(self, key: CommandItem, group_name: bytes, consumer_name:
return group.del_consumer(consumer_name)

@command(name="XINFO GROUPS", fixed=(Key(XStream),), repeat=())
def xinfo_groups(self, key: CommandItem) -> List[List[bytes]]:
def xinfo_groups(self, key: CommandItem) -> Dict[bytes, Any]:
if key.value is None:
raise SimpleError(msgs.NO_KEY_MSG)
res: List[List[bytes]] = key.value.groups_info()
res: Dict[bytes, Any] = key.value.groups_info()
return res

@command(name="XINFO STREAM", fixed=(Key(XStream),), repeat=(bytes,), flags=msgs.FLAG_DO_NOT_CREATE)
Expand Down Expand Up @@ -349,18 +359,20 @@ def _xreadgroup(

def _xread(
self, stream_start_id_list: List[Tuple[bytes, StreamRangeTest]], count: int, blocking: bool, first_pass: bool
) -> Optional[List[List[Union[bytes, List[Tuple[bytes, List[bytes]]]]]]]:
) -> Union[None, Dict[bytes, Any], List[List[Union[bytes, List[Tuple[bytes, List[bytes]]]]]]]:
max_inf = StreamRangeTest.decode(b"+")
res: List[Any] = list()
res: Dict[bytes, Any] = dict()
for stream_name, start_id in stream_start_id_list:
item = CommandItem(stream_name, self._db, item=self._db.get(stream_name), default=None)
stream_results = self._xrange(item.value, start_id, max_inf, False, count)
if len(stream_results) > 0:
res.append([item.key, stream_results])
res[item.key] = stream_results

# On blocking read, when count is not None, and there are no results, return None (instead of an empty list)
if blocking and count and len(res) == 0:
return None
if self.protocol_version == 2:
return [[k, v] for k, v in res.items()]
return res

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions fakeredis/model/_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def update(self, values: Dict[bytes, Any]) -> None:
self._expire_keys()
self._values.update(values)

def getall(self) -> Dict[bytes, Any]:
def getall(self) -> Dict[str, str]:
self._expire_keys()
return self._values.copy()
res = self._values.copy()
return {k.decode("utf-8"): v.decode("utf-8") for k, v in res.items()}
8 changes: 4 additions & 4 deletions fakeredis/model/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def del_consumer(self, consumer_name: bytes) -> int:
def consumers_info(self) -> List[List[Union[bytes, int]]]:
return [self.consumers[k].info(current_time()) for k in self.consumers]

def group_info(self) -> List[bytes]:
def group_info(self) -> Dict[bytes, Any]:
start_index, _ = self.stream.find_index(self.start_key)
last_delivered_index, _ = self.stream.find_index(self.last_delivered_key)
last_ack_index, _ = self.stream.find_index(self.last_ack_key)
Expand All @@ -139,7 +139,7 @@ def group_info(self) -> List[bytes]:
b"entries-read": self.entries_read,
b"lag": lag,
}
return list(itertools.chain(*res.items())) # type: ignore
return res

def group_read(
self, consumer_name: bytes, start_id: bytes, count: int, noack: bool
Expand Down Expand Up @@ -486,7 +486,7 @@ def trim(
del self._values_dict[k]
return res

def irange(self, start: StreamRangeTest, stop: StreamRangeTest, reverse: bool = False) -> List[Any]:
def irange(self, start: StreamRangeTest, stop: StreamRangeTest, reverse: bool = False) -> Tuple[Any]:
"""Returns a range of the stream values from start to stop.

:param start: Start key
Expand All @@ -510,7 +510,7 @@ def _find_index(elem: StreamRangeTest, from_left: bool = True) -> int:
matches = map(lambda x: self.format_record(self._ids[x]), range(start_ind, stop_ind))
if reverse:
return list(reversed(tuple(matches)))
return list(matches)
return tuple(list(matches))

def last_item_key(self) -> bytes:
return self._ids[-1].encode() if len(self._ids) > 0 else "0-0".encode()
Expand Down
Loading
Loading