|
17 | 17 | import requests |
18 | 18 | import trustme |
19 | 19 |
|
| 20 | +from cheroot.makefile import BufferedWriter |
| 21 | + |
20 | 22 | from .._compat import ( |
21 | 23 | IS_ABOVE_OPENSSL10, |
22 | 24 | IS_ABOVE_OPENSSL31, |
@@ -625,6 +627,143 @@ def test_ssl_env( # noqa: C901 # FIXME |
625 | 627 | ) |
626 | 628 |
|
627 | 629 |
|
| 630 | +@pytest.fixture |
| 631 | +def mock_raw_open(mocker): |
| 632 | + """Return a mocked raw socket prepared for writing (closed=False).""" |
| 633 | + # This fixture sets the state on the injected object |
| 634 | + mock_raw = mocker.Mock() |
| 635 | + mock_raw.closed = False |
| 636 | + return mock_raw |
| 637 | + |
| 638 | + |
| 639 | +@pytest.fixture |
| 640 | +def ssl_writer(mock_raw_open): |
| 641 | + """Return a BufferedWriter instance with a mocked raw socket.""" |
| 642 | + return BufferedWriter(mock_raw_open) |
| 643 | + |
| 644 | + |
| 645 | +def test_want_write_error_retry(ssl_writer, mock_raw_open): |
| 646 | + """Test that WantWriteError causes retry with same data.""" |
| 647 | + test_data = b'hello world' |
| 648 | + |
| 649 | + # set up mock socket so that when its write() method is called, |
| 650 | + # we get WantWriteError first, then success on the second call |
| 651 | + # indicated by returning the number of bytes written |
| 652 | + mock_raw_open.write.side_effect = [ |
| 653 | + OpenSSL.SSL.WantWriteError(), |
| 654 | + len(test_data), |
| 655 | + ] |
| 656 | + |
| 657 | + bytes_written = ssl_writer.write(test_data) |
| 658 | + assert bytes_written == len(test_data) |
| 659 | + |
| 660 | + # Assert against the injected mock object |
| 661 | + assert mock_raw_open.write.call_count == 2 |
| 662 | + |
| 663 | + |
| 664 | +def test_want_read_error_retry(ssl_writer, mock_raw_open): |
| 665 | + """Test that WantReadError causes retry with same data.""" |
| 666 | + test_data = b'test data' |
| 667 | + |
| 668 | + # set up mock socket so that when its write() method is called, |
| 669 | + # we get WantReadError first, then success on the second call |
| 670 | + # indicated by returning the number of bytes written |
| 671 | + mock_raw_open.write.side_effect = [ |
| 672 | + OpenSSL.SSL.WantReadError(), |
| 673 | + len(test_data), |
| 674 | + ] |
| 675 | + |
| 676 | + bytes_written = ssl_writer.write(test_data) |
| 677 | + assert bytes_written == len(test_data) |
| 678 | + |
| 679 | + |
| 680 | +@pytest.fixture( |
| 681 | + params=['builtin', 'pyopenssl'], |
| 682 | +) |
| 683 | +def adapter_type(request): |
| 684 | + """Fixture that yields the name of the SSL adapter.""" |
| 685 | + return request.param |
| 686 | + |
| 687 | + |
| 688 | +@pytest.fixture |
| 689 | +def ssl_writer_integration( |
| 690 | + mocker, |
| 691 | + adapter_type, |
| 692 | + tls_certificate_chain_pem_path, |
| 693 | + tls_certificate_private_key_pem_path, |
| 694 | +): |
| 695 | + """ |
| 696 | + Set up mock SSL writer for integration test. |
| 697 | +
|
| 698 | + Mocks the lowest-level write/send method to simulate a |
| 699 | + transient WantWriteError. |
| 700 | + """ |
| 701 | + # Set up SSL adapter |
| 702 | + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) |
| 703 | + tls_adapter = tls_adapter_cls( |
| 704 | + tls_certificate_chain_pem_path, |
| 705 | + tls_certificate_private_key_pem_path, |
| 706 | + ) |
| 707 | + |
| 708 | + # Ensure context is initialized if needed |
| 709 | + if adapter_type == 'pyopenssl': |
| 710 | + # --- PYOPENSSL SETUP |
| 711 | + tls_adapter.context = tls_adapter.get_context() |
| 712 | + mock_raw_socket = mocker.Mock(name='mock_raw_socket') |
| 713 | + mock_raw_socket.fileno.return_value = 1 # need to mock a dummy fd |
| 714 | + |
| 715 | + # Create the real OpenSSL.SSL.Connection object |
| 716 | + ssl_conn = OpenSSL.SSL.Connection(tls_adapter.context, mock_raw_socket) |
| 717 | + ssl_conn.set_connect_state() |
| 718 | + ssl_conn.closed = False |
| 719 | + |
| 720 | + # Return the BufferedWriter and the specific mock for assertions |
| 721 | + raw_io_object = ssl_conn |
| 722 | + raw_io_object.write = mocker.Mock(name='ssl_conn_write_mock') |
| 723 | + else: |
| 724 | + # adapter_type == 'builtin' |
| 725 | + # --- BUILTIN ADAPTER SETUP (Requires different mocking) --- |
| 726 | + # Mock the adapter's own low-level write method |
| 727 | + raw_io_object = tls_adapter |
| 728 | + raw_io_object.write = mocker.Mock( |
| 729 | + name='builtin_adapter_write', |
| 730 | + autospec=True, |
| 731 | + ) |
| 732 | + raw_io_object.closed = False |
| 733 | + raw_io_object.writable = mocker.Mock(return_value=True) |
| 734 | + |
| 735 | + # Return both the writer and the specific mock assertion target |
| 736 | + return BufferedWriter(raw_io_object), raw_io_object.write |
| 737 | + |
| 738 | + |
| 739 | +def test_want_write_error_integration(ssl_writer_integration): |
| 740 | + """Integration test for SSL writer handling of WantWriteError.""" |
| 741 | + writer, mock_write = ssl_writer_integration |
| 742 | + test_data = b'integration test data' |
| 743 | + successful_write_length = len(test_data) |
| 744 | + |
| 745 | + # Determine the failure mechanism |
| 746 | + failure_error = ( |
| 747 | + OpenSSL.SSL.WantWriteError() if adapter_type == 'pyopenssl' else 0 |
| 748 | + ) |
| 749 | + |
| 750 | + # Configure the mock's side effect with the first error |
| 751 | + # and then the calculated buffer length for success |
| 752 | + mock_write.side_effect = [ |
| 753 | + failure_error, |
| 754 | + successful_write_length, |
| 755 | + ] |
| 756 | + |
| 757 | + # write data and then flush |
| 758 | + # with the way the mock_write is set up this should fail once, |
| 759 | + # and then succeed on the retry. |
| 760 | + bytes_written = writer.write(test_data) |
| 761 | + writer.flush() |
| 762 | + |
| 763 | + assert bytes_written == successful_write_length |
| 764 | + assert mock_write.call_count == 2 |
| 765 | + |
| 766 | + |
628 | 767 | @pytest.mark.parametrize( |
629 | 768 | 'ip_addr', |
630 | 769 | ( |
|
0 commit comments