 
 if TYPE_CHECKING:
     from collections.abc import Callable
+    from typing import Any
 
 
 # %%
@@ -104,6 +105,71 @@ def geglu(a: Tensor, b: Tensor) -> Tensor:
     return out
 
 
+@helion.kernel()
+def geglu_bwd(grad_out: Tensor, a: Tensor, b: Tensor) -> tuple[Tensor, Tensor]:
+    grad_a = torch.empty_like(a)
+    grad_b = torch.empty_like(b)
+
+    grad_out_flat = grad_out.view(-1)
+    a_flat = a.view(-1)
+    b_flat = b.view(-1)
+    grad_a_flat = grad_a.view(-1)
+    grad_b_flat = grad_b.view(-1)
+
+    for tile_idx in hl.tile(a.numel()):
+        a_vals = a_flat[tile_idx].to(torch.float32)
+        b_vals = b_flat[tile_idx].to(torch.float32)
+        grad_out_vals = grad_out_flat[tile_idx].to(torch.float32)
+
+        sqrt_2_over_pi = 0.7978845608028654
+
+        a_cubed = a_vals * a_vals * a_vals
+        tanh_arg = sqrt_2_over_pi * (a_vals + 0.044715 * a_cubed)
+        tanh_result = torch.tanh(tanh_arg)
+        gelu_a = 0.5 * a_vals * (1.0 + tanh_result)
+
+        grad_b_vals = grad_out_vals * gelu_a
+        grad_b_flat[tile_idx] = grad_b_vals.to(b.dtype)
+
+        dz_da = sqrt_2_over_pi * (1.0 + 0.134145 * a_vals * a_vals)
+        sech_sq = 1.0 - tanh_result * tanh_result
+
+        dgelu_da = 0.5 * (1.0 + tanh_result) + 0.5 * a_vals * sech_sq * dz_da
+
+        grad_a_vals = grad_out_vals * b_vals * dgelu_da
+        grad_a_flat[tile_idx] = grad_a_vals.to(a.dtype)
+
+    return grad_a, grad_b
+
+
+class GEGLUFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        a: Tensor,
+        b: Tensor,
+    ) -> Tensor:
+        """Forward pass for GEGLU."""
+        out = geglu(a, b)
+        ctx.save_for_backward(a, b)
+        return out
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        """Backward pass for GEGLU."""
+        a, b = ctx.saved_tensors
+        grad_a, grad_b = geglu_bwd(grad_out, a, b)
+        return grad_a, grad_b
+
+
+def geglu_autograd(a: Tensor, b: Tensor) -> Tensor:
+    """GEGLU with forward + backward support."""
+    return GEGLUFunction.apply(a, b)  # type: ignore[no-any-return]
+
+
 # %%
 # GEGLU MLP Module (matches liger_kernel structure)
 # -------------------------------------------------
@@ -167,9 +233,6 @@ def check_geglu_kernel(shape: tuple[int, ...]) -> None:
     Args:
         shape: Shape of the input tensors to test.
     """
-    # Create test tensors
-    a = torch.randn(shape, device=DEVICE, dtype=torch.float16)
-    b = torch.randn(shape, device=DEVICE, dtype=torch.float16)
 
     def baseline_geglu(a: Tensor, b: Tensor) -> Tensor:
         """
@@ -178,8 +241,26 @@ def baseline_geglu(a: Tensor, b: Tensor) -> Tensor:
178241 """
179242 return nn .functional .gelu (a , approximate = "tanh" ).to (b .dtype ) * b
180243
244+ print ("\n === Forward Pass Test ===" )
245+ a = torch .randn (shape , device = DEVICE , dtype = torch .float16 )
246+ b = torch .randn (shape , device = DEVICE , dtype = torch .float16 )
181247 run_example (geglu , baseline_geglu , (a , b ))
182248
249+ # Test forward + backward pass
250+ print ("\n \n === Forward + Backward Pass Test ===" )
251+ a_grad = torch .randn (shape , device = DEVICE , dtype = torch .float16 , requires_grad = True )
252+ b_grad = torch .randn (shape , device = DEVICE , dtype = torch .float16 , requires_grad = True )
253+ run_example (
254+ geglu_autograd ,
255+ baseline_geglu ,
256+ (a_grad , b_grad ),
257+ kernel_name = "helion_autograd" ,
258+ baseline_name = "torch" ,
259+ rtol = 1e-2 ,
260+ atol = 1e-1 ,
261+ bwd = True ,
262+ )
263+
183264
184265class BaselineMLP (nn .Module ):
185266 def __init__ (self , config : Config ) -> None :
@@ -303,11 +384,11 @@ def main() -> None:
     kernel_test_shapes = [(8, 2048, 4096), (8, 4096, 8192)]
 
     for shape in kernel_test_shapes:
-        print(f"Testing GEGLU kernel shape: {shape}")
+        print(f"\nTesting GEGLU kernel shape: {shape}")
         check_geglu_kernel(shape)
         print(f"✓ GEGLU kernel shape {shape} passed")
 
-    print("\nTesting GEGLU MLP...")
+    print("\n\nTesting GEGLU MLP...")
 
     # Test GEGLU MLP with transformer-typical sizes
     mlp_test_configs = [
@@ -317,7 +398,7 @@ def main() -> None:
317398
318399 for batch_size , seq_len , hidden_size , intermediate_size in mlp_test_configs :
319400 print (
320- f"Testing GEGLU MLP: B={ batch_size } , T={ seq_len } , H={ hidden_size } , I={ intermediate_size } "
401+ f"\n Testing GEGLU MLP: B={ batch_size } , T={ seq_len } , H={ hidden_size } , I={ intermediate_size } "
321402 )
322403 check_geglu_mlp (batch_size , seq_len , hidden_size , intermediate_size )
323404 print ("✓ GEGLU MLP config passed" )
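For completeness, a minimal usage sketch of the new autograd path. Again not from the commit; it assumes Helion is installed and that `DEVICE` and `geglu_autograd` are in scope from this example file.

```python
# Hypothetical smoke test of the autograd wrapper added above.
a = torch.randn(2, 128, 256, device=DEVICE, dtype=torch.float16, requires_grad=True)
b = torch.randn(2, 128, 256, device=DEVICE, dtype=torch.float16, requires_grad=True)

out = geglu_autograd(a, b)   # forward: Helion geglu kernel via GEGLUFunction.apply
out.sum().backward()         # backward: one geglu_bwd launch fills a.grad and b.grad

print(a.grad.shape, b.grad.shape)
```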