feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160
- NativeRMSNormOp with forward / forward_fp32 (fp32 ground-truth path)
- covers both normalized dims: hidden=4096 and head_dim=128 (Qwen3 QK-Norm)
- register PYTORCH_NATIVE_RMS_NORM in OpBackend + cpu/cuda/rocm priority map
- tests/test_rms_norm.py: axis-A bitwise batch invariance + dtype tolerance, shape guard, purity, gradient flow, registry dispatch (16 tests)

Refs RL-Align#108
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/norm/rms_norm.py`:
- Around line 5-6: The file rms_norm.py has linting and formatting violations
from black, isort, and trailing-whitespace checks that are blocking CI. Run the
project's pre-commit formatting hooks (typically via a command like 'pre-commit
run --all-files' or 'black' and 'isort' individually) to automatically reformat
the file and fix signature spacing issues around lines 25-31, expression
formatting issues around lines 38-44 and 57-67, and any trailing whitespace
violations. Commit the reformatted file after the hooks complete.
In `@tests/test_rms_norm.py`:
- Around line 10-11: The test_rms_norm.py file has formatting inconsistencies
detected by Black. Run the Black formatter on this file to automatically fix
alignment and spacing issues in inline comments (like those on the _HIDDEN and
_HEAD_DIM constant definitions) and long assertions (around the test assertion
blocks). Apply Black's output and commit the formatted result to resolve the CI
formatting check.
📒 Files selected for processing (3)
rl_engine/kernels/ops/pytorch/norm/rms_norm.py
rl_engine/kernels/registry.py
tests/test_rms_norm.py
Resolve CodeRabbit formatting findings on RL-Align#160: black (line-length=100), isort (profile=black), trailing-whitespace and EOF fixes. No logic change; 16 tests still pass.
Summary
Adds the pure-PyTorch ground-truth reference op for RMSNorm (pre-norm / QK-Norm)
as the first WS1 batch-invariant operator built on top of the numerical contract
defined in #108. Ships the op, its registry wiring, and a 16-case test suite that
pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B per-dtype
tolerance).
Refs #108
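The Axis-A claim in the summary can be illustrated with a minimal sketch. The function `rms_norm_fp32` below is a hypothetical stand-in, not the PR's `NativeRMSNormOp`; the default `eps=1e-6` and the batch size of 4 are assumptions chosen for illustration.

```python
import torch


def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical stand-in for the reference op: fp32 accumulation, fp32 output."""
    x32 = x.float()
    # eps is added to the mean of squares before the reciprocal square root.
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x32 * inv_rms * weight.float()


# Fixed-seed generator so the inputs are reproducible (eps and batch size are assumed).
gen = torch.Generator().manual_seed(0)
hidden = 4096
x = torch.randn(4, hidden, generator=gen)
w = torch.randn(hidden, generator=gen)

full = rms_norm_fp32(x, w)       # all four rows normalized in one call
alone = rms_norm_fp32(x[:1], w)  # first row normalized on its own

# Axis-A: the first row should be bit-for-bit identical in both runs.
axis_a_ok = torch.equal(full[:1], alone)
print(axis_a_ok)
```

The comparison uses `torch.equal` rather than `torch.allclose` because Axis-A is a bitwise property: any tolerance would hide the small drift this check is meant to catch.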
Terminology
This PR uses the WS1 alignment vocabulary from #108:
- Axis-A (batch invariance): a row's output must not depend on how many rows share the batch (batch size, slicing, padding). Asserted bitwise (torch.equal). This is what keeps train-time (large batch) and sample-time (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
- Axis-B (dtype tolerance): the low-precision output must stay within a documented per-dtype tolerance of the fp32 ground-truth. Asserted with torch.allclose.

Motivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. RMSNorm is required on two normalized dims of the
target model (Qwen3-8B dense):
- hidden = 4096: input / post-attention norm
- head_dim = 128: QK-Norm (per-head RMSNorm on Q and K)

This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm RMSNorm) will be validated against.
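As a rough sketch of what validating against the fp32 reference could look like on both normalized dims, the snippet below compares a low-precision output with an fp32 output using the per-dtype tolerances quoted in this PR. The helper `rms_norm`, its `out_dtype` parameter, the default `eps=1e-6`, and the batch / sequence / head counts are all assumptions for illustration, not the PR's actual API.

```python
import torch


def rms_norm(x, weight, eps=1e-6, out_dtype=None):
    """Hypothetical helper: accumulate in fp32, then cast to out_dtype (default: x.dtype)."""
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x32 * inv_rms * weight.float()
    return out.to(x.dtype if out_dtype is None else out_dtype)


# Per-dtype tolerances as quoted in this PR description.
TOL = {
    torch.bfloat16: dict(atol=2e-2, rtol=1.6e-2),
    torch.float16: dict(atol=1e-3, rtol=1e-3),
}

# Shapes are illustrative: (batch, seq, hidden) and (batch, seq, heads, head_dim).
shapes = {"hidden": (2, 3, 4096), "head_dim": (2, 3, 4, 128)}

gen = torch.Generator().manual_seed(0)
results = {}
for name, shape in shapes.items():
    for dtype, tol in TOL.items():
        x = torch.randn(shape, generator=gen).to(dtype)
        w = torch.randn(shape[-1], generator=gen).to(dtype)
        candidate = rms_norm(x, w)                           # output in x.dtype
        reference = rms_norm(x, w, out_dtype=torch.float32)  # fp32 ground truth
        results[(name, dtype)] = torch.allclose(candidate.float(), reference, **tol)

print(results)
```

Because the normalization always runs over `dim=-1`, the same function covers the per-head QK-Norm case simply by keeping `head_dim` as the last axis.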
Changes
- rl_engine/kernels/ops/pytorch/norm/rms_norm.py: NativeRMSNormOp
  - forward(): accumulate in fp32, cast result back to x.dtype (Axis-B candidate path)
  - forward_fp32(): fp32 accumulation, forced fp32 output (ground-truth / backward golden source)
  - out = x * rsqrt(mean(x^2, dim=-1) + eps) * weight
  - eps lives inside the sqrt; plain weight scaling (not the 1 + weight variant)
  - weight must be 1-D of size x.shape[-1]
- rl_engine/kernels/registry.py: register PYTORCH_NATIVE_RMS_NORM and add rms_norm dispatch to the cuda / rocm / cpu priority maps
- tests/test_rms_norm.py: 16 tests (details below)

How this satisfies the #108 contract
- forward_fp32() accumulates in fp32 along dim=-1; tests use fixed-seed torch.Generator so outputs are reproducible
- Axis-A asserted bitwise (torch.equal); Axis-B asserted within documented per-dtype thresholds: bf16 atol=2e-2, rtol=1.6e-2; fp16 atol=1e-3, rtol=1e-3
- Both normalized dims covered: hidden=4096 and head_dim=128

Test Environment
Testing
Run from the repo root with
python -m pytest (the -m form puts the repo on