diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py new file mode 100644 index 00000000..740535fb --- /dev/null +++ b/kernels/chunk_gated_delta_h.py @@ -0,0 +1,773 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +""" +Gated Delta Net K5 hidden-state recurrence kernel using the @flyc.kernel API. + +Mirrors the Triton `chunk_gated_delta_rule_fwd_kernel_h_opt3` from ATOM/FLA, +rewritten in FlyDSL for AMD GPUs (gfx942/gfx950). + +For each chunk t (serial over NT chunks): + 1. Store h snapshot for downstream K6 + 2. v_new = u - w @ h (delta correction via MFMA) + 3. Gated decay + state update: + v_new *= exp(g_last - g_cumsum) + h = h * exp(g_last) + k^T @ v_new +""" + +import functools +import math + +import torch +import triton + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr.typing import T +from flydsl.expr import range_constexpr, arith, vector, gpu, rocdl, buffer_ops +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf, math as math_dialect, llvm as _llvm +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.compiler.protocol import fly_values +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from kernels.tensor_shim import GTensor, STensor, _to_raw + +_LOG2E = math.log2(math.e) # 1.4426950408889634 +_LLVM_GEP_DYNAMIC = -2147483648 + + +def _llvm_lds_ptr_ty(): + return ir.Type.parse("!llvm.ptr<3>") + + +def _llvm_exp2_f32(x): + """Emit llvm.exp2.f32 intrinsic directly (maps to single v_exp_f32 on AMD).""" + x_raw = _to_raw(x) + return _llvm.call_intrinsic( + ir.F32Type.get(), "llvm.exp2.f32", [x_raw], [], [] + ) + + +def _fast_exp(x): + """exp(x) via exp2(x * log2(e)) using the LLVM intrinsic.""" + log2e = arith.constant(_LOG2E, type=T.f32) + return _llvm_exp2_f32(arith.mulf(x, log2e)) + + +def _mfma_bf16_16x16x32(a_bf16x8, b_bf16x8, acc_f32x4): + """Single mfma_f32_16x16x32_bf16 instruction.""" + return rocdl.mfma_f32_16x16x32_bf16( + T.f32x4, a_bf16x8, b_bf16x8, acc_f32x4, 0, 0, 0 + ).res + + +# ── Utility helpers ────────────────────────────────────────────────────── + +def _prepare_lens(cu_seqlens): + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@functools.lru_cache(maxsize=8) +def _prepare_chunk_offsets(cu_seqlens_id, chunk_size, device): + cu_seqlens = torch._dynamo.utils.get_fake_value(cu_seqlens_id) if hasattr(torch._dynamo, 'utils') else None + return None + + +def prepare_chunk_offsets(cu_seqlens, chunk_size): + lens = _prepare_lens(cu_seqlens) + return torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(lens, chunk_size), + ]).cumsum(-1) + + +# ── Compile the kernel ─────────────────────────────────────────────────── + +def compile_chunk_gated_delta_h( + *, + K: int, + V: int, + BT: int = 64, + BV: int = 32, + H: int, + Hg: int, + USE_G: bool = True, + USE_INITIAL_STATE: bool = True, + STORE_FINAL_STATE: bool = True, + SAVE_NEW_VALUE: bool = True, + IS_VARLEN: bool = True, + WU_CONTIGUOUS: bool = True, +): + """Compile the GDN K5 kernel. + + Returns a @flyc.jit function: + launch_fn(k, v, w, v_new, g, h, h0, ht, + cu_seqlens, chunk_offsets, + T_val, T_flat, N_val, stream) + """ + assert K <= 256 + assert K % 64 == 0 + assert BV % 16 == 0 + NUM_K_BLOCKS = K // 64 + + WARP_SIZE = 64 + NUM_WARPS = 4 + BLOCK_THREADS = NUM_WARPS * WARP_SIZE + + WMMA_M = 16 + WMMA_N = 16 + WMMA_K = 32 + WMMA_C_FRAG = 4 + + M_REPEAT = BT // WMMA_M + N_REPEAT = BV // WMMA_N + + NUM_H_ACCS = NUM_K_BLOCKS * N_REPEAT + + # ── LDS layout: w and k store all K-blocks to reduce barriers ── + LDS_W_STRIDE = K + LDS_W_ELEMS = BT * LDS_W_STRIDE + LDS_W_BYTES = LDS_W_ELEMS * 2 + + LDS_K_STRIDE = K + LDS_K_ELEMS = BT * LDS_K_STRIDE + LDS_K_BYTES = LDS_K_ELEMS * 2 + + LDS_VN_STRIDE = BV + LDS_VN_ELEMS = BT * LDS_VN_STRIDE + LDS_VN_BYTES = LDS_VN_ELEMS * 2 + + LDS_H_STRIDE = BV + LDS_H_ELEMS = K * LDS_H_STRIDE + LDS_H_BYTES = LDS_H_ELEMS * 2 + + allocator = SmemAllocator(None, arch="gfx942", global_sym_name="gdn_h_smem") + lds_w_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_w_offset + LDS_W_BYTES + lds_k_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_k_offset + LDS_K_BYTES + lds_vn_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_vn_offset + LDS_VN_BYTES + lds_h_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_h_offset + LDS_H_BYTES + + # Cooperative load parameters + LOAD_VEC_WIDTH = 8 # 8 bf16 = 16 bytes = buffer_load_dwordx4 + THREADS_PER_ROW_64 = 64 // LOAD_VEC_WIDTH # 8 + ROWS_PER_BATCH_64 = BLOCK_THREADS // THREADS_PER_ROW_64 # 32 + NUM_LOAD_BATCHES_64 = BT // ROWS_PER_BATCH_64 # 2 + + @flyc.kernel(name="chunk_gdn_fwd_h_opt3") + def gdn_h_kernel( + k_tensor: fx.Tensor, + v_tensor: fx.Tensor, + w_tensor: fx.Tensor, + v_new_tensor: fx.Tensor, + g_tensor: fx.Tensor, + h_tensor: fx.Tensor, + h0_tensor: fx.Tensor, + ht_tensor: fx.Tensor, + cu_seqlens_tensor: fx.Tensor, + chunk_offsets_tensor: fx.Tensor, + T_val: fx.Int32, + T_flat: fx.Int32, + N_val: fx.Int32, + ): + i_v = arith.index_cast(T.i32, gpu.block_id("x")) + i_nh = arith.index_cast(T.i32, gpu.block_id("y")) + i_n = i_nh // fx.Int32(H) + i_h = i_nh % fx.Int32(H) + + tid = arith.index_cast(T.i32, gpu.thread_id("x")) + wid = tid // fx.Int32(WARP_SIZE) + lane = tid % fx.Int32(WARP_SIZE) + + k_ = GTensor(k_tensor, dtype=T.bf16, shape=(-1,)) + v_ = GTensor(v_tensor, dtype=T.bf16, shape=(-1,)) + w_ = GTensor(w_tensor, dtype=T.bf16, shape=(-1,)) + h_ = GTensor(h_tensor, dtype=T.bf16, shape=(-1,)) + g_ = GTensor(g_tensor, dtype=T.f32, shape=(-1,)) + + vn_ = GTensor(v_new_tensor, dtype=T.bf16, shape=(-1,)) + if USE_INITIAL_STATE: + h0_ = GTensor(h0_tensor, dtype=T.f32, shape=(-1,)) + if STORE_FINAL_STATE: + ht_ = GTensor(ht_tensor, dtype=T.f32, shape=(-1,)) + + if IS_VARLEN: + cu_ = GTensor(cu_seqlens_tensor, dtype=T.i32, shape=(-1,)) + co_ = GTensor(chunk_offsets_tensor, dtype=T.i32, shape=(-1,)) + + # ── LDS views ── + lds_base_ptr = allocator.get_base() + + # w tile (bf16) — separate from k + lds_w_ptr = SmemPtr(lds_base_ptr, lds_w_offset, T.bf16, shape=(LDS_W_ELEMS,)) + lds_w = STensor(lds_w_ptr, dtype=T.bf16, shape=(LDS_W_ELEMS,)) + + # k tile (bf16) — separate from w + lds_k_ptr = SmemPtr(lds_base_ptr, lds_k_offset, T.bf16, shape=(LDS_K_ELEMS,)) + lds_k = STensor(lds_k_ptr, dtype=T.bf16, shape=(LDS_K_ELEMS,)) + + # gated v_new (bf16) + lds_vn_ptr = SmemPtr(lds_base_ptr, lds_vn_offset, T.bf16, shape=(LDS_VN_ELEMS,)) + lds_vn = STensor(lds_vn_ptr, dtype=T.bf16, shape=(LDS_VN_ELEMS,)) + + # h snapshot (bf16) + lds_h_ptr = SmemPtr(lds_base_ptr, lds_h_offset, T.bf16, shape=(LDS_H_ELEMS,)) + lds_h = STensor(lds_h_ptr, dtype=T.bf16, shape=(LDS_H_ELEMS,)) + + # ── Cooperative load decomposition ── + load_row_in_batch = tid // fx.Int32(THREADS_PER_ROW_64) + load_col_base = (tid % fx.Int32(THREADS_PER_ROW_64)) * fx.Int32(LOAD_VEC_WIDTH) + + # ── XOR swizzle: col ^ ((row & 7) << 3) at 8-element granularity for bf16 ── + def _xor_swizzle(row, col): + return col ^ ((row & fx.Int32(0x7)) << fx.Int32(3)) + + def _xor_swizzle_idx(row, col): + return col ^ ((row & arith.index(0x7)) << arith.index(3)) + + # ── LDS vector read helpers (generates ds_read_b128 for 8xbf16) ── + v8bf16_type = T.vec(8, T.bf16) + lds_w_memref = lds_w_ptr.get() + lds_k_memref = lds_k_ptr.get() + + def _lds_vec_read_w_bf16x8(elem_idx): + return vector.load_op(v8bf16_type, lds_w_memref, [elem_idx]) + + def _lds_vec_read_k_bf16x8(elem_idx): + return vector.load_op(v8bf16_type, lds_k_memref, [elem_idx]) + + # ── ds_read_b64_tr_b16 helper (gfx950) ── + v4bf16_type = T.vec(4, T.bf16) + + def _ds_read_tr_bf16x4(lds_byte_offset): + byte_idx = arith.index_cast(T.index, lds_byte_offset) + byte_i64 = arith.index_cast(T.i64, byte_idx) + ptr = _llvm.IntToPtrOp(_llvm_lds_ptr_ty(), byte_i64).result + return rocdl.ds_read_tr16_b64(v4bf16_type, ptr).result + + # ds_read_b64_tr_b16 lane decomposition + tr_k_group = (lane % fx.Int32(16)) // fx.Int32(4) + tr_col_sub = lane % fx.Int32(4) + tr_col_half = (lane % fx.Int32(32)) // fx.Int32(16) + lane_div_32 = lane // fx.Int32(32) + + # ── Prologue: compute bos, T_local, NT, boh ── + if IS_VARLEN: + bos = cu_[fx.Index(i_n)] + eos = cu_[fx.Index(i_n) + fx.Index(1)] + T_local = eos - bos + NT = (T_local + fx.Int32(BT - 1)) // fx.Int32(BT) + boh = co_[fx.Index(i_n)] + else: + bos = i_n * T_val + T_local = T_val + NT = (T_local + fx.Int32(BT - 1)) // fx.Int32(BT) + boh = i_n * NT + + # ── Base pointer offsets (element counts) ── + # h: [B, NT, H, K, V] — base = (boh*H + i_h) * K * V + h_base = (boh * fx.Int32(H) + i_h) * fx.Int32(K * V) + stride_h = fx.Int32(H * K * V) + + # k: [B, T, Hg, K] — base = (bos*Hg + i_h//(H//Hg)) * K + gqa_ratio = H // Hg + k_base = (bos * fx.Int32(Hg) + i_h // fx.Int32(gqa_ratio)) * fx.Int32(K) + stride_k = fx.Int32(Hg * K) + + if WU_CONTIGUOUS: + if IS_VARLEN: + v_base = (i_h * T_flat + bos) * fx.Int32(V) + w_base = (i_h * T_flat + bos) * fx.Int32(K) + else: + v_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(V) + w_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(K) + stride_v = fx.Int32(V) + stride_w = fx.Int32(K) + else: + v_base = (bos * fx.Int32(H) + i_h) * fx.Int32(V) + w_base = (bos * fx.Int32(H) + i_h) * fx.Int32(K) + stride_v = fx.Int32(H * V) + stride_w = fx.Int32(H * K) + + if IS_VARLEN: + vn_base = (i_h * T_flat + bos) * fx.Int32(V) + else: + vn_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(V) + + if USE_INITIAL_STATE: + h0_base = (i_nh * fx.Int32(K * V)) + if STORE_FINAL_STATE: + ht_base = (i_nh * fx.Int32(K * V)) + + # ── MFMA lane mapping for 16x16 tiles ── + lane_n = lane % fx.Int32(16) + lane_m_base = lane // fx.Int32(16) + + # index-typed versions for LDS addressing + wid_idx = arith.index_cast(T.index, wid) + lane_n_idx = arith.index_cast(T.index, lane_n) + lane_m_base_idx = arith.index_cast(T.index, lane_m_base) + + # ── Initialize h accumulators ── + acc_zero = arith.constant_vector(0.0, T.f32x4) + + # h_accs[kb][nr] = f32x4 accumulator for k-block kb, v-repeat nr + h_accs = [] + for _kb in range_constexpr(NUM_K_BLOCKS): + for _nr in range_constexpr(N_REPEAT): + h_accs.append(acc_zero) + + # ── Load initial state if provided ── + if USE_INITIAL_STATE: + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + h0_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + h0_elems = [] + for elem_i in range_constexpr(4): + h0_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + h0_off = h0_base + h0_row * fx.Int32(V) + h0_col + h0_elems.append(h0_[fx.Index(h0_off)]) + loaded_vec = vector.from_elements(T.f32x4, h0_elems) + acc_idx = kb * N_REPEAT + nr + h_accs[acc_idx] = arith.addf(h_accs[acc_idx], loaded_vec) + + # ── Main chunk loop ── + init_state = [_to_raw(v) for v in h_accs] + c_zero = arith.index(0) + c_one = arith.index(1) + nt_idx = arith.index_cast(T.index, NT) + + for i_t, state in range(c_zero, nt_idx, c_one, init=init_state): + h_accs_in = list(state) + i_t_i32 = arith.index_cast(T.i32, i_t) + + # ── 1. Prefetch all w K-blocks from global (overlap with h snapshot store) ── + w_prefetch_all = [] + w_prefetch_lds_all = [] + for kb in range_constexpr(NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = w_base + safe_row * stride_w + fx.Int32(kb * 64) + load_col_base + w_prefetch_all.append(w_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + w_prefetch_lds_all.append(row * fx.Int32(LDS_W_STRIDE) + fx.Int32(kb * 64) + load_col_base) + + # ── Store h snapshot to global + LDS (w[0] loads in flight) ── + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + acc_idx = kb * N_REPEAT + nr + acc_val = h_accs_in[acc_idx] + h_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + lds_h_col = fx.Int32(nr * 16) + lane_n + + for elem_i in range_constexpr(4): + f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) + bf16_val = arith.trunc_f(T.bf16, f32_val) + + h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + h_off = h_base + i_t_i32 * stride_h + h_row * fx.Int32(V) + h_col + h_[fx.Index(h_off)] = bf16_val + + lds_h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + lds_h_idx = lds_h_row * fx.Int32(BV) + lds_h_col + lds_h[fx.Index(lds_h_idx)] = bf16_val + + # ── Store all w K-blocks to LDS in one batch ── + for i_wp in range_constexpr(NUM_K_BLOCKS * NUM_LOAD_BATCHES_64): + lds_w.vec_store((fx.Index(w_prefetch_lds_all[i_wp]),), w_prefetch_all[i_wp], LOAD_VEC_WIDTH) + + gpu.barrier() + + # ── 2. Delta correction: b_v = w @ h, then v_new = u - b_v ── + # Prefetch k[0] and u values during MFMA (overlap global loads with compute) + k_prefetch = [] + k_prefetch_lds = [] + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = k_base + safe_row * stride_k + fx.Int32(0 * 64) + load_col_base + k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + load_col_base) + + # Prefetch g values (overlap with MFMA below) + if USE_G: + next_chunk_end = (i_t_i32 + fx.Int32(1)) * fx.Int32(BT) + last_idx_raw = arith.select( + arith.cmpi(arith.CmpIPredicate.slt, next_chunk_end, T_local), + next_chunk_end, + T_local, + ) - fx.Int32(1) + g_last_off = (bos + last_idx_raw) * fx.Int32(H) + i_h + g_last_prefetch = g_[fx.Index(g_last_off)] + + g_row_prefetch = [] + for elem_i in range_constexpr(4): + abs_row = i_t_i32 * fx.Int32(BT) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_row_off = (bos + safe_row) * fx.Int32(H) + i_h + g_row_prefetch.append((g_[fx.Index(g_row_off)], in_bounds)) + + # Prefetch u values (overlap with MFMA below) + u_prefetch = [] + for nr in range_constexpr(N_REPEAT): + u_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + u_bt_row_raw = i_t_i32 * fx.Int32(BT) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + u_row_in_bounds = arith.cmpi(arith.CmpIPredicate.slt, u_bt_row_raw, T_local) + safe_u_row = arith.select(u_row_in_bounds, u_bt_row_raw, fx.Int32(0)) + u_off = v_base + safe_u_row * stride_v + u_col + u_prefetch.append(v_[fx.Index(u_off)]) + + bv_accs = [] + for _nr in range_constexpr(N_REPEAT): + bv_accs.append(arith.constant_vector(0.0, T.f32x4)) + + K_STEPS_PER_BLOCK = 64 // WMMA_K + + for kb in range_constexpr(NUM_K_BLOCKS): + for ks in range_constexpr(K_STEPS_PER_BLOCK): + w_lds_row_idx = wid_idx * arith.index(16) + lane_n_idx + w_lds_col_idx = arith.index(kb * 64 + ks * WMMA_K) + lane_m_base_idx * arith.index(8) + w_lds_idx = w_lds_row_idx * arith.index(LDS_W_STRIDE) + w_lds_col_idx + a_frag = _lds_vec_read_w_bf16x8(w_lds_idx) + + global_ks = kb * K_STEPS_PER_BLOCK + ks + + for nr in range_constexpr(N_REPEAT): + h_k_row = fx.Int32(global_ks * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group + h_v_col = fx.Int32(nr * 16) + tr_col_sub * fx.Int32(4) + h_lds_elem = h_k_row * fx.Int32(BV) + h_v_col + h_lds_byte = h_lds_elem * fx.Int32(2) + fx.Int32(lds_h_offset) + + h_lo = _ds_read_tr_bf16x4(h_lds_byte) + h_hi = _ds_read_tr_bf16x4(h_lds_byte + fx.Int32(4 * BV * 2)) + b_frag = vector.shuffle(h_lo, h_hi, [0, 1, 2, 3, 4, 5, 6, 7]) + + bv_accs[nr] = _mfma_bf16_16x16x32(a_frag, b_frag, bv_accs[nr]) + + # v_new = u - b_v (u values already prefetched) + vn_frags = [] + for nr in range_constexpr(N_REPEAT): + bv_val = bv_accs[nr] + u_f32_elems = [] + for elem_i in range_constexpr(4): + u_bf16 = u_prefetch[nr * 4 + elem_i] + u_f32_elems.append(arith.extf(T.f32, u_bf16)) + u_f32 = vector.from_elements(T.f32x4, u_f32_elems) + + vn_frags.append(arith.subf(u_f32, bv_val)) + + # ── 2b. Store v_new (pre-gating) for output ── + if SAVE_NEW_VALUE: + for nr in range_constexpr(N_REPEAT): + vn_val = vn_frags[nr] + vn_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + vn_bt_row = i_t_i32 * fx.Int32(BT) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + vn_in_bounds = arith.cmpi(arith.CmpIPredicate.slt, vn_bt_row, T_local) + _if_vn = scf.IfOp(vn_in_bounds) + with ir.InsertionPoint(_if_vn.then_block): + f32_v = vector.extract(vn_val, static_position=[elem_i], dynamic_position=[]) + bf16_v = arith.trunc_f(T.bf16, f32_v) + vn_off = vn_base + vn_bt_row * fx.Int32(V) + vn_col + vn_[fx.Index(vn_off)] = bf16_v + scf.YieldOp([]) + + # ── 3. Gating — g values prefetched before MFMA ── + if USE_G: + g_last = g_last_prefetch + exp_g_last = _fast_exp(g_last) + + gate_vec = arith.constant_vector(0.0, T.f32x4) + for elem_i in range_constexpr(4): + g_row, in_bounds = g_row_prefetch[elem_i] + gate = _fast_exp(arith.subf(g_last, g_row)) + gate_masked = arith.select(in_bounds, gate, arith.constant(0.0, type=T.f32)) + gate_vec = vector.insert(gate_masked, gate_vec, static_position=[elem_i], dynamic_position=[]) + + for nr in range_constexpr(N_REPEAT): + vn_frags[nr] = arith.mulf(vn_frags[nr], gate_vec) + + exp_g_last_vec = arith.constant_vector(0.0, T.f32x4) + for ei in range_constexpr(4): + exp_g_last_vec = vector.insert(exp_g_last, exp_g_last_vec, static_position=[ei], dynamic_position=[]) + + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + acc_idx = kb * N_REPEAT + nr + h_accs_in[acc_idx] = arith.mulf(h_accs_in[acc_idx], exp_g_last_vec) + + # ── 4. State update: h += k^T @ v_new_gated ── + BT_STEPS = BT // WMMA_K + + # Prefetch remaining k K-blocks (k[0] already prefetched during delta correction) + for kb_extra in range_constexpr(1, NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = k_base + safe_row * stride_k + fx.Int32(kb_extra * 64) + load_col_base + k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + fx.Int32(kb_extra * 64) + load_col_base) + + # Store gated v_new + all k K-blocks to LDS in one batch, single barrier + for nr in range_constexpr(N_REPEAT): + vn_val = vn_frags[nr] + lds_col = fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + f32_v = vector.extract(vn_val, static_position=[elem_i], dynamic_position=[]) + bf16_v = arith.trunc_f(T.bf16, f32_v) + lds_row = wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + lds_idx = lds_row * fx.Int32(LDS_VN_STRIDE) + lds_col + lds_vn[fx.Index(lds_idx)] = bf16_v + + for i_kp in range_constexpr(NUM_K_BLOCKS * NUM_LOAD_BATCHES_64): + lds_k.vec_store((fx.Index(k_prefetch_lds[i_kp]),), k_prefetch[i_kp], LOAD_VEC_WIDTH) + + gpu.barrier() + + for kb in range_constexpr(NUM_K_BLOCKS): + for bt_s in range_constexpr(BT_STEPS): + k_col_tr = wid * fx.Int32(16) + tr_col_sub * fx.Int32(4) + bt_row_tr = fx.Int32(bt_s * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group + k_lds_elem = bt_row_tr * fx.Int32(LDS_K_STRIDE) + fx.Int32(kb * 64) + k_col_tr + k_lds_byte = k_lds_elem * fx.Int32(2) + fx.Int32(lds_k_offset) + + k_lo = _ds_read_tr_bf16x4(k_lds_byte) + k_hi = _ds_read_tr_bf16x4(k_lds_byte + fx.Int32(4 * LDS_K_STRIDE * 2)) + k_a_frag = vector.shuffle(k_lo, k_hi, [0, 1, 2, 3, 4, 5, 6, 7]) + + for nr in range_constexpr(N_REPEAT): + vn_bt_row = fx.Int32(bt_s * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group + vn_v_col = fx.Int32(nr * 16) + tr_col_sub * fx.Int32(4) + vn_lds_elem = vn_bt_row * fx.Int32(LDS_VN_STRIDE) + vn_v_col + vn_lds_byte = vn_lds_elem * fx.Int32(2) + fx.Int32(lds_vn_offset) + + vn_lo = _ds_read_tr_bf16x4(vn_lds_byte) + vn_hi = _ds_read_tr_bf16x4(vn_lds_byte + fx.Int32(4 * BV * 2)) + vn_b_frag = vector.shuffle(vn_lo, vn_hi, [0, 1, 2, 3, 4, 5, 6, 7]) + + acc_idx = kb * N_REPEAT + nr + h_accs_in[acc_idx] = _mfma_bf16_16x16x32(k_a_frag, vn_b_frag, h_accs_in[acc_idx]) + + results = yield [_to_raw(v) for v in h_accs_in] + + h_accs_final = list(results) + + # ── Epilogue: store final state ── + if STORE_FINAL_STATE: + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + acc_idx = kb * N_REPEAT + nr + acc_val = h_accs_final[acc_idx] + + ht_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) + ht_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + ht_off = ht_base + ht_row * fx.Int32(V) + ht_col + ht_[fx.Index(ht_off)] = f32_val + + # ── Host launcher ────────────────────────────────────────────────────── + @flyc.jit + def launch_gdn_h( + k_tensor: fx.Tensor, + v_tensor: fx.Tensor, + w_tensor: fx.Tensor, + v_new_tensor: fx.Tensor, + g_tensor: fx.Tensor, + h_tensor: fx.Tensor, + h0_tensor: fx.Tensor, + ht_tensor: fx.Tensor, + cu_seqlens_tensor: fx.Tensor, + chunk_offsets_tensor: fx.Tensor, + T_val: fx.Int32, + T_flat: fx.Int32, + N_val: fx.Int32, + grid_v: fx.Int32, + grid_nh: fx.Int32, + stream: fx.Stream, + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = gdn_h_kernel( + k_tensor, v_tensor, w_tensor, v_new_tensor, g_tensor, + h_tensor, h0_tensor, ht_tensor, + cu_seqlens_tensor, chunk_offsets_tensor, + T_val, T_flat, N_val, + ) + launcher.launch( + grid=(grid_v, grid_nh, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_gdn_h + + +# ── Python wrapper (matches Triton interface) ──────────────────────────── + +_compiled_kernels = {} +_autotune_cache = {} # (shape_key) -> best BV +_BV_CANDIDATES = [16, 32, 64] +_AUTOTUNE_WARMUP = 5 +_AUTOTUNE_ITERS = 25 + + +def _get_or_compile(K, V, BT, BV, H, Hg, use_g, use_h0, store_fs, save_vn, is_varlen, wu_contig): + cache_key = (K, V, BT, BV, H, Hg, use_g, use_h0, store_fs, save_vn, is_varlen, wu_contig) + if cache_key not in _compiled_kernels: + _compiled_kernels[cache_key] = compile_chunk_gated_delta_h( + K=K, V=V, BT=BT, BV=BV, H=H, Hg=Hg, + USE_G=use_g, USE_INITIAL_STATE=use_h0, + STORE_FINAL_STATE=store_fs, SAVE_NEW_VALUE=save_vn, + IS_VARLEN=is_varlen, WU_CONTIGUOUS=wu_contig, + ) + return _compiled_kernels[cache_key] + + +def _launch_kernel(launch_fn, BV, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream): + grid_v = triton.cdiv(V, BV) + grid_nh = N * H + launch_fn( + k, u, w, vn_arg, g_arg, + h, h0_arg, ht_arg, + cu_arg, co_arg, + T, T_flat, N, + grid_v, grid_nh, + stream, + ) + + +def chunk_gated_delta_rule_fwd_h_flydsl( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + wu_contiguous: bool = True, + BV: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """FlyDSL K5 wrapper with wrapper-level autotune over BV.""" + B, T, Hg, K = k.shape + BT = chunk_size + + if wu_contiguous: + H = w.shape[1] + V = u.shape[-1] + T_flat = w.shape[2] + else: + H = u.shape[-2] + V = u.shape[-1] + T_flat = w.shape[1] + + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N = len(cu_seqlens) - 1 + lens = cu_seqlens[1:] - cu_seqlens[:-1] + NT = sum(triton.cdiv(int(l), BT) for l in lens.tolist()) + chunk_offsets = torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(lens, BT), + ]).cumsum(-1).to(torch.int32) + + assert K <= 256 + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new_buf = k.new_empty(B, H, T_flat, V, dtype=u.dtype) + v_new = v_new_buf if save_new_value else None + + dummy = torch.empty(1, device=k.device, dtype=torch.float32) + g_arg = g if g is not None else dummy + h0_arg = initial_state if initial_state is not None else dummy + ht_arg = final_state if final_state is not None else dummy + vn_arg = v_new_buf + cu_arg = cu_seqlens.to(torch.int32) if cu_seqlens is not None else dummy.to(torch.int32) + co_arg = chunk_offsets if chunk_offsets is not None else dummy.to(torch.int32) + stream = torch.cuda.current_stream() + + use_g = g is not None + use_h0 = initial_state is not None + is_varlen = cu_seqlens is not None + + # Resolve BV: explicit > autotune cache > benchmark + if BV <= 0: + shape_key = (K, V, BT, H, Hg, T_flat, N, + use_g, use_h0, output_final_state, + save_new_value, is_varlen, wu_contiguous) + + if shape_key in _autotune_cache: + BV = _autotune_cache[shape_key] + else: + candidates = [bv for bv in _BV_CANDIDATES if bv <= V and V % bv == 0] + if len(candidates) <= 1: + BV = candidates[0] if candidates else 16 + else: + print(f"[K5 autotune] benchmarking BV in {candidates} ...") + best_bv, best_us = candidates[0], float('inf') + for bv in candidates: + fn = _get_or_compile(K, V, BT, bv, H, Hg, + use_g, use_h0, output_final_state, + save_new_value, is_varlen, wu_contiguous) + # warmup + for _ in range(_AUTOTUNE_WARMUP): + _launch_kernel(fn, bv, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream) + torch.cuda.synchronize() + # benchmark + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + for _ in range(_AUTOTUNE_ITERS): + _launch_kernel(fn, bv, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream) + e.record() + torch.cuda.synchronize() + us = s.elapsed_time(e) / _AUTOTUNE_ITERS * 1000 + print(f" BV={bv:3d}: {us:.2f} us") + if us < best_us: + best_us = us + best_bv = bv + BV = best_bv + print(f"[K5 autotune] best BV={BV} ({best_us:.2f} us)") + _autotune_cache[shape_key] = BV + + launch_fn = _get_or_compile(K, V, BT, BV, H, Hg, + use_g, use_h0, output_final_state, + save_new_value, is_varlen, wu_contiguous) + _launch_kernel(launch_fn, BV, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream) + + return h, v_new, final_state + + +__all__ = [ + "compile_chunk_gated_delta_h", + "chunk_gated_delta_rule_fwd_h_flydsl", +] diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py new file mode 100755 index 00000000..7560702e --- /dev/null +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -0,0 +1,745 @@ +""" +Tests for FlyDSL K5: chunk_gated_delta_rule_fwd_h (GDN hidden-state recurrence) + +Correctness: compare FlyDSL kernel against a pure-PyTorch reference. +Performance: compare FlyDSL kernel against Triton opt3 kernel. +Rocprof: profile with rocprofv3 for accurate GPU kernel timing. + +Runtime parameters derived from Qwen3.5-397B-A17B TP=8 serving config: + K=128, V=128, Hk=16->Hg=2, Hv=64->H=8, BT=64 + max_num_batched_tokens=8192, full_prompt_len=8000 + +Usage: + cd /workspace/FlyDSL + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Correct" + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Perf" + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Rocprof" + + # Direct rocprofv3 profiling (without pytest): + python3 tests/kernels/test_chunk_gated_delta_h.py --mode rocprof + python3 tests/kernels/test_chunk_gated_delta_h.py --mode rocprof --full-prompt-len 1000 +""" + +import argparse +import csv +import ctypes +import subprocess +import sys +import os +from ctypes.util import find_library +from pathlib import Path + +import pytest +import torch +import triton + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from kernels.chunk_gated_delta_h import chunk_gated_delta_rule_fwd_h_flydsl + +# ── Triton opt3 kernel (inlined, no external dependency) ──────────────── + +import functools +import triton.language as tl + +def _check_platform(): + try: + backend = triton.runtime.driver.active.get_current_target().backend + except (RuntimeError, AttributeError): + backend = "cpu" + return {"cuda": "nvidia", "hip": "amd", "xpu": "intel"}.get(backend, backend) + +_use_cuda_graph = _check_platform() == "nvidia" and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + +def _tensor_cache(fn): + cache_entries = [] + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal cache_entries + for i, (la, lk, lr) in enumerate(cache_entries): + if len(args) == len(la) and all(a is b for a, b in zip(args, la)) \ + and len(kwargs) == len(lk) and all(k in lk and v is lk[k] for k, v in kwargs.items()): + cache_entries = cache_entries[:i] + cache_entries[i+1:] + [(la, lk, lr)] + return lr + result = fn(*args, **kwargs) + if len(cache_entries) >= 8: + cache_entries.pop(0) + cache_entries.append((args, kwargs, result)) + return result + return wrapper + +@_tensor_cache +def _prepare_lens(cu_seqlens): + return cu_seqlens[1:] - cu_seqlens[:-1] + +@_tensor_cache +def _prepare_chunk_indices(cu_seqlens, chunk_size): + indices = torch.cat([torch.arange(n) for n in triton.cdiv(_prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + +@_tensor_cache +def _prepare_chunk_offsets(cu_seqlens, chunk_size): + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(_prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) + +@triton.heuristics({ + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.autotune( + configs=[triton.Config({"BV": BV}, num_warps=nw, num_stages=ns) + for nw in [2, 4] for ns in [1, 2, 3, 4] for BV in [16, 32, 64]], + key=["H", "K", "V", "BT", "IS_VARLEN"], + use_cuda_graph=_use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def _triton_fwd_kernel_h_opt3( + k, v, w, v_new, g, gk, h, h0, ht, + cu_seqlens, chunk_offsets, T, T_flat, + H: tl.constexpr, Hg: tl.constexpr, K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BV: tl.constexpr, + USE_G: tl.constexpr, USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, IS_VARLEN: tl.constexpr, + WU_CONTIGUOUS: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos = tl.load(cu_seqlens + i_n).to(tl.int32) + eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: b_h4 = tl.zeros([64, BV], dtype=tl.float32) + h += ((boh * H + i_h) * K * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + if WU_CONTIGUOUS: + if IS_VARLEN: + v += ((i_h * T_flat + bos) * V).to(tl.int64) + w += ((i_h * T_flat + bos) * K).to(tl.int64) + else: + v += (((i_n * H + i_h) * T_flat) * V).to(tl.int64) + w += (((i_n * H + i_h) * T_flat) * K).to(tl.int64) + stride_v, stride_w = V, K + else: + v += ((bos * H + i_h) * V).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + stride_v, stride_w = H * V, H * K + if SAVE_NEW_VALUE: + if IS_VARLEN: v_new += ((i_h * T_flat + bos) * V).to(tl.int64) + else: v_new += (((i_n * H + i_h) * T_flat) * V).to(tl.int64) + stride_h, stride_k = H * K * V, Hg * K + if USE_INITIAL_STATE: h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: ht = ht + i_nh * K * V + if USE_INITIAL_STATE: + b_h1 += tl.load(tl.make_block_ptr(h0, (K,V),(V,1),(0,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + if K > 64: b_h2 += tl.load(tl.make_block_ptr(h0,(K,V),(V,1),(64,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + if K > 128: b_h3 += tl.load(tl.make_block_ptr(h0,(K,V),(V,1),(128,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + if K > 192: b_h4 += tl.load(tl.make_block_ptr(h0,(K,V),(V,1),(192,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + for i_t in range(NT): + tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(0,i_v*BV),(64,BV),(1,0)), b_h1.to(tl.bfloat16), boundary_check=(0,1)) + if K > 64: tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(64,i_v*BV),(64,BV),(1,0)), b_h2.to(tl.bfloat16), boundary_check=(0,1)) + if K > 128: tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(128,i_v*BV),(64,BV),(1,0)), b_h3.to(tl.bfloat16), boundary_check=(0,1)) + if K > 192: tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(192,i_v*BV),(64,BV),(1,0)), b_h4.to(tl.bfloat16), boundary_check=(0,1)) + p_w = tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,0),(BT,64),(1,0)) + b_v = tl.dot(tl.load(p_w, boundary_check=(0,1)), b_h1.to(tl.bfloat16)) + if K > 64: b_v += tl.dot(tl.load(tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,64),(BT,64),(1,0)), boundary_check=(0,1)), b_h2.to(tl.bfloat16)) + if K > 128: b_v += tl.dot(tl.load(tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,128),(BT,64),(1,0)), boundary_check=(0,1)), b_h3.to(tl.bfloat16)) + if K > 192: b_v += tl.dot(tl.load(tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,192),(BT,64),(1,0)), boundary_check=(0,1)), b_h4.to(tl.bfloat16)) + b_v = tl.load(tl.make_block_ptr(v,(T,V),(stride_v,1),(i_t*BT,i_v*BV),(BT,BV),(1,0)), boundary_check=(0,1)) - b_v + if SAVE_NEW_VALUE: + tl.store(tl.make_block_ptr(v_new,(T,V),(V,1),(i_t*BT,i_v*BV),(BT,BV),(1,0)), b_v.to(tl.bfloat16), boundary_check=(0,1)) + last_idx = min((i_t+1)*BT, T) - 1 + if USE_G: + m_t = (i_t*BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + bos*H + last_idx*H + i_h) + b_g = tl.load(tl.make_block_ptr(g+bos*H+i_h,(T,),(H,),(i_t*BT,),(BT,),(0,)), boundary_check=(0,)) + b_v = b_v * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None] + b_g_last = tl.exp(b_g_last) + b_h1 *= b_g_last + if K > 64: b_h2 *= b_g_last + if K > 128: b_h3 *= b_g_last + if K > 192: b_h4 *= b_g_last + if USE_GK: + o_k1 = tl.arange(0, 64) + b_h1 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+o_k1, mask=(o_k1 64: b_h2 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+64+o_k1, mask=(64+o_k1 128: b_h3 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+128+o_k1, mask=(128+o_k1 192: b_h4 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+192+o_k1, mask=(192+o_k1 64: b_h2 += tl.dot(tl.load(tl.make_block_ptr(k,(K,T),(1,stride_k),(64,i_t*BT),(64,BT),(0,1)), boundary_check=(0,1)), b_v) + if K > 128: b_h3 += tl.dot(tl.load(tl.make_block_ptr(k,(K,T),(1,stride_k),(128,i_t*BT),(64,BT),(0,1)), boundary_check=(0,1)), b_v) + if K > 192: b_h4 += tl.dot(tl.load(tl.make_block_ptr(k,(K,T),(1,stride_k),(192,i_t*BT),(64,BT),(0,1)), boundary_check=(0,1)), b_v) + if STORE_FINAL_STATE: + tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(0,i_v*BV),(64,BV),(1,0)), b_h1.to(tl.float32), boundary_check=(0,1)) + if K > 64: tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(64,i_v*BV),(64,BV),(1,0)), b_h2.to(tl.float32), boundary_check=(0,1)) + if K > 128: tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(128,i_v*BV),(64,BV),(1,0)), b_h3.to(tl.float32), boundary_check=(0,1)) + if K > 192: tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(192,i_v*BV),(64,BV),(1,0)), b_h4.to(tl.float32), boundary_check=(0,1)) + +def fwd_h_triton_opt3( + k, w, u, g=None, gk=None, initial_state=None, + output_final_state=False, chunk_size=64, save_new_value=True, + cu_seqlens=None, wu_contiguous=False, +): + B, T, Hg, K = k.shape + BT = chunk_size + if wu_contiguous: + H, V, T_flat = w.shape[1], u.shape[-1], w.shape[2] + else: + H, V, T_flat = u.shape[-2], u.shape[-1], w.shape[1] + chunk_indices = _prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N = len(cu_seqlens) - 1 + NT = len(chunk_indices) + chunk_offsets = _prepare_chunk_offsets(cu_seqlens, BT) + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new = k.new_empty(B, H, T_flat, V, dtype=u.dtype) if save_new_value else None + _triton_fwd_kernel_h_opt3[(lambda meta: (triton.cdiv(V, meta["BV"]), N * H))]( + k=k, v=u, w=w, v_new=v_new, g=g, gk=gk, + h=h, h0=initial_state, ht=final_state, + cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, + T=T, T_flat=T_flat, H=H, Hg=Hg, K=K, V=V, BT=BT, + WU_CONTIGUOUS=wu_contiguous, + ) + return h, v_new, final_state + + + +# ── Global test configuration ────────────────────────────────────────── + +K = 128 +V = 128 +Hg = 2 +H = 8 +BT = 64 + +MAX_NUM_BATCHED_TOKENS = 8192 +FULL_PROMPT_LENS = [8000] + +NUM_WARMUP = 10 +NUM_ITERS = 200 + + +def _build_context_lens(full_prompt_len, max_tokens=MAX_NUM_BATCHED_TOKENS): + context_lens = [] + remaining = max_tokens + while remaining > 0: + cur = min(full_prompt_len, remaining) + context_lens.append(cur) + remaining -= cur + return context_lens + + +def _build_cu_seqlens(context_lens, device="cuda"): + scheduled_q_lens = context_lens + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(scheduled_q_lens), 0).tolist()), + dtype=torch.int32, + device=device, + ) + return scheduled_q_lens, cu_seqlens + + +def _make_inputs(context_lens, dtype=torch.bfloat16, device="cuda", + with_initial_state=True): + scheduled_q_lens, cu_seqlens = _build_cu_seqlens(context_lens, device=device) + T_total = int(cu_seqlens[-1].item()) + N = len(scheduled_q_lens) + B = 1 + + k = torch.randn(B, T_total, Hg, K, dtype=dtype, device=device) * 0.1 + w_orig = torch.randn(B, T_total, H, K, dtype=dtype, device=device) * 0.1 + u_orig = torch.randn(B, T_total, H, V, dtype=dtype, device=device) * 0.1 + g = torch.randn(T_total, H, dtype=torch.float32, device=device).abs() * -0.5 + g = g.cumsum(dim=0) + + w_c = w_orig.permute(0, 2, 1, 3).contiguous() + u_c = u_orig.permute(0, 2, 1, 3).contiguous() + + initial_state = None + if with_initial_state: + initial_state = torch.randn(N, H, K, V, dtype=torch.float32, device=device) * 0.01 + + return k, w_orig, u_orig, w_c, u_c, g, initial_state, cu_seqlens, scheduled_q_lens + + +# ── Pure-PyTorch reference ────────────────────────────────────────────── + +def ref_chunk_gated_delta_rule_fwd_h( + k, w, u, g, + initial_state=None, + output_final_state=False, + chunk_size=64, + cu_seqlens=None, +): + """Reference in FP32 for correctness checking.""" + B, T, Hg_dim, K_dim = k.shape + H_dim, V_dim = u.shape[-2], u.shape[-1] + BT_dim = chunk_size + if cu_seqlens is None: + NT = triton.cdiv(T, BT_dim) + else: + seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + NT = sum(triton.cdiv(int(seq_len), BT_dim) for seq_len in seq_lens) + gqa_ratio = H_dim // Hg_dim + + h_out = k.new_zeros(B, NT, H_dim, K_dim, V_dim, dtype=torch.float32) + v_new_out = torch.zeros_like(u, dtype=torch.float32) + + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + final_state = torch.zeros(N, H_dim, K_dim, V_dim, dtype=torch.float32, + device=k.device) if output_final_state else None + + for b_idx in range(B): + if cu_seqlens is not None: + seqs = [(s, cu_seqlens[s].item(), cu_seqlens[s + 1].item()) + for s in range(N)] + else: + seqs = [(b_idx, 0, T)] + + chunk_offset = 0 + for seq_idx, bos, eos in seqs: + seq_len = eos - bos + seq_nt = triton.cdiv(seq_len, BT_dim) + + for i_h in range(H_dim): + i_hg = i_h // gqa_ratio + h_state = torch.zeros(K_dim, V_dim, dtype=torch.float32, + device=k.device) + if initial_state is not None: + h_state = initial_state[seq_idx, i_h].float().clone() + + for i_t in range(seq_nt): + t_start = i_t * BT_dim + t_end = min(t_start + BT_dim, seq_len) + actual_bt = t_end - t_start + + h_out[b_idx, chunk_offset + i_t, i_h] = h_state.clone() + + w_chunk = w[b_idx, bos + t_start:bos + t_end, i_h].float() + u_chunk = u[b_idx, bos + t_start:bos + t_end, i_h].float() + b_v = u_chunk - w_chunk @ h_state + v_new_out[b_idx, bos + t_start:bos + t_end, i_h] = b_v + + last_idx = bos + t_end - 1 + g_last = g[last_idx, i_h].float() + g_chunk = g[bos + t_start:bos + t_end, i_h].float() + + mask = torch.zeros(BT_dim, device=k.device) + mask[:actual_bt] = 1.0 + gate = torch.where( + mask[:actual_bt].bool(), + torch.exp(g_last - g_chunk), + torch.zeros_like(g_chunk), + ) + b_v_gated = b_v * gate.unsqueeze(-1) + + h_state = h_state * torch.exp(g_last) + k_chunk = k[b_idx, bos + t_start:bos + t_end, i_hg].float() + b_v_gated_cast = b_v_gated.to(k.dtype).float() + h_state = h_state + k_chunk.T @ b_v_gated_cast + + if output_final_state: + final_state[seq_idx, i_h] = h_state + + chunk_offset += seq_nt + + return h_out, v_new_out.to(u.dtype), final_state + + +def _normalize_opt_v_new(vn_opt): + """Convert opt v_new layout [B, H, T, V] back to [B, T, H, V].""" + return vn_opt.permute(0, 2, 1, 3).contiguous() + + +# ── Correctness tests ─────────────────────────────────────────────────── + +class TestCorrectness: + """Correctness against PyTorch reference.""" + + @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) + def test_correctness_flydsl(self, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + + h_fly, vn_fly, fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + h_ref, vn_ref, fs_ref = ref_chunk_gated_delta_rule_fwd_h( + k, w_orig, u_orig, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + + torch.testing.assert_close( + h_fly.float(), h_ref.float(), atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + _normalize_opt_v_new(vn_fly).float(), vn_ref.float(), + atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + fs_fly.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) + + @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) + def test_correctness_triton_opt3(self, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + + h_tri, vn_tri, fs_tri = fwd_h_triton_opt3( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + h_ref, vn_ref, fs_ref = ref_chunk_gated_delta_rule_fwd_h( + k, w_orig, u_orig, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + + torch.testing.assert_close( + h_tri.float(), h_ref.float(), atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + _normalize_opt_v_new(vn_tri).float(), vn_ref.float(), + atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + fs_tri.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) + + @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) + def test_correctness_flydsl_vs_triton(self, full_prompt_len): + """Direct comparison between FlyDSL and Triton opt3 kernels.""" + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + + h_fly, vn_fly, fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + h_tri, vn_tri, fs_tri = fwd_h_triton_opt3( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + h_fly_f, h_tri_f = h_fly.float(), h_tri.float() + vn_fly_f, vn_tri_f = vn_fly.float(), vn_tri.float() + fs_fly_f, fs_tri_f = fs_fly.float(), fs_tri.float() + + def _report(name, a, b): + diff = (a - b).abs() + diff_flat = diff.flatten() + sorted_diff, _ = diff_flat.sort() + n = sorted_diff.numel() + median_val = sorted_diff[n // 2].item() + p99_val = sorted_diff[min(int(n * 0.99), n - 1)].item() + print(f" {name}:") + print(f" FlyDSL range: [{a.min().item():.6f}, {a.max().item():.6f}]") + print(f" Triton range: [{b.min().item():.6f}, {b.max().item():.6f}]") + print(f" abs_err max={diff.max().item():.6f} " + f"mean={diff.mean().item():.6f} " + f"median={median_val:.6f} " + f"p99={p99_val:.6f}") + + print(f"\n[FlyDSL vs Triton opt3 full_prompt_len={full_prompt_len}]") + _report("h", h_fly_f, h_tri_f) + _report("v_new", vn_fly_f, vn_tri_f) + _report("final_state", fs_fly_f, fs_tri_f) + + torch.testing.assert_close(h_fly_f, h_tri_f, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(vn_fly_f, vn_tri_f, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(fs_fly_f, fs_tri_f, atol=1e-1, rtol=1e-1) + + +# ── Performance tests ─────────────────────────────────────────────────── + +def _bench_fn(fn, *args, **kwargs): + """Warmup + measure, return average us.""" + fn(*args, **kwargs) + torch.cuda.synchronize() + for _ in range(NUM_WARMUP): + fn(*args, **kwargs) + torch.cuda.synchronize() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + s.record() + for _ in range(NUM_ITERS): + fn(*args, **kwargs) + e.record() + torch.cuda.synchronize() + return s.elapsed_time(e) / NUM_ITERS * 1000 + + +PERF_SHAPES = [ + pytest.param(fpl, id=f"full{fpl}") + for fpl in FULL_PROMPT_LENS +] + + +class TestPerformance: + """Performance comparison: FlyDSL vs Triton opt3.""" + + @pytest.mark.parametrize("full_prompt_len", PERF_SHAPES) + def test_perf_comparison(self, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, scheduled_q_lens = _make_inputs( + context_lens) + total_tokens = int(cu[-1].item()) + + # FlyDSL kernel + us_fly = _bench_fn( + chunk_gated_delta_rule_fwd_h_flydsl, + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + print(f"\n[K5 FlyDSL T={total_tokens}] {us_fly:.2f} us") + + # Triton opt3 kernel for comparison + us_triton = _bench_fn( + fwd_h_triton_opt3, + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + speedup = us_triton / us_fly if us_fly > 0 else float('inf') + print(f"[K5 Triton opt3 T={total_tokens}] {us_triton:.2f} us") + print(f"[Speedup FlyDSL/Triton] {speedup:.3f}x") + + +# ── rocprofv3 profiling infrastructure ────────────────────────────────── + +TARGET_KERNEL_FLYDSL = "chunk_gdn_fwd_h_opt3" +TARGET_KERNEL_TRITON = "_triton_fwd_kernel_h_opt3" + + +def _load_roctx_library(): + """Load the roctx shared library for profiler pause/resume control.""" + for candidate in ("rocprofiler-sdk-roctx", "roctx64"): + libname = find_library(candidate) + if libname is None: + continue + lib = ctypes.CDLL(libname) + lib.roctxGetThreadId.argtypes = [ctypes.POINTER(ctypes.c_uint64)] + lib.roctxGetThreadId.restype = None + lib.roctxProfilerPause.argtypes = [ctypes.c_uint64] + lib.roctxProfilerPause.restype = None + lib.roctxProfilerResume.argtypes = [ctypes.c_uint64] + lib.roctxProfilerResume.restype = None + lib.roctxRangePushA.argtypes = [ctypes.c_char_p] + lib.roctxRangePushA.restype = ctypes.c_int + lib.roctxRangePop.argtypes = [] + lib.roctxRangePop.restype = ctypes.c_int + return lib + return None + + +def _roctx_thread_id(lib): + tid = ctypes.c_uint64() + lib.roctxGetThreadId(ctypes.byref(tid)) + return int(tid.value) + + +def _rocprof_worker(full_prompt_len): + """Inner worker: runs under rocprofv3 --selected-regions. + + Profiling starts paused. We warmup both kernels, then + Resume -> measured iterations -> Pause for each kernel sequentially. + """ + roctx = _load_roctx_library() + if roctx is None: + raise RuntimeError("roctx library not found; cannot run as profiling worker") + + tid = _roctx_thread_id(roctx) + + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + total_tokens = int(cu[-1].item()) + + run_fly = lambda: chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + # Warmup FlyDSL (paused) + print(f"[rocprof-worker] Warmup FlyDSL (T={total_tokens}) ...", flush=True) + for _ in range(NUM_WARMUP): + run_fly() + torch.cuda.synchronize() + + # Measure FlyDSL + roctx.roctxProfilerResume(tid) + roctx.roctxRangePushA(b"flydsl_k5_bench") + for _ in range(NUM_ITERS): + run_fly() + torch.cuda.synchronize() + roctx.roctxRangePop() + roctx.roctxProfilerPause(tid) + print(f"[rocprof-worker] FlyDSL: {NUM_ITERS} iterations done", flush=True) + + # Triton opt3 + run_tri = lambda: fwd_h_triton_opt3( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + print(f"[rocprof-worker] Warmup Triton opt3 ...", flush=True) + for _ in range(NUM_WARMUP): + run_tri() + torch.cuda.synchronize() + + roctx.roctxProfilerResume(tid) + roctx.roctxRangePushA(b"triton_k5_bench") + for _ in range(NUM_ITERS): + run_tri() + torch.cuda.synchronize() + roctx.roctxRangePop() + roctx.roctxProfilerPause(tid) + print(f"[rocprof-worker] Triton: {NUM_ITERS} iterations done", flush=True) + + +def _parse_kernel_stats(stats_path: Path) -> dict[str, dict]: + """Parse kernel_stats CSV -> {name: {AverageNs, TotalDurationNs, Calls, ...}}.""" + result = {} + with stats_path.open(newline="") as f: + for row in csv.DictReader(f): + result[row["Name"]] = row + return result + + +def _print_rocprof_summary(stats_path: Path, total_tokens: int): + """Print a formatted summary from rocprofv3 kernel_stats CSV.""" + stats = _parse_kernel_stats(stats_path) + + targets = [ + ("FlyDSL", TARGET_KERNEL_FLYDSL), + ("Triton opt3", TARGET_KERNEL_TRITON), + ] + + results = {} + for label, kname in targets: + entry = stats.get(kname) + if entry is None: + for name in stats: + if kname in name: + entry = stats[name] + break + if entry is None: + print(f" {label}: kernel '{kname}' not found in stats") + continue + + avg_ns = float(entry["AverageNs"]) + min_ns = float(entry["MinNs"]) + max_ns = float(entry["MaxNs"]) + calls = int(entry["Calls"]) + total_ns = float(entry["TotalDurationNs"]) + results[label] = avg_ns + + print(f" {label} ({kname}):") + print(f" Calls: {calls}") + print(f" Average: {avg_ns / 1000:.2f} us ({avg_ns:.0f} ns)") + print(f" Min: {min_ns / 1000:.2f} us") + print(f" Max: {max_ns / 1000:.2f} us") + print(f" Total: {total_ns / 1e6:.2f} ms") + + if "FlyDSL" in results and "Triton opt3" in results: + speedup = results["Triton opt3"] / results["FlyDSL"] + print(f"\n Speedup (FlyDSL vs Triton): {speedup:.3f}x") + + if not stats: + print(" WARNING: no kernels found in stats file") + elif not results: + print(" Available kernels:") + for name in sorted(stats.keys()): + print(f" {name}") + + +def _do_rocprof(full_prompt_len): + """Outer driver: launches rocprofv3 wrapping this script in --_rocprof-worker mode.""" + repo_root = Path(__file__).resolve().parent.parent.parent + output_dir = repo_root / "rocprof_output" + output_dir.mkdir(parents=True, exist_ok=True) + output_stem = f"gdn_k5_fpl{full_prompt_len}" + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + + inner_cmd = [ + "python3", "-u", str(Path(__file__).resolve()), + "--_rocprof-worker", + "--full-prompt-len", str(full_prompt_len), + ] + rocprof_cmd = [ + "rocprofv3", + "--kernel-trace", + "--marker-trace", + "--output-format", "csv", + "-d", str(output_dir), + "-o", output_stem, + "--stats", + "--selected-regions", + "--", *inner_cmd, + ] + + context_lens = _build_context_lens(full_prompt_len) + total_tokens = sum(context_lens) + + print(f"\n[rocprof] full_prompt_len={full_prompt_len}, T={total_tokens}") + print(f"[rocprof] cmd: {' '.join(rocprof_cmd)}", flush=True) + result = subprocess.run(rocprof_cmd, cwd=repo_root, env=env) + + stats_path = output_dir / f"{output_stem}_kernel_stats.csv" + if stats_path.exists(): + print(f"\n[rocprof] Results (full_prompt_len={full_prompt_len}, T={total_tokens}):") + _print_rocprof_summary(stats_path, total_tokens) + else: + print(f"[rocprof] kernel stats not found: {stats_path}", flush=True) + trace_path = output_dir / f"{output_stem}_kernel_trace.csv" + if trace_path.exists(): + print(f"[rocprof] trace file exists: {trace_path}") + + if result.returncode != 0: + print(f"[rocprof] rocprofv3 exited with code {result.returncode}", flush=True) + + return result.returncode + + +# ── rocprofv3 pytest tests ───────────────────────────────────────────── + +class TestRocprof: + """Profile FlyDSL and Triton kernels with rocprofv3.""" + + @pytest.mark.parametrize("full_prompt_len", PERF_SHAPES) + def test_rocprof(self, full_prompt_len): + rc = _do_rocprof(full_prompt_len) + assert rc == 0, f"rocprofv3 exited with code {rc}" + + +# ── Main ──────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GDN K5 test / profile") + parser.add_argument("--mode", choices=["test", "rocprof"], default="test", + help="test=pytest (default), rocprof=rocprofv3 profiling") + parser.add_argument("--full-prompt-len", type=int, default=8000) + parser.add_argument("--_rocprof-worker", action="store_true", + help=argparse.SUPPRESS) + args = parser.parse_args() + + if args._rocprof_worker: + _rocprof_worker(args.full_prompt_len) + elif args.mode == "rocprof": + _do_rocprof(args.full_prompt_len) + else: + pytest.main([__file__, "-v", "-s"])