From 90d120c1322ba4aa38cd40038c38349ec8d93fc3 Mon Sep 17 00:00:00 2001 From: vensen Date: Tue, 23 Jun 2026 01:10:55 +0800 Subject: [PATCH 1/3] add ws1 gemm op Signed-off-by: vensen --- benchmarks/benchmark_det_gemm.py | 90 ++++++++++++ csrc/cuda/gemm/det_gemm_kernel.cu | 111 +++++++++++++++ csrc/ops.cpp | 10 ++ docs/.nav.yml | 1 + docs/operators/det-gemm.md | 53 +++++++ rl_engine/kernels/ops/cuda/__init__.py | 2 + rl_engine/kernels/ops/cuda/matmul/__init__.py | 4 + rl_engine/kernels/ops/cuda/matmul/det_gemm.py | 59 ++++++++ .../kernels/ops/pytorch/matmul/__init__.py | 4 + .../kernels/ops/pytorch/matmul/det_gemm.py | 27 ++++ .../kernels/ops/triton/matmul/__init__.py | 4 + .../kernels/ops/triton/matmul/det_gemm.py | 100 +++++++++++++ rl_engine/kernels/registry.py | 9 ++ setup.py | 18 +++ tests/test_det_gemm.py | 134 ++++++++++++++++++ 15 files changed, 626 insertions(+) create mode 100644 benchmarks/benchmark_det_gemm.py create mode 100644 csrc/cuda/gemm/det_gemm_kernel.cu create mode 100644 docs/operators/det-gemm.md create mode 100644 rl_engine/kernels/ops/cuda/matmul/__init__.py create mode 100644 rl_engine/kernels/ops/cuda/matmul/det_gemm.py create mode 100644 rl_engine/kernels/ops/pytorch/matmul/__init__.py create mode 100644 rl_engine/kernels/ops/pytorch/matmul/det_gemm.py create mode 100644 rl_engine/kernels/ops/triton/matmul/__init__.py create mode 100644 rl_engine/kernels/ops/triton/matmul/det_gemm.py create mode 100644 tests/test_det_gemm.py diff --git a/benchmarks/benchmark_det_gemm.py b/benchmarks/benchmark_det_gemm.py new file mode 100644 index 0000000..0716b8e --- /dev/null +++ b/benchmarks/benchmark_det_gemm.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Overhead of batch-invariant det_gemm vs cuBLAS + Triton (WS1 #146). + +det_gemm (CUDA, naive first milestone) and the Triton path are batch-invariant +and SLOWER than cuBLAS by design (no split-K/stream-K, fixed accumulation, FP32, +no TF32). 
DEV = "cuda"
WARMUP, ITERS = 10, 50

# (label, M, K, N) for the projection GEMMs being benchmarked.
SHAPES = [
    ("qkv", 4096, 4096, 12288),
    ("o_proj", 4096, 4096, 4096),
    ("mlp_up", 4096, 4096, 14336),
    ("mlp_dn", 4096, 14336, 4096),
    ("lm_head", 4096, 4096, 32000),
]


def _time(fn, a, b):
    """Return the mean time per ``fn(a, b)`` call as reported by CUDA events.

    Runs ``WARMUP`` untimed calls, synchronizes, then records one CUDA event
    before and one after ``ITERS`` calls and divides the elapsed time by
    ``ITERS``.
    """
    for _ in range(WARMUP):
        fn(a, b)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(ITERS):
        fn(a, b)
    stop.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / ITERS


def run():
    """Time every entry of ``SHAPES`` on each backend.

    Returns:
        A list of tuples ``(name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri,
        overhead)`` where ``overhead`` is ``t_cuda / t_fp32`` and ``t_tri`` is
        NaN when the Triton path is unavailable.

    The global ``torch.backends.cuda.matmul.allow_tf32`` switch is toggled
    while measuring and is restored to its previous value on exit, including
    when a backend raises.
    """
    # NOTE(review): inputs are BF16 while allow_tf32 is documented as an FP32
    # matmul setting, so the two cuBLAS columns may exercise the same path --
    # confirm the intended baseline.
    rows = []
    saved_tf32 = torch.backends.cuda.matmul.allow_tf32
    try:
        for name, M, K, N in SHAPES:
            a = torch.randn(M, K, device=DEV, dtype=torch.bfloat16)
            b = torch.randn(K, N, device=DEV, dtype=torch.bfloat16)
            torch.backends.cuda.matmul.allow_tf32 = True
            t_tf32 = _time(torch.matmul, a, b)
            torch.backends.cuda.matmul.allow_tf32 = False
            t_fp32 = _time(native_gemm, a, b)
            t_cuda = _time(deterministic_gemm, a, b)
            t_tri = _time(deterministic_gemm_triton, a, b) if _HAS_TRITON else float("nan")
            rows.append((name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri, t_cuda / t_fp32))
    finally:
        # Do not leak the benchmark's TF32 setting into the caller's process.
        torch.backends.cuda.matmul.allow_tf32 = saved_tf32
    return rows


def to_markdown(rows, dev, cap):
    """Render the rows produced by ``run()`` as a Markdown report.

    Args:
        rows: Iterable of 9-tuples in the layout returned by ``run()``.
        dev: Device name printed in the heading.
        cap: Two-element compute capability; rendered as ``SM<major><minor>``.

    Returns:
        The report as a single newline-joined string (no trailing newline).
    """
    out = [
        f"## det_gemm overhead — {dev} (SM{cap[0]}{cap[1]})",
        "",
        "| shape | M | K | N | cuBLAS tf32 | cuBLAS fp32 | det CUDA | det Triton | overhead |",
        "|---|---|---|---|---|---|---|---|---|",
    ]
    for n, M, K, N, t1, t2, t3, t4, ov in rows:
        out.append(
            f"| {n} | {M} | {K} | {N} | {t1:.3f} | {t2:.3f} | {t3:.3f} | {t4:.3f} | {ov:.1f}x |"
        )
    out += [
        "",
        "_Overhead = det CUDA vs cuBLAS (TF32 disabled). Naive CUDA kernel is "
        "correctness-first; both det paths trade speed for bitwise "
        "batch-invariance. Tensor-core pass is a follow-up (#146)._",
    ]
    return "\n".join(out)


def main():
    """CLI entry point: run the benchmark, print it, optionally save it.

    ``--out PATH`` additionally writes the Markdown report to ``PATH``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default=None)
    args = ap.parse_args()
    name, cap = torch.cuda.get_device_name(), torch.cuda.get_device_capability()
    print(name, cap)
    md = to_markdown(run(), name, cap)
    print("\n" + md)
    if args.out:
        # The heading contains a non-ASCII dash; pin the encoding so the write
        # does not depend on the platform's default codec.
        with open(args.out, "w", encoding="utf-8") as f:
            f.write(md + "\n")


if __name__ == "__main__":
    main()
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>

#include <cstdint>
#include <limits>

namespace {

using nv_bf16 = __nv_bfloat16;

constexpr int TILE = 16;  // 16x16 thread block

__host__ __device__ constexpr int cdiv(int a, int b) { return (a + b - 1) / b; }

// C[M,N] = A[M,K] @ B[K,N], all row-major, BF16 in / FP32 accumulate / BF16 out.
// Each thread owns one C[row, col] and walks k = 0..K-1 in ascending order, so
// the reduction order for an output element does not depend on M.
__global__ void det_gemm_naive(const nv_bf16* __restrict__ A,
                               const nv_bf16* __restrict__ B,
                               nv_bf16* __restrict__ C,
                               int M, int N, int K) {
  const int row = blockIdx.y * TILE + threadIdx.y;
  const int col = blockIdx.x * TILE + threadIdx.x;
  if (row >= M || col >= N) return;

  // Element offsets are formed in 64-bit: row*K and k*N exceed INT_MAX once an
  // operand holds more than 2^31 elements, even though M, N, K each fit in int.
  const nv_bf16* a_row = A + static_cast<int64_t>(row) * K;
  const nv_bf16* b_col = B + col;

  float acc = 0.0f;  // FP32 accumulation
  // Fixed ascending K order, no split-K, no atomics.
  for (int k = 0; k < K; ++k) {
    const float a = __bfloat162float(a_row[k]);
    const float b = __bfloat162float(b_col[static_cast<int64_t>(k) * N]);
    acc += a * b;
  }
  C[static_cast<int64_t>(row) * N + col] = __float2bfloat16(acc);
}

// Launches det_gemm_naive over an [M, N] output on `stream`.
void launch_naive(const nv_bf16* A, const nv_bf16* B, nv_bf16* C,
                  int M, int N, int K, cudaStream_t stream) {
  // A zero-sized grid is an invalid launch configuration; an empty output
  // needs no work.
  if (M == 0 || N == 0) return;
  dim3 block(TILE, TILE);
  dim3 grid(cdiv(N, TILE), cdiv(M, TILE));
  det_gemm_naive<<<grid, block, 0, stream>>>(A, B, C, M, N, K);
  // Surface bad launch configurations (e.g. grid.y over the device limit)
  // here instead of at some later, unrelated synchronization point.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

inline const nv_bf16* bf16(const torch::Tensor& t) {
  return reinterpret_cast<const nv_bf16*>(t.data_ptr<at::BFloat16>());
}
inline nv_bf16* bf16o(torch::Tensor& t) {
  return reinterpret_cast<nv_bf16*>(t.data_ptr<at::BFloat16>());
}

// Requires `t` to be a CUDA BF16 tensor; `n` names it in the error message.
void check_in(const torch::Tensor& t, const char* n) {
  TORCH_CHECK(t.is_cuda(), n, " must be CUDA");
  TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16");
}

// Narrows a tensor extent to the int the kernel takes, rejecting values that
// would silently wrap.
int to_i32(int64_t v, const char* what) {
  TORCH_CHECK(v <= std::numeric_limits<int>::max(),
              "det_gemm: ", what, " = ", v, " exceeds the int32 range");
  return static_cast<int>(v);
}

// out[R, Cn] = lhs[R, I] @ rhs[I, Cn]. Callers have already validated dtype,
// device, rank and the shared inner extent.
torch::Tensor gemm_2d(const torch::Tensor& lhs, const torch::Tensor& rhs) {
  // Make the operands' device current so allocation and the stream lookup
  // below refer to that device.
  const c10::cuda::CUDAGuard device_guard(lhs.device());
  const torch::Tensor l = lhs.contiguous();
  const torch::Tensor r = rhs.contiguous();
  const int R = to_i32(l.size(0), "rows");
  const int I = to_i32(l.size(1), "inner");
  const int Cn = to_i32(r.size(1), "cols");
  torch::Tensor out = torch::empty({R, Cn}, l.options());
  launch_naive(bf16(l), bf16(r), bf16o(out), R, Cn, I,
               at::cuda::getCurrentCUDAStream());
  return out;
}

}  // namespace

// fwd: C = A @ B
torch::Tensor det_gemm_fwd(torch::Tensor a, torch::Tensor b) {
  check_in(a, "A"); check_in(b, "B");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "det_gemm_fwd: expect 2D [M,K]@[K,N]");
  TORCH_CHECK(a.device() == b.device(), "det_gemm_fwd: A and B must be on the same device");
  TORCH_CHECK(b.size(0) == a.size(1), "det_gemm_fwd: K mismatch");
  return gemm_2d(a, b);
}

// dA = dC @ B^T -> forward GEMM on materialized transpose of B
torch::Tensor det_gemm_da(torch::Tensor dc, torch::Tensor b) {
  check_in(dc, "dC"); check_in(b, "B");
  // Rank is checked before .t(): transposing a non-2D tensor would raise an
  // unrelated error (or index a missing dimension).
  TORCH_CHECK(dc.dim() == 2 && b.dim() == 2, "det_gemm_da: expect 2D dC[M,N], B[K,N]");
  TORCH_CHECK(dc.device() == b.device(), "det_gemm_da: dC and B must be on the same device");
  TORCH_CHECK(b.size(1) == dc.size(1), "det_gemm_da: N mismatch");
  return gemm_2d(dc, b.t());  // [M,N] @ [N,K] -> [M,K]
}

// dB = A^T @ dC -> forward GEMM on materialized transpose of A
torch::Tensor det_gemm_db(torch::Tensor a, torch::Tensor dc) {
  check_in(a, "A"); check_in(dc, "dC");
  TORCH_CHECK(a.dim() == 2 && dc.dim() == 2, "det_gemm_db: expect 2D A[M,K], dC[M,N]");
  TORCH_CHECK(a.device() == dc.device(), "det_gemm_db: A and dC must be on the same device");
  TORCH_CHECK(dc.size(0) == a.size(0), "det_gemm_db: M mismatch");
  return gemm_2d(a.t(), dc);  // [K,M] @ [M,N] -> [K,N]
}
Declarations & Wrappers void prefix_shared_attention_forward( @@ -95,5 +100,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // registry Prefix-Shared Attention m.def("prefix_shared_attention", &prefix_shared_attention, "Prefix-Shared Fused Attention for GRPO"); + + // registry Batch-Invariant Deterministic GEMM + m.def("det_gemm_fwd", &det_gemm_fwd, "Batch-invariant deterministic GEMM forward (C=A@B)"); + m.def("det_gemm_da", &det_gemm_da, "Batch-invariant deterministic GEMM backward dA (dC@B^T)"); + m.def("det_gemm_db", &det_gemm_db, "Batch-invariant deterministic GEMM backward dB (A^T@dC)"); #endif } diff --git a/docs/.nav.yml b/docs/.nav.yml index e9ebaf0..0bfecc1 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -15,6 +15,7 @@ nav: - operators/grpo-loss.md - operators/ratio-kl.md - operators/sampling.md + - operators/det-gemm.md - Developer Guide: - contributing/README.md - General: diff --git a/docs/operators/det-gemm.md b/docs/operators/det-gemm.md new file mode 100644 index 0000000..b2a8b53 --- /dev/null +++ b/docs/operators/det-gemm.md @@ -0,0 +1,53 @@ +# Batch-Invariant Deterministic GEMM (`det_gemm`) + +WS1 #146. A matrix multiply whose output for a given row is **bitwise invariant** +to batch size, chunked-prefill splitting, and padding — the property cuBLAS does +not provide, and the root fix for matmul-driven KL drift between rollout and +training. + +## Why + +Matmul is the most frequent op in a transformer (QKV, MLP, LM head), so +batch-dependent drift here dominates everything downstream. cuBLAS selects +kernels by problem shape and may use split-K, both of which change the +K-reduction order when batch size or sequence length shifts the chosen kernel. +`det_gemm` pins the accumulation order so a row's result never depends on the +rows around it. + +## Guarantees + +- Forward `C = A @ B`, backward `dA = dC·Bᵀ`, `dB = Aᵀ·dC`. +- BF16 inputs, FP32 accumulation, no TF32, no split-K, fixed K-loop order. 
+- Bitwise-identical output for a fixed row across batch=1/N, chunked-prefill + on/off, and padding layouts. + +## Backends + +| Backend | Deterministic | Notes | +|---|---|---| +| CUDA (`DetGemmOp`) | yes | Hand-written kernel. First milestone is a naive FP32 implementation (correctness first); a tensor-core (`mma.sync`) pass matching `prefix_shared_attention.cu` follows. NVIDIA SM80+. | +| Triton (`TritonDetGemmOp`) | yes | Autotune disabled, BLOCK pinned, no split-K. Portable / ROCm fallback and cross-backend reference. | +| PyTorch (`NativeGemmOp`) | **no** | Plain `torch.matmul`. Reference & benchmark target ONLY — cuBLAS is not batch-invariant. Excluded from registry dispatch. | + +Registry dispatch for `det_gemm` includes only the deterministic backends +(CUDA → Triton). The PyTorch op must be called explicitly. + +## Usage + +```python +from rl_engine.kernels.registry import kernel_registry +gemm = kernel_registry.get_op("det_gemm") # CUDA if built, else Triton +c = gemm(a, b) # a:[M,K] bf16, b:[K,N] bf16 +``` + +## Scope + +In: single-rank forward + backward, BF16 / FP32-accum, SM80+. +Out: tensor-parallel GEMM (WS2), FP8, ROCm-native kernel (Triton covers ROCm). + +## Performance + +`det_gemm` trades speed for determinism. The naive CUDA kernel is slow by +design; see `benchmarks/benchmark_det_gemm.py`. Overhead is reported vs cuBLAS +with TF32 disabled (the fair, same-FP32-path baseline), not as a speedup. A +slower deterministic baseline is the accepted first milestone (#146). \ No newline at end of file diff --git a/rl_engine/kernels/ops/cuda/__init__.py b/rl_engine/kernels/ops/cuda/__init__.py index e69de29..84ef18f 100644 --- a/rl_engine/kernels/ops/cuda/__init__.py +++ b/rl_engine/kernels/ops/cuda/__init__.py @@ -0,0 +1,2 @@ +# append matmul to the existing imports +from . 
class _DetGemmFn(torch.autograd.Function):
    """Autograd bridge to the compiled ``det_gemm_fwd`` / ``_da`` / ``_db`` kernels."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return _C.det_gemm_fwd(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        grad_out = grad_out.contiguous()
        # Only launch the GEMM for inputs that actually require a gradient.
        da = _C.det_gemm_da(grad_out, b) if ctx.needs_input_grad[0] else None
        db = _C.det_gemm_db(a, grad_out) if ctx.needs_input_grad[1] else None
        return da, db


def _det_gemm_available():
    """Return True when the compiled extension exposes ``det_gemm_fwd``."""
    return bool(_EXT_AVAILABLE) and hasattr(_C, "det_gemm_fwd")


class DetGemmOp:
    """Hand-written batch-invariant GEMM. a:[M,K] bf16, b:[K,N] bf16 -> [M,N] bf16."""

    def __init__(self):
        self.has_hardware_op = False
        if _det_gemm_available():
            self.op = _C.det_gemm_fwd
            self.has_hardware_op = True
            logger.info("Successfully linked to RL-Kernel _C.det_gemm_fwd.")
        else:
            logger.warning(
                "RL-Kernel _C.det_gemm_fwd unavailable; DetGemmOp requires the "
                "compiled CUDA extension and has no batch-invariant fallback."
            )

    def __call__(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Compute ``a @ b`` through the deterministic kernel.

        Raises:
            AssertionError: if either input is not BF16 or not on a CUDA device.
            RuntimeError: if the compiled kernel is unavailable.
        """
        # Raised explicitly rather than via `assert` so the checks survive
        # `python -O`; the exception type and messages are unchanged.
        if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
            raise AssertionError("BF16 only")
        if not (a.is_cuda and b.is_cuda):
            raise AssertionError("Inputs must be on CUDA device")
        if not self.has_hardware_op:
            raise RuntimeError(
                "DetGemmOp: compiled _C.det_gemm kernel unavailable; no "
                "batch-invariant fallback exists. Build the extension first."
            )
        return _DetGemmFn.apply(a.contiguous(), b.contiguous())


def deterministic_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Functional entry. a:[M,K] bf16, b:[K,N] bf16 -> [M,N] bf16.

    Raises:
        RuntimeError: if the compiled kernel is unavailable. Same guard and
            message as ``DetGemmOp.__call__`` so a missing build fails with an
            actionable error instead of an attribute lookup failure on ``_C``.
    """
    if not _det_gemm_available():
        raise RuntimeError(
            "DetGemmOp: compiled _C.det_gemm kernel unavailable; no "
            "batch-invariant fallback exists. Build the extension first."
        )
    return _DetGemmFn.apply(a, b)
def native_gemm(a, b):
    """Return ``torch.matmul(a, b)`` -- the plain, non-deterministic reference."""
    product = torch.matmul(a, b)
    return product


class NativeGemmOp:
    """Callable wrapper around ``torch.matmul`` for reference / benchmark use.

    Not a batch-invariant implementation; it exists to compare against.
    """

    def __init__(self):
        # Process-wide switch: constructing this op turns TF32 matmul off.
        torch.backends.cuda.matmul.allow_tf32 = False
        logger.info("NativeGemmOp ready (non-deterministic torch.matmul reference).")

    def __call__(self, a, b):
        return native_gemm(a, b)
+""" +import torch + +try: + import triton + import triton.language as tl + + _TRITON_AVAILABLE = True +except ImportError: + _TRITON_AVAILABLE = False + +from rl_engine.utils.logger import logger + +# Pinned. NOT autotuned (autotune picks per-shape configs -> breaks invariance). +_BLOCK_M, _BLOCK_N, _BLOCK_K = 64, 64, 32 + + +if _TRITON_AVAILABLE: + + @triton.jit + def _det_gemm_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ): + # One program = one output tile, walks the whole K in fixed order. + # No split-K -> K-accumulation order independent of M -> batch-invariant. + pid_m, pid_n = tl.program_id(0), tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_rem = K - k * BLOCK_K + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_rem, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_rem, other=0.0) + acc += tl.dot(a, b, allow_tf32=False) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = acc.to(c_ptr.dtype.element_ty) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def _triton_gemm(a, b): + a, b = a.contiguous(), b.contiguous() + M, K = a.shape + _, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N)) + _det_gemm_kernel[grid]( + a, b, c, M, N, K, + a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + BLOCK_M=_BLOCK_M, BLOCK_N=_BLOCK_N, 
BLOCK_K=_BLOCK_K, + ) + return c + + +class _TritonDetGemmFn(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b): + ctx.save_for_backward(a, b) + return _triton_gemm(a, b) + + @staticmethod + def backward(ctx, grad_out): + a, b = ctx.saved_tensors + grad_out = grad_out.contiguous() + da = _triton_gemm(grad_out, b.t().contiguous()) if ctx.needs_input_grad[0] else None + db = _triton_gemm(a.t().contiguous(), grad_out) if ctx.needs_input_grad[1] else None + return da, db + + +class TritonDetGemmOp: + """Batch-invariant deterministic GEMM, Triton path.""" + + def __init__(self): + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton not available for TritonDetGemmOp") + logger.info("TritonDetGemmOp ready (deterministic, autotune disabled).") + + def __call__(self, a, b): + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16, "BF16 only" + assert a.is_cuda and b.is_cuda, "CUDA only" + return _TritonDetGemmFn.apply(a, b) + + +def deterministic_gemm_triton(a, b): + return _TritonDetGemmFn.apply(a, b) \ No newline at end of file diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..7cce176 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -49,6 +49,13 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): TRITON_RATIO_KL = "rl_engine.kernels.ops.triton.loss.ratio_kl.TritonRatioKLOp" PYTORCH_RATIO_KL = "rl_engine.kernels.ops.pytorch.loss.ratio_kl.NativeRatioKLOp" + # Batch-invariant deterministic GEMM (WS1 #146) + CUDA_DET_GEMM = "rl_engine.kernels.ops.cuda.matmul.det_gemm.DetGemmOp" + TRITON_DET_GEMM = "rl_engine.kernels.ops.triton.matmul.det_gemm.TritonDetGemmOp" + # NON-deterministic reference (torch.matmul); reference/benchmark ONLY, + # intentionally excluded from det_gemm dispatch (cuBLAS breaks invariance). 
+ PYTORCH_GEMM = "rl_engine.kernels.ops.pytorch.matmul.det_gemm.NativeGemmOp" + # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp" @@ -89,6 +96,7 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "det_gemm": [OpBackend.CUDA_DET_GEMM, OpBackend.TRITON_DET_GEMM], # Default dispatch logic for new operators }, "rocm": { @@ -101,6 +109,7 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "det_gemm": [OpBackend.TRITON_DET_GEMM], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], diff --git a/setup.py b/setup.py index d5ddb89..76c4850 100644 --- a/setup.py +++ b/setup.py @@ -58,10 +58,24 @@ def get_extensions(): "csrc/ops.cpp", "csrc/fused_logp_kernel.cu", "csrc/cuda/attention/prefix_shared_attention.cu", + "csrc/cuda/gemm/det_gemm_kernel.cu", + ] + + # CUTLASS headers for det_gemm + cutlass_dir = os.environ.get("CUTLASS_DIR", "third_party/cutlass") + cutlass_includes = [ + os.path.join(cutlass_dir, "include"), + os.path.join(cutlass_dir, "tools", "util", "include"), ] cc_major, cc_minor = torch.cuda.get_device_capability() nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"] + # det_gemm SM80 path: explicit gencode + CUTLASS-friendly flags + nvcc_flags.append( + f"-gencode=arch=compute_{cc_major}{cc_minor},code=sm_{cc_major}{cc_minor}" + ) + nvcc_flags.append("--expt-relaxed-constexpr") + nvcc_flags.append("--expt-extended-lambda") nvcc_flags.extend( _cuda_define_from_env( "FUSED_LOGP_TWOPASS_BLOCK_SIZE", @@ -122,11 +136,15 @@ def get_extensions(): 
nvcc_flags.append(f"-gencode=arch=compute_{tma_arch},code=sm_{tma_arch}") cxx_flags.append("-DKERNEL_ALIGN_WITH_SM90") extra_link_args.append("-lcuda") + # det_gemm SM90 (WGMMA) path: enable conditional include in det_gemm_kernel.cu + nvcc_flags.append("-DRL_KERNEL_ENABLE_SM90") + cxx_flags.append("-DRL_KERNEL_ENABLE_SM90") extensions.append( CUDAExtension( name="rl_engine._C", sources=cuda_sources, + include_dirs=cutlass_includes, extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags, diff --git a/tests/test_det_gemm.py b/tests/test_det_gemm.py new file mode 100644 index 0000000..2f9b3ac --- /dev/null +++ b/tests/test_det_gemm.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Invariance + correctness tests for det_gemm (WS1). + +Runs against both deterministic backends — the hand-written CUDA kernel and the +Triton path — each of which must independently satisfy the invariance contract. +The PyTorch path (torch.matmul) is intentionally NOT tested here: it is the +non-deterministic reference baseline and would fail batch-invariance by design. +""" +import pytest +import torch + +from rl_engine.kernels.ops.cuda.matmul import deterministic_gemm + +try: + from rl_engine.kernels.ops.triton.matmul import deterministic_gemm_triton + _HAS_TRITON = True +except ImportError: + _HAS_TRITON = False + +torch.backends.cuda.matmul.allow_tf32 = False +DEV = "cuda" + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, + reason="det_gemm requires CUDA SM80+", +) + +# Each deterministic backend is validated independently. 
+_BACKENDS = [("cuda", deterministic_gemm)] +if _HAS_TRITON: + _BACKENDS.append(("triton", deterministic_gemm_triton)) + + +def _rand(*shape): + return torch.randn(*shape, device=DEV, dtype=torch.bfloat16) + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_batch_invariance(name, gemm): + # A row's output must not change when other rows join the batch. + torch.manual_seed(0) + K, N = 4096, 4096 + b = _rand(K, N) + row = _rand(1, K) + out1 = gemm(row, b) + big = _rand(512, K); big[0] = row[0] + outN = gemm(big, b) + assert torch.equal(out1[0], outN[0]), f"{name}: forward batch-invariance broken" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_chunked_prefill(name, gemm): + # Splitting M then concatenating must match the full GEMM bitwise. + torch.manual_seed(1) + M, K, N = 256, 4096, 4096 + a, b = _rand(M, K), _rand(K, N) + full = gemm(a, b) + chunked = torch.cat([gemm(a[:100], b), gemm(a[100:], b)], dim=0) + assert torch.equal(full, chunked), f"{name}: chunked-prefill broke invariance" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_padding_invariance(name, gemm): + # Padding rows must not affect valid rows' output. + torch.manual_seed(2) + M, K, N = 100, 4096, 4096 + a, b = _rand(M, K), _rand(K, N) + base = gemm(a, b) + a_pad = torch.cat([a, _rand(28, K)], dim=0) + padded = gemm(a_pad, b) + assert torch.equal(base, padded[:M]), f"{name}: padding changed valid-row output" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_forward_correctness(name, gemm): + # vs FP32 reference. Placeholder tolerance; PR3 swaps for #108 contract. 
+ torch.manual_seed(3) + M, K, N = 128, 2048, 2048 + a, b = _rand(M, K), _rand(K, N) + out = gemm(a, b).float() + ref = a.float() @ b.float() + assert (out - ref).abs().max().item() < 1.0 # TODO(#108): contract threshold + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_backward_batch_invariance(name, gemm): + # dA for a row must be invariant to the surrounding batch. + torch.manual_seed(4) + K, N = 2048, 2048 + b = _rand(K, N) + row = _rand(1, K).requires_grad_(True) + gemm(row, b).sum().backward() + g1 = row.grad.clone() + big = _rand(256, K); big[0] = row.detach()[0] + big.requires_grad_(True) + gemm(big, b).sum().backward() + assert torch.equal(g1[0], big.grad[0]), f"{name}: backward dA batch-invariance broken" + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +def test_backward_correctness(name, gemm): + # dA / dB vs FP32 reference gradients. Placeholder tolerance; PR3 -> #108. + torch.manual_seed(5) + M, K, N = 64, 1024, 1024 + a = _rand(M, K).requires_grad_(True) + b = _rand(K, N).requires_grad_(True) + g = _rand(M, N) + gemm(a, b).backward(g) + af = a.detach().float().requires_grad_(True) + bf = b.detach().float().requires_grad_(True) + (af @ bf).backward(g.float()) + assert (a.grad.float() - af.grad).abs().max().item() < 2.0 # TODO(#108) + assert (b.grad.float() - bf.grad).abs().max().item() < 2.0 # TODO(#108) + + +@pytest.mark.parametrize("name,gemm", _BACKENDS) +@pytest.mark.parametrize("shape", [ + (4096, 4096, 12288), # qkv + (4096, 4096, 4096), # o_proj + (4096, 4096, 14336), # mlp_up + (4096, 14336, 4096), # mlp_dn + (4096, 4096, 32000), # lm_head +]) +def test_target_shapes_invariance(name, gemm, shape): + # Standard-Transformer projection shapes stay batch-invariant. 
+ torch.manual_seed(6) + M, K, N = shape + b = _rand(K, N) + row = _rand(1, K) + big = _rand(64, K); big[0] = row[0] + assert torch.equal( + gemm(row, b)[0], gemm(big, b)[0] + ), f"{name}: batch-invariance broken at shape {shape}" \ No newline at end of file From f198931fe78d1fd55cb3c1a06ae6cd8d89724635 Mon Sep 17 00:00:00 2001 From: vensen Date: Wed, 24 Jun 2026 15:22:35 +0800 Subject: [PATCH 2/3] remove cutlass --- benchmarks/benchmark_det_gemm.py | 35 ++++++++++------- csrc/cuda/gemm/det_gemm_kernel.cu | 2 +- docs/operators/det-gemm.md | 2 +- rl_engine/kernels/ops/cuda/__init__.py | 2 +- rl_engine/kernels/ops/cuda/matmul/__init__.py | 2 +- rl_engine/kernels/ops/cuda/matmul/det_gemm.py | 2 +- .../kernels/ops/pytorch/matmul/__init__.py | 2 +- .../kernels/ops/pytorch/matmul/det_gemm.py | 2 +- .../kernels/ops/triton/matmul/__init__.py | 2 +- .../kernels/ops/triton/matmul/det_gemm.py | 38 +++++++++++++++---- tests/test_det_gemm.py | 29 ++++++++------ 11 files changed, 79 insertions(+), 39 deletions(-) diff --git a/benchmarks/benchmark_det_gemm.py b/benchmarks/benchmark_det_gemm.py index 0716b8e..2e1ef1f 100644 --- a/benchmarks/benchmark_det_gemm.py +++ b/benchmarks/benchmark_det_gemm.py @@ -8,6 +8,7 @@ speedup. The naive CUDA kernel is correctness-first; a tensor-core pass follows. 
""" import argparse + import torch from rl_engine.kernels.ops.cuda.matmul import deterministic_gemm @@ -15,6 +16,7 @@ try: from rl_engine.kernels.ops.triton.matmul import deterministic_gemm_triton + _HAS_TRITON = True except ImportError: _HAS_TRITON = False @@ -23,10 +25,10 @@ WARMUP, ITERS = 10, 50 SHAPES = [ - ("qkv", 4096, 4096, 12288), - ("o_proj", 4096, 4096, 4096), - ("mlp_up", 4096, 4096, 14336), - ("mlp_dn", 4096, 14336, 4096), + ("qkv", 4096, 4096, 12288), + ("o_proj", 4096, 4096, 4096), + ("mlp_up", 4096, 4096, 14336), + ("mlp_dn", 4096, 14336, 4096), ("lm_head", 4096, 4096, 32000), ] @@ -61,15 +63,22 @@ def run(): def to_markdown(rows, dev, cap): - out = [f"## det_gemm overhead — {dev} (SM{cap[0]}{cap[1]})", "", - "| shape | M | K | N | cuBLAS tf32 | cuBLAS fp32 | det CUDA | det Triton | overhead |", - "|---|---|---|---|---|---|---|---|---|"] + out = [ + f"## det_gemm overhead — {dev} (SM{cap[0]}{cap[1]})", + "", + "| shape | M | K | N | cuBLAS tf32 | cuBLAS fp32 | det CUDA | det Triton | overhead |", + "|---|---|---|---|---|---|---|---|---|", + ] for n, M, K, N, t1, t2, t3, t4, ov in rows: - out.append(f"| {n} | {M} | {K} | {N} | {t1:.3f} | {t2:.3f} | {t3:.3f} | {t4:.3f} | {ov:.1f}x |") - out += ["", - "_Overhead = det CUDA vs cuBLAS (TF32 disabled). Naive CUDA kernel is " - "correctness-first; both det paths trade speed for bitwise " - "batch-invariance. Tensor-core pass is a follow-up (#146)._"] + out.append( + f"| {n} | {M} | {K} | {N} | {t1:.3f} | {t2:.3f} | {t3:.3f} | {t4:.3f} | {ov:.1f}x |" + ) + out += [ + "", + "_Overhead = det CUDA vs cuBLAS (TF32 disabled). Naive CUDA kernel is " + "correctness-first; both det paths trade speed for bitwise " + "batch-invariance. 
Tensor-core pass is a follow-up (#146)._", + ] return "\n".join(out) @@ -87,4 +96,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/csrc/cuda/gemm/det_gemm_kernel.cu b/csrc/cuda/gemm/det_gemm_kernel.cu index 1c38866..5a36e9d 100644 --- a/csrc/cuda/gemm/det_gemm_kernel.cu +++ b/csrc/cuda/gemm/det_gemm_kernel.cu @@ -108,4 +108,4 @@ torch::Tensor det_gemm_db(torch::Tensor a, torch::Tensor dc) { launch_naive(bf16(at), bf16(dc), bf16o(db), K, N, M, at::cuda::getCurrentCUDAStream()); return db; -} \ No newline at end of file +} diff --git a/docs/operators/det-gemm.md b/docs/operators/det-gemm.md index b2a8b53..cd48d26 100644 --- a/docs/operators/det-gemm.md +++ b/docs/operators/det-gemm.md @@ -50,4 +50,4 @@ Out: tensor-parallel GEMM (WS2), FP8, ROCm-native kernel (Triton covers ROCm). `det_gemm` trades speed for determinism. The naive CUDA kernel is slow by design; see `benchmarks/benchmark_det_gemm.py`. Overhead is reported vs cuBLAS with TF32 disabled (the fair, same-FP32-path baseline), not as a speedup. A -slower deterministic baseline is the accepted first milestone (#146). \ No newline at end of file +slower deterministic baseline is the accepted first milestone (#146). diff --git a/rl_engine/kernels/ops/cuda/__init__.py b/rl_engine/kernels/ops/cuda/__init__.py index 84ef18f..5f1ae8f 100644 --- a/rl_engine/kernels/ops/cuda/__init__.py +++ b/rl_engine/kernels/ops/cuda/__init__.py @@ -1,2 +1,2 @@ # append matmul to the existing imports -from . import attention, loss, norm, matmul # noqa: F401 \ No newline at end of file +from . 
import attention, loss, matmul, norm # noqa: F401 diff --git a/rl_engine/kernels/ops/cuda/matmul/__init__.py b/rl_engine/kernels/ops/cuda/matmul/__init__.py index 74657f5..ede9f55 100644 --- a/rl_engine/kernels/ops/cuda/matmul/__init__.py +++ b/rl_engine/kernels/ops/cuda/matmul/__init__.py @@ -1,4 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 from .det_gemm import DetGemmOp, deterministic_gemm -__all__ = ["DetGemmOp", "deterministic_gemm"] \ No newline at end of file +__all__ = ["DetGemmOp", "deterministic_gemm"] diff --git a/rl_engine/kernels/ops/cuda/matmul/det_gemm.py b/rl_engine/kernels/ops/cuda/matmul/det_gemm.py index b822c29..4778be9 100644 --- a/rl_engine/kernels/ops/cuda/matmul/det_gemm.py +++ b/rl_engine/kernels/ops/cuda/matmul/det_gemm.py @@ -56,4 +56,4 @@ def __call__(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: def deterministic_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Functional entry. a:[M,K] bf16, b:[K,N] bf16 -> [M,N] bf16.""" - return _DetGemmFn.apply(a, b) \ No newline at end of file + return _DetGemmFn.apply(a, b) diff --git a/rl_engine/kernels/ops/pytorch/matmul/__init__.py b/rl_engine/kernels/ops/pytorch/matmul/__init__.py index 5ebfa2e..e0bae08 100644 --- a/rl_engine/kernels/ops/pytorch/matmul/__init__.py +++ b/rl_engine/kernels/ops/pytorch/matmul/__init__.py @@ -1,4 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 from .det_gemm import NativeGemmOp, native_gemm -__all__ = ["NativeGemmOp", "native_gemm"] \ No newline at end of file +__all__ = ["NativeGemmOp", "native_gemm"] diff --git a/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py b/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py index 42aecce..9c617aa 100644 --- a/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py +++ b/rl_engine/kernels/ops/pytorch/matmul/det_gemm.py @@ -24,4 +24,4 @@ def __call__(self, a, b): def native_gemm(a, b): - return torch.matmul(a, b) \ No newline at end of file + return torch.matmul(a, b) diff --git 
a/rl_engine/kernels/ops/triton/matmul/__init__.py b/rl_engine/kernels/ops/triton/matmul/__init__.py index 3dfc979..b674dc8 100644 --- a/rl_engine/kernels/ops/triton/matmul/__init__.py +++ b/rl_engine/kernels/ops/triton/matmul/__init__.py @@ -1,4 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 from .det_gemm import TritonDetGemmOp, deterministic_gemm_triton -__all__ = ["TritonDetGemmOp", "deterministic_gemm_triton"] \ No newline at end of file +__all__ = ["TritonDetGemmOp", "deterministic_gemm_triton"] diff --git a/rl_engine/kernels/ops/triton/matmul/det_gemm.py b/rl_engine/kernels/ops/triton/matmul/det_gemm.py index 17f149c..8318329 100644 --- a/rl_engine/kernels/ops/triton/matmul/det_gemm.py +++ b/rl_engine/kernels/ops/triton/matmul/det_gemm.py @@ -27,9 +27,21 @@ @triton.jit def _det_gemm_kernel( - a_ptr, b_ptr, c_ptr, M, N, K, - stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ): # One program = one output tile, walks the whole K in fixed order. # No split-K -> K-accumulation order independent of M -> batch-invariant. 
@@ -60,9 +72,21 @@ def _triton_gemm(a, b): c = torch.empty((M, N), device=a.device, dtype=a.dtype) grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N)) _det_gemm_kernel[grid]( - a, b, c, M, N, K, - a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - BLOCK_M=_BLOCK_M, BLOCK_N=_BLOCK_N, BLOCK_K=_BLOCK_K, + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_M=_BLOCK_M, + BLOCK_N=_BLOCK_N, + BLOCK_K=_BLOCK_K, ) return c @@ -97,4 +121,4 @@ def __call__(self, a, b): def deterministic_gemm_triton(a, b): - return _TritonDetGemmFn.apply(a, b) \ No newline at end of file + return _TritonDetGemmFn.apply(a, b) diff --git a/tests/test_det_gemm.py b/tests/test_det_gemm.py index 2f9b3ac..440dddc 100644 --- a/tests/test_det_gemm.py +++ b/tests/test_det_gemm.py @@ -14,6 +14,7 @@ try: from rl_engine.kernels.ops.triton.matmul import deterministic_gemm_triton + _HAS_TRITON = True except ImportError: _HAS_TRITON = False @@ -44,7 +45,8 @@ def test_forward_batch_invariance(name, gemm): b = _rand(K, N) row = _rand(1, K) out1 = gemm(row, b) - big = _rand(512, K); big[0] = row[0] + big = _rand(512, K) + big[0] = row[0] outN = gemm(big, b) assert torch.equal(out1[0], outN[0]), f"{name}: forward batch-invariance broken" @@ -92,7 +94,8 @@ def test_backward_batch_invariance(name, gemm): row = _rand(1, K).requires_grad_(True) gemm(row, b).sum().backward() g1 = row.grad.clone() - big = _rand(256, K); big[0] = row.detach()[0] + big = _rand(256, K) + big[0] = row.detach()[0] big.requires_grad_(True) gemm(big, b).sum().backward() assert torch.equal(g1[0], big.grad[0]), f"{name}: backward dA batch-invariance broken" @@ -115,20 +118,24 @@ def test_backward_correctness(name, gemm): @pytest.mark.parametrize("name,gemm", _BACKENDS) -@pytest.mark.parametrize("shape", [ - (4096, 4096, 12288), # qkv - (4096, 4096, 4096), # o_proj - (4096, 4096, 14336), # mlp_up - (4096, 14336, 4096), # mlp_dn - 
(4096, 4096, 32000), # lm_head -]) +@pytest.mark.parametrize( + "shape", + [ + (4096, 4096, 12288), # qkv + (4096, 4096, 4096), # o_proj + (4096, 4096, 14336), # mlp_up + (4096, 14336, 4096), # mlp_dn + (4096, 4096, 32000), # lm_head + ], +) def test_target_shapes_invariance(name, gemm, shape): # Standard-Transformer projection shapes stay batch-invariant. torch.manual_seed(6) M, K, N = shape b = _rand(K, N) row = _rand(1, K) - big = _rand(64, K); big[0] = row[0] + big = _rand(64, K) + big[0] = row[0] assert torch.equal( gemm(row, b)[0], gemm(big, b)[0] - ), f"{name}: batch-invariance broken at shape {shape}" \ No newline at end of file + ), f"{name}: batch-invariance broken at shape {shape}" From 45a3836dfbc7c8d6456120f91a8449c9023ccf68 Mon Sep 17 00:00:00 2001 From: vensen Date: Wed, 24 Jun 2026 15:18:29 +0000 Subject: [PATCH 3/3] update cu --- csrc/cuda/gemm/det_gemm_kernel.cu | 315 +++++++++++++++++++++++++----- 1 file changed, 262 insertions(+), 53 deletions(-) diff --git a/csrc/cuda/gemm/det_gemm_kernel.cu b/csrc/cuda/gemm/det_gemm_kernel.cu index 5a36e9d..9b846b3 100644 --- a/csrc/cuda/gemm/det_gemm_kernel.cu +++ b/csrc/cuda/gemm/det_gemm_kernel.cu @@ -1,62 +1,269 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2026 RL-Kernel Contributors - -// WS1 - Batch-invariant deterministic GEMM (hand-written, no CUTLASS). +// csrc/cuda/gemm/det_gemm_kernel.cu // -// First-milestone naive implementation: one thread computes one output element, -// walking the whole K dimension in a fixed loop order with FP32 accumulation. -// NO split-K, NO shape-based kernel selection -> a row's reduction order is -// independent of the batch (M) dimension, so the output is bitwise-invariant to -// batch size, chunked-prefill splitting, and padding layout. +// WS1 - Batch-invariant deterministic GEMM (hand-written, no CUTLASS). // -// This is intentionally slow (correctness + invariance first, per #146). 
A -// tensor-core (mma.sync / ldmatrix) optimization, matching the -// prefix_shared_attention.cu style, is a follow-up within this same file. +// SM90 path: TMA load + mma.sync (m16n8k16) tensor cores, single-CTA-per-tile, +// fixed K-accumulation order, NO split-K -> batch-invariant. +// Fallback : naive FP32 scalar kernel (also the correctness ground truth). // +// Both: BF16 in / FP32 accum / no TF32 / no split-K. // fwd: C = A @ B | dA = dC @ B^T | dB = A^T @ dC -// Backward reuses the same kernel on transposed operands, so the gradients -// inherit the same invariance. +// Backward reuses the forward kernel on transposed operands. #include #include #include +#if defined(RL_KERNEL_ENABLE_SM90) +#include "../utils/tma_utils.cuh" +#include +#endif + namespace { using nv_bf16 = __nv_bfloat16; -constexpr int TILE = 16; // 16x16 thread block - __host__ __device__ constexpr int cdiv(int a, int b) { return (a + b - 1) / b; } -// C[M,N] = A[M,K] @ B[K,N], all row-major, BF16 in / FP32 accumulate / BF16 out. -// Each thread owns one C[row, col]; the K loop order is fixed and identical for -// every (row, col) regardless of M -> batch-invariant. +// Naive FP32 scalar kernel (SM80 fallback + ground truth). Batch-invariant by +// construction: one thread = one output element, fixed ascending K loop. +constexpr int NAIVE_TILE = 16; + __global__ void det_gemm_naive(const nv_bf16* __restrict__ A, const nv_bf16* __restrict__ B, nv_bf16* __restrict__ C, int M, int N, int K) { - const int row = blockIdx.y * TILE + threadIdx.y; - const int col = blockIdx.x * TILE + threadIdx.x; + const int row = blockIdx.y * NAIVE_TILE + threadIdx.y; + const int col = blockIdx.x * NAIVE_TILE + threadIdx.x; if (row >= M || col >= N) return; - - float acc = 0.0f; // FP32 accumulation - // Fixed ascending K order, no split-K, no atomics. Deterministic. 
- for (int k = 0; k < K; ++k) { - float a = __bfloat162float(A[row * K + k]); - float b = __bfloat162float(B[k * N + col]); - acc += a * b; - } + float acc = 0.0f; + for (int k = 0; k < K; ++k) + acc += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]); C[row * N + col] = __float2bfloat16(acc); } void launch_naive(const nv_bf16* A, const nv_bf16* B, nv_bf16* C, int M, int N, int K, cudaStream_t stream) { - dim3 block(TILE, TILE); - dim3 grid(cdiv(N, TILE), cdiv(M, TILE)); + dim3 block(NAIVE_TILE, NAIVE_TILE); + dim3 grid(cdiv(N, NAIVE_TILE), cdiv(M, NAIVE_TILE)); det_gemm_naive<<>>(A, B, C, M, N, K); } +#if defined(RL_KERNEL_ENABLE_SM90) +// ============================================================================ +// SM90 path: TMA load + mma.sync. C[M,N] = A[M,K] @ B[K,N]. +// +// mma.sync.m16n8k16.row.col needs A row-major [m16,k16] and B col-major +// [n8,k16] (i.e. B operand indexed as [n, k]). Our B is row-major [K,N], so we +// load a B tile of shape [BK, BN] but feed the mma the (n,k) operand by +// addressing smem as B[k][n] -> we ldmatrix B with the same trick as the logp +// kernel (which loads W[V,D] = [n, k] directly). To match that validated path, +// A tile is [BM, BK] (row=token, col=k), B tile is [BN, BK] (row=n, col=k), +// which is exactly B^T. So the SM90 kernel computes C = A @ (Bt)^T where Bt is +// the [BN,BK] tile = B[k,n] transposed. We materialize that by giving TMA a +// descriptor over B viewed as [N,K]... but B is [K,N] row-major. +// +// To keep it simple and provably correct, the SM90 forward kernel REQUIRES its +// B operand already in [N,K] layout (row=n, col=k). The host wrapper passes +// B^T (contiguous) so the kernel sees Bt[N,K]; mathematically +// C = A[M,K] @ B[K,N] = A @ (Bt[N,K])^T, and the per-tile mma contracts over K +// in fixed order. This mirrors the logp kernel's W[V,D] @ hidden[N,D] pattern. 
+// ============================================================================ +constexpr int BM = 64; // rows (M) per CTA tile +constexpr int BN = 64; // cols (N) per CTA tile +constexpr int BK = 32; // K slice streamed per TMA load +constexpr int WARPS = 4; +constexpr int WG_THREADS = WARPS * 32; // 128 +constexpr int STAGES = 2; + +constexpr int MMA_M = 16, MMA_N = 8, MMA_K = 16; +constexpr int WARP_M = BM / WARPS; // 16 -> 1 m-tile per warp +constexpr int M_TILES = WARP_M / MMA_M; // 1 +constexpr int N_TILES = BN / MMA_N; // 8 +constexpr int K_TILES = BK / MMA_K; // 2 +constexpr int KK_GROUPS = BK / 32; // 1 (ldmatrix.x4 spans 32 cols) + +__device__ __forceinline__ void ldmatrix_x4(uint32_t regs[4], uint32_t addr) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(regs[0]), "=r"(regs[1]), "=r"(regs[2]), "=r"(regs[3]) + : "r"(addr)); +} +__device__ __forceinline__ void mma_m16n8k16(const uint32_t A[4], const uint32_t B[2], float D[4]) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(D[0]), "f"(D[1]), "f"(D[2]), "f"(D[3])); +} + +// A: row-major [M,K] via TMA tile [BM,BK]. Bt: row-major [N,K] via TMA tile +// [BN,BK]. C: row-major [M,N]. Each CTA owns one [BM,BN] output tile and +// walks the full K in fixed order (no split-K). 
+__global__ void det_gemm_sm90_kernel(const __grid_constant__ CUtensorMap a_tmap, + const __grid_constant__ CUtensorMap bt_tmap, + nv_bf16* __restrict__ C, + int M, int N, int K) { + const int tid = threadIdx.x; + const int warp = tid / 32; + const int lane = tid % 32; + const int row_base = blockIdx.y * BM; + const int col_base = blockIdx.x * BN; + const int kd = K / BK; // K validated multiple of BK on host + + extern __shared__ __align__(1024) char smem[]; + nv_bf16* sA = reinterpret_cast(smem); + nv_bf16* sB = reinterpret_cast(sA + STAGES * BM * BK); + int* mbar_base = reinterpret_cast(sB + STAGES * BN * BK); + + const uint32_t sA_base = static_cast(__cvta_generic_to_shared(sA)); + const uint32_t sB_base = static_cast(__cvta_generic_to_shared(sB)); + int mbar[STAGES]; +#pragma unroll + for (int s = 0; s < STAGES; ++s) + mbar[s] = static_cast(__cvta_generic_to_shared(mbar_base + 2 * s)); + + if (tid == 0) { +#pragma unroll + for (int s = 0; s < STAGES; ++s) mbarrier_init(mbar[s], 1); + asm volatile("fence.mbarrier_init.release.cluster;"); + } + __syncthreads(); + + const uint32_t tile_bytes = (BM * BK + BN * BK) * sizeof(nv_bf16); + + auto issue_load = [&](int k) { + const int buf = k % STAGES; + const int k_off = k * BK; + tma_2d_g2s(static_cast(sA_base + buf * BM * BK * sizeof(nv_bf16)), + &a_tmap, k_off, row_base, mbar[buf]); + tma_2d_g2s(static_cast(sB_base + buf * BN * BK * sizeof(nv_bf16)), + &bt_tmap, k_off, col_base, mbar[buf]); + mbarrier_arrive_expect_tx(mbar[buf], tile_bytes); + }; + + int phase[STAGES]; +#pragma unroll + for (int s = 0; s < STAGES; ++s) phase[s] = 0; + + // accumulators: this warp's M_TILES m-tiles x N_TILES n-tiles + float acc[M_TILES][N_TILES][4]; +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) +#pragma unroll + for (int n = 0; n < N_TILES; ++n) + acc[mi][n][0] = acc[mi][n][1] = acc[mi][n][2] = acc[mi][n][3] = 0.0f; + + if (tid == 0) +#pragma unroll + for (int s = 0; s < STAGES - 1; ++s) + if (s < kd) issue_load(s); + + for 
(int k = 0; k < kd; ++k) { // fixed ascending K order, NO split-K + const int buf = k % STAGES; + if (tid == 0 && k + (STAGES - 1) < kd) issue_load(k + (STAGES - 1)); + mbarrier_wait(mbar[buf], phase[buf]); + phase[buf] ^= 1; + __syncthreads(); + + const uint32_t sA_buf = sA_base + buf * BM * BK * sizeof(nv_bf16); + const uint32_t sB_buf = sB_base + buf * BN * BK * sizeof(nv_bf16); + + // Load A operand (this warp's rows), all K-steps. Same addressing as logp. + uint32_t A[M_TILES][K_TILES][4]; +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) { + const int row0 = warp * WARP_M + mi * MMA_M + (lane % 16); +#pragma unroll + for (int kt = 0; kt < K_TILES; ++kt) { + const uint32_t a_addr = + sA_buf + (row0 * BK + (lane / 16) * 8 + kt * MMA_K) * sizeof(nv_bf16); + ldmatrix_x4(A[mi][kt], a_addr); + } + } + + // Load B operand (all n-tiles) and contract. Same addressing as logp's W. +#pragma unroll + for (int n = 0; n < N_TILES; ++n) { +#pragma unroll + for (int kk = 0; kk < KK_GROUPS; ++kk) { + uint32_t b4[4]; + const uint32_t b_addr = + sB_buf + ((n * MMA_N + (lane % 8)) * BK + (lane / 8) * 8 + kk * 32) * sizeof(nv_bf16); + ldmatrix_x4(b4, b_addr); + const uint32_t B0[2] = {b4[0], b4[1]}; + const uint32_t B1[2] = {b4[2], b4[3]}; +#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) { + mma_m16n8k16(A[mi][2 * kk + 0], B0, acc[mi][n]); + mma_m16n8k16(A[mi][2 * kk + 1], B1, acc[mi][n]); + } + } + } + __syncthreads(); + } + + // Epilogue: write acc to C (row-major [M,N]). mma m16n8k16 output layout. 
+#pragma unroll + for (int mi = 0; mi < M_TILES; ++mi) { + const int row = row_base + warp * WARP_M + mi * MMA_M + lane / 4; +#pragma unroll + for (int n = 0; n < N_TILES; ++n) { + const int col = col_base + n * MMA_N + (lane % 4) * 2; + if (row < M && col + 1 < N) { + C[row * N + col + 0] = __float2bfloat16(acc[mi][n][0]); + C[row * N + col + 1] = __float2bfloat16(acc[mi][n][1]); + } + if (row + 8 < M && col + 1 < N) { + C[(row + 8) * N + col + 0] = __float2bfloat16(acc[mi][n][2]); + C[(row + 8) * N + col + 1] = __float2bfloat16(acc[mi][n][3]); + } + } + } +} + +// noswizzle TMA descriptor (kernel uses plain row-major ldmatrix addressing). +inline void init_tmap_noswizzle(CUtensorMap* tmap, const nv_bf16* gmem, + uint64_t height, uint64_t width, + uint32_t box_h, uint32_t box_w) { + uint64_t size[2] = {width, height}; + uint64_t stride[1] = {width * sizeof(nv_bf16)}; + uint32_t box[2] = {box_w, box_h}; + uint32_t estride[2] = {1, 1}; + CUresult res = cuTensorMapEncodeTiled( + tmap, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, (void*)gmem, size, stride, box, estride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + TORCH_CHECK(res == CUDA_SUCCESS, "det_gemm: cuTensorMapEncodeTiled failed"); +} + +// Launch SM90 GEMM. A:[M,K] row-major, Bt:[N,K] row-major (= B transposed). +// Requires M%BM==0, N%BN==0, K%BK==0 (host pads/falls back otherwise). 
+bool launch_sm90(const nv_bf16* A, const nv_bf16* Bt, nv_bf16* C, + int M, int N, int K, cudaStream_t stream) { + if (M % BM != 0 || N % BN != 0 || K % BK != 0) return false; // fall back + + CUtensorMap a_tmap, bt_tmap; + init_tmap_noswizzle(&a_tmap, A, M, K, BM, BK); + init_tmap_noswizzle(&bt_tmap, Bt, N, K, BN, BK); + + const int smem = STAGES * (BM * BK + BN * BK) * sizeof(nv_bf16) + STAGES * 8; + if (smem > 48 * 1024) + cudaFuncSetAttribute(det_gemm_sm90_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + dim3 grid(cdiv(N, BN), cdiv(M, BM)); + det_gemm_sm90_kernel<<>>(a_tmap, bt_tmap, C, M, N, K); + return true; +} +#endif // RL_KERNEL_ENABLE_SM90 + +int sm_major() { + int dev = 0; cudaGetDevice(&dev); + cudaDeviceProp p{}; cudaGetDeviceProperties(&p, dev); + return p.major; +} inline const nv_bf16* bf16(const torch::Tensor& t) { return reinterpret_cast(t.data_ptr()); } @@ -68,44 +275,46 @@ void check_in(const torch::Tensor& t, const char* n) { TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16"); } +// Core dispatch: C = A[M,K] @ B[K,N]. SM90 kernel needs B^T ([N,K]); it is +// materialized contiguous here. Falls back to naive on SM80 or odd shapes. 
+torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) { + const int M = a.size(0), K = a.size(1), N = b.size(1); + auto c = torch::empty({M, N}, a.options()); + auto stream = at::cuda::getCurrentCUDAStream(); + +#if defined(RL_KERNEL_ENABLE_SM90) + if (sm_major() >= 9) { + auto bt = b.t().contiguous(); // [N,K] + if (launch_sm90(bf16(a), bf16(bt), bf16o(c), M, N, K, stream)) return c; + // else fall through to naive (shape not tile-aligned) + } +#endif + launch_naive(bf16(a), bf16(b), bf16o(c), M, N, K, stream); + return c; } -// fwd: C = A @ B +} // anonymous namespace + torch::Tensor det_gemm_fwd(torch::Tensor a, torch::Tensor b) { check_in(a, "A"); check_in(b, "B"); a = a.contiguous(); b = b.contiguous(); TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "det_gemm_fwd: expect 2D [M,K]@[K,N]"); - const int M = a.size(0), K = a.size(1); - TORCH_CHECK(b.size(0) == K, "det_gemm_fwd: K mismatch"); - const int N = b.size(1); - auto c = torch::empty({M, N}, a.options()); - launch_naive(bf16(a), bf16(b), bf16o(c), M, N, K, - at::cuda::getCurrentCUDAStream()); - return c; + TORCH_CHECK(b.size(0) == a.size(1), "det_gemm_fwd: K mismatch"); + return gemm_dispatch(a, b); } -// dA = dC @ B^T -> forward GEMM on materialized transpose of B torch::Tensor det_gemm_da(torch::Tensor dc, torch::Tensor b) { check_in(dc, "dC"); check_in(b, "B"); dc = dc.contiguous(); - auto bt = b.t().contiguous(); // [N, K] - const int M = dc.size(0), N = dc.size(1), K = bt.size(1); - TORCH_CHECK(bt.size(0) == N, "det_gemm_da: N mismatch"); - auto da = torch::empty({M, K}, dc.options()); // [M, K] - launch_naive(bf16(dc), bf16(bt), bf16o(da), M, K, N, - at::cuda::getCurrentCUDAStream()); - return da; + auto bt = b.t().contiguous(); // [N,K]; dA[M,K] = dC[M,N] @ bt[N,K] + TORCH_CHECK(bt.size(0) == dc.size(1), "det_gemm_da: N mismatch"); + return gemm_dispatch(dc, bt); } -// dB = A^T @ dC -> forward GEMM on materialized transpose of A torch::Tensor det_gemm_db(torch::Tensor a, 
torch::Tensor dc) { check_in(a, "A"); check_in(dc, "dC"); dc = dc.contiguous(); - auto at = a.t().contiguous(); // [K, M] - const int K = at.size(0), M = at.size(1), N = dc.size(1); - TORCH_CHECK(dc.size(0) == M, "det_gemm_db: M mismatch"); - auto db = torch::empty({K, N}, a.options()); // [K, N] - launch_naive(bf16(at), bf16(dc), bf16o(db), K, N, M, - at::cuda::getCurrentCUDAStream()); - return db; + auto at = a.t().contiguous(); // [K,M]; dB[K,N] = at[K,M] @ dC[M,N] + TORCH_CHECK(dc.size(0) == at.size(1), "det_gemm_db: M mismatch"); + return gemm_dispatch(at, dc); }