Skip to content

Commit c4011c5

Browse files
committed
Added the test
1 parent 4c5e36f commit c4011c5

1 file changed

Lines changed: 122 additions & 0 deletions

File tree

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Regression tests for issue #196.
3+
4+
When a SSCHA minimization runs in parallel under an MPI launcher
5+
(``mpirun``/``srun``), the standard output is frequently a pipe opened in
6+
*non-blocking* mode. A large write (for instance the table of imaginary
7+
frequencies printed by ``SchaMinimizer.check_imaginary_frequencies``) fills
8+
the pipe buffer and a write on a non-blocking descriptor raises
9+
``BlockingIOError`` ([Errno 11]) instead of waiting for the buffer to drain.
10+
This used to abort the whole calculation.
11+
12+
The fix lives in ``sscha.Parallel.pprint`` (the function aliased as ``print``
13+
across the package): it restores the blocking mode of stdout and, as a last
14+
resort, never lets a log line crash the run.
15+
"""
16+
import os
17+
import sys
18+
import threading
19+
20+
import pytest
21+
22+
import sscha.Parallel
23+
24+
# Non-blocking pipe semantics and os.set_blocking are POSIX-only.
25+
pytestmark = pytest.mark.skipif(
26+
sys.platform.startswith("win") or not hasattr(os, "set_blocking"),
27+
reason="requires POSIX non-blocking file descriptors (os.set_blocking)",
28+
)
29+
30+
# Much larger than any pipe buffer (~64 KiB) or stdio buffer, so that the
31+
# write cannot complete in a single non-blocking shot.
32+
BIG_MESSAGE = "x" * (4 * 1024 * 1024)
33+
34+
35+
def _drain(read_fd, sink=None):
36+
"""Consume a pipe until EOF, optionally collecting the bytes."""
37+
while True:
38+
chunk = os.read(read_fd, 1 << 16)
39+
if not chunk:
40+
break
41+
if sink is not None:
42+
sink.extend(chunk)
43+
44+
45+
def test_builtin_print_raises_on_nonblocking_stdout():
46+
"""Reproduce the original failure: a plain ``print`` on a non-blocking
47+
stdout raises ``BlockingIOError`` once the buffer fills up. This is exactly
48+
what ``pprint`` used to do before the fix."""
49+
read_fd, write_fd = os.pipe()
50+
os.set_blocking(write_fd, False)
51+
stdout = os.fdopen(write_fd, "w")
52+
saved = sys.stdout
53+
sys.stdout = stdout
54+
try:
55+
with pytest.raises(BlockingIOError):
56+
print(BIG_MESSAGE) # builtin print, no reader draining the pipe
57+
stdout.flush()
58+
finally:
59+
sys.stdout = saved
60+
# Drain the leftover buffered bytes so closing does not block/raise.
61+
os.set_blocking(write_fd, True)
62+
drainer = threading.Thread(target=_drain, args=(read_fd,))
63+
drainer.start()
64+
try:
65+
stdout.close()
66+
except OSError:
67+
pass
68+
drainer.join()
69+
os.close(read_fd)
70+
71+
72+
def test_pprint_survives_nonblocking_stdout():
73+
"""With the fix, ``pprint`` restores blocking mode and the large write
74+
completes successfully instead of raising."""
75+
read_fd, write_fd = os.pipe()
76+
os.set_blocking(write_fd, False)
77+
78+
received = bytearray()
79+
reader = threading.Thread(target=_drain, args=(read_fd, received))
80+
reader.start()
81+
82+
stdout = os.fdopen(write_fd, "w")
83+
saved = sys.stdout
84+
sys.stdout = stdout
85+
try:
86+
# Must not raise BlockingIOError.
87+
sscha.Parallel.pprint(BIG_MESSAGE)
88+
stdout.flush()
89+
finally:
90+
sys.stdout = saved
91+
stdout.close()
92+
reader.join()
93+
os.close(read_fd)
94+
95+
assert BIG_MESSAGE.encode() in bytes(received)
96+
97+
98+
def test_pprint_never_raises_when_blocking_cannot_be_set(monkeypatch):
99+
"""Safety net: even if blocking mode cannot be enforced, ``pprint`` must
100+
swallow the error rather than abort the calculation."""
101+
monkeypatch.setattr(sscha.Parallel, "_force_stdout_blocking", lambda: None)
102+
103+
read_fd, write_fd = os.pipe()
104+
os.set_blocking(write_fd, False)
105+
stdout = os.fdopen(write_fd, "w")
106+
saved = sys.stdout
107+
sys.stdout = stdout
108+
try:
109+
# stdout stays non-blocking and nobody reads: the internal write
110+
# raises BlockingIOError, which pprint must catch.
111+
sscha.Parallel.pprint(BIG_MESSAGE)
112+
finally:
113+
sys.stdout = saved
114+
os.set_blocking(write_fd, True)
115+
drainer = threading.Thread(target=_drain, args=(read_fd,))
116+
drainer.start()
117+
try:
118+
stdout.close()
119+
except OSError:
120+
pass
121+
drainer.join()
122+
os.close(read_fd)

0 commit comments

Comments
 (0)