Skip to content

Commit 67137c6

Browse files
committed
Fixes to handle retries for WantWriteError and WantReadError in SSL
Added handling for WantWriteError and WantReadError in BufferedWriter and StreamReader to enable retries. This addresses long standing issues discussed in #245. The fix depends on fixes that were added in pyOpenSSL v25.2.0.
1 parent 4f040de commit 67137c6

File tree

2 files changed

+219
-3
lines changed

2 files changed

+219
-3
lines changed

cheroot/makefile.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# prefer slower Python-based io module
44
import _pyio as io
55
import socket
6+
import time
7+
8+
from OpenSSL import SSL
69

710

811
# Write only 16K at a time to sockets
@@ -31,7 +34,24 @@ def _flush_unlocked(self):
3134
# so perhaps we should conditionally wrap this for perf?
3235
n = self.raw.write(bytes(self._write_buf))
3336
except io.BlockingIOError as e:
37+
# some data may have been written
38+
# we need to remove that from the buffer before retryings
3439
n = e.characters_written
40+
except (
41+
SSL.WantReadError,
42+
SSL.WantWriteError,
43+
SSL.WantX509LookupError,
44+
):
45+
# these errors require retries with the same data
46+
# regardless of whether data has already been written
47+
continue
48+
except OSError:
49+
# This catches errors like EBADF (Bad File Descriptor)
50+
# or EPIPE (Broken pipe), which indicate the underlying
51+
# socket is already closed or invalid.
52+
# Since this happens in __del__, we silently stop flushing.
53+
self._write_buf.clear()
54+
return # Exit the function
3555
del self._write_buf[:n]
3656

3757

@@ -45,9 +65,22 @@ def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
4565

4666
def read(self, *args, **kwargs):
4767
"""Capture bytes read."""
48-
val = super().read(*args, **kwargs)
49-
self.bytes_read += len(val)
50-
return val
68+
MAX_ATTEMPTS = 10
69+
last_error = None
70+
for _ in range(MAX_ATTEMPTS):
71+
try:
72+
val = super().read(*args, **kwargs)
73+
except (SSL.WantReadError, SSL.WantWriteError) as ssl_want_error:
74+
last_error = ssl_want_error
75+
time.sleep(0.1)
76+
else:
77+
self.bytes_read += len(val)
78+
return val
79+
80+
# If we get here, all attempts failed
81+
raise TimeoutError(
82+
'Max retries exceeded while waiting for data.',
83+
) from last_error
5184

5285
def has_data(self):
5386
"""Return true if there is buffered data to read."""

cheroot/test/test_ssl.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for TLS support."""
22

3+
import errno
34
import functools
45
import http.client
56
import json
@@ -17,6 +18,8 @@
1718
import requests
1819
import trustme
1920

21+
from cheroot.makefile import BufferedWriter
22+
2023
from .._compat import (
2124
IS_ABOVE_OPENSSL10,
2225
IS_ABOVE_OPENSSL31,
@@ -625,6 +628,186 @@ def test_ssl_env( # noqa: C901 # FIXME
625628
)
626629

627630

631+
@pytest.fixture
632+
def mock_raw_open_socket(mocker):
633+
"""Return a mocked raw socket prepared for writing (closed=False)."""
634+
# This fixture sets the state on the injected object
635+
mock_raw = mocker.Mock(name='mock_raw_socket')
636+
mock_raw.closed = False
637+
return mock_raw
638+
639+
640+
@pytest.fixture
641+
def ssl_writer(mock_raw_open_socket):
642+
"""Return a BufferedWriter instance with a mocked raw socket."""
643+
return BufferedWriter(mock_raw_open_socket)
644+
645+
646+
def test_want_write_error_retry(ssl_writer, mock_raw_open_socket):
647+
"""Test that WantWriteError causes retry with same data."""
648+
test_data = b'hello world'
649+
650+
# set up mock socket so that when its write() method is called,
651+
# we get WantWriteError first, then success on the second call
652+
# indicated by returning the number of bytes written
653+
mock_raw_open_socket.write.side_effect = [
654+
OpenSSL.SSL.WantWriteError(),
655+
len(test_data),
656+
]
657+
658+
bytes_written = ssl_writer.write(test_data)
659+
assert bytes_written == len(test_data)
660+
661+
# Assert against the injected mock object
662+
assert mock_raw_open_socket.write.call_count == 2
663+
664+
665+
def test_want_read_error_retry(ssl_writer, mock_raw_open_socket):
666+
"""Test that WantReadError causes retry with same data."""
667+
test_data = b'test data'
668+
669+
# set up mock socket so that when its write() method is called,
670+
# we get WantReadError first, then success on the second call
671+
# indicated by returning the number of bytes written
672+
mock_raw_open_socket.write.side_effect = [
673+
OpenSSL.SSL.WantReadError(),
674+
len(test_data),
675+
]
676+
677+
bytes_written = ssl_writer.write(test_data)
678+
assert bytes_written == len(test_data)
679+
680+
681+
@pytest.fixture(
682+
params=('builtin', 'pyopenssl'),
683+
)
684+
def adapter_type(request):
685+
"""Fixture that yields the name of the SSL adapter."""
686+
return request.param
687+
688+
689+
@pytest.fixture
690+
def create_side_effects_factory(adapter_type):
691+
"""
692+
Fixture that returns a factory function to create the side effect list.
693+
694+
The factory function returns a list of two items:
695+
1. An error to be raised on the first call
696+
2. The length of data written on the second call
697+
698+
It returns a function that takes one argument,
699+
allowing the data length to be injected from the test function.
700+
"""
701+
if adapter_type == 'pyopenssl':
702+
failure_error = OpenSSL.SSL.WantWriteError()
703+
else: # adapter_type == 'builtin'
704+
failure_error = BlockingIOError(
705+
errno.EWOULDBLOCK,
706+
'Resource temporarily unavailable',
707+
)
708+
failure_error.characters_written = 0
709+
710+
def generate_side_effects(data_length): # noqa: WPS430
711+
"""Return the list: [failure_error, data_length]."""
712+
return [
713+
failure_error,
714+
data_length, # This uses the length provided by the test
715+
]
716+
717+
# Return the inner function
718+
return generate_side_effects
719+
720+
721+
@pytest.fixture
722+
def ssl_writer_integration(
723+
mocker,
724+
mock_raw_open_socket,
725+
adapter_type,
726+
tls_certificate_chain_pem_path,
727+
tls_certificate_private_key_pem_path,
728+
):
729+
"""
730+
Set up mock SSL writer for integration test.
731+
732+
Mocks the lowest-level write/send method to simulate a
733+
WantWriteError for the PYOPENSSL adapter, and a
734+
BlockingIOError for the BUILTIN adapter.
735+
"""
736+
# Set up SSL adapter
737+
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
738+
tls_adapter = tls_adapter_cls(
739+
tls_certificate_chain_pem_path,
740+
tls_certificate_private_key_pem_path,
741+
)
742+
743+
if adapter_type == 'pyopenssl':
744+
# --- PYOPENSSL SETUP
745+
# Ensure context is initialized, as required by an OpenSSL Connection
746+
tls_adapter.context = tls_adapter.get_context()
747+
# need to mock a dummy fd on the mocked raw socket
748+
mock_raw_open_socket.fileno.return_value = 1
749+
750+
# Create an OpenSSL.SSL.Connection object using the mocked raw socket
751+
ssl_conn = OpenSSL.SSL.Connection(
752+
tls_adapter.context,
753+
mock_raw_open_socket,
754+
)
755+
ssl_conn.set_connect_state()
756+
ssl_conn.closed = False
757+
758+
# we need to mock a write method on the mocked raw socket
759+
raw_io_object = ssl_conn
760+
raw_io_object.write = mocker.Mock(name='ssl_conn_write_mock')
761+
else:
762+
# adapter_type == 'builtin'
763+
# --- BUILTIN ADAPTER SETUP (Requires different mocking) ---
764+
# we need to mock the adapter's own write and writable methods
765+
raw_io_object = tls_adapter
766+
raw_io_object.writable = mocker.Mock(return_value=True)
767+
raw_io_object.write = mocker.Mock(
768+
name='builtin_adapter_write',
769+
)
770+
raw_io_object.closed = False
771+
772+
# return mock assertion target
773+
return raw_io_object
774+
775+
776+
def test_want_write_error_integration(
777+
ssl_writer_integration,
778+
create_side_effects_factory,
779+
):
780+
"""Integration test for SSL writer handling of WantWriteError.
781+
782+
This test gets called twice, once for each adapter type.
783+
The fixture ssl_writer_integration sets up the mock write method.
784+
The fixture create_side_effects_factory creates the side effect list
785+
with the data length injected from this test function.
786+
"""
787+
raw_io_object = ssl_writer_integration
788+
test_data = b'integration test data'
789+
successful_write_length = len(test_data)
790+
791+
# Call side effects factory function to create
792+
# a two step list for the mock write method.
793+
# First call raises error, second call returns length.
794+
# We have to inject the length because the factory
795+
# is created in a fixture that doesn't know the test data.
796+
side_effects = create_side_effects_factory(successful_write_length)
797+
raw_io_object.write.side_effect = side_effects
798+
799+
writer = BufferedWriter(raw_io_object)
800+
801+
# write data and then flush
802+
# with the way the mock_write is set up this should fail once,
803+
# and then succeed on the retry.
804+
bytes_written = writer.write(test_data)
805+
writer.flush()
806+
807+
assert bytes_written == successful_write_length
808+
assert raw_io_object.write.call_count == 2
809+
810+
628811
@pytest.mark.parametrize(
629812
'ip_addr',
630813
(

0 commit comments

Comments
 (0)