@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import contextlib
 import linecache
 import os
 import unittest
+from unittest import mock
 
 import pytest
 import torch
@@ -24,65 +26,80 @@ def _store_capfd_on_class(request, capfd):
         request.cls._capfd = capfd
 
 
+@pytest.fixture(autouse=True)
+def _store_caplog_on_class(request, caplog):
+    """
+    Expose pytest's caplog fixture as `self._caplog` inside the TestDebugUtils class
+    (works for unittest.TestCase-style tests).
+    """
+    if request.cls is not None:
+        request.cls._caplog = caplog
+
+
 class TestDebugUtils(RefEagerTestDisabled, TestCase):
-    def test_print_repro_env_var(self):
-        """Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
+    @contextlib.contextmanager
+    def _with_print_repro_enabled(self):
+        """Context manager to temporarily set HELION_PRINT_REPRO=1."""
         original = os.environ.get("HELION_PRINT_REPRO")
         os.environ["HELION_PRINT_REPRO"] = "1"
         try:
+            yield
+        finally:
+            if original is None:
+                os.environ.pop("HELION_PRINT_REPRO", None)
+            else:
+                os.environ["HELION_PRINT_REPRO"] = original
+
+    def _clear_captures(self):
+        """Clear pytest capture fixtures if available."""
+        if hasattr(self, "_capfd"):
+            self._capfd.readouterr()
+        if hasattr(self, "_caplog"):
+            self._caplog.clear()
+
+    def _create_kernel(self, **kwargs):
+        """Create a simple 1D kernel for testing.
+
+        Args:
+            **kwargs: Arguments to pass to @helion.kernel decorator.
+        """
 
-            @helion.kernel(
-                config=helion.Config(
-                    block_sizes=[2, 2],
-                    flatten_loops=[False],
-                    indexing=["pointer", "pointer"],
-                    l2_groupings=[1],
-                    load_eviction_policies=[""],
-                    loop_orders=[[0, 1]],
-                    num_stages=1,
-                    num_warps=4,
-                    pid_type="flat",
-                    range_flattens=[None],
-                    range_multi_buffers=[None],
-                    range_num_stages=[0],
-                    range_unroll_factors=[0],
-                ),
+        @helion.kernel(**kwargs)
+        def kernel(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            n = x.shape[0]
+            for tile_n in hl.tile([n]):
+                out[tile_n] = x[tile_n] + 1
+            return out
+
+        return kernel
+
+    def test_print_repro_env_var(self):
+        """Ensure HELION_PRINT_REPRO=1 emits an executable repro script."""
+        with self._with_print_repro_enabled():
+            kernel = self._create_kernel(
+                config=helion.Config(block_sizes=[32], num_warps=4),
                 static_shapes=True,
             )
-            def kernel1(x: torch.Tensor) -> torch.Tensor:
-                out = torch.empty_like(x)
-                m, n = x.shape
-                for tile_m, tile_n in hl.tile([m, n]):
-                    out[tile_m, tile_n] = x[tile_m, tile_n] + 1
-                return out
 
             torch.manual_seed(0)
-            x = torch.randn([2, 2], dtype=torch.float32, device=DEVICE)
+            x = torch.randn([128], dtype=torch.float32, device=DEVICE)
 
-            if hasattr(self, "_capfd"):
-                self._capfd.readouterr()
+            self._clear_captures()
 
-            result = kernel1(x)
+            result = kernel(x)
             torch.testing.assert_close(result, x + 1)
 
-            if not hasattr(self, "_capfd"):
-                return  # Cannot test without capture
-
-            captured = "".join(self._capfd.readouterr())
+            # Extract repro script from logs (use records to get the raw message without formatting)
+            assert hasattr(self, "_caplog"), "caplog fixture not available"
+            repro_script = None
+            for record in self._caplog.records:
+                if "# === HELION KERNEL REPRO ===" in record.message:
+                    repro_script = record.message
+                    break
 
-            # Extract repro script
-            lines = captured.splitlines()
-            start = next(
-                i
-                for i, line in enumerate(lines)
-                if "# === HELION KERNEL REPRO ===" in line
-            )
-            end = next(
-                i
-                for i, line in enumerate(lines[start:], start)
-                if "# === END HELION KERNEL REPRO ===" in line
-            )
-            repro_script = "\n".join(lines[start : end + 1])
+            if repro_script is None:
+                self.fail("No repro script found in logs")
 
             # Normalize range_warp_specializes=[None] to [] for comparison
             normalized_script = repro_script.replace(
@@ -92,26 +109,18 @@ def kernel1(x: torch.Tensor) -> torch.Tensor:
             # Verify repro script matches expected script
             self.assertExpectedJournal(normalized_script)
 
-            # Extract the actual code (without the comment markers) for execution
-            repro_lines = repro_script.splitlines()
-            code_start = 1 if repro_lines[0].startswith("# === HELION") else 0
-            code_end = len(repro_lines) - (
-                1 if repro_lines[-1].startswith("# === END") else 0
-            )
-            repro_code = "\n".join(repro_lines[code_start:code_end])
-
             # Setup linecache so inspect.getsource() works on exec'd code
             filename = "<helion_repro_test>"
             linecache.cache[filename] = (
-                len(repro_code),
+                len(repro_script),
                 None,
-                [f"{line}\n" for line in repro_code.splitlines()],
+                [f"{line}\n" for line in repro_script.splitlines()],
                 filename,
             )
 
             # Execute the repro script
             namespace = {}
-            exec(compile(repro_code, filename, "exec"), namespace)
+            exec(compile(repro_script, filename, "exec"), namespace)
 
             # Call the generated helper and verify it runs successfully
             helper = namespace["helion_repro_caller"]
@@ -121,11 +130,52 @@ def kernel1(x: torch.Tensor) -> torch.Tensor:
             torch.testing.assert_close(repro_result, x + 1)
 
             linecache.cache.pop(filename, None)
-        finally:
-            if original is None:
-                os.environ.pop("HELION_PRINT_REPRO", None)
-            else:
-                os.environ["HELION_PRINT_REPRO"] = original
+
+    def test_print_repro_on_autotune_error(self):
+        """Ensure HELION_PRINT_REPRO=1 prints repro when configs fail during autotuning.
+
+        This test mocks do_bench to fail on the second config, guaranteeing the repro
+        printing code path is exercised for "warn" level errors.
+        """
+        with self._with_print_repro_enabled():
+            kernel = self._create_kernel(
+                configs=[
+                    helion.Config(block_sizes=[32], num_warps=4),
+                    helion.Config(block_sizes=[64], num_warps=8),
+                ],
+                autotune_precompile=False,
+            )
+
+            torch.manual_seed(0)
+            x = torch.randn([128], dtype=torch.float32, device=DEVICE)
+
+            self._clear_captures()
+
+            # Mock do_bench to fail on the second config with PTXASError (warn level)
+            from torch._inductor.runtime.triton_compat import PTXASError
+            from triton.testing import do_bench as original_do_bench
+
+            call_count = [0]
+
+            def mock_do_bench(*args, **kwargs):
+                call_count[0] += 1
+                if call_count[0] == 2:  # Fail on second config
+                    raise PTXASError("Mocked PTXAS error")
+                return original_do_bench(*args, **kwargs)
+
+            with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
+                # Autotune will try both configs, second one will fail and print repro
+                kernel.autotune([x], force=False)
+
+            # Extract repro script from stderr
+            assert hasattr(self, "_capfd"), "capfd fixture not available"
+            captured = "".join(self._capfd.readouterr())
+
+            # Verify that a repro script was printed for the failing config
+            self.assertIn("# === HELION KERNEL REPRO ===", captured)
+            self.assertIn("# === END HELION KERNEL REPRO ===", captured)
+            self.assertIn("kernel", captured)
+            self.assertIn("helion_repro_caller()", captured)
 
 
 if __name__ == "__main__":