Skip to content

Commit bfa223a

Browse files
authored
Add HELION_PRINT_REPRO=1 to print Helion kernel repro script to console (#1049)
1 parent 0bafd91 commit bfa223a

File tree

9 files changed

+259
-2
lines changed

9 files changed

+259
-2
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU
301301
helpful for debugging and understanding Helion's compilation process. One can also use
302302
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.
303303

304+
To emit a repro script that includes the Helion kernel definition, the config decorator, and a
305+
`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set
306+
`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints
307+
the repro script to `stderr`, which is helpful for debugging and for sharing minimal repro on GitHub issue tracker.
308+
304309
Within an `hl.tile`/`hl.grid` device loop, if you want to print intermediate results using `print("x", ...)` syntax,
305310
or pause execution using Python's built-in `breakpoint()`, set either `TRITON_INTERPRET=1` (runs Triton's CPU interpreter)
306311
or `HELION_INTERPRET=1` (runs the Helion kernel in eager mode).

docs/api/config.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ The `Config` class represents kernel optimization parameters that control how He
2727
|--------|--------|----------|
2828
| **Purpose** | Control execution performance | Control compilation behavior |
2929
| **Autotuning** | ✅ Automatically optimized | ❌ Never autotuned |
30-
| **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `autotune_effort` |
30+
| **Examples** | `block_sizes`, `num_warps`, `indexing` | `print_output_code`, `print_repro`, `autotune_effort` |
3131
| **When to use** | Performance optimization | Development, debugging, environment setup |
3232

3333

docs/api/kernel.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ Settings control **how the kernel is compiled** and the development environment:
161161
autotune_effort="none", # Skip autotuning for development
162162
autotune_effort="quick", # Smaller autotuning budget when search is enabled
163163
print_output_code=True, # Debug: show generated Triton code
164+
print_repro=True, # Debug: show Helion kernel code, config, and caller code as a standalone repro script
164165
static_shapes=True, # Compilation optimization strategy
165166
autotune_log_level=logging.DEBUG # Verbose autotuning output
166167
)

docs/api/settings.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ import helion.language as hl
6161

6262
@helion.kernel(
6363
autotune_effort="none", # Skip autotuning
64-
print_output_code=True, # Debug output
64+
print_output_code=True, # Debug: show generated Triton code
65+
print_repro=True, # Debug: show Helion kernel code, config, and caller code as a standalone repro script
6566
)
6667
def my_kernel(x: torch.Tensor) -> torch.Tensor:
6768
result = torch.zeros_like(x)
@@ -190,6 +191,10 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b
190191
191192
Print generated Triton code to stderr. Default is ``False``. Controlled by ``HELION_PRINT_OUTPUT_CODE=1``.
192193
194+
.. autoattribute:: Settings.print_repro
195+
196+
Print Helion kernel code, config, and caller code to stderr as a standalone repro script. Default is ``False``. Controlled by ``HELION_PRINT_REPRO=1``.
197+
193198
.. autoattribute:: Settings.output_origin_lines
194199
195200
Annotate generated Triton code with ``# src[<file>:<line>]`` comments indicating the originating Helion statements.
@@ -259,6 +264,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
259264
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
260265
| ``HELION_ASSERT_CACHE_HIT`` | ``AutotuneCacheBase`` | When set to ``1``, require a cache hit; raises ``CacheAssertionError`` on cache miss with detailed diagnostics. |
261266
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
267+
| ``HELION_PRINT_REPRO`` | ``print_repro`` | Print Helion kernel code, config, and caller code to stderr as a standalone repro script. |
262268
| ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. |
263269
| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |
264270
| ``HELION_ALLOW_WARP_SPECIALIZE`` | ``allow_warp_specialize`` | Permit warp-specialized code generation for ``tl.range``. |

docs/index.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ To view the generated Triton code, set the environment variable `HELION_PRINT_OU
241241
helpful for debugging and understanding Helion's compilation process. One can also use
242242
`foo_kernel.bind(args).to_triton_code(config)` to get the Triton code as a string.
243243

244+
To emit a repro script that includes the Helion kernel definition, the config decorator, and a
245+
`helion_repro_caller()` helper that recreates the runtime inputs before invoking the Helion kernel, set
246+
`HELION_PRINT_REPRO=1` or include `print_repro=True` in the `@helion.kernel` decorator. This prints
247+
the repro script to `stderr`, which is helpful for debugging and for sharing minimal repro on GitHub issue tracker.
248+
244249
To force autotuning, bypassing provided configurations, set `HELION_FORCE_AUTOTUNE=1` or invoke `foo_kernel.autotune(args,
245250
force=True)`.
246251

helion/runtime/kernel.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import operator
1010
import re
1111
import sys
12+
import textwrap
1213
import types
1314
from typing import TYPE_CHECKING
1415
from typing import Callable
@@ -641,8 +642,88 @@ def __call__(self, *args: object) -> _R:
641642
self.format_kernel_decorator(self._config, self.settings)
642643
] = 1
643644

645+
if self.settings.print_repro:
646+
self._print_repro(args)
647+
644648
return self._run(*args)
645649

650+
def _print_repro(
651+
self, args: tuple[object, ...], config: Config | None = None
652+
) -> None:
653+
effective_config = config or self._config
654+
assert effective_config is not None
655+
656+
# Get kernel source
657+
try:
658+
raw_source = inspect.getsource(self.kernel.fn)
659+
source_lines = textwrap.dedent(raw_source).splitlines()
660+
# Skip decorator lines (including multi-line decorators)
661+
start_idx = 0
662+
while start_idx < len(source_lines) and not source_lines[
663+
start_idx
664+
].lstrip().startswith("def "):
665+
start_idx += 1
666+
kernel_body = "\n".join(source_lines[start_idx:])
667+
except (OSError, TypeError):
668+
kernel_body = f"# Source unavailable for {self.kernel.fn.__module__}.{self.kernel.fn.__qualname__}"
669+
670+
# Format decorator
671+
decorator = self.format_kernel_decorator(effective_config, self.settings)
672+
673+
# Build output
674+
output_lines = [
675+
"# === HELION KERNEL REPRO ===",
676+
"import helion",
677+
"import helion.language as hl",
678+
"import torch",
679+
"from torch._dynamo.testing import rand_strided",
680+
"",
681+
decorator,
682+
kernel_body,
683+
]
684+
685+
# Generate caller function
686+
if args:
687+
688+
def _render_input_arg_assignment(name: str, value: object) -> list[str]:
689+
if isinstance(value, torch.Tensor):
690+
shape = tuple(int(d) for d in value.shape)
691+
stride = tuple(int(s) for s in value.stride())
692+
device = str(value.device)
693+
dtype = str(value.dtype)
694+
695+
lines = [
696+
f"{name} = rand_strided({shape!r}, {stride!r}, dtype={dtype}, device={device!r})"
697+
]
698+
699+
if value.requires_grad:
700+
lines.append(f"{name}.requires_grad_(True)")
701+
return lines
702+
703+
return [f"{name} = {value!r}"]
704+
705+
sig_param_names = list(self.kernel.signature.parameters.keys())
706+
assert len(args) == len(sig_param_names)
707+
708+
output_lines.extend(["", "def helion_repro_caller():"])
709+
output_lines.append(" torch.manual_seed(0)")
710+
arg_names = []
711+
712+
for i, value in enumerate(args):
713+
var_name = sig_param_names[i]
714+
arg_names.append(var_name)
715+
716+
# Add assignment lines with indentation
717+
for line in _render_input_arg_assignment(var_name, value):
718+
output_lines.append(f" {line}")
719+
720+
# Add return statement
721+
call_args = ", ".join(arg_names)
722+
output_lines.append(f" return {self.kernel.name}({call_args})")
723+
724+
output_lines.append("# === END HELION KERNEL REPRO ===")
725+
print("\n".join(output_lines), file=sys.stderr)
726+
646727

647728
class _KernelDecorator(Protocol):
648729
def __call__(

helion/runtime/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ class _Settings:
315315
_env_get_bool, "HELION_PRINT_OUTPUT_CODE", False
316316
)
317317
)
318+
print_repro: bool = dataclasses.field(
319+
default_factory=functools.partial(_env_get_bool, "HELION_PRINT_REPRO", False)
320+
)
318321
output_origin_lines: bool = dataclasses.field(
319322
default_factory=functools.partial(
320323
_env_get_bool, "HELION_OUTPUT_ORIGIN_LINES", True
@@ -386,6 +389,7 @@ class Settings(_Settings):
386389
"Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally."
387390
),
388391
"print_output_code": "If True, print the output code of the kernel to stderr.",
392+
"print_repro": "If True, print Helion kernel code, config, and caller code to stderr as a standalone repro script.",
389393
"output_origin_lines": (
390394
"If True, annotate generated Triton code with source-origin comments. "
391395
"Set HELION_OUTPUT_ORIGIN_LINES=0 to disable."

test/test_debug_utils.expected

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
This file is automatically generated by assertExpectedJournal calls in test_debug_utils.py.
2+
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
3+
4+
--- assertExpectedJournal(TestDebugUtils.test_print_repro_env_var)
5+
# === HELION KERNEL REPRO ===
6+
import helion
7+
import helion.language as hl
8+
import torch
9+
from torch._dynamo.testing import rand_strided
10+
11+
@helion.kernel(config=helion.Config(block_sizes=[2, 2], flatten_loops=[False], indexing=['pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=[''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)
12+
def kernel1(x: torch.Tensor) -> torch.Tensor:
13+
out = torch.empty_like(x)
14+
m, n = x.shape
15+
for tile_m, tile_n in hl.tile([m, n]):
16+
out[tile_m, tile_n] = x[tile_m, tile_n] + 1
17+
return out
18+
19+
def helion_repro_caller():
20+
torch.manual_seed(0)
21+
x = rand_strided((2, 2), (2, 1), dtype=torch.float32, device=DEVICE)
22+
return kernel1(x)
23+
# === END HELION KERNEL REPRO ===

test/test_debug_utils.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
import linecache
4+
import os
5+
import unittest
6+
7+
import pytest
8+
import torch
9+
10+
import helion
11+
from helion._testing import DEVICE
12+
from helion._testing import RefEagerTestDisabled
13+
from helion._testing import TestCase
14+
import helion.language as hl
15+
16+
17+
@pytest.fixture(autouse=True)
18+
def _store_capfd_on_class(request, capfd):
19+
"""
20+
Expose pytest's capfd fixture as `self._capfd` inside the TestDebugUtils class
21+
(works for unittest.TestCase-style tests).
22+
"""
23+
if request.cls is not None:
24+
request.cls._capfd = capfd
25+
26+
27+
class TestDebugUtils(RefEagerTestDisabled, TestCase):
28+
def test_print_repro_env_var(self):
29+
"""Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
30+
original = os.environ.get("HELION_PRINT_REPRO")
31+
os.environ["HELION_PRINT_REPRO"] = "1"
32+
try:
33+
34+
@helion.kernel(
35+
config=helion.Config(
36+
block_sizes=[2, 2],
37+
flatten_loops=[False],
38+
indexing=["pointer", "pointer"],
39+
l2_groupings=[1],
40+
load_eviction_policies=[""],
41+
loop_orders=[[0, 1]],
42+
num_stages=1,
43+
num_warps=4,
44+
pid_type="flat",
45+
range_flattens=[None],
46+
range_multi_buffers=[None],
47+
range_num_stages=[0],
48+
range_unroll_factors=[0],
49+
),
50+
static_shapes=True,
51+
)
52+
def kernel1(x: torch.Tensor) -> torch.Tensor:
53+
out = torch.empty_like(x)
54+
m, n = x.shape
55+
for tile_m, tile_n in hl.tile([m, n]):
56+
out[tile_m, tile_n] = x[tile_m, tile_n] + 1
57+
return out
58+
59+
torch.manual_seed(0)
60+
x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE)
61+
62+
if hasattr(self, "_capfd"):
63+
self._capfd.readouterr()
64+
65+
result = kernel1(x)
66+
torch.testing.assert_close(result, x + 1)
67+
68+
if not hasattr(self, "_capfd"):
69+
return # Cannot test without capture
70+
71+
captured = "".join(self._capfd.readouterr())
72+
73+
# Extract repro script
74+
lines = captured.splitlines()
75+
start = next(
76+
i
77+
for i, line in enumerate(lines)
78+
if "# === HELION KERNEL REPRO ===" in line
79+
)
80+
end = next(
81+
i
82+
for i, line in enumerate(lines[start:], start)
83+
if "# === END HELION KERNEL REPRO ===" in line
84+
)
85+
repro_script = "\n".join(lines[start : end + 1])
86+
87+
# Normalize range_warp_specializes=[None] to [] for comparison
88+
normalized_script = repro_script.replace(
89+
"range_warp_specializes=[None]", "range_warp_specializes=[]"
90+
)
91+
92+
# Verify repro script matches expected script
93+
self.assertExpectedJournal(normalized_script)
94+
95+
# Extract the actual code (without the comment markers) for execution
96+
repro_lines = repro_script.splitlines()
97+
code_start = 1 if repro_lines[0].startswith("# === HELION") else 0
98+
code_end = len(repro_lines) - (
99+
1 if repro_lines[-1].startswith("# === END") else 0
100+
)
101+
repro_code = "\n".join(repro_lines[code_start:code_end])
102+
103+
# Setup linecache so inspect.getsource() works on exec'd code
104+
filename = "<helion_repro_test>"
105+
linecache.cache[filename] = (
106+
len(repro_code),
107+
None,
108+
[f"{line}\n" for line in repro_code.splitlines()],
109+
filename,
110+
)
111+
112+
# Execute the repro script
113+
namespace = {}
114+
exec(compile(repro_code, filename, "exec"), namespace)
115+
116+
# Call the generated helper and verify it runs successfully
117+
helper = namespace["helion_repro_caller"]
118+
repro_result = helper()
119+
120+
# Verify the output
121+
torch.testing.assert_close(repro_result, x + 1)
122+
123+
linecache.cache.pop(filename, None)
124+
finally:
125+
if original is None:
126+
os.environ.pop("HELION_PRINT_REPRO", None)
127+
else:
128+
os.environ["HELION_PRINT_REPRO"] = original
129+
130+
131+
if __name__ == "__main__":
132+
unittest.main()

0 commit comments

Comments
 (0)