feat(ws1): NativeLMHeadOp pure-PyTorch ground-truth reference + numerical contract tests#170
Conversation
WS1 ground-truth language-model-head op for issue RL-Align#108 (Qwen3-8B output projection, vocab=151936 x hidden=4096, tie_word_embeddings=false, no bias): - NativeLMHeadOp: out = hidden @ weight.t() (+ bias), a reduction over hidden exposing the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); weight is HF [out, in] and transposed internally (the one difference from the bare matmul op); pure function, no in-place mutation. - register PYTORCH_NATIVE_LM_HEAD in OpBackend and the cuda/rocm/cpu priority maps. - tests/test_lm_head.py: fp32 correctness vs naive matmul (bitwise), bf16/fp16 dtype-path accuracy (relative-to-peak tolerance, bf16 ~0.37% / fp16 ~0.05% of output peak), bias semantics, Axis-A batch invariance (slice + padding, all dtypes) under a pinned single-thread reduction so the CPU GEMM K-split is M-independent, purity, closed-form gradient flow to hidden/weight, registry dispatch, and a GPU-only real-shape smoke test (vocab=151936, hidden=4096). - docs/operators/lm_head.md + nav/index wiring.
📝 WalkthroughWalkthroughAdds a new ChangesLM Head Operator
Sequence DiagramsequenceDiagram
participant Caller
participant KernelRegistry
participant NativeLMHeadOp
participant _lm_head
Caller->>KernelRegistry: get_op("lm_head")
KernelRegistry-->>Caller: NativeLMHeadOp instance
Caller->>NativeLMHeadOp: forward(hidden, weight, bias)
NativeLMHeadOp->>_lm_head: compute_dtype=hidden.dtype, output_dtype=hidden.dtype
_lm_head->>_lm_head: cast hidden/weight to compute_dtype
_lm_head->>_lm_head: out = hidden @ weight.t()
_lm_head->>_lm_head: add bias if present
_lm_head->>_lm_head: cast to output_dtype
_lm_head-->>NativeLMHeadOp: logits
NativeLMHeadOp-->>Caller: logits [batch, seq, vocab]
Caller->>NativeLMHeadOp: forward_fp32(hidden, weight, bias)
NativeLMHeadOp->>_lm_head: compute_dtype=float32, output_dtype=float32, strict_fp32=True
_lm_head-->>NativeLMHeadOp: logits float32
NativeLMHeadOp-->>Caller: logits [batch, seq, vocab] fp32
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
docs/.nav.yml (1)
13-17: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick winMaintain alphabetical ordering of operator entries in navigation.
The new
operators/lm_head.mdentry should be inserted in alphabetical order betweenoperators/grpo-loss.mdandoperators/ratio-kl.md, not appended at the end. This keeps navigation consistent and easier to scan.📖 Proposed ordering fix
- Operators: - operators/README.md - operators/fused-logp.md - operators/grpo-loss.md + - operators/lm_head.md - operators/ratio-kl.md - - operators/sampling.md + - operators/sampling.md🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@docs/.nav.yml` around lines 13 - 17, The operators list in the navigation file has an alphabetically misplaced entry. Move the operators/lm_head.md entry to its correct alphabetical position between operators/grpo-loss.md and operators/ratio-kl.md, as it should come after "grpo-loss" and before "ratio-kl" when entries are sorted alphabetically. Remove it from its current position at the end of the operators list and insert it in the proper alphabetical order to maintain consistency and readability of the navigation structure.docs/operators/README.md (1)
21-26: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick winMaintain alphabetical ordering of operator index entries.
The new [LM Head] entry should be inserted in alphabetical order between [GRPO Loss] and [Policy Ratio + KL Penalty], not after [Sampling]. This keeps the index consistent and easier to navigate.
📖 Proposed ordering fix
- [Fused LogP](fused-logp.md) - [GRPO Loss](grpo-loss.md) +- [LM Head](lm_head.md) - [Policy Ratio + KL Penalty](ratio-kl.md) - [Sampling](sampling.md) -- [LM Head](lm_head.md) - [Operator Doc Template](../contributing/operator-doc-template.md)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@docs/operators/README.md` around lines 21 - 26, Reorder the operator index entries in the README.md file to maintain alphabetical ordering. Move the [LM Head] link from its current position after [Sampling] to its correct alphabetical position between [GRPO Loss] and [Policy Ratio + KL Penalty]. Ensure all entries in the list are ordered alphabetically by their display names to keep the index consistent and easy to navigate.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/test_lm_head.py`:
- Around line 196-200: The _enough_gpu_memory function can fail test collection
when torch.cuda.mem_get_info() raises a RuntimeError in partially configured
CUDA environments. Wrap the torch.cuda.mem_get_info() call in a try-except block
that catches RuntimeError and returns False when the error is caught, allowing
tests to be skipped gracefully instead of failing during collection.
- Around line 86-87: The test functions test_native_lm_head_dtype_path_accuracy,
and the two other similar test functions at lines 130 and 141 unconditionally
parametrize with torch.float16, which can cause failures on CPU hardware where
half-precision matmul is not guaranteed to be supported. Extract the dtype
tuples used in the parametrize decorators into module-level constants, then
replace the direct parametrization with pytest.param calls that include
conditional runtime checks to skip torch.float16 on CPU backends. Apply this
pattern consistently across all three affected test functions to prevent
backend-dependent test failures.
---
Nitpick comments:
In `@docs/.nav.yml`:
- Around line 13-17: The operators list in the navigation file has an
alphabetically misplaced entry. Move the operators/lm_head.md entry to its
correct alphabetical position between operators/grpo-loss.md and
operators/ratio-kl.md, as it should come after "grpo-loss" and before "ratio-kl"
when entries are sorted alphabetically. Remove it from its current position at
the end of the operators list and insert it in the proper alphabetical order to
maintain consistency and readability of the navigation structure.
In `@docs/operators/README.md`:
- Around line 21-26: Reorder the operator index entries in the README.md file to
maintain alphabetical ordering. Move the [LM Head] link from its current
position after [Sampling] to its correct alphabetical position between [GRPO
Loss] and [Policy Ratio + KL Penalty]. Ensure all entries in the list are
ordered alphabetically by their display names to keep the index consistent and
easy to navigate.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 18702bfa-053c-4402-b236-414de2b74d14
📒 Files selected for processing (7)
docs/.nav.ymldocs/operators/README.mddocs/operators/lm_head.mdrl_engine/kernels/ops/pytorch/linear/__init__.pyrl_engine/kernels/ops/pytorch/linear/lm_head.pyrl_engine/kernels/registry.pytests/test_lm_head.py
- Gate CPU float16 matmul parametrizations behind a runtime support probe so unsupported backends skip rather than fail collection. - Harden _enough_gpu_memory against RuntimeError from mem_get_info in partially-configured CUDA environments. - Add docstrings across the op and test suite to meet coverage. - Sort lm_head entries alphabetically in operator nav/README.
Keep consistent with other PyTorch native ops.
| output_dtype: torch.dtype, | ||
| ) -> torch.Tensor: | ||
| """Core matmul: cast to ``compute_dtype``, project, optionally add bias, cast out.""" | ||
| h = hidden.to(compute_dtype) |
There was a problem hiding this comment.
One precision nit: forward_fp32() casts the inputs to fp32, but the matmul can still run under autocast or CUDA TF32 settings, so it may not be a true fp32 golden reference. Since downstream kernels will compare against this path, could we explicitly disable autocast/TF32 around the matmul, or document that required precision context?
There was a problem hiding this comment.
Thanks, @inaniloquentee .
forward_fp32() now disables autocast and CUDA TF32 around the matmul, while saving/restoring the previous TF32 setting so global state does not leak. The regular forward() path is unchanged and still
follows the ambient precision context.
I also added regression coverage for this:
- CPU autocast case: forward_fp32 remains equal to the fp32 reference and restores the TF32 flag.
- CUDA TF32 case: forward_fp32 is checked against a higher-precision reference when CUDA is available.
Docs were updated to note the precision-context behavior.
Wrap the forward_fp32 matmul to disable autocast and CUDA TF32 (saving and restoring the global allow_tf32 flag) so the fp32 golden path is not silently downcast by the caller's ambient precision context. The dtype-behavior forward path is left to follow ambient precision intentionally. Add tests: forward_fp32 stays true fp32 under ambient autocast and restores the TF32 flag (CPU); numerically beats a TF32 matmul (GPU). Pin TF32 off in the fp32-vs-naive bitwise test. Sync docs accordingly.
There was a problem hiding this comment.
🧹 Nitpick comments (1)
docs/operators/lm_head.md (1)
108-115: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick winBreak the test coverage list into separate sentences for readability.
The test coverage section is accurate and comprehensive, but lines 108–115 condense nine distinct test concerns into a single long sentence that impairs readability. Consider splitting into 2–3 sentences by major category (e.g., "Covers: [precision & dtype behavior]...." then "Also covers: [batch invariance & purity]..." then "GPU-only smoke test...").
📝 Suggested restructuring
- Covers: fp32 correctness vs naive matmul (bitwise, with ambient TF32 pinned off), - `forward_fp32` precision-context safety (true fp32 under ambient autocast + restores the - global TF32 flag on CPU; numerically beats a TF32 matmul on GPU), bf16/fp16 dtype-path - accuracy (relative-to-peak tolerance, with `bias`), output shape, bias semantics, Axis-A - batch invariance (slice + padding, single-thread reduction, all dtypes), input purity, - gradient flow to `hidden`/`weight` (closed-form check), registry dispatch, and a GPU-only - smoke test at the real Qwen3-8B dims (`vocab=151936, hidden=4096`) that skips when CUDA or - GPU memory is unavailable. + Covers: fp32 correctness vs naive matmul (bitwise, with ambient TF32 pinned off) and + `forward_fp32` precision-context safety (true fp32 under ambient autocast, restores the + global TF32 flag on CPU, numerically beats a TF32 matmul on GPU). + + Also covers: bf16/fp16 dtype-path accuracy (relative-to-peak tolerance, with `bias`), + output shape, bias semantics, Axis-A batch invariance (slice + padding, single-thread + reduction, all dtypes), input purity, and gradient flow to `hidden`/`weight` (closed-form). + + Registry dispatch and a GPU-only smoke test at the real Qwen3-8B dims (`vocab=151936, + hidden=4096`) round out coverage; the smoke test skips when CUDA or GPU memory is unavailable.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@docs/operators/lm_head.md` around lines 108 - 115, The test coverage description in the lm_head.md documentation is a single long sentence that lists nine distinct test concerns, making it difficult to read and parse. Break this single sentence into 2-3 shorter sentences organized by major category: group precision and dtype-related tests together (fp32 correctness, forward_fp32 safety, bf16/fp16 accuracy, output shape and bias semantics), then create a second sentence for batch invariance and purity tests (Axis-A batch invariance, input purity, gradient flow), and finally add a third sentence for the GPU-only smoke test at Qwen3-8B dimensions. This restructuring will improve readability while maintaining all the technical details.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@docs/operators/lm_head.md`:
- Around line 108-115: The test coverage description in the lm_head.md
documentation is a single long sentence that lists nine distinct test concerns,
making it difficult to read and parse. Break this single sentence into 2-3
shorter sentences organized by major category: group precision and dtype-related
tests together (fp32 correctness, forward_fp32 safety, bf16/fp16 accuracy,
output shape and bias semantics), then create a second sentence for batch
invariance and purity tests (Axis-A batch invariance, input purity, gradient
flow), and finally add a third sentence for the GPU-only smoke test at Qwen3-8B
dimensions. This restructuring will improve readability while maintaining all
the technical details.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 75685450-1855-40ee-b44d-c14e47328e11
📒 Files selected for processing (3)
docs/operators/lm_head.mdrl_engine/kernels/ops/pytorch/linear/lm_head.pytests/test_lm_head.py
🚧 Files skipped from review as they are similar to previous changes (2)
- rl_engine/kernels/ops/pytorch/linear/lm_head.py
- tests/test_lm_head.py
Summary
Adds the pure-PyTorch ground-truth reference op for the language-model head — the
output layer of the WS1 batch-invariant forward chain — built on top of the numerical
contract defined in #108. Ships the
op, its registry wiring, docs, and a 15-case test suite that pins down both alignment
axes (Axis-A bitwise batch invariance, Axis-B per-dtype accuracy), plus a GPU-only smoke
test at the real Qwen3-8B projection dims.
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
many rows share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time (smallbatch / dynamic padding) numerics identical so the policy ratio doesn't drift.
lossless embedding gather, lm_head is a reduction over hidden, so low-precision
accumulation drifts from the fp32 reference and is checked against a tolerance
window (not bitwise).
Motivation / Context
#108 establishes the ground-truth
harness and numerical contract for the WS1 batch-invariant forward chain. The final stage
of the Qwen3-8B stack projects hidden states back to vocabulary logits:
logits = hidden @ weight.t() # weight is HF [out, in]
This PR provides the deterministic fp32 reference path that downstream kernels (Triton /
CUDA / ROCm) will be validated against. For Qwen3-8B the weight is the output projection
[vocab=151936, hidden=4096]in the HFnn.Linear[out, in]convention (transposedinternally), is independent from the embedding table (
tie_word_embeddings=false),and has no bias.
Changes
rl_engine/kernels/ops/pytorch/linear/lm_head.py—NativeLMHeadOpforward()— project in the input dtype, output the input dtype (Axis-B path)forward_fp32()— upcast to fp32, accumulate in fp32, forced fp32 output(ground-truth / backward golden source)
out = hidden @ weight.t() (+ bias)[out, in]and transposed internally — the one difference from the barematmulop; do not use interchangeablyhiddenrl_engine/kernels/registry.py— registerPYTORCH_NATIVE_LM_HEADinOpBackendand add
lm_headdispatch to the cuda / rocm / cpu priority mapstests/test_lm_head.py— 15 tests (details below)docs/operators/lm_head.md+ nav / index wiringHow this satisfies the #108 contract
forward_fp32()accumulates in fp32; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B reduction drift checked against a per-dtype tolerance measured relative to the output peak (bf16 ~0.37%, fp16 ~0.05%)vocab=151936, hidden=4096); CPU tests keep the realhidden=4096reduction length (only vocab is shrunk) so the drift is representative; skips when CUDA / GPU memory is unavailableReduction-specific note (Axis-A reduction order)
A single
torch.matmulis not bitwise batch-invariant by default: multi-threaded CPUGEMM splits the
hidden(K) reduction across threads by theM = batch*seqdimension, so"compute full then slice" ≠ "compute slice" once
hiddenis large. The tests pin a singlethread to fix the reduction order (a local stand-in for the planned
testing/determinism.py::deterministic_context). On GPU cuBLAS likewise splits K byM,so a batch-invariant GEMM is a downstream kernel concern — the GPU smoke test validates the
full-vocab shape and fp32 correctness, not Axis-A bitwise.
Test Environment
Test Results
python -m pytest tests/test_lm_head.py -v

17 passed
The 17 tests cover:
torch.equal); fp32 forwardpath bitwise-equal to the ground truth
fp16 ~0.05%), with error stats printed
hidden.shape[:-1] + (vocab,))Nonedefault == no bias; provided[vocab]bias added)bitwise under a pinned single-thread reduction
hidden,weight, norbiasmutated in place)hidden/weight(fp32 autograd = backward golden source), verifiedagainst the closed-form grads
lm_head→NativeLMHeadOpvocab=151936, hidden=4096)Checklist
Summary by CodeRabbit
lm_headroutes to the PyTorch-native implementation instead of using the generic fallback.