
Commit 5aad251

add geglu backward (#1069)
1 parent a0ba4d2 commit 5aad251

3 files changed: +210, −6 lines

examples/geglu.py

Lines changed: 87 additions & 6 deletions
@@ -36,6 +36,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import Callable
+    from typing import Any
 
 
 # %%
@@ -104,6 +105,71 @@ def geglu(a: Tensor, b: Tensor) -> Tensor:
     return out
 
 
+@helion.kernel()
+def geglu_bwd(grad_out: Tensor, a: Tensor, b: Tensor) -> tuple[Tensor, Tensor]:
+    grad_a = torch.empty_like(a)
+    grad_b = torch.empty_like(b)
+
+    grad_out_flat = grad_out.view(-1)
+    a_flat = a.view(-1)
+    b_flat = b.view(-1)
+    grad_a_flat = grad_a.view(-1)
+    grad_b_flat = grad_b.view(-1)
+
+    for tile_idx in hl.tile(a.numel()):
+        a_vals = a_flat[tile_idx].to(torch.float32)
+        b_vals = b_flat[tile_idx].to(torch.float32)
+        grad_out_vals = grad_out_flat[tile_idx].to(torch.float32)
+
+        sqrt_2_over_pi = 0.7978845608028654
+
+        a_cubed = a_vals * a_vals * a_vals
+        tanh_arg = sqrt_2_over_pi * (a_vals + 0.044715 * a_cubed)
+        tanh_result = torch.tanh(tanh_arg)
+        gelu_a = 0.5 * a_vals * (1.0 + tanh_result)
+
+        grad_b_vals = grad_out_vals * gelu_a
+        grad_b_flat[tile_idx] = grad_b_vals.to(b.dtype)
+
+        dz_da = sqrt_2_over_pi * (1.0 + 0.134145 * a_vals * a_vals)
+        sech_sq = 1.0 - tanh_result * tanh_result
+
+        dgelu_da = 0.5 * (1.0 + tanh_result) + 0.5 * a_vals * sech_sq * dz_da
+
+        grad_a_vals = grad_out_vals * b_vals * dgelu_da
+        grad_a_flat[tile_idx] = grad_a_vals.to(a.dtype)
+
+    return grad_a, grad_b
+
+
+class GEGLUFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        a: Tensor,
+        b: Tensor,
+    ) -> Tensor:
+        """Forward pass for GEGLU."""
+        out = geglu(a, b)
+        ctx.save_for_backward(a, b)
+        return out
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        """Backward pass for GEGLU."""
+        a, b = ctx.saved_tensors
+        grad_a, grad_b = geglu_bwd(grad_out, a, b)
+        return grad_a, grad_b
+
+
+def geglu_autograd(a: Tensor, b: Tensor) -> Tensor:
+    """GEGLU with forward + backward support."""
+    return GEGLUFunction.apply(a, b)  # type: ignore[no-any-return]
+
+
 # %%
 # GEGLU MLP Module (matches liger_kernel structure)
 # -------------------------------------------------
@@ -167,9 +233,6 @@ def check_geglu_kernel(shape: tuple[int, ...]) -> None:
     Args:
         shape: Shape of the input tensors to test.
     """
-    # Create test tensors
-    a = torch.randn(shape, device=DEVICE, dtype=torch.float16)
-    b = torch.randn(shape, device=DEVICE, dtype=torch.float16)
 
     def baseline_geglu(a: Tensor, b: Tensor) -> Tensor:
         """
@@ -178,8 +241,26 @@ def baseline_geglu(a: Tensor, b: Tensor) -> Tensor:
         """
         return nn.functional.gelu(a, approximate="tanh").to(b.dtype) * b
 
+    print("\n=== Forward Pass Test ===")
+    a = torch.randn(shape, device=DEVICE, dtype=torch.float16)
+    b = torch.randn(shape, device=DEVICE, dtype=torch.float16)
     run_example(geglu, baseline_geglu, (a, b))
 
+    # Test forward + backward pass
+    print("\n\n=== Forward + Backward Pass Test ===")
+    a_grad = torch.randn(shape, device=DEVICE, dtype=torch.float16, requires_grad=True)
+    b_grad = torch.randn(shape, device=DEVICE, dtype=torch.float16, requires_grad=True)
+    run_example(
+        geglu_autograd,
+        baseline_geglu,
+        (a_grad, b_grad),
+        kernel_name="helion_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-1,
+        bwd=True,
+    )
+
 
 class BaselineMLP(nn.Module):
     def __init__(self, config: Config) -> None:
@@ -303,11 +384,11 @@ def main() -> None:
     kernel_test_shapes = [(8, 2048, 4096), (8, 4096, 8192)]
 
     for shape in kernel_test_shapes:
-        print(f"Testing GEGLU kernel shape: {shape}")
+        print(f"\nTesting GEGLU kernel shape: {shape}")
         check_geglu_kernel(shape)
         print(f"✓ GEGLU kernel shape {shape} passed")
 
-    print("\nTesting GEGLU MLP...")
+    print("\n\nTesting GEGLU MLP...")
 
     # Test GEGLU MLP with transformer-typical sizes
     mlp_test_configs = [
@@ -317,7 +398,7 @@ def main() -> None:
 
     for batch_size, seq_len, hidden_size, intermediate_size in mlp_test_configs:
         print(
-            f"Testing GEGLU MLP: B={batch_size}, T={seq_len}, H={hidden_size}, I={intermediate_size}"
+            f"\nTesting GEGLU MLP: B={batch_size}, T={seq_len}, H={hidden_size}, I={intermediate_size}"
        )
         check_geglu_mlp(batch_size, seq_len, hidden_size, intermediate_size)
         print("✓ GEGLU MLP config passed")
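For readers who want to check the closed-form derivative used in geglu_bwd above: the kernel differentiates the tanh-approximation GELU, so dGELU/da = 0.5·(1 + tanh(z)) + 0.5·a·(1 − tanh²(z))·dz/da, with z = √(2/π)·(a + 0.044715·a³) and dz/da = √(2/π)·(1 + 0.134145·a²), and the per-tile gradients are grad_a = grad_out · b · dGELU/da and grad_b = grad_out · GELU(a). A minimal sketch (not part of this commit; the tensor size and names are illustrative) that verifies the closed form against PyTorch autograd:

import torch

# Closed-form dGELU/da for the tanh approximation, mirroring geglu_bwd above.
a = torch.randn(4096, dtype=torch.float64, requires_grad=True)
sqrt_2_over_pi = 0.7978845608028654
tanh_arg = sqrt_2_over_pi * (a + 0.044715 * a**3)
t = torch.tanh(tanh_arg)
dz_da = sqrt_2_over_pi * (1.0 + 0.134145 * a * a)  # 0.134145 == 3 * 0.044715
dgelu_da = 0.5 * (1.0 + t) + 0.5 * a * (1.0 - t * t) * dz_da

# Reference gradient from autograd on the same tanh-approximation GELU.
(ref,) = torch.autograd.grad(torch.nn.functional.gelu(a, approximate="tanh").sum(), a)
torch.testing.assert_close(dgelu_da, ref)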

test/test_examples.expected

Lines changed: 99 additions & 0 deletions
@@ -1756,6 +1756,105 @@ def geglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     # src[geglu.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestExamples.test_geglu_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_geglu_bwd(a_flat, b_flat, grad_out_flat, grad_b_flat, grad_a_flat, _BLOCK_SIZE_0: tl.constexpr):
+    # src[geglu.py:N]: for tile_idx in hl.tile(a.numel()):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[geglu.py:N]: a_vals = a_flat[tile_idx].to(torch.float32)
+    load = tl.load(a_flat + indices_0 * 1, None)
+    v_0 = tl.cast(load, tl.float32)
+    # src[geglu.py:N]: b_vals = b_flat[tile_idx].to(torch.float32)
+    load_1 = tl.load(b_flat + indices_0 * 1, None)
+    v_1 = tl.cast(load_1, tl.float32)
+    # src[geglu.py:N]: grad_out_vals = grad_out_flat[tile_idx].to(torch.float32)
+    load_2 = tl.load(grad_out_flat + indices_0 * 1, None)
+    v_2 = tl.cast(load_2, tl.float32)
+    # src[geglu.py:N]: a_cubed = a_vals * a_vals * a_vals
+    v_3 = v_0 * v_0
+    v_4 = v_3 * v_0
+    # src[geglu.py:N]: tanh_arg = sqrt_2_over_pi * (a_vals + 0.044715 * a_cubed)
+    v_5 = 0.044715
+    v_6 = v_4 * v_5
+    v_7 = v_0 + v_6
+    v_8 = 0.7978845608028654
+    v_9 = v_7 * v_8
+    # src[geglu.py:N]: tanh_result = torch.tanh(tanh_arg)
+    v_10 = libdevice.tanh(v_9)
+    # src[geglu.py:N]: gelu_a = 0.5 * a_vals * (1.0 + tanh_result)
+    v_11 = 0.5
+    v_12 = v_0 * v_11
+    v_13 = 1.0
+    v_14 = v_10 + v_13
+    v_15 = v_12 * v_14
+    # src[geglu.py:N]: grad_b_vals = grad_out_vals * gelu_a
+    v_16 = v_2 * v_15
+    # src[geglu.py:N]: grad_b_flat[tile_idx] = grad_b_vals.to(b.dtype)
+    v_17 = tl.cast(v_16, tl.bfloat16)
+    tl.store(grad_b_flat + indices_0 * 1, v_17, None)
+    # src[geglu.py:N]: dz_da = sqrt_2_over_pi * (1.0 + 0.134145 * a_vals * a_vals)
+    v_18 = 0.134145
+    v_19 = v_0 * v_18
+    v_20 = v_19 * v_0
+    v_21 = 1.0
+    v_22 = v_20 + v_21
+    v_23 = 0.7978845608028654
+    v_24 = v_22 * v_23
+    # src[geglu.py:N]: sech_sq = 1.0 - tanh_result * tanh_result
+    v_25 = v_10 * v_10
+    v_26 = 1.0
+    v_27 = v_26 - v_25
+    # src[geglu.py:N]: dgelu_da = 0.5 * (1.0 + tanh_result) + 0.5 * a_vals * sech_sq * dz_da
+    v_28 = 1.0
+    v_29 = v_10 + v_28
+    v_30 = 0.5
+    v_31 = v_29 * v_30
+    v_32 = 0.5
+    v_33 = v_0 * v_32
+    v_34 = v_33 * v_27
+    v_35 = v_34 * v_24
+    v_36 = v_31 + v_35
+    # src[geglu.py:N]: grad_a_vals = grad_out_vals * b_vals * dgelu_da
+    v_37 = v_2 * v_1
+    v_38 = v_37 * v_36
+    # src[geglu.py:N]: grad_a_flat[tile_idx] = grad_a_vals.to(a.dtype)
+    v_39 = tl.cast(v_38, tl.bfloat16)
+    tl.store(grad_a_flat + indices_0 * 1, v_39, None)
+
+def geglu_bwd(grad_out: Tensor, a: Tensor, b: Tensor, *, _launcher=_default_launcher):
+    # src[geglu.py:N]: grad_a = torch.empty_like(a)
+    grad_a = torch.empty_like(a)
+    # src[geglu.py:N]: grad_b = torch.empty_like(b)
+    grad_b = torch.empty_like(b)
+    # src[geglu.py:N]: grad_out_flat = grad_out.view(-1)
+    grad_out_flat = grad_out.view(-1)
+    # src[geglu.py:N]: a_flat = a.view(-1)
+    a_flat = a.view(-1)
+    # src[geglu.py:N]: b_flat = b.view(-1)
+    b_flat = b.view(-1)
+    # src[geglu.py:N]: grad_a_flat = grad_a.view(-1)
+    grad_a_flat = grad_a.view(-1)
+    # src[geglu.py:N]: grad_b_flat = grad_b.view(-1)
+    grad_b_flat = grad_b.view(-1)
+    # src[geglu.py:N]: for tile_idx in hl.tile(a.numel()):
+    _BLOCK_SIZE_0 = 16
+    # src[geglu.py:N]: for tile_idx in hl.tile(a.numel()):
+    # src[geglu.py:N]: a_vals = a_flat[tile_idx].to(torch.float32)
+    # src[geglu.py:N]: b_vals = b_flat[tile_idx].to(torch.float32)
+    # src[geglu.py:N-N]: ...
+    _launcher(_helion_geglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), a_flat, b_flat, grad_out_flat, grad_b_flat, grad_a_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    # src[geglu.py:N]: return grad_a, grad_b
+    return (grad_a, grad_b)
+
 --- assertExpectedJournal(TestExamples.test_grouped_gemm_jagged)
 from __future__ import annotations
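One detail worth noting in the generated launcher above: the test input has 1024 elements and the test config pins _BLOCK_SIZE_0 to 16, so the 1-D launch grid is triton.cdiv(1024, 16) = 64 programs. A trivial sketch of that arithmetic (illustrative only, not part of the expected output):

import math

numel = 1024       # flattened element count of the test tensors
block_size = 16    # _BLOCK_SIZE_0 fixed by block_sizes=[16] in the test
grid = math.ceil(numel / block_size)  # what triton.cdiv(1024, _BLOCK_SIZE_0) computes
assert grid == 64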

test/test_examples.py

Lines changed: 24 additions & 0 deletions
@@ -1227,6 +1227,30 @@ def test_geglu(self):
             )
         )
 
+    def test_geglu_bwd(self):
+        x1, x2 = [
+            torch.randn(1024, device=DEVICE, dtype=torch.bfloat16, requires_grad=True)
+            for _ in range(2)
+        ]
+
+        out = torch.nn.functional.gelu(x1, approximate="tanh") * x2
+        grad_out = torch.randn_like(out)
+        out.backward(grad_out)
+
+        args = (grad_out, x1, x2)
+
+        self.assertExpectedJournal(
+            check_example(
+                "geglu",
+                args,
+                (x1.grad, x2.grad),
+                fn_name="geglu_bwd",
+                block_sizes=[16],
+                num_warps=4,
+                num_stages=3,
+            )
+        )
+
     def test_swiglu(self):
         args = (
             torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16),
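For context, the autograd path that the new test exercises through check_example can also be driven directly, roughly as below. This is an illustrative sketch, not code from the commit; the import path and device string are assumptions.

import torch

from examples.geglu import geglu_autograd  # assumed import path

a = torch.randn(8, 2048, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(8, 2048, 4096, device="cuda", dtype=torch.float16, requires_grad=True)

out = geglu_autograd(a, b)  # forward: Helion geglu kernel
out.sum().backward()        # backward: Helion geglu_bwd kernel
print(a.grad.shape, b.grad.shape)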
