[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180
Draft
Flink-ddd wants to merge 1 commit into
Summary
Draft / WIP for #146. Implements the single-rank, batch-invariant deterministic GEMM op (forward + backward), which is one op in the WS1 forward chain. Each row's output is bitwise invariant to batch size, chunked-prefill splitting, and padding layout.
This is PR2 of the planned series (design note → kernel → tests → LM-head wiring → benchmark). Scope here is the kernel + op wiring + invariance tests.
Scope
In scope:
Out of scope (per #146):
Invariance contract
The kernel pins the tile shape and fixes the K-accumulation order, so a given input row produces the same output for the same (M, N, K) problem regardless of the surrounding batch, chunked-prefill splitting, or padding layout. There is no split-K and no heuristic kernel selection that varies with batch shape.
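To illustrate why the accumulation order has to be pinned, here is a minimal pure-Python sketch, not the kernel itself. The names `dot_fixed_order`, `dot_split_k`, and `gemm_rows` are hypothetical, and Python float64 is used only to make rounding effects easy to see; the kernel's actual dtypes and tile sizes are not assumed here.

```python
def dot_fixed_order(a_row, b_col, k_tile=2):
    """Accumulate over K in one fixed order: tile by tile, left to right."""
    acc = 0.0
    for k0 in range(0, len(a_row), k_tile):
        for k in range(k0, min(k0 + k_tile, len(a_row))):
            acc += a_row[k] * b_col[k]
    return acc


def dot_split_k(a_row, b_col, splits):
    """Split-K: reduce each K slice on its own, then add the partial sums."""
    step = (len(a_row) + splits - 1) // splits
    total = 0.0
    for s in range(0, len(a_row), step):
        total += dot_fixed_order(a_row[s:s + step], b_col[s:s + step])
    return total


def gemm_rows(a, b_col, dot):
    """One output column of A @ B, computed row by row with the given dot."""
    return [dot(a_row, b_col) for a_row in a]


# Values chosen so that float64 rounding depends on the summation order.
row = [1e16, 1.0, -1e16, 1.0]
col = [1.0, 1.0, 1.0, 1.0]

# Fixed order: the row's result does not depend on its neighbours or padding.
alone = gemm_rows([row], col, dot_fixed_order)[0]
padded_batch = [[0.0] * 4, row, [2.0] * 4]
batched = gemm_rows(padded_batch, col, dot_fixed_order)[1]
assert alone == batched

# Split-K changes the order of additions, so the same row can round differently.
print(dot_fixed_order(row, col), dot_split_k(row, col, 2))  # 1.0 0.0
```

If a library picked the number of K splits from the batch shape, the same row could therefore produce different bits in different batches, which is the behaviour the contract rules out.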
Implementation note (WIP)
The kernel is being implemented as a hand-written mma.sync (m16n8k16) GEMM, matching the existing prefix_shared_attention.cu style (ldmatrix / cp.async / fixed-order register accumulation) to keep the repo dependency-free and consistent with the other hand-written kernels. The hand-written kernel is naturally batch-invariant: the K-accumulation order is the loop order, fixed at compile time, with no shape-based kernel selection.
Forward lands first; the backward GEMMs reuse the same fixed-tile / fixed-accumulation path with transposed operand layouts, so invariance is inherited rather than re-proven.
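As a rough illustration of the operand layouts involved, the two backward GEMMs for C = A·B are dA = dC·Bᵀ and dB = Aᵀ·dC. The sketch below routes both through one fixed-order routine; `dot_fixed`, `matmul_fixed`, and `backward` are hypothetical names, and this is plain Python rather than the kernel's tiled path.

```python
def dot_fixed(xs, ys):
    """Left-to-right accumulation; the order never depends on problem shape."""
    acc = 0.0
    for x, y in zip(xs, ys):
        acc += x * y
    return acc


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul_fixed(a, b):
    """Plain matmul built only from the fixed-order dot product."""
    b_cols = transpose(b)
    return [[dot_fixed(row, col) for col in b_cols] for row in a]


def backward(a, b, d_c):
    """Gradients of C = A @ B with respect to A and B."""
    d_a = matmul_fixed(d_c, transpose(b))  # (M, N) x (N, K): reduces over N
    d_b = matmul_fixed(transpose(a), d_c)  # (K, M) x (M, N): reduces over M
    return d_a, d_b


a = [[1.0, 2.0], [3.0, 4.0]]               # M=2, K=2
b = [[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]    # K=2, N=3
d_c = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]   # M=2, N=3

d_a, d_b = backward(a, b, d_c)

# A row of dA computed alone matches the same row computed inside the batch.
d_a_first_only, _ = backward(a[:1], b, d_c[:1])
assert d_a_first_only[0] == d_a[0]
```

In this sketch each row of dA depends only on the matching row of dC, so per-row invariance carries over directly. dB reduces over the M dimension, so its value depends on which rows are in the batch by definition; what the fixed order guarantees there is a reproducible result for a given set of rows.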
Files
Tests
tests/test_det_gemm.py (no dependency on the #108 harness):
Verification
Built and validated on SM80 (A100); SM90 (H100) to follow. The bitwise-invariance tests are the hard gate; correctness uses placeholder tolerances pending #108.
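For reference, a bitwise gate differs from a tolerance check in that it compares raw bit patterns rather than values. The following is a minimal sketch of that idea in pure Python; it is not the contents of tests/test_det_gemm.py, and `bit_pattern`, `reference_rows`, and `run_chunked` are hypothetical stand-ins for the real op and test helpers.

```python
import struct


def bit_pattern(values):
    """Raw IEEE-754 bytes of a list of floats, for exact comparison."""
    return struct.pack(f"<{len(values)}d", *values)


def reference_rows(a, b_cols):
    """Stand-in for the op under test: fixed left-to-right K accumulation."""
    out = []
    for row in a:
        out_row = []
        for col in b_cols:
            acc = 0.0
            for x, y in zip(row, col):
                acc += x * y
            out_row.append(acc)
        out.append(out_row)
    return out


def run_chunked(a, b_cols, chunk):
    """Mimic chunked prefill: feed the rows through in pieces of size `chunk`."""
    out = []
    for i in range(0, len(a), chunk):
        out.extend(reference_rows(a[i:i + chunk], b_cols))
    return out


a = [[0.1, 0.2, 0.3], [1e16, 1.0, -1e16], [0.5, -0.25, 0.125], [3.0, 1e-9, -3.0]]
b_cols = [[1.0, 1.0, 1.0], [0.3, 0.7, 1.1]]

full = run_chunked(a, b_cols, chunk=len(a))
for chunk in (1, 2, 3):
    split = run_chunked(a, b_cols, chunk)
    for full_row, split_row in zip(full, split):
        # Exact byte equality, not a tolerance-based closeness check.
        assert bit_pattern(full_row) == bit_pattern(split_row)

# Why compare bytes instead of using ==: value equality hides sign-of-zero changes.
assert 0.0 == -0.0
assert bit_pattern([0.0]) != bit_pattern([-0.0])
```

A tolerance-based correctness check would sit alongside this gate and compare against a reference result, with thresholds left as placeholders until they are defined.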
Follow-ups
PR3: swap placeholder tolerances for the #108 threshold table; full QKV / MLP / LM-head shape sweep.
PR4: wire one real projection (LM head) through the deterministic path.
PR5: benchmark vs cuBLAS, document overhead + supported shapes.
Backward-pass invariance validation aligns with the WS1 backward-consistency issue.