|
1 | 1 | """Tests for TLS support.""" |
2 | 2 |
|
| 3 | +import errno |
3 | 4 | import functools |
4 | 5 | import http.client |
5 | 6 | import json |
|
17 | 18 | import requests |
18 | 19 | import trustme |
19 | 20 |
|
| 21 | +from cheroot.makefile import BufferedWriter |
| 22 | + |
20 | 23 | from .._compat import ( |
21 | 24 | IS_ABOVE_OPENSSL10, |
22 | 25 | IS_ABOVE_OPENSSL31, |
@@ -625,6 +628,186 @@ def test_ssl_env( # noqa: C901 # FIXME |
625 | 628 | ) |
626 | 629 |
|
627 | 630 |
|
| 631 | +@pytest.fixture |
| 632 | +def mock_raw_open_socket(mocker): |
| 633 | + """Return a mocked raw socket prepared for writing (closed=False).""" |
| 634 | + # This fixture sets the state on the injected object |
| 635 | + mock_raw = mocker.Mock(name='mock_raw_socket') |
| 636 | + mock_raw.closed = False |
| 637 | + return mock_raw |
| 638 | + |
| 639 | + |
| 640 | +@pytest.fixture |
| 641 | +def ssl_writer(mock_raw_open_socket): |
| 642 | + """Return a BufferedWriter instance with a mocked raw socket.""" |
| 643 | + return BufferedWriter(mock_raw_open_socket) |
| 644 | + |
| 645 | + |
| 646 | +def test_want_write_error_retry(ssl_writer, mock_raw_open_socket): |
| 647 | + """Test that WantWriteError causes retry with same data.""" |
| 648 | + test_data = b'hello world' |
| 649 | + |
| 650 | + # set up mock socket so that when its write() method is called, |
| 651 | + # we get WantWriteError first, then success on the second call |
| 652 | + # indicated by returning the number of bytes written |
| 653 | + mock_raw_open_socket.write.side_effect = [ |
| 654 | + OpenSSL.SSL.WantWriteError(), |
| 655 | + len(test_data), |
| 656 | + ] |
| 657 | + |
| 658 | + bytes_written = ssl_writer.write(test_data) |
| 659 | + assert bytes_written == len(test_data) |
| 660 | + |
| 661 | + # Assert against the injected mock object |
| 662 | + assert mock_raw_open_socket.write.call_count == 2 |
| 663 | + |
| 664 | + |
| 665 | +def test_want_read_error_retry(ssl_writer, mock_raw_open_socket): |
| 666 | + """Test that WantReadError causes retry with same data.""" |
| 667 | + test_data = b'test data' |
| 668 | + |
| 669 | + # set up mock socket so that when its write() method is called, |
| 670 | + # we get WantReadError first, then success on the second call |
| 671 | + # indicated by returning the number of bytes written |
| 672 | + mock_raw_open_socket.write.side_effect = [ |
| 673 | + OpenSSL.SSL.WantReadError(), |
| 674 | + len(test_data), |
| 675 | + ] |
| 676 | + |
| 677 | + bytes_written = ssl_writer.write(test_data) |
| 678 | + assert bytes_written == len(test_data) |
| 679 | + |
| 680 | + |
| 681 | +@pytest.fixture( |
| 682 | + params=('builtin', 'pyopenssl'), |
| 683 | +) |
| 684 | +def adapter_type(request): |
| 685 | + """Fixture that yields the name of the SSL adapter.""" |
| 686 | + return request.param |
| 687 | + |
| 688 | + |
| 689 | +@pytest.fixture |
| 690 | +def create_side_effects_factory(adapter_type): |
| 691 | + """ |
| 692 | + Fixture that returns a factory function to create the side effect list. |
| 693 | +
|
| 694 | + The factory function returns a list of two items: |
| 695 | + 1. An error to be raised on the first call |
| 696 | + 2. The length of data written on the second call |
| 697 | +
|
| 698 | + It returns a function that takes one argument, |
| 699 | + allowing the data length to be injected from the test function. |
| 700 | + """ |
| 701 | + if adapter_type == 'pyopenssl': |
| 702 | + failure_error = OpenSSL.SSL.WantWriteError() |
| 703 | + else: # adapter_type == 'builtin' |
| 704 | + failure_error = BlockingIOError( |
| 705 | + errno.EWOULDBLOCK, |
| 706 | + 'Resource temporarily unavailable', |
| 707 | + ) |
| 708 | + failure_error.characters_written = 0 |
| 709 | + |
| 710 | + def generate_side_effects(data_length): # noqa: WPS430 |
| 711 | + """Return the list: [failure_error, data_length].""" |
| 712 | + return [ |
| 713 | + failure_error, |
| 714 | + data_length, # This uses the length provided by the test |
| 715 | + ] |
| 716 | + |
| 717 | + # Return the inner function |
| 718 | + return generate_side_effects |
| 719 | + |
| 720 | + |
| 721 | +@pytest.fixture |
| 722 | +def ssl_writer_integration( |
| 723 | + mocker, |
| 724 | + mock_raw_open_socket, |
| 725 | + adapter_type, |
| 726 | + tls_certificate_chain_pem_path, |
| 727 | + tls_certificate_private_key_pem_path, |
| 728 | +): |
| 729 | + """ |
| 730 | + Set up mock SSL writer for integration test. |
| 731 | +
|
| 732 | + Mocks the lowest-level write/send method to simulate a |
| 733 | + WantWriteError for the PYOPENSSL adapter, and a |
| 734 | + BlockingIOError for the BUILTIN adapter. |
| 735 | + """ |
| 736 | + # Set up SSL adapter |
| 737 | + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) |
| 738 | + tls_adapter = tls_adapter_cls( |
| 739 | + tls_certificate_chain_pem_path, |
| 740 | + tls_certificate_private_key_pem_path, |
| 741 | + ) |
| 742 | + |
| 743 | + if adapter_type == 'pyopenssl': |
| 744 | + # --- PYOPENSSL SETUP |
| 745 | + # Ensure context is initialized, as required by an OpenSSL Connection |
| 746 | + tls_adapter.context = tls_adapter.get_context() |
| 747 | + # need to mock a dummy fd on the mocked raw socket |
| 748 | + mock_raw_open_socket.fileno.return_value = 1 |
| 749 | + |
| 750 | + # Create an OpenSSL.SSL.Connection object using the mocked raw socket |
| 751 | + ssl_conn = OpenSSL.SSL.Connection( |
| 752 | + tls_adapter.context, |
| 753 | + mock_raw_open_socket, |
| 754 | + ) |
| 755 | + ssl_conn.set_connect_state() |
| 756 | + ssl_conn.closed = False |
| 757 | + |
| 758 | + # we need to mock a write method on the mocked raw socket |
| 759 | + raw_io_object = ssl_conn |
| 760 | + raw_io_object.write = mocker.Mock(name='ssl_conn_write_mock') |
| 761 | + else: |
| 762 | + # adapter_type == 'builtin' |
| 763 | + # --- BUILTIN ADAPTER SETUP (Requires different mocking) --- |
| 764 | + # we need to mock the adapter's own write and writable methods |
| 765 | + raw_io_object = tls_adapter |
| 766 | + raw_io_object.writable = mocker.Mock(return_value=True) |
| 767 | + raw_io_object.write = mocker.Mock( |
| 768 | + name='builtin_adapter_write', |
| 769 | + ) |
| 770 | + raw_io_object.closed = False |
| 771 | + |
| 772 | + # return mock assertion target |
| 773 | + return raw_io_object |
| 774 | + |
| 775 | + |
| 776 | +def test_want_write_error_integration( |
| 777 | + ssl_writer_integration, |
| 778 | + create_side_effects_factory, |
| 779 | +): |
| 780 | + """Integration test for SSL writer handling of WantWriteError. |
| 781 | +
|
| 782 | + This test gets called twice, once for each adapter type. |
| 783 | + The fixture ssl_writer_integration sets up the mock write method. |
| 784 | + The fixture create_side_effects_factory creates the side effect list |
| 785 | + with the data length injected from this test function. |
| 786 | + """ |
| 787 | + raw_io_object = ssl_writer_integration |
| 788 | + test_data = b'integration test data' |
| 789 | + successful_write_length = len(test_data) |
| 790 | + |
| 791 | + # Call side effects factory function to create |
| 792 | + # a two step list for the mock write method. |
| 793 | + # First call raises error, second call returns length. |
| 794 | + # We have to inject the length because the factory |
| 795 | + # is created in a fixture that doesn't know the test data. |
| 796 | + side_effects = create_side_effects_factory(successful_write_length) |
| 797 | + raw_io_object.write.side_effect = side_effects |
| 798 | + |
| 799 | + writer = BufferedWriter(raw_io_object) |
| 800 | + |
| 801 | + # write data and then flush |
| 802 | + # with the way the mock_write is set up this should fail once, |
| 803 | + # and then succeed on the retry. |
| 804 | + bytes_written = writer.write(test_data) |
| 805 | + writer.flush() |
| 806 | + |
| 807 | + assert bytes_written == successful_write_length |
| 808 | + assert raw_io_object.write.call_count == 2 |
| 809 | + |
| 810 | + |
628 | 811 | @pytest.mark.parametrize( |
629 | 812 | 'ip_addr', |
630 | 813 | ( |
|
0 commit comments