Skip to content

Commit a7750e9

Browse files
julianz-webknjaz
authored andcommitted
Fixes to handle retries for WantWriteError and WantReadError in SSL
As discussed in #245
1 parent 4f040de commit a7750e9

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

cheroot/makefile.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# prefer slower Python-based io module
44
import _pyio as io
55
import socket
6+
import time
7+
from OpenSSL import SSL
68

79

810
# Write only 16K at a time to sockets
@@ -31,7 +33,11 @@ def _flush_unlocked(self):
3133
# so perhaps we should conditionally wrap this for perf?
3234
n = self.raw.write(bytes(self._write_buf))
3335
except io.BlockingIOError as e:
34-
n = e.characters_written
36+
n = e.characters_writteni
37+
except (SSL.WantReadError,SSL.WantWriteError, SSL.WantX509LookupError) as e:
38+
# these errors require retries with the same data
39+
# if some data has already been written
40+
n = 0
3541
del self._write_buf[:n]
3642

3743

@@ -45,9 +51,15 @@ def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
4551

4652
def read(self, *args, **kwargs):
4753
"""Capture bytes read."""
48-
val = super().read(*args, **kwargs)
49-
self.bytes_read += len(val)
50-
return val
54+
while True:
55+
try:
56+
val = super().read(*args, **kwargs)
57+
self.bytes_read += len(val)
58+
return val
59+
except SSL.WantReadError:
60+
time.sleep(0.1) # allow some retry delay
61+
except SSL.WantWriteError:
62+
time.sleep(0.1)
5163

5264
def has_data(self):
5365
"""Return true if there is buffered data to read."""

cheroot/test/test_ssl.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import requests
1818
import trustme
1919

20+
from cheroot.makefile import BufferedWriter
21+
2022
from .._compat import (
2123
IS_ABOVE_OPENSSL10,
2224
IS_ABOVE_OPENSSL31,
@@ -38,6 +40,7 @@
3840
_get_conn_data,
3941
_probe_ipv6_sock,
4042
)
43+
from unittest import mock
4144
from ..wsgi import Gateway_10
4245

4346

@@ -624,6 +627,37 @@ def test_ssl_env( # noqa: C901 # FIXME
624627
),
625628
)
626629

630+
class TestBufferedWriterSSLWantErrors:
631+
632+
def setup_method(self):
633+
self.mock_raw = mock.Mock()
634+
self.mock_raw.closed = False
635+
self.writer = BufferedWriter(self.mock_raw)
636+
637+
def test_want_write_error_retry(self):
638+
"""Test that WantWriteError causes retry with same data."""
639+
test_data = b"hello world"
640+
641+
self.mock_raw.write.side_effect = [
642+
OpenSSL.SSL.WantWriteError(),
643+
len(test_data)
644+
]
645+
646+
result = self.writer.write(test_data)
647+
assert result == len(test_data)
648+
assert self.mock_raw.write.call_count == 2
649+
650+
def test_want_read_error_retry(self):
651+
"""Test that WantReadError causes retry with same data."""
652+
test_data = b"test data"
653+
654+
self.mock_raw.write.side_effect = [
655+
OpenSSL.SSL.WantReadError(),
656+
len(test_data)
657+
]
658+
659+
result = self.writer.write(test_data)
660+
assert result == len(test_data)
627661

628662
@pytest.mark.parametrize(
629663
'ip_addr',

0 commit comments

Comments
 (0)