Skip to content

Commit 0b46d59

Browse files
authored
Merge pull request #2081 from minrk/recv_into
implement Socket.recv_into
2 parents 5e0cdbc + 34e707f commit 0b46d59

19 files changed

+454
-289
lines changed

docs/source/conf.py

+5
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,8 @@
272272

273273
# If false, no module index is generated.
274274
# latex_use_modindex = True
275+
276+
linkcheck_ignore = [
277+
r"https://github\.com(.*)#", # javascript based anchors
278+
r"https://github\.com/zeromq/pyzmq/(issues|commits)(.*)", # too many links
279+
]

examples/recv_into/discard.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
use recv_into with an empty buffer to discard unwanted message frames
3+
4+
avoids unnecessary allocations for message frames that won't be used
5+
"""
6+
7+
import logging
8+
import os
9+
import random
10+
import secrets
11+
import time
12+
from pathlib import Path
13+
from tempfile import TemporaryDirectory
14+
from threading import Thread
15+
16+
import zmq
17+
18+
EMPTY = bytearray()
19+
20+
21+
def subscriber(url: str) -> None:
22+
log = logging.getLogger("subscriber")
23+
with zmq.Context() as ctx, ctx.socket(zmq.SUB) as sub:
24+
sub.linger = 0
25+
sub.connect(url)
26+
sub.subscribe(b"")
27+
log.info("Receiving...")
28+
while True:
29+
frame_0 = sub.recv_string()
30+
if frame_0 == "exit":
31+
log.info("Exiting...")
32+
break
33+
elif frame_0 == "large":
34+
discarded_bytes = 0
35+
discarded_frames = 0
36+
while sub.rcvmore:
37+
discarded_bytes += sub.recv_into(EMPTY)
38+
discarded_frames += 1
39+
log.info(
40+
"Discarding large message frames: %s, bytes: %s",
41+
discarded_frames,
42+
discarded_bytes,
43+
)
44+
else:
45+
msg: list = [frame_0]
46+
if sub.rcvmore:
47+
msg.extend(sub.recv_multipart(flags=zmq.DONTWAIT))
48+
log.info("Received %s", msg)
49+
log.info("Done")
50+
51+
52+
def publisher(url) -> None:
53+
log = logging.getLogger("publisher")
54+
choices = ["large", "small"]
55+
with zmq.Context() as ctx, ctx.socket(zmq.PUB) as pub:
56+
pub.linger = 1000
57+
pub.bind(url)
58+
time.sleep(1)
59+
for i in range(10):
60+
kind = random.choice(choices)
61+
frames = [kind.encode()]
62+
if kind == "large":
63+
for _ in range(random.randint(0, 5)):
64+
chunk_size = random.randint(1024, 2048)
65+
chunk = os.urandom(chunk_size)
66+
frames.append(chunk)
67+
else:
68+
for _ in range(random.randint(0, 3)):
69+
chunk_size = random.randint(0, 5)
70+
chunk = secrets.token_hex(chunk_size).encode()
71+
frames.append(chunk)
72+
nbytes = sum(len(chunk) for chunk in frames)
73+
log.info("Sending %s: %s bytes", kind, nbytes)
74+
pub.send_multipart(frames)
75+
time.sleep(0.1)
76+
log.info("Sending exit")
77+
pub.send(b"exit")
78+
log.info("Done")
79+
80+
81+
def main() -> None:
82+
logging.basicConfig(level=logging.INFO)
83+
with TemporaryDirectory() as td:
84+
sock_path = Path(td) / "example.sock"
85+
url = f"ipc://{sock_path}"
86+
s_thread = Thread(
87+
target=subscriber, args=(url,), daemon=True, name="subscriber"
88+
)
89+
s_thread.start()
90+
publisher(url)
91+
s_thread.join(timeout=3)
92+
93+
94+
if __name__ == "__main__":
95+
main()

examples/recv_into/recv_into_array.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
Use recv_into to receive data directly into a numpy array
3+
"""
4+
5+
import numpy as np
6+
import numpy.testing as nt
7+
8+
import zmq
9+
10+
url = "inproc://test"
11+
12+
13+
def main() -> None:
14+
A = (np.random.random((5, 5)) * 255).astype(dtype=np.int64)
15+
B = np.empty_like(A)
16+
assert not (A == B).all()
17+
18+
with (
19+
zmq.Context() as ctx,
20+
ctx.socket(zmq.PUSH) as push,
21+
ctx.socket(zmq.PULL) as pull,
22+
):
23+
push.bind(url)
24+
pull.connect(url)
25+
print("sending:\n", A)
26+
push.send(A)
27+
bytes_received = pull.recv_into(B)
28+
print(f"received {bytes_received} bytes:\n", B)
29+
assert bytes_received == A.nbytes
30+
nt.assert_allclose(A, B)
31+
32+
33+
if __name__ == "__main__":
34+
main()

tests/test_asyncio.py

+55
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,61 @@ async def test_recv(create_bound_pair):
6464
assert recvd == b"there"
6565

6666

67+
async def test_recv_into(create_bound_pair):
68+
a, b = create_bound_pair()
69+
b.rcvtimeo = 1000
70+
msg = [
71+
b'hello',
72+
b'there world',
73+
b'part 3',
74+
b'rest',
75+
]
76+
a.send_multipart(msg)
77+
78+
# default nbytes: fits in array
79+
buf = bytearray(10)
80+
nbytes = await b.recv_into(buf)
81+
assert nbytes == len(msg[0])
82+
assert buf[:nbytes] == msg[0]
83+
84+
# default nbytes: truncates to sizeof(buf)
85+
buf = bytearray(4)
86+
nbytes = await b.recv_into(buf, flags=zmq.DONTWAIT)
87+
# returned nbytes is the actual received length,
88+
# which indicates truncation
89+
assert nbytes == len(msg[1])
90+
assert buf[:] == msg[1][: len(buf)]
91+
92+
# specify nbytes, truncates
93+
buf = bytearray(10)
94+
nbytes = 4
95+
nbytes_recvd = await b.recv_into(buf, nbytes=nbytes)
96+
assert nbytes_recvd == len(msg[2])
97+
98+
# recv_into empty buffer discards everything
99+
buf = bytearray(10)
100+
view = memoryview(buf)[:0]
101+
assert view.nbytes == 0
102+
nbytes = await b.recv_into(view)
103+
assert nbytes == len(msg[3])
104+
105+
106+
async def test_recv_into_bad(create_bound_pair):
107+
a, b = create_bound_pair()
108+
b.rcvtimeo = 1000
109+
110+
# bad calls
111+
# make sure flags work
112+
with pytest.raises(zmq.Again):
113+
await b.recv_into(bytearray(5), flags=zmq.DONTWAIT)
114+
115+
await a.send(b'msg')
116+
# negative nbytes
117+
buf = bytearray(10)
118+
with pytest.raises(ValueError):
119+
await b.recv_into(buf, nbytes=-1)
120+
121+
67122
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
68123
async def test_recv_timeout(push_pull):
69124
a, b = push_pull

tests/test_mypy.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
Test our typing with mypy
33
"""
44

5-
import os
65
import sys
76
from pathlib import Path
87
from subprocess import PIPE, STDOUT, Popen
98

109
import pytest
1110

1211
pytest.importorskip("mypy")
12+
pytestmark = pytest.mark.skipif(sys.version_info < (3, 10), reason="targets 3.10")
1313

1414
repo_root = Path(__file__).parents[1]
1515

@@ -25,7 +25,9 @@ def run_mypy(*mypy_args):
2525
Captures output and reports it on errors
2626
"""
2727
p = Popen(
28-
[sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT
28+
[sys.executable, "-m", "mypy", "--python-version=3.10"] + list(mypy_args),
29+
stdout=PIPE,
30+
stderr=STDOUT,
2931
)
3032
o, _ = p.communicate()
3133
out = o.decode("utf8", "replace")

tests/test_socket.py

+76
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
import time
1212
import warnings
13+
from array import array
1314
from unittest import mock
1415

1516
import pytest
@@ -455,6 +456,81 @@ def test_recv_multipart(self):
455456
for i in range(3):
456457
assert self.recv_multipart(b) == [msg]
457458

459+
def test_recv_into(self):
460+
a, b = self.create_bound_pair()
461+
if not self.green:
462+
b.rcvtimeo = 1000
463+
msg = [
464+
b'hello',
465+
b'there world',
466+
b'part 3',
467+
b'rest',
468+
]
469+
a.send_multipart(msg)
470+
471+
# default nbytes: fits in array
472+
# make sure itemsize > 1 is handled right
473+
buf = array('Q', [0])
474+
nbytes = b.recv_into(buf)
475+
assert nbytes == len(msg[0])
476+
assert buf.tobytes()[:nbytes] == msg[0]
477+
478+
# default nbytes: truncates to sizeof(buf)
479+
buf = bytearray(4)
480+
nbytes = b.recv_into(buf)
481+
# returned nbytes is the actual received length,
482+
# which indicates truncation
483+
assert nbytes == len(msg[1])
484+
assert buf[:] == msg[1][: len(buf)]
485+
486+
# specify nbytes, truncates
487+
buf = bytearray(10)
488+
nbytes = 4
489+
nbytes_recvd = b.recv_into(buf, nbytes=nbytes)
490+
assert nbytes_recvd == len(msg[2])
491+
assert buf[:nbytes] == msg[2][:nbytes]
492+
# didn't recv excess bytes
493+
assert buf[nbytes:] == bytearray(10 - nbytes)
494+
495+
# recv_into empty buffer discards everything
496+
buf = bytearray(10)
497+
view = memoryview(buf)[:0]
498+
assert view.nbytes == 0
499+
nbytes = b.recv_into(view)
500+
assert nbytes == len(msg[3])
501+
assert buf == bytearray(10)
502+
503+
def test_recv_into_bad(self):
504+
a, b = self.create_bound_pair()
505+
if not self.green:
506+
b.rcvtimeo = 1000
507+
508+
# bad calls
509+
510+
# negative nbytes
511+
buf = bytearray(10)
512+
with pytest.raises(ValueError):
513+
b.recv_into(buf, nbytes=-1)
514+
# not contiguous
515+
buf = memoryview(bytearray(10))[::2]
516+
with pytest.raises(ValueError):
517+
b.recv_into(buf)
518+
# readonly
519+
buf = memoryview(b"readonly")
520+
with pytest.raises(ValueError):
521+
b.recv_into(buf)
522+
# too big
523+
buf = bytearray(10)
524+
with pytest.raises(ValueError):
525+
b.recv_into(buf, nbytes=11)
526+
# not memory-viewable
527+
with pytest.raises(TypeError):
528+
b.recv_into(pytest)
529+
530+
# make sure flags work
531+
with pytest.raises(zmq.Again):
532+
b.recv_into(bytearray(5), flags=zmq.DONTWAIT)
533+
458534
def test_close_after_destroy(self):
459535
"""s.close() after ctx.destroy() should be fine"""
460536
ctx = self.Context()

tests/zmq_test_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def create_bound_pair(
134134
s2.setsockopt(zmq.LINGER, 0)
135135
s2.connect(f'{interface}:{port}')
136136
self.sockets.extend([s1, s2])
137+
s2.setsockopt(zmq.LINGER, 0)
137138
return s1, s2
138139

139140
def ping_pong(self, s1, s2, msg):

0 commit comments

Comments
 (0)