Skip to content

Commit a0a2c8d

Browse files
authored
Merge pull request #44 from aws-neuron/feat/deltanet
example: add linear attention (deltanet) demo
2 parents 2f709b5 + 2aba9bc commit a0a2c8d

1 file changed

Lines changed: 238 additions & 0 deletions

File tree

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
#!/usr/bin/env python3
2+
"""
3+
DeltaNet Linear Attention Example using NKIPy
4+
5+
DeltaNet applies the "delta rule" to linear attention, replacing the softmax
6+
with a recurrent state update that achieves O(N*D^2) complexity instead of
7+
O(N^2*D). For each timestep t:
8+
9+
S_t = S_{t-1} + beta_t * (v_t - S_{t-1} @ k_t) outer k_t # state update
10+
o_t = S_t @ q_t # output
11+
12+
This example provides:
13+
1. A PyTorch reference implementation for correctness validation
14+
2. An NKIPy kernel using pure NumPy ops (the timestep loop is unrolled at trace time)
15+
3. Optional on-device compilation and benchmarking
16+
"""
17+
18+
import time
19+
20+
import numpy as np
21+
22+
try:
23+
import torch
24+
25+
TORCH_AVAILABLE = True
26+
except ImportError:
27+
TORCH_AVAILABLE = False
28+
29+
from nkipy.core import tensor_apis
30+
from nkipy.runtime import DeviceKernel, DeviceTensor, is_neuron_compatible
31+
32+
33+
def deltanet_pytorch(q, k, v, beta):
34+
"""
35+
PyTorch reference for DeltaNet recurrent linear attention.
36+
37+
Args:
38+
q: queries [B, H, L, D]
39+
k: keys [B, H, L, D] (will be L2-normalized)
40+
v: values [B, H, L, D]
41+
beta: gates [B, H, L] in (0, 1), post-sigmoid
42+
43+
Returns:
44+
output [B, H, L, D]
45+
"""
46+
B, H, L, D = q.shape
47+
48+
# L2-normalize keys
49+
k = k / torch.clamp(torch.norm(k, dim=-1, keepdim=True), min=1e-6)
50+
51+
S = torch.zeros(B, H, D, D, dtype=q.dtype, device=q.device)
52+
outputs = []
53+
54+
for t in range(L):
55+
q_t = q[:, :, t, :] # [B, H, D]
56+
k_t = k[:, :, t, :] # [B, H, D]
57+
v_t = v[:, :, t, :] # [B, H, D]
58+
beta_t = beta[:, :, t] # [B, H]
59+
60+
# delta = beta_t * (v_t - S @ k_t)
61+
Sk = torch.einsum("bhde,bhe->bhd", S, k_t) # [B, H, D]
62+
delta = beta_t.unsqueeze(-1) * (v_t - Sk) # [B, H, D]
63+
64+
# S += delta outer k_t
65+
S = S + torch.einsum("bhd,bhe->bhde", delta, k_t) # [B, H, D, D]
66+
67+
# o_t = S @ q_t
68+
o_t = torch.einsum("bhde,bhe->bhd", S, q_t) # [B, H, D]
69+
outputs.append(o_t.unsqueeze(2))
70+
71+
return torch.cat(outputs, dim=2) # [B, H, L, D]
72+
73+
74+
def deltanet_nkipy(q, k, v, beta_logits):
75+
"""
76+
NKIPy kernel for DeltaNet recurrent linear attention.
77+
78+
Args:
79+
q: queries [B, H, L, D] float32
80+
k: keys [B, H, L, D] float32 (will be L2-normalized)
81+
v: values [B, H, L, D] float32
82+
beta_logits: gate logits [B, H, L] float32 (pre-sigmoid)
83+
84+
Returns:
85+
output [B, H, L, D] float32
86+
"""
87+
B, H, L, D = q.shape
88+
89+
# Sigmoid activation: beta = 1 / (1 + exp(-x))
90+
beta = 1.0 / (1.0 + np.exp(-beta_logits))
91+
92+
# L2-normalize keys
93+
k_norm = np.linalg.norm(k, axis=-1, keepdims=True)
94+
k = k / np.maximum(k_norm, 1e-6)
95+
96+
# Initialize state [B, H, D, D]
97+
# Use tensor_apis.zeros so this works during both CPU and HLO tracing
98+
S = tensor_apis.zeros((B, H, D, D), dtype=q.dtype)
99+
100+
outputs = []
101+
for t in range(L):
102+
q_t = q[:, :, t, :] # [B, H, D]
103+
k_t = k[:, :, t, :] # [B, H, D]
104+
v_t = v[:, :, t, :] # [B, H, D]
105+
beta_t = beta[:, :, t] # [B, H]
106+
107+
# S @ k_t: matmul state with key vector
108+
# [B, H, D, D] @ [B, H, D, 1] -> [B, H, D, 1] -> [B, H, D]
109+
k_col = np.expand_dims(k_t, axis=-1) # [B, H, D, 1]
110+
Sk = np.matmul(S, k_col)[:, :, :, 0] # [B, H, D]
111+
112+
# delta = beta_t * (v_t - Sk)
113+
beta_2d = np.expand_dims(beta_t, axis=-1) # [B, H, 1]
114+
delta = beta_2d * (v_t - Sk) # [B, H, D]
115+
116+
# Batched outer product: delta outer k_t -> [B, H, D, D]
117+
outer = np.expand_dims(delta, axis=-1) * np.expand_dims(k_t, axis=-2)
118+
119+
# State update
120+
S = S + outer
121+
122+
# Output: S @ q_t -> [B, H, D]
123+
q_col = np.expand_dims(q_t, axis=-1) # [B, H, D, 1]
124+
o_t = np.matmul(S, q_col)[:, :, :, 0] # [B, H, D]
125+
126+
outputs.append(np.expand_dims(o_t, axis=2)) # [B, H, 1, D]
127+
128+
return np.concatenate(outputs, axis=2) # [B, H, L, D]
129+
130+
131+
def main():
132+
print("=" * 80)
133+
print("DeltaNet Linear Attention Example")
134+
print("=" * 80)
135+
136+
# Configuration
137+
B, H, L, D = 1, 4, 64, 32
138+
dtype = np.float32
139+
140+
print(f"\nConfiguration: B={B}, H={H}, L={L}, D={D}, dtype={dtype.__name__}")
141+
142+
# Create random inputs
143+
print("\n[1/5] Creating test data...")
144+
np.random.seed(42)
145+
q = np.random.randn(B, H, L, D).astype(dtype) * 0.1
146+
k = np.random.randn(B, H, L, D).astype(dtype) * 0.1
147+
v = np.random.randn(B, H, L, D).astype(dtype) * 0.1
148+
beta_logits = np.random.randn(B, H, L).astype(dtype) # pre-sigmoid
149+
150+
# PyTorch reference
151+
if TORCH_AVAILABLE:
152+
print("\n[2/5] Running PyTorch reference...")
153+
q_pt = torch.from_numpy(q)
154+
k_pt = torch.from_numpy(k)
155+
v_pt = torch.from_numpy(v)
156+
beta_pt = torch.sigmoid(torch.from_numpy(beta_logits))
157+
ref_output = deltanet_pytorch(q_pt, k_pt, v_pt, beta_pt).numpy()
158+
print(f" PyTorch output shape: {ref_output.shape}")
159+
else:
160+
print("\n[2/5] PyTorch not available, skipping reference...")
161+
ref_output = None
162+
163+
# NKIPy CPU execution (pure numpy)
164+
print("\n[3/5] Running NKIPy kernel (CPU mode)...")
165+
cpu_output = deltanet_nkipy(q, k, v, beta_logits)
166+
print(f" NKIPy CPU output shape: {cpu_output.shape}")
167+
168+
# Compare CPU vs PyTorch
169+
if ref_output is not None:
170+
print("\n[4/5] Validating CPU correctness against PyTorch...")
171+
try:
172+
np.testing.assert_allclose(cpu_output, ref_output, rtol=1e-4, atol=1e-4)
173+
max_err = np.max(np.abs(cpu_output - ref_output))
174+
print(f" PASSED - max absolute error: {max_err:.2e}")
175+
except AssertionError as e:
176+
print(f" FAILED: {e}")
177+
return
178+
else:
179+
print("\n[4/5] Skipping validation (no PyTorch reference)...")
180+
181+
# On-device execution
182+
if is_neuron_compatible():
183+
print("\n[5/5] Compiling and running on Neuron hardware...")
184+
compile_start = time.time()
185+
kernel = DeviceKernel.compile_and_load(
186+
deltanet_nkipy,
187+
q,
188+
k,
189+
v,
190+
beta_logits,
191+
name="deltanet_kernel",
192+
use_cached_if_exists=False,
193+
)
194+
compile_time = time.time() - compile_start
195+
print(f" Compiled in {compile_time:.2f}s")
196+
197+
# Create device tensors
198+
d_q = DeviceTensor.from_numpy(q)
199+
d_k = DeviceTensor.from_numpy(k)
200+
d_v = DeviceTensor.from_numpy(v)
201+
d_beta = DeviceTensor.from_numpy(beta_logits)
202+
d_out = DeviceTensor.from_numpy(np.zeros_like(cpu_output))
203+
204+
kernel(
205+
inputs={"q": d_q, "k": d_k, "v": d_v, "beta_logits": d_beta},
206+
outputs={"output0": d_out},
207+
)
208+
device_output = d_out.numpy()
209+
210+
try:
211+
np.testing.assert_allclose(device_output, cpu_output, rtol=1e-2, atol=1e-2)
212+
max_err = np.max(np.abs(device_output - cpu_output))
213+
print(f" Device output matches CPU - max error: {max_err:.2e}")
214+
except AssertionError as e:
215+
print(f" Device validation failed: {e}")
216+
return
217+
218+
# Benchmark
219+
stats = kernel.benchmark(
220+
inputs={"q": d_q, "k": d_k, "v": d_v, "beta_logits": d_beta},
221+
outputs={"output0": d_out},
222+
warmup_iter=5,
223+
benchmark_iter=10,
224+
)
225+
print(
226+
f"\n Performance: mean={stats.mean_ms:.3f}ms, "
227+
f"min={stats.min_ms:.3f}ms, max={stats.max_ms:.3f}ms"
228+
)
229+
else:
230+
print("\n[5/5] No Neuron hardware detected, skipping on-device execution.")
231+
232+
print(f"\n{'=' * 80}")
233+
print("Example completed successfully!")
234+
print("=" * 80)
235+
236+
237+
if __name__ == "__main__":
238+
main()

0 commit comments

Comments
 (0)