Skip to content

Commit

Permalink
use pipelines for set/get options if possible (#65)
Browse files Browse the repository at this point in the history
* use pipelines for set/get options if possible

* do not use `json dump` but `max_commands_per_call`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add `max_commands_per_call` to List

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use logging instead of warnings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove comment

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Nov 7, 2024
1 parent 5307bda commit 20ad62b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 24 deletions.
16 changes: 12 additions & 4 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pytest
import redis.exceptions

Expand Down Expand Up @@ -195,10 +197,16 @@ def test_set_none(client, request):
@pytest.mark.parametrize("client", ["znsclient", "znsclient_w_redis"])
def test_set_large_message(client, request, caplog):
c = request.getfixturevalue(client)
pipeline = c.pipeline(max_message_size=3000)

logger = logging.getLogger("znsocket")
logger.setLevel(logging.DEBUG)

pipeline = c.pipeline(max_commands_per_call=75)
for _ in range(100):
pipeline.set("foo", "bar")

with pytest.warns(UserWarning):
# assert that the message is too large and is being split
assert pipeline.execute() == [True] * 100
assert pipeline.execute() == [True] * 100

assert any(
"splitting message at index" in record.message for record in caplog.records
), "Expected 'splitting message' debug log not found."
35 changes: 23 additions & 12 deletions znsocket/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import functools
import json
import logging
import typing as t
import warnings

Expand All @@ -12,6 +13,8 @@
from znsocket.abc import RefreshDataTypeDict
from znsocket.utils import parse_url

log = logging.getLogger(__name__)


def _handle_data(data: dict):
if "type" in data:
Expand Down Expand Up @@ -120,9 +123,21 @@ def hmset(self, name, mapping):

@dataclasses.dataclass
class Pipeline:
"""A pipeline of Redis commands to be executed as a batch on the server.
Arguments
---------
client : Client
The client to send the pipeline to.
max_commands_per_call : int
The maximum number of commands to send in a single call to the server.
Decrease this number for large commands to avoid hitting the message size limit.
Increase it for small commands to reduce latency.
"""

client: Client
max_message_size: t.Optional[int] = 10 * 1024 * 1024
pipeline: list = dataclasses.field(default_factory=list)
max_commands_per_call: int = 1_000_000
pipeline: list = dataclasses.field(default_factory=list, init=False)

def _add_to_pipeline(self, command, *args, **kwargs):
"""Generic handler to add Redis commands to the pipeline."""
Expand Down Expand Up @@ -161,16 +176,12 @@ def execute(self):
results = []
for idx, entry in enumerate(self.pipeline):
message.append(entry)
if self.max_message_size is not None:
msg_size = json.dumps(message).__sizeof__()
if msg_size > self.max_message_size:
warnings.warn(
f"Message size '{msg_size}' is greater than"
f" '{self.max_message_size = }'. Sending message"
f" at index {idx} and continuing."
)
results.extend(self._send_message(message))
message = []
if len(message) > self.max_commands_per_call:
log.debug(
f"splitting message at index {idx} due to max_message_chunk",
)
results.extend(self._send_message(message))
message = []
if message:
results.extend(self._send_message(message))

Expand Down
35 changes: 27 additions & 8 deletions znsocket/objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
callbacks: ListCallbackTypedDict | None = None,
repr_type: ListRepr = "length",
converter: list[t.Type[znjson.ConverterBase]] | None = None,
max_commands_per_call: int = 1_000_000,
):
"""Synchronized list object.
Expand All @@ -65,6 +66,12 @@ def __init__(
converter: list[znjson.ConverterBase]|None
Optional list of znjson converters
to use for encoding/decoding the data.
max_commands_per_call: int
Maximum number of commands to send in a
single call when using pipelines.
Reduce for large list operations to avoid
hitting the message size limit.
Only applies when using `znsocket.Client`.
"""
self.redis = r
Expand All @@ -74,6 +81,11 @@ def __init__(
self.converter = converter
self._on_refresh = lambda x: None

if isinstance(r, Client):
self._pipeline_kwargs = {"max_commands_per_call": max_commands_per_call}
else:
self._pipeline_kwargs = {}

self._callbacks = {
"setitem": None,
"delitem": None,
Expand All @@ -93,9 +105,13 @@ def __getitem__(self, index: int | list | slice) -> t.Any | list[t.Any]:
if isinstance(index, slice):
index = list(range(*index.indices(len(self))))

items = []
pipeline = self.redis.pipeline(**self._pipeline_kwargs)
for i in index:
value = self.redis.lindex(self.key, i)
pipeline.lindex(self.key, i)
data = pipeline.execute()

items = []
for value in data:
if value is None:
item = None
else:
Expand Down Expand Up @@ -131,6 +147,7 @@ def __setitem__(self, index: int | list | slice, value: t.Any) -> None:
f"attempt to assign sequence of size {len(value)} to extended slice of size {len(index)}"
)

pipeline = self.redis.pipeline(**self._pipeline_kwargs)
for i, v in zip(index, value):
if i >= LENGTH or i < -LENGTH:
raise IndexError("list index out of range")
Expand All @@ -140,7 +157,8 @@ def __setitem__(self, index: int | list | slice, value: t.Any) -> None:
if value.key == self.key:
raise ValueError("Can not set circular reference to self")
v = f"znsocket.List:{v.key}"
self.redis.lset(self.key, i, _encode(self, v))
pipeline.lset(self.key, i, _encode(self, v))
pipeline.execute()

if callback := self._callbacks["setitem"]:
callback(index, value)
Expand All @@ -159,14 +177,15 @@ def __delitem__(self, index: int | list | slice) -> None:
if len(index) == 0:
return # nothing to delete

pipeline = self.redis.pipeline(**self._pipeline_kwargs)
for i in index:
pipeline.lset(self.key, i, "__DELETED__")
pipeline.lrem(self.key, 0, "__DELETED__")
try:
for i in index:
self.redis.lset(self.key, i, "__DELETED__")
pipeline.execute()
except redis.exceptions.ResponseError:
raise IndexError("list index out of range")

self.redis.lrem(self.key, 0, "__DELETED__")

if self._callbacks["delitem"]:
self._callbacks["delitem"](index)

Expand Down Expand Up @@ -242,7 +261,7 @@ def extend(self, values: t.Iterable) -> None:
"""Extend the list with an iterable using redis pipelines."""
if self.socket is not None:
refresh: RefreshTypeDict = {"start": len(self), "stop": None}
pipe = self.redis.pipeline()
pipe = self.redis.pipeline(**self._pipeline_kwargs)
for value in values:
if isinstance(value, Dict):
value = f"znsocket.Dict:{value.key}"
Expand Down

0 comments on commit 20ad62b

Please sign in to comment.