diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 67005a0..c9c145c 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -195,7 +195,7 @@ def test_set_none(client, request): @pytest.mark.parametrize("client", ["znsclient", "znsclient_w_redis"]) def test_set_large_message(client, request, caplog): c = request.getfixturevalue(client) - pipeline = c.pipeline(max_message_size=3000) + pipeline = c.pipeline(max_commands_per_call=75) for _ in range(100): pipeline.set("foo", "bar") diff --git a/znsocket/client.py b/znsocket/client.py index d2752cb..a106c68 100644 --- a/znsocket/client.py +++ b/znsocket/client.py @@ -120,9 +120,20 @@ def hmset(self, name, mapping): @dataclasses.dataclass class Pipeline: + """A pipeline of Redis commands to be executed as a batch on the server. + + Arguments + --------- + client : Client + The client to send the pipeline to. + max_commands_per_call : int + The maximum number of commands to send in a single call to the server. + Decrease this number for large commands to avoid hitting the message size limit. + Increase it for small commands to reduce latency. + """ client: Client - max_message_size: t.Optional[int] = 10 * 1024 * 1024 - pipeline: list = dataclasses.field(default_factory=list) + max_commands_per_call: int = 1_000_000 + pipeline: list = dataclasses.field(default_factory=list, init=False) def _add_to_pipeline(self, command, *args, **kwargs): """Generic handler to add Redis commands to the pipeline.""" @@ -161,16 +172,12 @@ def execute(self): results = [] for idx, entry in enumerate(self.pipeline): message.append(entry) - if self.max_message_size is not None: - msg_size = json.dumps(message).__sizeof__() - if msg_size > self.max_message_size: - warnings.warn( - f"Message size '{msg_size}' is greater than" - f" '{self.max_message_size = }'. Sending message" - f" at index {idx} and continuing." - ) - results.extend(self._send_message(message)) - message = [] + if len(message) >= self.max_commands_per_call: + warnings.warn( + f"splitting message at index {idx} due to max_message_chunk", + ) + results.extend(self._send_message(message)) + message = [] if message: results.extend(self._send_message(message))