Skip to content

Commit 788c8cf

Browse files
authored
[Unblock internal] Fix log capture issue on internal tests (#1076)
1 parent 5aad251 commit 788c8cf

File tree

2 files changed

+75
-28
lines changed

2 files changed

+75
-28
lines changed

helion/_testing.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import functools
66
import importlib
77
import inspect
8+
import io
9+
import logging
810
import operator
911
import os
1012
from pathlib import Path
@@ -39,6 +41,39 @@
3941
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"
4042

4143

44+
class _LogCapture(logging.Handler):
45+
"""Simple logging handler to capture log records."""
46+
47+
def __init__(self) -> None:
48+
super().__init__()
49+
self.records: list[logging.LogRecord] = []
50+
51+
def emit(self, record: logging.LogRecord) -> None:
52+
self.records.append(record)
53+
54+
def clear(self) -> None:
55+
self.records.clear()
56+
57+
58+
class _OutputCapture:
59+
"""Simple output capture class for stdout/stderr."""
60+
61+
def __init__(self) -> None:
62+
self.stdout = io.StringIO()
63+
self.stderr = io.StringIO()
64+
65+
def readouterr(self) -> tuple[str, str]:
66+
"""Read and clear captured output, returning (stdout, stderr) tuple."""
67+
stdout_val = self.stdout.getvalue()
68+
stderr_val = self.stderr.getvalue()
69+
# Clear the buffers
70+
self.stdout.seek(0)
71+
self.stdout.truncate()
72+
self.stderr.seek(0)
73+
self.stderr.truncate()
74+
return (stdout_val, stderr_val)
75+
76+
4277
def is_cuda() -> bool:
4378
"""Return True if running on CUDA (NVIDIA GPU)."""
4479
return (
@@ -940,3 +975,26 @@ def assertExpectedJournal(self, value: str) -> None:
940975
expected,
941976
msg="To accept the new output, re-run test with env EXPECTTEST_ACCEPT=1",
942977
)
978+
979+
@contextlib.contextmanager
980+
def capture_logs(self) -> Generator[_LogCapture, None, None]:
981+
"""Context manager to capture logs."""
982+
handler = _LogCapture()
983+
handler.setLevel(logging.DEBUG)
984+
logger = logging.getLogger()
985+
logger.addHandler(handler)
986+
try:
987+
yield handler
988+
finally:
989+
logger.removeHandler(handler)
990+
991+
@contextlib.contextmanager
992+
def capture_output(self) -> Generator[_OutputCapture, None, None]:
993+
"""Context manager to capture stdout/stderr."""
994+
capture = _OutputCapture()
995+
old_stdout, old_stderr = sys.stdout, sys.stderr
996+
sys.stdout, sys.stderr = capture.stdout, capture.stderr
997+
try:
998+
yield capture
999+
finally:
1000+
sys.stdout, sys.stderr = old_stdout, old_stderr

test/test_debug_utils.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,6 @@ def _with_print_repro_enabled(self):
5050
else:
5151
os.environ["HELION_PRINT_REPRO"] = original
5252

53-
def _clear_captures(self):
54-
"""Clear pytest capture fixtures if available."""
55-
if hasattr(self, "_capfd"):
56-
self._capfd.readouterr()
57-
if hasattr(self, "_caplog"):
58-
self._caplog.clear()
59-
6053
def _create_kernel(self, **kwargs):
6154
"""Create a simple 1D kernel for testing.
6255
@@ -85,21 +78,19 @@ def test_print_repro_env_var(self):
8578
torch.manual_seed(0)
8679
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
8780

88-
self._clear_captures()
89-
90-
result = kernel(x)
91-
torch.testing.assert_close(result, x + 1)
81+
with self.capture_logs() as log_capture:
82+
result = kernel(x)
83+
torch.testing.assert_close(result, x + 1)
9284

93-
# Extract repro script from logs (use records to get the raw message without formatting)
94-
assert hasattr(self, "_caplog"), "caplog fixture not available"
95-
repro_script = None
96-
for record in self._caplog.records:
97-
if "# === HELION KERNEL REPRO ===" in record.message:
98-
repro_script = record.message
99-
break
85+
# Extract repro script from logs (use records to get the raw message without formatting)
86+
repro_script = None
87+
for record in log_capture.records:
88+
if "# === HELION KERNEL REPRO ===" in record.message:
89+
repro_script = record.message
90+
break
10091

101-
if repro_script is None:
102-
self.fail("No repro script found in logs")
92+
if repro_script is None:
93+
self.fail("No repro script found in logs")
10394

10495
# Normalize range_warp_specializes=[None] to [] for comparison
10596
normalized_script = repro_script.replace(
@@ -149,8 +140,6 @@ def test_print_repro_on_autotune_error(self):
149140
torch.manual_seed(0)
150141
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
151142

152-
self._clear_captures()
153-
154143
# Mock do_bench to fail on the second config with PTXASError (warn level)
155144
from torch._inductor.runtime.triton_compat import PTXASError
156145
from triton.testing import do_bench as original_do_bench
@@ -163,13 +152,13 @@ def mock_do_bench(*args, **kwargs):
163152
raise PTXASError("Mocked PTXAS error")
164153
return original_do_bench(*args, **kwargs)
165154

166-
with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
167-
# Autotune will try both configs, second one will fail and print repro
168-
kernel.autotune([x], force=False)
155+
with self.capture_output() as output_capture:
156+
with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
157+
# Autotune will try both configs, second one will fail and print repro
158+
kernel.autotune([x], force=False)
169159

170-
# Extract repro script from stderr
171-
assert hasattr(self, "_capfd"), "capfd fixture not available"
172-
captured = "".join(self._capfd.readouterr())
160+
# Extract repro script from stderr
161+
captured = "".join(output_capture.readouterr())
173162

174163
# Verify that a repro script was printed for the failing config
175164
self.assertIn("# === HELION KERNEL REPRO ===", captured)

0 commit comments

Comments
 (0)