Skip to content

Commit 5c4225d

Browse files
committed
Fixes to handle retries for WantWriteError and WantReadError in SSL
1 parent 4f040de commit 5c4225d

File tree

2 files changed

+166
-3
lines changed

2 files changed

+166
-3
lines changed

cheroot/makefile.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# prefer slower Python-based io module
44
import _pyio as io
55
import socket
6+
import time
7+
8+
from OpenSSL import SSL
69

710

811
# Write only 16K at a time to sockets
@@ -32,6 +35,14 @@ def _flush_unlocked(self):
3235
n = self.raw.write(bytes(self._write_buf))
3336
except io.BlockingIOError as e:
3437
n = e.characters_written
38+
except (
39+
SSL.WantReadError,
40+
SSL.WantWriteError,
41+
SSL.WantX509LookupError,
42+
):
43+
# these errors require retries with the same data
44+
# if some data has already been written
45+
n = 0
3546
del self._write_buf[:n]
3647

3748

@@ -45,9 +56,22 @@ def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
4556

4657
def read(self, *args, **kwargs):
4758
"""Capture bytes read."""
48-
val = super().read(*args, **kwargs)
49-
self.bytes_read += len(val)
50-
return val
59+
MAX_ATTEMPTS = 10
60+
attempts = 0
61+
while True:
62+
try:
63+
val = super().read(*args, **kwargs)
64+
except (SSL.WantReadError, SSL.WantWriteError):
65+
attempts += 1
66+
if attempts >= MAX_ATTEMPTS:
67+
# Raise an error if max retries reached
68+
raise TimeoutError(
69+
'Max retries exceeded while waiting for data.',
70+
)
71+
time.sleep(0.1)
72+
else:
73+
self.bytes_read += len(val)
74+
return val
5175

5276
def has_data(self):
5377
"""Return true if there is buffered data to read."""

cheroot/test/test_ssl.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import requests
1818
import trustme
1919

20+
from cheroot.makefile import BufferedWriter
21+
2022
from .._compat import (
2123
IS_ABOVE_OPENSSL10,
2224
IS_ABOVE_OPENSSL31,
@@ -625,6 +627,143 @@ def test_ssl_env( # noqa: C901 # FIXME
625627
)
626628

627629

630+
@pytest.fixture
631+
def mock_raw_open(mocker):
632+
"""Return a mocked raw socket prepared for writing (closed=False)."""
633+
# This fixture sets the state on the injected object
634+
mock_raw = mocker.Mock()
635+
mock_raw.closed = False
636+
return mock_raw
637+
638+
639+
@pytest.fixture
640+
def ssl_writer(mock_raw_open):
641+
"""Return a BufferedWriter instance with a mocked raw socket."""
642+
return BufferedWriter(mock_raw_open)
643+
644+
645+
def test_want_write_error_retry(ssl_writer, mock_raw_open):
646+
"""Test that WantWriteError causes retry with same data."""
647+
test_data = b'hello world'
648+
649+
# set up mock socket so that when its write() method is called,
650+
# we get WantWriteError first, then success on the second call
651+
# indicated by returning the number of bytes written
652+
mock_raw_open.write.side_effect = [
653+
OpenSSL.SSL.WantWriteError(),
654+
len(test_data),
655+
]
656+
657+
bytes_written = ssl_writer.write(test_data)
658+
assert bytes_written == len(test_data)
659+
660+
# Assert against the injected mock object
661+
assert mock_raw_open.write.call_count == 2
662+
663+
664+
def test_want_read_error_retry(ssl_writer, mock_raw_open):
665+
"""Test that WantReadError causes retry with same data."""
666+
test_data = b'test data'
667+
668+
# set up mock socket so that when its write() method is called,
669+
# we get WantReadError first, then success on the second call
670+
# indicated by returning the number of bytes written
671+
mock_raw_open.write.side_effect = [
672+
OpenSSL.SSL.WantReadError(),
673+
len(test_data),
674+
]
675+
676+
bytes_written = ssl_writer.write(test_data)
677+
assert bytes_written == len(test_data)
678+
679+
680+
@pytest.fixture(
681+
params=['builtin', 'pyopenssl'],
682+
)
683+
def adapter_type(request):
684+
"""Fixture that yields the name of the SSL adapter."""
685+
return request.param
686+
687+
688+
@pytest.fixture
689+
def ssl_writer_integration(
690+
mocker,
691+
adapter_type,
692+
tls_certificate_chain_pem_path,
693+
tls_certificate_private_key_pem_path,
694+
):
695+
"""
696+
Set up mock SSL writer for integration test.
697+
698+
Mocks the lowest-level write/send method to simulate a
699+
transient WantWriteError.
700+
"""
701+
# Set up SSL adapter
702+
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
703+
tls_adapter = tls_adapter_cls(
704+
tls_certificate_chain_pem_path,
705+
tls_certificate_private_key_pem_path,
706+
)
707+
708+
# Ensure context is initialized if needed
709+
if adapter_type == 'pyopenssl':
710+
# --- PYOPENSSL SETUP
711+
tls_adapter.context = tls_adapter.get_context()
712+
mock_raw_socket = mocker.Mock(name='mock_raw_socket')
713+
mock_raw_socket.fileno.return_value = 1 # need to mock a dummy fd
714+
715+
# Create the real OpenSSL.SSL.Connection object
716+
ssl_conn = OpenSSL.SSL.Connection(tls_adapter.context, mock_raw_socket)
717+
ssl_conn.set_connect_state()
718+
ssl_conn.closed = False
719+
720+
# Return the BufferedWriter and the specific mock for assertions
721+
raw_io_object = ssl_conn
722+
raw_io_object.write = mocker.Mock(name='ssl_conn_write_mock')
723+
else:
724+
# adapter_type == 'builtin'
725+
# --- BUILTIN ADAPTER SETUP (Requires different mocking) ---
726+
# Mock the adapter's own low-level write method
727+
raw_io_object = tls_adapter
728+
raw_io_object.write = mocker.Mock(
729+
name='builtin_adapter_write',
730+
autospec=True,
731+
)
732+
raw_io_object.closed = False
733+
raw_io_object.writable = mocker.Mock(return_value=True)
734+
735+
# Return both the writer and the specific mock assertion target
736+
return BufferedWriter(raw_io_object), raw_io_object.write
737+
738+
739+
def test_want_write_error_integration(ssl_writer_integration):
740+
"""Integration test for SSL writer handling of WantWriteError."""
741+
writer, mock_write = ssl_writer_integration
742+
test_data = b'integration test data'
743+
successful_write_length = len(test_data)
744+
745+
# Determine the failure mechanism
746+
failure_error = (
747+
OpenSSL.SSL.WantWriteError() if adapter_type == 'pyopenssl' else 0
748+
)
749+
750+
# Configure the mock's side effect with the first error
751+
# and then the calculated buffer length for success
752+
mock_write.side_effect = [
753+
failure_error,
754+
successful_write_length,
755+
]
756+
757+
# write data and then flush
758+
# with the way the mock_write is set up this should fail once,
759+
# and then succeed on the retry.
760+
bytes_written = writer.write(test_data)
761+
writer.flush()
762+
763+
assert bytes_written == successful_write_length
764+
assert mock_write.call_count == 2
765+
766+
628767
@pytest.mark.parametrize(
629768
'ip_addr',
630769
(

0 commit comments

Comments
 (0)