|
10 | 10 | import sys
|
11 | 11 | import time
|
12 | 12 | import warnings
|
| 13 | +from array import array |
13 | 14 | from unittest import mock
|
14 | 15 |
|
15 | 16 | import pytest
|
@@ -455,6 +456,81 @@ def test_recv_multipart(self):
|
455 | 456 | for i in range(3):
|
456 | 457 | assert self.recv_multipart(b) == [msg]
|
457 | 458 |
|
| 459 | + def test_recv_into(self): |
| 460 | + a, b = self.create_bound_pair() |
| 461 | + if not self.green: |
| 462 | + b.rcvtimeo = 1000 |
| 463 | + msg = [ |
| 464 | + b'hello', |
| 465 | + b'there world', |
| 466 | + b'part 3', |
| 467 | + b'rest', |
| 468 | + ] |
| 469 | + a.send_multipart(msg) |
| 470 | + |
| 471 | + # default nbytes: fits in array |
| 472 | + # make sure itemsize > 1 is handled right |
| 473 | + buf = array('Q', [0]) |
| 474 | + nbytes = b.recv_into(buf) |
| 475 | + assert nbytes == len(msg[0]) |
| 476 | + assert buf.tobytes()[:nbytes] == msg[0] |
| 477 | + |
| 478 | + # default nbytes: truncates to sizeof(buf) |
| 479 | + buf = bytearray(4) |
| 480 | + nbytes = b.recv_into(buf) |
| 481 | + # returned nbytes is the actual received length, |
| 482 | + # which indicates truncation |
| 483 | + assert nbytes == len(msg[1]) |
| 484 | + assert buf[:] == msg[1][: len(buf)] |
| 485 | + |
| 486 | + # specify nbytes, truncates |
| 487 | + buf = bytearray(10) |
| 488 | + nbytes = 4 |
| 489 | + nbytes_recvd = b.recv_into(buf, nbytes=nbytes) |
| 490 | + assert nbytes_recvd == len(msg[2]) |
| 491 | + assert buf[:nbytes] == msg[2][:nbytes] |
| 492 | + # didn't recv excess bytes |
| 493 | + assert buf[nbytes:] == bytearray(10 - nbytes) |
| 494 | + |
| 495 | + # recv_into empty buffer discards everything |
| 496 | + buf = bytearray(10) |
| 497 | + view = memoryview(buf)[:0] |
| 498 | + assert view.nbytes == 0 |
| 499 | + nbytes = b.recv_into(view) |
| 500 | + assert nbytes == len(msg[3]) |
| 501 | + assert buf == bytearray(10) |
| 502 | + |
| 503 | + def test_recv_into_bad(self): |
| 504 | + a, b = self.create_bound_pair() |
| 505 | + if not self.green: |
| 506 | + b.rcvtimeo = 1000 |
| 507 | + |
| 508 | + # bad calls |
| 509 | + |
| 510 | + # negative nbytes |
| 511 | + buf = bytearray(10) |
| 512 | + with pytest.raises(ValueError): |
| 513 | + b.recv_into(buf, nbytes=-1) |
| 514 | + # not contiguous |
| 515 | + buf = memoryview(bytearray(10))[::2] |
| 516 | + with pytest.raises(ValueError): |
| 517 | + b.recv_into(buf) |
| 518 | + # readonly |
| 519 | + buf = memoryview(b"readonly") |
| 520 | + with pytest.raises(ValueError): |
| 521 | + b.recv_into(buf) |
| 522 | + # too big |
| 523 | + buf = bytearray(10) |
| 524 | + with pytest.raises(ValueError): |
| 525 | + b.recv_into(buf, nbytes=11) |
| 526 | + # not memory-viewable |
| 527 | + with pytest.raises(TypeError): |
| 528 | + b.recv_into(pytest) |
| 529 | + |
| 530 | + # make sure flags work |
| 531 | + with pytest.raises(zmq.Again): |
| 532 | + b.recv_into(bytearray(5), flags=zmq.DONTWAIT) |
| 533 | + |
458 | 534 | def test_close_after_destroy(self):
|
459 | 535 | """s.close() after ctx.destroy() should be fine"""
|
460 | 536 | ctx = self.Context()
|
|
0 commit comments