
[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd) #180

Draft
Flink-ddd wants to merge 1 commit into main from feat/add-ws1-gemm

@Flink-ddd

Collaborator

Summary

Draft / WIP for #146. Implements the single-rank batch-invariant deterministic GEMM op (forward + backward) — one op in the WS1 forward chain. A row's output is invariant to batch size, chunked-prefill splitting, and padding layout.
This is PR2 of the planned series (design note → kernel → tests → LM-head wiring → benchmark). Scope here is the kernel + op wiring + invariance tests.

Scope

In scope:

  • Single-rank forward C = A @ B and backward dA = dC @ Bᵀ, dB = Aᵀ @ dC.
  • BF16 inputs, FP32 accumulation, no TF32.
  • Fixed tile, no split-K — a row's reduction order is independent of the batch (M) dimension.
  • SM80 (A100) first; SM90 (H100) path follows.

Out of scope (per #146):

  • Tensor-parallel GEMM → WS2.
  • FP8, ROCm.
  • Performance tuning — correctness and invariance first; a slower deterministic baseline is the accepted first milestone.

Invariance contract

The kernel pins the tile shape and fixes the K-accumulation order, so for a given N and K an input row produces the same output row regardless of the surrounding batch size (M), chunked-prefill splitting, or padding layout. Neither split-K nor any heuristic kernel selection that varies with batch shape is used.
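
To illustrate why the contract rules out split-K and batch-dependent kernel selection, here is a minimal pure-Python sketch. It is not code from this PR: Python floats (FP64) stand in for the kernel's FP32 accumulator, and `pick_splits` is a hypothetical batch-size heuristic invented for the example.

```python
# Illustrative only: Python floats (FP64) stand in for the FP32 accumulator.


def dot_fixed_order(a_row, b_col):
    """Accumulate k = 0..K-1 in loop order, as a fixed-order kernel would."""
    acc = 0.0
    for a, b in zip(a_row, b_col):
        acc += a * b
    return acc


def dot_split_k(a_row, b_col, splits):
    """Split K into `splits` chunks, reduce each chunk, then sum the partials."""
    k = len(a_row)
    chunk = (k + splits - 1) // splits
    acc = 0.0
    for start in range(0, k, chunk):
        acc += dot_fixed_order(a_row[start:start + chunk],
                               b_col[start:start + chunk])
    return acc


def pick_splits(batch_rows):
    """Hypothetical heuristic: small batches use split-K, large ones do not."""
    return 2 if batch_rows < 64 else 1


# Values chosen so that rounding makes the reduction order visible.
a_row = [1e16, 1.0, -1e16, 1.0]
b_col = [1.0, 1.0, 1.0, 1.0]

# Same row, same weights; only the surrounding batch size changes.
heuristic_small = dot_split_k(a_row, b_col, pick_splits(batch_rows=8))
heuristic_large = dot_split_k(a_row, b_col, pick_splits(batch_rows=512))
fixed_small = dot_fixed_order(a_row, b_col)
fixed_large = dot_fixed_order(a_row, b_col)

print(heuristic_small, heuristic_large)  # differ: the split count followed M
print(fixed_small, fixed_large)          # identical: the order never saw M
```

Because floating-point addition is not associative, any path where the split count or kernel choice depends on M can change a row's bits. Pinning the loop order removes that dependence.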

Implementation note (WIP)

The kernel is being implemented as a hand-written mma.sync (m16n8k16) GEMM, matching the existing prefix_shared_attention.cu style (ldmatrix / cp.async / fixed-order register accumulation) to keep the repo dependency-free and consistent with the other hand-written kernels. The hand-written kernel is naturally batch-invariant: the K-accumulation order is the loop order, fixed at compile time, with no shape-based kernel selection.
Forward lands first; the backward GEMMs reuse the same fixed-tile / fixed-accumulation path with transposed operand layouts, so invariance is inherited rather than re-proven.
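
The reuse described above can be sketched in plain Python. This is a reference-style illustration under stated assumptions, not the CUDA implementation: the function names mirror the entry points listed under Files, but the list-of-lists signatures, the `det_matmul` helper, and `transpose` are hypothetical.

```python
def det_matmul(a, b):
    """C = A @ B with a fixed k = 0..K-1 accumulation order per output element."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kk in range(k):  # loop order never depends on m
                acc += a[i][kk] * b[kk][j]
            c[i][j] = acc
    return c


def transpose(x):
    return [list(col) for col in zip(*x)]


def det_gemm_fwd(a, b):
    return det_matmul(a, b)                # C  = A  @ B


def det_gemm_da(dc, b):
    return det_matmul(dc, transpose(b))    # dA = dC @ B^T


def det_gemm_db(a, dc):
    return det_matmul(transpose(a), dc)    # dB = A^T @ dC


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]     # M=2, K=3
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # K=3, N=2
dC = [[1.0, 2.0], [3.0, 4.0]]              # M=2, N=2

C = det_gemm_fwd(A, B)
dA = det_gemm_da(dC, B)
dB = det_gemm_db(A, dC)

# Row 0 of dA is the same whether or not row 1 is in the batch.
dA_row0_alone = det_gemm_da([dC[0]], B)[0]
print(C, dA, dB, dA_row0_alone)
```

All three products go through the same `det_matmul` loop, so the accumulation order is shared. Note that dB reduces over M, which is why the per-row invariance check above is stated for dA.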

Files

  • csrc/cuda/gemm/det_gemm_kernel.cu — entry points (det_gemm_fwd / det_gemm_da / det_gemm_db), dispatch by compute capability.
  • csrc/ops.cpp — pybind registration (3 ops).
  • rl_engine/kernels/ops/cuda/matmul/det_gemm.py — autograd wrapper + DetGemmOp (registry backend, no PyTorch fallback — a generic matmul would break invariance).
  • rl_engine/kernels/registry.py — CUDA_DET_GEMM backend + det_gemm dispatch entry.
  • tests/test_det_gemm.py — invariance + correctness.
  • benchmarks/benchmark_det_gemm.py — overhead vs cuBLAS baseline.
  • setup.py — det_gemm CUDA source wired into the _C extension.
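
The "no PyTorch fallback" choice for the registry backend can be sketched as follows. Only the names `det_gemm` and `CUDA_DET_GEMM` come from the file list above. The registry layout, `register`, `dispatch`, `BackendUnavailableError`, and the availability probe are assumptions made for illustration and may not match rl_engine/kernels/registry.py.

```python
class BackendUnavailableError(RuntimeError):
    """Raised instead of silently falling back to a non-invariant matmul."""


_REGISTRY = {}  # op name -> {backend name -> (callable, availability probe)}


def register(op_name, backend, fn, is_available):
    _REGISTRY.setdefault(op_name, {})[backend] = (fn, is_available)


def dispatch(op_name, backend, *args):
    entry = _REGISTRY.get(op_name, {}).get(backend)
    if entry is None:
        raise BackendUnavailableError(
            f"{op_name}: backend {backend!r} is not registered")
    fn, is_available = entry
    if not is_available():
        # Refuse rather than route to a generic matmul that could pick a
        # different kernel per batch shape.
        raise BackendUnavailableError(
            f"{op_name}: backend {backend!r} is unavailable; refusing to fall back")
    return fn(*args)


def _cuda_det_gemm(a, b):
    # Placeholder for the compiled extension call; never reached in this sketch.
    raise NotImplementedError


# Assumption: a real probe would check for the compiled extension and a
# supported GPU. Here it is hard-coded to False to show the failure path.
register("det_gemm", "CUDA_DET_GEMM", _cuda_det_gemm, is_available=lambda: False)

try:
    dispatch("det_gemm", "CUDA_DET_GEMM", [[1.0]], [[1.0]])
except BackendUnavailableError as err:
    print(f"dispatch refused: {err}")
```

Failing loudly keeps a missing build from quietly producing outputs that pass tolerance checks but break bitwise invariance.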

Tests

tests/test_det_gemm.py (no dependency on the #108 harness):

  • Forward batch-invariance (same row, different batch sizes) — bitwise (torch.equal).
  • Chunked-prefill split == full GEMM — bitwise.
  • Padding rows do not affect valid rows — bitwise.
  • Backward dA batch-invariance — bitwise.
  • Forward / backward correctness vs FP32 reference (placeholder tolerance, to be replaced in PR3 by the #108 numerical contract, "[WS1] Ground-truth harness + numerical contract for batch-invariant ops").
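
The shape of the three forward invariance checks can be sketched without a GPU. This is not the content of tests/test_det_gemm.py: `det_matmul` is a pure-Python stand-in for the op under test (FP64 instead of BF16/FP32), and `same_bits` is a hypothetical helper that compares raw bit patterns where the real tests use torch.equal on tensors.

```python
import random
import struct


def det_matmul(a, b):
    """Stand-in for the op under test: fixed k = 0..K-1 order per output element."""
    n = len(b[0])
    out = []
    for row in a:
        out_row = []
        for j in range(n):
            acc = 0.0
            for kk, a_val in enumerate(row):
                acc += a_val * b[kk][j]
            out_row.append(acc)
        out.append(out_row)
    return out


def same_bits(x, y):
    """Bitwise equality on the raw FP64 bit patterns of two matrices."""
    def pack(m):
        return [[struct.pack("<d", v) for v in row] for row in m]
    return pack(x) == pack(y)


rng = random.Random(0)
M, K, N = 8, 16, 4
A = [[rng.uniform(-1.0, 1.0) for _ in range(K)] for _ in range(M)]
B = [[rng.uniform(-1.0, 1.0) for _ in range(N)] for _ in range(K)]
full = det_matmul(A, B)

# 1. Batch invariance: row 3 alone vs row 3 inside the full batch.
batch_ok = same_bits(det_matmul([A[3]], B), [full[3]])

# 2. Chunked prefill: two chunks concatenated vs one full GEMM.
chunk_ok = same_bits(det_matmul(A[:5], B) + det_matmul(A[5:], B), full)

# 3. Padding: appended zero rows must not change the valid rows.
padded = A + [[0.0] * K for _ in range(3)]
pad_ok = same_bits(det_matmul(padded, B)[:M], full)

print(batch_ok, chunk_ok, pad_ok)
```

The stand-in passes trivially because each output row depends only on its own input row. The real tests are meaningful because a tiled GPU kernel has to preserve that property across tile boundaries and launch shapes.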

Verification

Built and validated on SM80 (A100); SM90 (H100) to follow. The bitwise-invariance tests are the hard gate; correctness uses placeholder tolerances pending #108.

pip install -e .
pytest tests/test_det_gemm.py -v

Follow-ups

  • PR3: swap placeholder tolerances for the #108 threshold table; full QKV / MLP / LM-head shape sweep.
  • PR4: wire one real projection (LM head) through the deterministic path.
  • PR5: benchmark vs cuBLAS, document overhead + supported shapes.

Backward-pass invariance validation aligns with the WS1 backward-consistency issue.

@coderabbitai

coderabbitai Bot commented Jun 22, 2026


Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


@Flink-ddd changed the title from "WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)" to "[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)" on Jun 22, 2026