Skip to content

Commit a0ba4d2

Browse files
authored
Make HELION_PRINT_REPRO=1 take effect in more error cases (#1066)
1 parent a94d16d commit a0ba4d2

File tree

6 files changed

+150
-73
lines changed

6 files changed

+150
-73
lines changed

helion/autotuner/base_search.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
188188
baseline_config,
189189
prefix=f"Generated Triton code for {decorator}:",
190190
)
191+
self.kernel.maybe_log_repro(self.log.error, new_args, baseline_config)
191192
raise exc.InvalidConfig(
192193
"Default config failed while computing baseline.\n"
193194
f"Default config: {decorator}\n"
@@ -340,6 +341,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
340341
return res
341342
except Exception as e:
342343
if match_unrecoverable_runtime_error(e):
344+
self.kernel.maybe_log_repro(self.log.error, self.args, config)
343345
raise exc.TritonUnrecoverableRuntimeError(
344346
reason=str(e),
345347
decorator=self.kernel.format_kernel_decorator(
@@ -358,6 +360,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
358360
config,
359361
prefix=f"Generated Triton code for {decorator}:",
360362
)
363+
self.kernel.maybe_log_repro(self.log.error, self.args, config)
361364
raise exc.TritonError(
362365
error=f"{type(e).__qualname__}: {e}",
363366
decorator=decorator,
@@ -372,6 +375,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
372375
prefix=f"Generated Triton code for {decorator}:",
373376
)
374377
self.log.warning(format_triton_compile_failure(config, e, self.kernel))
378+
self.kernel.maybe_log_repro(self.log.warning, self.args, config)
375379
else:
376380
decorator = self.kernel.format_kernel_decorator(config, self.settings)
377381
log_generated_triton_code_debug(
@@ -381,6 +385,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
381385
prefix=f"Generated Triton code for {decorator}:",
382386
)
383387
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
388+
self.kernel.maybe_log_repro(self.log.debug, self.args, config)
384389
return inf
385390

386391
def start_precompile_and_check_for_hangs(
@@ -1198,6 +1203,9 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
11981203
self.config,
11991204
prefix=f"Generated Triton code for {decorator}:",
12001205
)
1206+
self.search.kernel.maybe_log_repro(
1207+
self.search.log.error, self.search.args, self.config
1208+
)
12011209
raise exc.TritonError(
12021210
error=f"{type(exc_obj).__qualname__}: {exc_obj}",
12031211
decorator=decorator,
@@ -1223,8 +1231,14 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
12231231
)
12241232
if classification == "warn":
12251233
self.search.log.warning(formatted)
1234+
self.search.kernel.maybe_log_repro(
1235+
self.search.log.warning, self.search.args, self.config
1236+
)
12261237
elif not ignore_errors:
12271238
self.search.log.debug(formatted)
1239+
self.search.kernel.maybe_log_repro(
1240+
self.search.log.debug, self.search.args, self.config
1241+
)
12281242
self._remote_error_handled = True
12291243

12301244

helion/autotuner/logger.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def __call__(
5858
if level >= self.level:
5959
self._logger.log(level, " ".join(map(_maybe_call, msg)))
6060

61+
def error(self, *msg: str | Callable[[], str]) -> None:
62+
return self(*msg, level=logging.ERROR)
63+
6164
def warning(self, *msg: str | Callable[[], str]) -> None:
6265
return self(*msg, level=logging.WARNING)
6366

helion/runtime/kernel.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -645,14 +645,19 @@ def __call__(self, *args: object) -> _R:
645645
self.format_kernel_decorator(self._config, self.settings)
646646
] = 1
647647

648-
if self.settings.print_repro:
649-
self._print_repro(args)
648+
self.maybe_log_repro(log.warning, args)
650649

651650
return self._run(*args)
652651

653-
def _print_repro(
654-
self, args: tuple[object, ...], config: Config | None = None
652+
def maybe_log_repro(
653+
self,
654+
log_func: Callable[[str], None],
655+
args: Sequence[object],
656+
config: Config | None = None,
655657
) -> None:
658+
if not self.settings.print_repro:
659+
return
660+
656661
effective_config = config or self._config
657662
assert effective_config is not None
658663

@@ -723,9 +728,11 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]:
723728
# Add return statement
724729
call_args = ", ".join(arg_names)
725730
output_lines.append(f" return {self.kernel.name}({call_args})")
731+
output_lines.extend(["", "helion_repro_caller()"])
726732

727733
output_lines.append("# === END HELION KERNEL REPRO ===")
728-
print("\n".join(output_lines), file=sys.stderr)
734+
repro_text = "\n".join(output_lines)
735+
log_func(repro_text)
729736

730737

731738
class _KernelDecorator(Protocol):

test/test_autotuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _make_search(
8282
search.kernel = SimpleNamespace(
8383
format_kernel_decorator=lambda config, s: "decorator",
8484
to_triton_code=lambda config: "code",
85+
maybe_log_repro=lambda log_func, args, config=None: None,
8586
)
8687
search.args = args
8788
search.counters = collections.Counter()

test/test_debug_utils.expected

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@ import helion.language as hl
88
import torch
99
from torch._dynamo.testing import rand_strided
1010

11-
@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
12-
def kernel1(x: torch.Tensor) -> torch.Tensor:
11+
@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
12+
def kernel(x: torch.Tensor) -> torch.Tensor:
1313
out = torch.empty_like(x)
14-
m, n = x.shape
15-
for tile_m, tile_n in hl.tile([m, n]):
16-
out[tile_m, tile_n] = x[tile_m, tile_n] + 1
14+
n = x.shape[0]
15+
for tile_n in hl.tile([n]):
16+
out[tile_n] = x[tile_n] + 1
1717
return out
1818

1919
def helion_repro_caller():
2020
torch.manual_seed(0)
21-
x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE)
22-
return kernel1(x)
21+
x = rand_strided((128,), (1,), dtype=torch.float32, device=DEVICE)
22+
return kernel(x)
23+
24+
helion_repro_caller()
2325
# === END HELION KERNEL REPRO ===

test/test_debug_utils.py

Lines changed: 111 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import linecache
45
import os
56
import unittest
7+
from unittest import mock
68

79
import pytest
810
import torch
@@ -24,65 +26,80 @@ def _store_capfd_on_class(request, capfd):
2426
request.cls._capfd = capfd
2527

2628

29+
@pytest.fixture(autouse=True)
30+
def _store_caplog_on_class(request, caplog):
31+
"""
32+
Expose pytest's caplog fixture as `self._caplog` inside the TestDebugUtils class
33+
(works for unittest.TestCase-style tests).
34+
"""
35+
if request.cls is not None:
36+
request.cls._caplog = caplog
37+
38+
2739
class TestDebugUtils(RefEagerTestDisabled, TestCase):
28-
def test_print_repro_env_var(self):
29-
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
40+
@contextlib.contextmanager
41+
def _with_print_repro_enabled(self):
42+
"""Context manager to temporarily set HELION_PRINT_REPRO=1."""
3043
original = os.environ.get("HELION_PRINT_REPRO")
3144
os.environ["HELION_PRINT_REPRO"] = "1"
3245
try:
46+
yield
47+
finally:
48+
if original is None:
49+
os.environ.pop("HELION_PRINT_REPRO", None)
50+
else:
51+
os.environ["HELION_PRINT_REPRO"] = original
52+
53+
def _clear_captures(self):
54+
"""Clear pytest capture fixtures if available."""
55+
if hasattr(self, "_capfd"):
56+
self._capfd.readouterr()
57+
if hasattr(self, "_caplog"):
58+
self._caplog.clear()
59+
60+
def _create_kernel(self, **kwargs):
61+
"""Create a simple 1D kernel for testing.
62+
63+
Args:
64+
**kwargs: Arguments to pass to @helion.kernel decorator.
65+
"""
3366

34-
@helion.kernel(
35-
config=helion.Config(
36-
block_sizes=[2, 2],
37-
flatten_loops=[False],
38-
indexing=["pointer", "pointer"],
39-
l2_groupings=[1],
40-
load_eviction_policies=[""],
41-
loop_orders=[[0, 1]],
42-
num_stages=1,
43-
num_warps=4,
44-
pid_type="flat",
45-
range_flattens=[None],
46-
range_multi_buffers=[None],
47-
range_num_stages=[0],
48-
range_unroll_factors=[0],
49-
),
67+
@helion.kernel(**kwargs)
68+
def kernel(x: torch.Tensor) -> torch.Tensor:
69+
out = torch.empty_like(x)
70+
n = x.shape[0]
71+
for tile_n in hl.tile([n]):
72+
out[tile_n] = x[tile_n] + 1
73+
return out
74+
75+
return kernel
76+
77+
def test_print_repro_env_var(self):
78+
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
79+
with self._with_print_repro_enabled():
80+
kernel = self._create_kernel(
81+
config=helion.Config(block_sizes=[32], num_warps=4),
5082
static_shapes=True,
5183
)
52-
def kernel1(x: torch.Tensor) -> torch.Tensor:
53-
out = torch.empty_like(x)
54-
m, n = x.shape
55-
for tile_m, tile_n in hl.tile([m, n]):
56-
out[tile_m, tile_n] = x[tile_m, tile_n] + 1
57-
return out
5884

5985
torch.manual_seed(0)
60-
x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE)
86+
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
6187

62-
if hasattr(self, "_capfd"):
63-
self._capfd.readouterr()
88+
self._clear_captures()
6489

65-
result = kernel1(x)
90+
result = kernel(x)
6691
torch.testing.assert_close(result, x + 1)
6792

68-
if not hasattr(self, "_capfd"):
69-
return # Cannot test without capture
70-
71-
captured = "".join(self._capfd.readouterr())
93+
# Extract repro script from logs (use records to get the raw message without formatting)
94+
assert hasattr(self, "_caplog"), "caplog fixture not available"
95+
repro_script = None
96+
for record in self._caplog.records:
97+
if "# === HELION KERNEL REPRO ===" in record.message:
98+
repro_script = record.message
99+
break
72100

73-
# Extract repro script
74-
lines = captured.splitlines()
75-
start = next(
76-
i
77-
for i, line in enumerate(lines)
78-
if "# === HELION KERNEL REPRO ===" in line
79-
)
80-
end = next(
81-
i
82-
for i, line in enumerate(lines[start:], start)
83-
if "# === END HELION KERNEL REPRO ===" in line
84-
)
85-
repro_script = "\n".join(lines[start : end + 1])
101+
if repro_script is None:
102+
self.fail("No repro script found in logs")
86103

87104
# Normalize range_warp_specializes=[None] to [] for comparison
88105
normalized_script = repro_script.replace(
@@ -92,26 +109,18 @@ def kernel1(x: torch.Tensor) -> torch.Tensor:
92109
# Verify repro script matches expected script
93110
self.assertExpectedJournal(normalized_script)
94111

95-
# Extract the actual code (without the comment markers) for execution
96-
repro_lines = repro_script.splitlines()
97-
code_start = 1 if repro_lines[0].startswith("# === HELION") else 0
98-
code_end = len(repro_lines) - (
99-
1 if repro_lines[-1].startswith("# === END") else 0
100-
)
101-
repro_code = "\n".join(repro_lines[code_start:code_end])
102-
103112
# Setup linecache so inspect.getsource() works on exec'd code
104113
filename = "<helion_repro_test>"
105114
linecache.cache[filename] = (
106-
len(repro_code),
115+
len(repro_script),
107116
None,
108-
[f"{line}\n" for line in repro_code.splitlines()],
117+
[f"{line}\n" for line in repro_script.splitlines()],
109118
filename,
110119
)
111120

112121
# Execute the repro script
113122
namespace = {}
114-
exec(compile(repro_code, filename, "exec"), namespace)
123+
exec(compile(repro_script, filename, "exec"), namespace)
115124

116125
# Call the generated helper and verify it runs successfully
117126
helper = namespace["helion_repro_caller"]
@@ -121,11 +130,52 @@ def kernel1(x: torch.Tensor) -> torch.Tensor:
121130
torch.testing.assert_close(repro_result, x + 1)
122131

123132
linecache.cache.pop(filename, None)
124-
finally:
125-
if original is None:
126-
os.environ.pop("HELION_PRINT_REPRO", None)
127-
else:
128-
os.environ["HELION_PRINT_REPRO"] = original
133+
134+
def test_print_repro_on_autotune_error(self):
135+
"""Ensure HELION_PRINT_REPRO=1 prints repro when configs fail during autotuning.
136+
137+
This test mocks do_bench to fail on the second config, guaranteeing the repro
138+
printing code path is exercised for "warn" level errors.
139+
"""
140+
with self._with_print_repro_enabled():
141+
kernel = self._create_kernel(
142+
configs=[
143+
helion.Config(block_sizes=[32], num_warps=4),
144+
helion.Config(block_sizes=[64], num_warps=8),
145+
],
146+
autotune_precompile=False,
147+
)
148+
149+
torch.manual_seed(0)
150+
x = torch.randn([128], dtype=torch.float32, device=DEVICE)
151+
152+
self._clear_captures()
153+
154+
# Mock do_bench to fail on the second config with PTXASError (warn level)
155+
from torch._inductor.runtime.triton_compat import PTXASError
156+
from triton.testing import do_bench as original_do_bench
157+
158+
call_count = [0]
159+
160+
def mock_do_bench(*args, **kwargs):
161+
call_count[0] += 1
162+
if call_count[0] == 2: # Fail on second config
163+
raise PTXASError("Mocked PTXAS error")
164+
return original_do_bench(*args, **kwargs)
165+
166+
with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
167+
# Autotune will try both configs, second one will fail and print repro
168+
kernel.autotune([x], force=False)
169+
170+
# Collect everything captured on stdout and stderr (the repro is expected in the logged output)
171+
assert hasattr(self, "_capfd"), "capfd fixture not available"
172+
captured = "".join(self._capfd.readouterr())
173+
174+
# Verify that a repro script was printed for the failing config
175+
self.assertIn("# === HELION KERNEL REPRO ===", captured)
176+
self.assertIn("# === END HELION KERNEL REPRO ===", captured)
177+
self.assertIn("kernel", captured)
178+
self.assertIn("helion_repro_caller()", captured)
129179

130180

131181
if __name__ == "__main__":

0 commit comments

Comments
 (0)