Skip to content

Commit 8ac02bf

Browse files
committed
Fixes to handle retries for WantWriteError and WantReadError in SSL
Added handling for WantWriteError and WantReadError in BufferedWriter and StreamReader to enable retries. This addresses long standing issues discussed in #245. The fix depends on fixes that were added in pyOpenSSL v25.2.0.
1 parent 4f040de commit 8ac02bf

File tree

3 files changed

+171
-3
lines changed

3 files changed

+171
-3
lines changed

cheroot/makefile.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# prefer slower Python-based io module
44
import _pyio as io
55
import socket
6+
import time
7+
8+
from OpenSSL import SSL
69

710

811
# Write only 16K at a time to sockets
@@ -32,6 +35,14 @@ def _flush_unlocked(self):
3235
n = self.raw.write(bytes(self._write_buf))
3336
except io.BlockingIOError as e:
3437
n = e.characters_written
38+
except (
39+
SSL.WantReadError,
40+
SSL.WantWriteError,
41+
SSL.WantX509LookupError,
42+
):
43+
# these errors require retries with the same data
44+
# if some data has already been written
45+
n = 0
3546
del self._write_buf[:n]
3647

3748

@@ -45,9 +56,22 @@ def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
4556

4657
def read(self, *args, **kwargs):
4758
"""Capture bytes read."""
48-
val = super().read(*args, **kwargs)
49-
self.bytes_read += len(val)
50-
return val
59+
MAX_ATTEMPTS = 10
60+
attempts = 0
61+
while True:
62+
try:
63+
val = super().read(*args, **kwargs)
64+
except (SSL.WantReadError, SSL.WantWriteError):
65+
attempts += 1
66+
if attempts >= MAX_ATTEMPTS:
67+
# Raise an error if max retries reached
68+
raise TimeoutError(
69+
'Max retries exceeded while waiting for data.',
70+
)
71+
time.sleep(0.1)
72+
else:
73+
self.bytes_read += len(val)
74+
return val
5175

5276
def has_data(self):
5377
"""Return true if there is buffered data to read."""

cheroot/test/test_ssl.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import requests
1818
import trustme
1919

20+
from cheroot.makefile import BufferedWriter
21+
2022
from .._compat import (
2123
IS_ABOVE_OPENSSL10,
2224
IS_ABOVE_OPENSSL31,
@@ -625,6 +627,145 @@ def test_ssl_env( # noqa: C901 # FIXME
625627
)
626628

627629

630+
@pytest.fixture
631+
def mock_raw_open(mocker):
632+
"""Return a mocked raw socket prepared for writing (closed=False)."""
633+
# This fixture sets the state on the injected object
634+
mock_raw = mocker.Mock()
635+
mock_raw.closed = False
636+
return mock_raw
637+
638+
639+
@pytest.fixture
640+
def ssl_writer(mock_raw_open):
641+
"""Return a BufferedWriter instance with a mocked raw socket."""
642+
return BufferedWriter(mock_raw_open)
643+
644+
645+
def test_want_write_error_retry(ssl_writer, mock_raw_open):
646+
"""Test that WantWriteError causes retry with same data."""
647+
test_data = b'hello world'
648+
649+
# set up mock socket so that when its write() method is called,
650+
# we get WantWriteError first, then success on the second call
651+
# indicated by returning the number of bytes written
652+
mock_raw_open.write.side_effect = [
653+
OpenSSL.SSL.WantWriteError(),
654+
len(test_data),
655+
]
656+
657+
bytes_written = ssl_writer.write(test_data)
658+
assert bytes_written == len(test_data)
659+
660+
# Assert against the injected mock object
661+
assert mock_raw_open.write.call_count == 2
662+
663+
664+
def test_want_read_error_retry(ssl_writer, mock_raw_open):
665+
"""Test that WantReadError causes retry with same data."""
666+
test_data = b'test data'
667+
668+
# set up mock socket so that when its write() method is called,
669+
# we get WantReadError first, then success on the second call
670+
# indicated by returning the number of bytes written
671+
mock_raw_open.write.side_effect = [
672+
OpenSSL.SSL.WantReadError(),
673+
len(test_data),
674+
]
675+
676+
bytes_written = ssl_writer.write(test_data)
677+
assert bytes_written == len(test_data)
678+
679+
680+
@pytest.fixture(
681+
params=['builtin', 'pyopenssl'],
682+
)
683+
def adapter_type(request):
684+
"""Fixture that yields the name of the SSL adapter."""
685+
return request.param
686+
687+
688+
@pytest.fixture
689+
def ssl_writer_integration(
690+
mocker,
691+
adapter_type,
692+
tls_certificate_chain_pem_path,
693+
tls_certificate_private_key_pem_path,
694+
):
695+
"""
696+
Set up mock SSL writer for integration test.
697+
698+
Mocks the lowest-level write/send method to simulate a
699+
transient WantWriteError.
700+
"""
701+
# Set up SSL adapter
702+
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
703+
tls_adapter = tls_adapter_cls(
704+
tls_certificate_chain_pem_path,
705+
tls_certificate_private_key_pem_path,
706+
)
707+
708+
# Ensure context is initialized if needed
709+
if adapter_type == 'pyopenssl':
710+
# --- PYOPENSSL SETUP
711+
tls_adapter.context = tls_adapter.get_context()
712+
mock_raw_socket = mocker.Mock(name='mock_raw_socket')
713+
mock_raw_socket.fileno.return_value = 1 # need to mock a dummy fd
714+
715+
# Create the real OpenSSL.SSL.Connection object
716+
ssl_conn = OpenSSL.SSL.Connection(tls_adapter.context, mock_raw_socket)
717+
ssl_conn.set_connect_state()
718+
ssl_conn.closed = False
719+
720+
# Return the BufferedWriter and the specific mock for assertions
721+
raw_io_object = ssl_conn
722+
raw_io_object.write = mocker.Mock(name='ssl_conn_write_mock')
723+
else:
724+
# adapter_type == 'builtin'
725+
# --- BUILTIN ADAPTER SETUP (Requires different mocking) ---
726+
# Mock the adapter's own low-level write method
727+
raw_io_object = tls_adapter
728+
raw_io_object.write = mocker.Mock(
729+
name='builtin_adapter_write',
730+
autospec=True,
731+
)
732+
raw_io_object.closed = False
733+
raw_io_object.writable = mocker.Mock(return_value=True)
734+
735+
# Return both the writer and the specific mock assertion target
736+
return BufferedWriter(raw_io_object), raw_io_object.write
737+
738+
739+
def test_want_write_error_integration(ssl_writer_integration):
740+
"""Integration test for SSL writer handling of WantWriteError."""
741+
writer, mock_write = ssl_writer_integration
742+
test_data = b'integration test data'
743+
successful_write_length = len(test_data)
744+
745+
# Determine the failure mechanism
746+
if adapter_type == 'pyopenssl':
747+
# Linter is perfectly happy with this explicit assignment
748+
failure_error = OpenSSL.SSL.WantWriteError()
749+
else:
750+
failure_error = 0
751+
752+
# Configure the mock's side effect with the first error
753+
# and then the calculated buffer length for success
754+
mock_write.side_effect = [
755+
failure_error,
756+
successful_write_length,
757+
]
758+
759+
# write data and then flush
760+
# with the way the mock_write is set up this should fail once,
761+
# and then succeed on the retry.
762+
bytes_written = writer.write(test_data)
763+
writer.flush()
764+
765+
assert bytes_written == successful_write_length
766+
assert mock_write.call_count == 2
767+
768+
628769
@pytest.mark.parametrize(
629770
'ip_addr',
630771
(

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ file = "README.rst"
7272
content-type = "text/x-rst"
7373

7474
[project.optional-dependencies]
75+
ssl = [
76+
"pyOpenSSL >= 25.2.0",
77+
]
7578
docs = [
7679
# upstream
7780
"sphinx >= 1.8.2",

0 commit comments

Comments
 (0)