diff --git a/benchmarks/benchmark_det_gemm.py b/benchmarks/benchmark_det_gemm.py new file mode 100644 index 0000000..2e1ef1f --- /dev/null +++ b/benchmarks/benchmark_det_gemm.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Overhead of batch-invariant det_gemm vs cuBLAS + Triton (WS1 #146). + +det_gemm (CUDA, naive first milestone) and the Triton path are batch-invariant +and SLOWER than cuBLAS by design (no split-K/stream-K, fixed accumulation, FP32, +no TF32). Reports overhead vs the fair baseline (cuBLAS, TF32 disabled), not a +speedup. The naive CUDA kernel is correctness-first; a tensor-core pass follows. +""" +import argparse + +import torch + +from rl_engine.kernels.ops.cuda.matmul import deterministic_gemm +from rl_engine.kernels.ops.pytorch.matmul import native_gemm + +try: + from rl_engine.kernels.ops.triton.matmul import deterministic_gemm_triton + + _HAS_TRITON = True +except ImportError: + _HAS_TRITON = False + +DEV = "cuda" +WARMUP, ITERS = 10, 50 + +SHAPES = [ + ("qkv", 4096, 4096, 12288), + ("o_proj", 4096, 4096, 4096), + ("mlp_up", 4096, 4096, 14336), + ("mlp_dn", 4096, 14336, 4096), + ("lm_head", 4096, 4096, 32000), +] + + +def _time(fn, a, b): + for _ in range(WARMUP): + fn(a, b) + torch.cuda.synchronize() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + for _ in range(ITERS): + fn(a, b) + e.record() + torch.cuda.synchronize() + return s.elapsed_time(e) / ITERS + + +def run(): + rows = [] + for name, M, K, N in SHAPES: + a = torch.randn(M, K, device=DEV, dtype=torch.bfloat16) + b = torch.randn(K, N, device=DEV, dtype=torch.bfloat16) + torch.backends.cuda.matmul.allow_tf32 = True + t_tf32 = _time(lambda x, y: torch.matmul(x, y), a, b) + torch.backends.cuda.matmul.allow_tf32 = False + t_fp32 = _time(native_gemm, a, b) + t_cuda = _time(deterministic_gemm, a, b) + t_tri = _time(deterministic_gemm_triton, a, b) if _HAS_TRITON else float("nan") + rows.append((name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri, t_cuda / t_fp32)) + return rows + + +def to_markdown(rows, dev, cap): + out = [ + f"## det_gemm overhead — {dev} (SM{cap[0]}{cap[1]})", + "", + "| shape | M | K | N | cuBLAS tf32 | cuBLAS fp32 | det CUDA | det Triton | overhead |", + "|---|---|---|---|---|---|---|---|---|", + ] + for n, M, K, N, t1, t2, t3, t4, ov in rows: + out.append( + f"| {n} | {M} | {K} | {N} | {t1:.3f} | {t2:.3f} | {t3:.3f} | {t4:.3f} | {ov:.1f}x |" + ) + out += [ + "", + "_Overhead = det CUDA vs cuBLAS (TF32 disabled). Naive CUDA kernel is " + "correctness-first; both det paths trade speed for bitwise " + "batch-invariance. Tensor-core pass is a follow-up (#146)._", + ] + return "\n".join(out) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--out", default=None) + args = ap.parse_args() + name, cap = torch.cuda.get_device_name(), torch.cuda.get_device_capability() + print(name, cap) + md = to_markdown(run(), name, cap) + print("\n" + md) + if args.out: + with open(args.out, "w") as f: + f.write(md + "\n") + + +if __name__ == "__main__": + main() diff --git a/csrc/cuda/gemm/det_gemm_kernel.cu b/csrc/cuda/gemm/det_gemm_kernel.cu new file mode 100644 index 0000000..9b846b3 --- /dev/null +++ b/csrc/cuda/gemm/det_gemm_kernel.cu @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 RL-Kernel Contributors +// csrc/cuda/gemm/det_gemm_kernel.cu +// +// WS1 - Batch-invariant deterministic GEMM (hand-written, no CUTLASS). +// +// SM90 path: TMA load + mma.sync (m16n8k16) tensor cores, single-CTA-per-tile, +// fixed K-accumulation order, NO split-K -> batch-invariant. +// Fallback : naive FP32 scalar kernel (also the correctness ground truth). +// +// Both: BF16 in / FP32 accum / no TF32 / no split-K. +// fwd: C = A @ B | dA = dC @ B^T | dB = A^T @ dC +// Backward reuses the forward kernel on transposed operands. + +#include +#include +#include + +#if defined(RL_KERNEL_ENABLE_SM90) +#include "../utils/tma_utils.cuh" +#include +#endif + +namespace { + +using nv_bf16 = __nv_bfloat16; + +__host__ __device__ constexpr int cdiv(int a, int b) { return (a + b - 1) / b; } + +// Naive FP32 scalar kernel (SM80 fallback + ground truth). Batch-invariant by +// construction: one thread = one output element, fixed ascending K loop. +constexpr int NAIVE_TILE = 16; + +__global__ void det_gemm_naive(const nv_bf16* __restrict__ A, + const nv_bf16* __restrict__ B, + nv_bf16* __restrict__ C, + int M, int N, int K) { + const int row = blockIdx.y * NAIVE_TILE + threadIdx.y; + const int col = blockIdx.x * NAIVE_TILE + threadIdx.x; + if (row >= M || col >= N) return; + float acc = 0.0f; + for (int k = 0; k < K; ++k) + acc += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]); + C[row * N + col] = __float2bfloat16(acc); +} + +void launch_naive(const nv_bf16* A, const nv_bf16* B, nv_bf16* C, + int M, int N, int K, cudaStream_t stream) { + dim3 block(NAIVE_TILE, NAIVE_TILE); + dim3 grid(cdiv(N, NAIVE_TILE), cdiv(M, NAIVE_TILE)); + det_gemm_naive<<>>(A, B, C, M, N, K); +} + +#if defined(RL_KERNEL_ENABLE_SM90) +// ============================================================================ +// SM90 path: TMA load + mma.sync. C[M,N] = A[M,K] @ B[K,N]. +// +// mma.sync.m16n8k16.row.col needs A row-major [m16,k16] and B col-major +// [n8,k16] (i.e. B operand indexed as [n, k]). Our B is row-major [K,N], so we +// load a B tile of shape [BK, BN] but feed the mma the (n,k) operand by +// addressing smem as B[k][n] -> we ldmatrix B with the same trick as the logp +// kernel (which loads W[V,D] = [n, k] directly). To match that validated path, +// A tile is [BM, BK] (row=token, col=k), B tile is [BN, BK] (row=n, col=k), +// which is exactly B^T. So the SM90 kernel computes C = A @ (Bt)^T where Bt is +// the [BN,BK] tile = B[k,n] transposed. We materialize that by giving TMA a +// descriptor over B viewed as [N,K]... but B is [K,N] row-major. +// +// To keep it simple and provably correct, the SM90 forward kernel REQUIRES its +// B operand already in [N,K] layout (row=n, col=k). The host wrapper passes +// B^T (contiguous) so the kernel sees Bt[N,K]; mathematically +// C = A[M,K] @ B[K,N] = A @ (Bt[N,K])^T, and the per-tile mma contracts over K +// in fixed order. This mirrors the logp kernel's W[V,D] @ hidden[N,D] pattern. +// ============================================================================ +constexpr int BM = 64; // rows (M) per CTA tile +constexpr int BN = 64; // cols (N) per CTA tile +constexpr int BK = 32; // K slice streamed per TMA load +constexpr int WARPS = 4; +constexpr int WG_THREADS = WARPS * 32; // 128 +constexpr int STAGES = 2; + +constexpr int MMA_M = 16, MMA_N = 8, MMA_K = 16; +constexpr int WARP_M = BM / WARPS; // 16 -> 1 m-tile per warp +constexpr int M_TILES = WARP_M / MMA_M; // 1 +constexpr int N_TILES = BN / MMA_N; // 8 +constexpr int K_TILES = BK / MMA_K; // 2 +constexpr int KK_GROUPS = BK / 32; // 1 (ldmatrix.x4 spans 32 cols) + +__device__ __forceinline__ void ldmatrix_x4(uint32_t regs[4], uint32_t addr) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3]) + : "r"(addr)); +} +__device__ __forceinline__ void mma_m16n8k16(const uint32_t A[4], const uint32_t B[2], float D[4]) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(D[0]), "f"(D[1]), "f"(D[2]), "f"(D[3])); +} + +// A: row-major [M,K] via TMA tile [BM,BK]. Bt: row-major [N,K] via TMA tile +// [BN,BK]. C: row-major [M,N]. Each CTA owns one [BM,BN] output tile and +// walks the full K in fixed order (no split-K). +__global__ void det_gemm_sm90_kernel(const __grid_constant__ CUtensorMap a_tmap, + const __grid_constant__ CUtensorMap bt_tmap, + nv_bf16* __restrict__ C, + int M, int N, int K) { + const int tid = threadIdx.x; + const int warp = tid / 32; + const int lane = tid % 32; + const int row_base = blockIdx.y * BM; + const int col_base = blockIdx.x * BN; + const int kd = K / BK; // K validated multiple of BK on host + + extern __shared__ __align__(1024) char smem[]; + nv_bf16* sA = reinterpret_cast(smem); + nv_bf16* sB = reinterpret_cast(sA + STAGES * BM * BK); + int* mbar_base = reinterpret_cast(sB + STAGES * BN * BK); + + const uint32_t sA_base = static_cast(__cvta_generic_to_shared(sA)); + const uint32_t sB_base = static_cast(__cvta_generic_to_shared(sB)); + int mbar[STAGES]; +#pragma unroll + for (int s = 0; s < STAGES; ++s) + mbar[s] = static_cast(__cvta_generic_to_shared(mbar_base + 2 * s)); + + if (tid == 0) { +#pragma unroll + for (int s = 0; s < STAGES; ++s) mbarrier_init(mbar[s], 1); + asm volatile("fence.mbarrier_init.release.cluster;"); + } + __syncthreads(); + + const uint32_t tile_bytes = (BM * BK + BN * BK) * sizeof(nv_bf16); + + auto issue_load = [&](int k) { + const int buf = k % STAGES; + const int k_off = k * BK; + tma_2d_g2s(static_cast(sA_base + buf * BM * BK * sizeof(nv_bf16)), + &a_tmap, k_off, row_base, mbar[buf]); + tma_2d_g2s(static_cast(sB_base + buf * BN * BK * sizeof(nv_bf16)), + &bt_tmap, k_off, col_base, mbar[buf]); + mbarrier_arrive_expect_tx(mbar[buf], tile_bytes); + }; + + int phase[STAGES]; +#pragma unroll + for (int s = 0; s < STAGES; ++s) phase[s] = 0; + + // accumulators: this warp's M_TILES m-tiles x N_TILES n-tiles + float acc[M_TILES][N_TILES][4]; +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) +#pragma unroll + for (int n = 0; n < N_TILES; ++n) + acc[mi][n][0] = acc[mi][n][1] = acc[mi][n][2] = acc[mi][n][3] = 0.0f; + + if (tid == 0) +#pragma unroll + for (int s = 0; s < STAGES - 1; ++s) + if (s < kd) issue_load(s); + + for (int k = 0; k < kd; ++k) { // fixed ascending K order, NO split-K + const int buf = k % STAGES; + if (tid == 0 && k + (STAGES - 1) < kd) issue_load(k + (STAGES - 1)); + mbarrier_wait(mbar[buf], phase[buf]); + phase[buf] ^= 1; + __syncthreads(); + + const uint32_t sA_buf = sA_base + buf * BM * BK * sizeof(nv_bf16); + const uint32_t sB_buf = sB_base + buf * BN * BK * sizeof(nv_bf16); + + // Load A operand (this warp's rows), all K-steps. Same addressing as logp. + uint32_t A[M_TILES][K_TILES][4]; +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) { + const int row0 = warp * WARP_M + mi * MMA_M + (lane % 16); +#pragma unroll + for (int kt = 0; kt < K_TILES; ++kt) { + const uint32_t a_addr = + sA_buf + (row0 * BK + (lane / 16) * 8 + kt * MMA_K) * sizeof(nv_bf16); + ldmatrix_x4(A[mi][kt], a_addr); + } + } + + // Load B operand (all n-tiles) and contract. Same addressing as logp's W. +#pragma unroll + for (int n = 0; n < N_TILES; ++n) { +#pragma unroll + for (int kk = 0; kk < KK_GROUPS; ++kk) { + uint32_t b4[4]; + const uint32_t b_addr = + sB_buf + ((n * MMA_N + (lane % 8)) * BK + (lane / 8) * 8 + kk * 32) * sizeof(nv_bf16); + ldmatrix_x4(b4, b_addr); + const uint32_t B0[2] = {b4[0], b4[1]}; + const uint32_t B1[2] = {b4[2], b4[3]}; +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) { + mma_m16n8k16(A[mi][2 * kk + 0], B0, acc[mi][n]); + mma_m16n8k16(A[mi][2 * kk + 1], B1, acc[mi][n]); + } + } + } + __syncthreads(); + } + + // Epilogue: write acc to C (row-major [M,N]). mma m16n8k16 output layout. +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) { + const int row = row_base + warp * WARP_M + mi * MMA_M + lane / 4; +#pragma unroll + for (int n = 0; n < N_TILES; ++n) { + const int col = col_base + n * MMA_N + (lane % 4) * 2; + if (row < M && col + 1 < N) { + C[row * N + col + 0] = __float2bfloat16(acc[mi][n][0]); + C[row * N + col + 1] = __float2bfloat16(acc[mi][n][1]); + } + if (row + 8 < M && col + 1 < N) { + C[(row + 8) * N + col + 0] = __float2bfloat16(acc[mi][n][2]); + C[(row + 8) * N + col + 1] = __float2bfloat16(acc[mi][n][3]); + } + } + } +} + +// noswizzle TMA descriptor (kernel uses plain row-major ldmatrix addressing). +inline void init_tmap_noswizzle(CUtensorMap* tmap, const nv_bf16* gmem, + uint64_t height, uint64_t width, + uint32_t box_h, uint32_t box_w) { + uint64_t size[2] = {width, height}; + uint64_t stride[1] = {width * sizeof(nv_bf16)}; + uint32_t box[2] = {box_w, box_h}; + uint32_t estride[2] = {1, 1}; + CUresult res = cuTensorMapEncodeTiled( + tmap, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, (void*)gmem, size, stride, box, estride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + TORCH_CHECK(res == CUDA_SUCCESS, "det_gemm: cuTensorMapEncodeTiled failed"); +} + +// Launch SM90 GEMM. A:[M,K] row-major, Bt:[N,K] row-major (= B transposed). +// Requires M%BM==0, N%BN==0, K%BK==0 (host pads/falls back otherwise). +bool launch_sm90(const nv_bf16* A, const nv_bf16* Bt, nv_bf16* C, + int M, int N, int K, cudaStream_t stream) { + if (M % BM != 0 || N % BN != 0 || K % BK != 0) return false; // fall back + + CUtensorMap a_tmap, bt_tmap; + init_tmap_noswizzle(&a_tmap, A, M, K, BM, BK); + init_tmap_noswizzle(&bt_tmap, Bt, N, K, BN, BK); + + const int smem = STAGES * (BM * BK + BN * BK) * sizeof(nv_bf16) + STAGES * 8; + if (smem > 48 * 1024) + cudaFuncSetAttribute(det_gemm_sm90_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + dim3 grid(cdiv(N, BN), cdiv(M, BM)); + det_gemm_sm90_kernel<<>>(a_tmap, bt_tmap, C, M, N, K); + return true; +} +#endif // RL_KERNEL_ENABLE_SM90 + +int sm_major() { + int dev = 0; cudaGetDevice(&dev); + cudaDeviceProp p{}; cudaGetDeviceProperties(&p, dev); + return p.major; +} +inline const nv_bf16* bf16(const torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} +inline nv_bf16* bf16o(torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} +void check_in(const torch::Tensor& t, const char* n) { + TORCH_CHECK(t.is_cuda(), n, " must be CUDA"); + TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16"); +} + +// Core dispatch: C = A[M,K] @ B[K,N]. SM90 kernel needs B^T ([N,K]); it is +// materialized contiguous here. Falls back to naive on SM80 or odd shapes. +torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) { + const int M = a.size(0), K = a.size(1), N = b.size(1); + auto c = torch::empty({M, N}, a.options()); + auto stream = at::cuda::getCurrentCUDAStream(); + +#if defined(RL_KERNEL_ENABLE_SM90) + if (sm_major() >= 9) { + auto bt = b.t().contiguous(); // [N,K] + if (launch_sm90(bf16(a), bf16(bt), bf16o(c), M, N, K, stream)) return c; + // else fall through to naive (shape not tile-aligned) + } +#endif + launch_naive(bf16(a), bf16(b), bf16o(c), M, N, K, stream); + return c; +} + +} // anonymous namespace + +torch::Tensor det_gemm_fwd(torch::Tensor a, torch::Tensor b) { + check_in(a, "A"); check_in(b, "B"); + a = a.contiguous(); b = b.contiguous(); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "det_gemm_fwd: expect 2D [M,K]@[K,N]"); + TORCH_CHECK(b.size(0) == a.size(1), "det_gemm_fwd: K mismatch"); + return gemm_dispatch(a, b); +} + +torch::Tensor det_gemm_da(torch::Tensor dc, torch::Tensor b) { + check_in(dc, "dC"); check_in(b, "B"); + dc = dc.contiguous(); + auto bt = b.t().contiguous(); // [N,K]; dA[M,K] = dC[M,N] @ bt[N,K] + TORCH_CHECK(bt.size(0) == dc.size(1), "det_gemm_da: N mismatch"); + return gemm_dispatch(dc, bt); +} + +torch::Tensor det_gemm_db(torch::Tensor a, torch::Tensor dc) { + check_in(a, "A"); check_in(dc, "dC"); + dc = dc.contiguous(); + auto at = a.t().contiguous(); // [K,M]; dB[K,N] = at[K,M] @ dC[M,N] + TORCH_CHECK(dc.size(0) == at.size(1), "det_gemm_db: M mismatch"); + return gemm_dispatch(at, dc); +} diff --git a/csrc/ops.cpp b/csrc/ops.cpp index f48e4f9..2095d05 100644 --- a/csrc/ops.cpp +++ b/csrc/ops.cpp @@ -25,6 +25,11 @@ torch::Tensor fused_logp_forward_online_fp32(torch::Tensor logits, torch::Tensor torch::Tensor fused_logp_forward_online_indexed_out(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices, torch::Tensor output); torch::Tensor fused_logp_forward_online_indexed_fp32(torch::Tensor logits, torch::Tensor token_ids, torch::Tensor row_indices); +// Batch-Invariant Deterministic GEMM Declarations +torch::Tensor det_gemm_fwd(torch::Tensor a, torch::Tensor b); +torch::Tensor det_gemm_da(torch::Tensor dc, torch::Tensor b); +torch::Tensor det_gemm_db(torch::Tensor a, torch::Tensor dc); + // Prefix-Shared Attention Declarations & Wrappers void prefix_shared_attention_forward( @@ -95,5 +100,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // registry Prefix-Shared Attention m.def("prefix_shared_attention", &prefix_shared_attention, "Prefix-Shared Fused Attention for GRPO"); + + // registry Batch-Invariant Deterministic GEMM + m.def("det_gemm_fwd", &det_gemm_fwd, "Batch-invariant deterministic GEMM forward (C=A@B)"); + m.def("det_gemm_da", &det_gemm_da, "Batch-invariant deterministic GEMM backward dA (dC@B^T)"); + m.def("det_gemm_db", &det_gemm_db, "Batch-invariant deterministic GEMM backward dB (A^T@dC)"); #endif } diff --git a/docs/.nav.yml b/docs/.nav.yml index e9ebaf0..0bfecc1 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -15,6 +15,7 @@ nav: - operators/grpo-loss.md - operators/ratio-kl.md - operators/sampling.md + - operators/det-gemm.md - Developer Guide: - contributing/README.md - General: diff --git a/docs/operators/det-gemm.md b/docs/operators/det-gemm.md new file mode 100644 index 0000000..cd48d26 --- /dev/null +++ b/docs/operators/det-gemm.md @@ -0,0 +1,53 @@ +# Batch-Invariant Deterministic GEMM (`det_gemm`) + +WS1 #146. A matrix multiply whose output for a given row is **bitwise invariant** +to batch size, chunked-prefill splitting, and padding — the property cuBLAS does +not provide, and the root fix for matmul-driven KL drift between rollout and +training. + +## Why + +Matmul is the most frequent op in a transformer (QKV, MLP, LM head), so +batch-dependent drift here dominates everything downstream. cuBLAS selects +kernels by problem shape and may use split-K, both of which change the +K-reduction order when batch size or sequence length shifts the chosen kernel. +`det_gemm` pins the accumulation order so a row's result never depends on the +rows around it. + +## Guarantees + +- Forward `C = A @ B`, backward `dA = dC·Bᵀ`, `dB = Aᵀ·dC`. +- BF16 inputs, FP32 accumulation, no TF32, no split-K, fixed K-loop order. +- Bitwise-identical output for a fixed row across batch=1/N, chunked-prefill + on/off, and padding layouts. + +## Backends + +| Backend | Deterministic | Notes | +|---|---|---| +| CUDA (`DetGemmOp`) | yes | Hand-written kernel. First milestone is a naive FP32 implementation (correctness first); a tensor-core (`mma.sync`) pass matching `prefix_shared_attention.cu` follows. NVIDIA SM80+. | +| Triton (`TritonDetGemmOp`) | yes | Autotune disabled, BLOCK pinned, no split-K. Portable / ROCm fallback and cross-backend reference. | +| PyTorch (`NativeGemmOp`) | **no** | Plain `torch.matmul`. Reference & benchmark target ONLY — cuBLAS is not batch-invariant. Excluded from registry dispatch. | + +Registry dispatch for `det_gemm` includes only the deterministic backends +(CUDA → Triton). The PyTorch op must be called explicitly. + +## Usage + +```python +from rl_engine.kernels.registry import kernel_registry +gemm = kernel_registry.get_op("det_gemm") # CUDA if built, else Triton +c = gemm(a, b) # a:[M,K] bf16, b:[K,N] bf16 +``` + +## Scope + +In: single-rank forward + backward, BF16 / FP32-accum, SM80+. +Out: tensor-parallel GEMM (WS2), FP8, ROCm-native kernel (Triton covers ROCm). + +## Performance + +`det_gemm` trades speed for determinism. The naive CUDA kernel is slow by +design; see `benchmarks/benchmark_det_gemm.py`. Overhead is reported vs cuBLAS +with TF32 disabled (the fair, same-FP32-path baseline), not as a speedup. A +slower deterministic baseline is the accepted first milestone (#146). diff --git a/rl_engine/kernels/ops/cuda/__init__.py b/rl_engine/kernels/ops/cuda/__init__.py index e69de29..5f1ae8f 100644 --- a/rl_engine/kernels/ops/cuda/__init__.py +++ b/rl_engine/kernels/ops/cuda/__init__.py @@ -0,0 +1,2 @@ +# append matmul to the existing imports +from . import attention, loss, matmul, norm # noqa: F401 diff --git a/rl_engine/kernels/ops/cuda/matmul/__init__.py b/rl_engine/kernels/ops/cuda/matmul/__init__.py new file mode 100644 index 0000000..ede9f55 --- /dev/null +++ b/rl_engine/kernels/ops/cuda/matmul/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +from .det_gemm import DetGemmOp, deterministic_gemm + +__all__ = ["DetGemmOp", "deterministic_gemm"] diff --git a/rl_engine/kernels/ops/cuda/matmul/det_gemm.py b/rl_engine/kernels/ops/cuda/matmul/det_gemm.py new file mode 100644 index 0000000..4778be9 --- /dev/null +++ b/rl_engine/kernels/ops/cuda/matmul/det_gemm.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Batch-invariant deterministic GEMM, CUDA path (WS1 #146). + +Hand-written kernel (csrc/cuda/gemm/det_gemm_kernel.cu): fixed K-accumulation +order, FP32 accumulation, no split-K. A row's output is invariant to batch size, +chunked-prefill, and padding. No PyTorch fallback -- a generic matmul (cuBLAS) +would silently break invariance (see NativeGemmOp). Tensor-parallel GEMM is WS2. +""" +import torch + +from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE +from rl_engine.utils.logger import logger + + +class _DetGemmFn(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b): + ctx.save_for_backward(a, b) + return _C.det_gemm_fwd(a, b) + + @staticmethod + def backward(ctx, grad_out): + a, b = ctx.saved_tensors + grad_out = grad_out.contiguous() + da = _C.det_gemm_da(grad_out, b) if ctx.needs_input_grad[0] else None + db = _C.det_gemm_db(a, grad_out) if ctx.needs_input_grad[1] else None + return da, db + + +class DetGemmOp: + """Hand-written batch-invariant GEMM. a:[M,K] bf16, b:[K,N] bf16 -> [M,N] bf16.""" + + def __init__(self): + self.has_hardware_op = False + if _EXT_AVAILABLE and hasattr(_C, "det_gemm_fwd"): + self.op = _C.det_gemm_fwd + self.has_hardware_op = True + logger.info("Successfully linked to RL-Kernel _C.det_gemm_fwd.") + else: + logger.warning( + "RL-Kernel _C.det_gemm_fwd unavailable; DetGemmOp requires the " + "compiled CUDA extension and has no batch-invariant fallback." + ) + + def __call__(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16, "BF16 only" + assert a.is_cuda and b.is_cuda, "Inputs must be on CUDA device" + if not self.has_hardware_op: + raise RuntimeError( + "DetGemmOp: compiled _C.det_gemm kernel unavailable; no " + "batch-invariant fallback exists. Build the extension first." + ) + return _DetGemmFn.apply(a.contiguous(), b.contiguous()) + + +def deterministic_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Functional entry. a:[M,K] bf16, b:[K,N] bf16 -> [M,N] bf16.""" + return _DetGemmFn.apply(a, b) diff --git a/rl_engine/kernels/ops/pytorch/matmul/__init__.py b/rl_engine/kernels/ops/pytorch/matmul/__init__.py new file mode 100644 index 0000000..e0bae08 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/matmul/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +from .det_gemm import NativeGemmOp, native_gemm + +__all__ = ["NativeGemmOp", "native_gemm"] diff --git a/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py b/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py new file mode 100644 index 0000000..9c617aa --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Native PyTorch GEMM -- NON-deterministic reference baseline (WS1). + +WARNING: torch.matmul (cuBLAS) does NOT guarantee batch-invariance -- cuBLAS +selects kernels by shape and may use split-K. This op exists only as a +correctness reference and benchmark target, NOT as a fallback. It is +intentionally excluded from the det_gemm registry dispatch. +""" +import torch + +from rl_engine.utils.logger import logger + + +class NativeGemmOp: + """Plain torch.matmul. Non-deterministic; reference / benchmark use only.""" + + def __init__(self): + torch.backends.cuda.matmul.allow_tf32 = False + logger.info("NativeGemmOp ready (non-deterministic torch.matmul reference).") + + def __call__(self, a, b): + return torch.matmul(a, b) + + +def native_gemm(a, b): + return torch.matmul(a, b) diff --git a/rl_engine/kernels/ops/triton/matmul/__init__.py b/rl_engine/kernels/ops/triton/matmul/__init__.py new file mode 100644 index 0000000..b674dc8 --- /dev/null +++ b/rl_engine/kernels/ops/triton/matmul/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +from .det_gemm import TritonDetGemmOp, deterministic_gemm_triton + +__all__ = ["TritonDetGemmOp", "deterministic_gemm_triton"] diff --git a/rl_engine/kernels/ops/triton/matmul/det_gemm.py b/rl_engine/kernels/ops/triton/matmul/det_gemm.py new file mode 100644 index 0000000..8318329 --- /dev/null +++ b/rl_engine/kernels/ops/triton/matmul/det_gemm.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Batch-invariant deterministic GEMM, Triton path (WS1). + +Portable implementation with the SAME invariance guarantees as the CUDA path: +autotune disabled, BLOCK sizes pinned, no split-K, fixed K-loop order, FP32 +accumulation, no TF32. Used as the cross-backend reference and the ROCm/portable +fallback. Slower than a tuned GEMM by design. +""" +import torch + +try: + import triton + import triton.language as tl + + _TRITON_AVAILABLE = True +except ImportError: + _TRITON_AVAILABLE = False + +from rl_engine.utils.logger import logger + +# Pinned. NOT autotuned (autotune picks per-shape configs -> breaks invariance). +_BLOCK_M, _BLOCK_N, _BLOCK_K = 64, 64, 32 + + +if _TRITON_AVAILABLE: + + @triton.jit + def _det_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ): + # One program = one output tile, walks the whole K in fixed order. + # No split-K -> K-accumulation order independent of M -> batch-invariant. + pid_m, pid_n = tl.program_id(0), tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_rem = K - k * BLOCK_K + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_rem, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_rem, other=0.0) + acc += tl.dot(a, b, allow_tf32=False) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = acc.to(c_ptr.dtype.element_ty) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def _triton_gemm(a, b): + a, b = a.contiguous(), b.contiguous() + M, K = a.shape + _, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N)) + _det_gemm_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_M=_BLOCK_M, + BLOCK_N=_BLOCK_N, + BLOCK_K=_BLOCK_K, + ) + return c + + +class _TritonDetGemmFn(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b): + ctx.save_for_backward(a, b) + return _triton_gemm(a, b) + + @staticmethod + def backward(ctx, grad_out): + a, b = ctx.saved_tensors + grad_out = grad_out.contiguous() + da = _triton_gemm(grad_out, b.t().contiguous()) if ctx.needs_input_grad[0] else None + db = _triton_gemm(a.t().contiguous(), grad_out) if ctx.needs_input_grad[1] else None + return da, db + + +class TritonDetGemmOp: + """Batch-invariant deterministic GEMM, Triton path.""" + + def __init__(self): + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton not available for TritonDetGemmOp") + logger.info("TritonDetGemmOp ready (deterministic, autotune disabled).") + + def __call__(self, a, b): + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16, "BF16 only" + assert a.is_cuda and b.is_cuda, "CUDA only" + return _TritonDetGemmFn.apply(a, b) + + +def deterministic_gemm_triton(a, b): + return _TritonDetGemmFn.apply(a, b) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..7cce176 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -49,6 +49,13 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): TRITON_RATIO_KL = "rl_engine.kernels.ops.triton.loss.ratio_kl.TritonRatioKLOp" PYTORCH_RATIO_KL = "rl_engine.kernels.ops.pytorch.loss.ratio_kl.NativeRatioKLOp" + # Batch-invariant deterministic GEMM (WS1 #146) + CUDA_DET_GEMM = "rl_engine.kernels.ops.cuda.matmul.det_gemm.DetGemmOp" + TRITON_DET_GEMM = "rl_engine.kernels.ops.triton.matmul.det_gemm.TritonDetGemmOp" + # NON-deterministic reference (torch.matmul); reference/benchmark ONLY, + # intentionally excluded from det_gemm dispatch (cuBLAS breaks invariance). + PYTORCH_GEMM = "rl_engine.kernels.ops.pytorch.matmul.det_gemm.NativeGemmOp" + # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp" @@ -89,6 +96,7 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "det_gemm": [OpBackend.CUDA_DET_GEMM, OpBackend.TRITON_DET_GEMM], # Default dispatch logic for new operators }, "rocm": { @@ -101,6 +109,7 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "det_gemm": [OpBackend.TRITON_DET_GEMM], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], diff --git a/setup.py b/setup.py index d5ddb89..76c4850 100644 --- a/setup.py +++ b/setup.py @@ -58,10 +58,24 @@ def get_extensions(): "csrc/ops.cpp", "csrc/fused_logp_kernel.cu", "csrc/cuda/attention/prefix_shared_attention.cu", + "csrc/cuda/gemm/det_gemm_kernel.cu", + ] + + # CUTLASS headers for det_gemm + cutlass_dir = os.environ.get("CUTLASS_DIR", "third_party/cutlass") + cutlass_includes = [ + os.path.join(cutlass_dir, "include"), + os.path.join(cutlass_dir, "tools", "util", "include"), ] cc_major, cc_minor = torch.cuda.get_device_capability() nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"] + # det_gemm SM80 path: explicit gencode + CUTLASS-friendly flags + nvcc_flags.append( + f"-gencode=arch=compute_{cc_major}{cc_minor},code=sm_{cc_major}{cc_minor}" + ) + nvcc_flags.append("--expt-relaxed-constexpr") + nvcc_flags.append("--expt-extended-lambda") nvcc_flags.extend( _cuda_define_from_env( "FUSED_LOGP_TWOPASS_BLOCK_SIZE", @@ -122,11 +136,15 @@ def get_extensions(): nvcc_flags.append(f"-gencode=arch=compute_{tma_arch},code=sm_{tma_arch}") cxx_flags.append("-DKERNEL_ALIGN_WITH_SM90") extra_link_args.append("-lcuda") + # det_gemm SM90 (WGMMA) path: enable conditional include in det_gemm_kernel.cu + nvcc_flags.append("-DRL_KERNEL_ENABLE_SM90") + cxx_flags.append("-DRL_KERNEL_ENABLE_SM90") extensions.append( CUDAExtension( name="rl_engine._C", sources=cuda_sources, + include_dirs=cutlass_includes, extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags, diff --git a/tests/test_det_gemm.py b/tests/test_det_gemm.py new file mode 100644 index 0000000..440dddc --- /dev/null +++ b/tests/test_det_gemm.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Invariance + correctness tests for det_gemm (WS1). + +Runs against both deterministic backends — the hand-written CUDA kernel and the +Triton path — each of which must independently satisfy the invariance contract. +The PyTorch path (torch.matmul) is intentionally NOT tested here: it is the +non-deterministic reference baseline and would fail batch-invariance by design. +""" +import pytest +import torch + +from rl_engine.kernels.ops.cuda.matmul import deterministic_gemm + +try: + from rl_engine.kernels.ops.triton.matmul import deterministic_gemm_triton + + _HAS_TRITON = True +except ImportError: + _HAS_TRITON = False + +torch.backends.cuda.matmul.allow_tf32 = False +DEV = "cuda" + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, + reason="det_gemm requires CUDA SM80+", +) + +# Each deterministic backend is validated independently. +_BACKENDS = [("cuda", deterministic_gemm)] +if _HAS_TRITON: + _BACKENDS.append(("triton", deterministic_gemm_triton)) + + +def _rand(*shape): + return torch.randn(*shape, device=DEV, dtype=torch.bfloat16) + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_batch_invariance(name, gemm): + # A row's output must not change when other rows join the batch. + torch.manual_seed(0) + K, N = 4096, 4096 + b = _rand(K, N) + row = _rand(1, K) + out1 = gemm(row, b) + big = _rand(512, K) + big[0] = row[0] + outN = gemm(big, b) + assert torch.equal(out1[0], outN[0]), f"{name}: forward batch-invariance broken" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_chunked_prefill(name, gemm): + # Splitting M then concatenating must match the full GEMM bitwise. + torch.manual_seed(1) + M, K, N = 256, 4096, 4096 + a, b = _rand(M, K), _rand(K, N) + full = gemm(a, b) + chunked = torch.cat([gemm(a[:100], b), gemm(a[100:], b)], dim=0) + assert torch.equal(full, chunked), f"{name}: chunked-prefill broke invariance" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_padding_invariance(name, gemm): + # Padding rows must not affect valid rows' output. + torch.manual_seed(2) + M, K, N = 100, 4096, 4096 + a, b = _rand(M, K), _rand(K, N) + base = gemm(a, b) + a_pad = torch.cat([a, _rand(28, K)], dim=0) + padded = gemm(a_pad, b) + assert torch.equal(base, padded[:M]), f"{name}: padding changed valid-row output" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_correctness(name, gemm): + # vs FP32 reference. Placeholder tolerance; PR3 swaps for #108 contract. + torch.manual_seed(3) + M, K, N = 128, 2048, 2048 + a, b = _rand(M, K), _rand(K, N) + out = gemm(a, b).float() + ref = a.float() @ b.float() + assert (out - ref).abs().max().item() < 1.0 # TODO(#108): contract threshold + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_backward_batch_invariance(name, gemm): + # dA for a row must be invariant to the surrounding batch. + torch.manual_seed(4) + K, N = 2048, 2048 + b = _rand(K, N) + row = _rand(1, K).requires_grad_(True) + gemm(row, b).sum().backward() + g1 = row.grad.clone() + big = _rand(256, K) + big[0] = row.detach()[0] + big.requires_grad_(True) + gemm(big, b).sum().backward() + assert torch.equal(g1[0], big.grad[0]), f"{name}: backward dA batch-invariance broken" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_backward_correctness(name, gemm): + # dA / dB vs FP32 reference gradients. Placeholder tolerance; PR3 -> #108. + torch.manual_seed(5) + M, K, N = 64, 1024, 1024 + a = _rand(M, K).requires_grad_(True) + b = _rand(K, N).requires_grad_(True) + g = _rand(M, N) + gemm(a, b).backward(g) + af = a.detach().float().requires_grad_(True) + bf = b.detach().float().requires_grad_(True) + (af @ bf).backward(g.float()) + assert (a.grad.float() - af.grad).abs().max().item() < 2.0 # TODO(#108) + assert (b.grad.float() - bf.grad).abs().max().item() < 2.0 # TODO(#108) + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +@pytest.mark.parametrize( + "shape", + [ + (4096, 4096, 12288), # qkv + (4096, 4096, 4096), # o_proj + (4096, 4096, 14336), # mlp_up + (4096, 14336, 4096), # mlp_dn + (4096, 4096, 32000), # lm_head + ], +) +def test_target_shapes_invariance(name, gemm, shape): + # Standard-Transformer projection shapes stay batch-invariant. + torch.manual_seed(6) + M, K, N = shape + b = _rand(K, N) + row = _rand(1, K) + big = _rand(64, K) + big[0] = row[0] + assert torch.equal( + gemm(row, b)[0], gemm(big, b)[0] + ), f"{name}: batch-invariance broken at shape {shape}"