feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160

Open
maxiaosong1124 wants to merge 4 commits into RL-Align:main from maxiaosong1124:feat/ws1-rms_norm-pytorch-op

Conversation

@maxiaosong1124 maxiaosong1124 commented Jun 20, 2026

Collaborator

Summary

Adds the pure-PyTorch ground-truth reference op for RMSNorm (pre-norm / QK-Norm)
as the first WS1 batch-invariant operator built on top of the numerical contract
defined in #108. Ships the op, its registry wiring, and a 16-case test suite that
pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B per-dtype
tolerance).

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A row's output must not depend on
    how many rows share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward must stay within a
    documented per-dtype tolerance of the fp32 ground-truth. Asserted with torch.allclose
    + per-dtype thresholds.
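
A minimal sketch of what an Axis-A check can look like, assuming a row-wise RMSNorm written directly from the formula in this PR. The helper `rms_norm_fp32`, the `eps` default, and the tensor sizes are illustrative assumptions, not the PR's actual test code.

```python
import torch


def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical stand-in for the reference op (eps default is an assumption)."""
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return xf * inv_rms * weight.float()


# Fixed-seed generator so the inputs are reproducible.
gen = torch.Generator().manual_seed(0)
x = torch.randn(8, 128, generator=gen)
w = torch.randn(128, generator=gen)

full = rms_norm_fp32(x, w)

# Slice variant: rows computed in a smaller batch must match the full-batch rows bitwise.
sliced = rms_norm_fp32(x[:3], w)
slice_invariant = torch.equal(sliced, full[:3])

# Padding variant: appending zero rows must not change the original rows.
padded = rms_norm_fp32(torch.cat([x, torch.zeros(4, 128)], dim=0), w)
padding_invariant = torch.equal(padded[:8], full)
```

Because each row is reduced only along `dim=-1`, no value from another row enters its result, which is why a bitwise `torch.equal` assertion is a reasonable bar for this op.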

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. RMSNorm is required on two normalized dims of the
target model (Qwen3-8B dense):

  • hidden = 4096 — input / post-attention norm
  • head_dim = 128 — QK-Norm (per-head RMSNorm on Q and K)

This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm RMSNorm) will be validated against.

Changes

  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py: NativeRMSNormOp
    • forward() — accumulate in fp32, cast result back to x.dtype (Axis-B candidate path)
    • forward_fp32() — fp32 accumulation, forced fp32 output (ground-truth / backward golden source)
    • Formula: out = x * rsqrt(mean(x^2, dim=-1) + eps) * weight
    • eps lives inside the sqrt; plain weight scaling (not the 1 + weight variant)
    • Shape guard: weight must be 1-D of size x.shape[-1]
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_RMS_NORM
    and add rms_norm dispatch to the cuda / rocm / cpu priority maps
  • tests/test_rms_norm.py — 16 tests (details below)
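
For readers without the diff open, the op described in the list above can be sketched roughly as follows. This is an illustration built from the bullet points and the review walkthrough, not the PR's source: the class name `RMSNormReferenceSketch`, the `eps` default, and the error message are assumptions.

```python
import torch


class RMSNormReferenceSketch:
    """Illustrative stand-in for NativeRMSNormOp; not the PR's actual code."""

    def __init__(self, eps: float = 1e-6) -> None:  # eps default is an assumption
        self.eps = eps

    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Accumulate in fp32, then cast the result back to the input dtype.
        return self._rms_norm(x, weight, self.eps, output_dtype=x.dtype)

    def forward_fp32(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Same math, but the output is forced to fp32 (ground-truth path).
        return self._rms_norm(x, weight, self.eps, output_dtype=torch.float32)

    @staticmethod
    def _rms_norm(x, weight, eps, output_dtype):
        # Shape guard: weight must be 1-D of size x.shape[-1].
        if weight.dim() != 1 or weight.shape[0] != x.shape[-1]:
            raise ValueError(
                f"weight must be 1-D of size {x.shape[-1]}, got {tuple(weight.shape)}"
            )
        xf = x.float()
        # out = x * rsqrt(mean(x^2, dim=-1) + eps) * weight, with eps inside the sqrt.
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (xf * inv_rms * weight.float()).to(output_dtype)


op = RMSNormReferenceSketch()
x = torch.randn(2, 4, 128).to(torch.bfloat16)
w = torch.ones(128)
y = op.forward(x, w)         # dtype follows the input (bf16 here)
y32 = op.forward_fp32(x, w)  # dtype is always fp32
```

Routing both entry points through one static helper keeps the arithmetic identical between the candidate and ground-truth paths, so the only difference is the final cast.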

How this satisfies the #108 contract

| #108 requirement | How it's met here |
| --- | --- |
| Deterministic reference path, fixed reduction order | forward_fp32() accumulates in fp32 along dim=-1; tests use fixed-seed torch.Generator so outputs are reproducible |
| Per-dtype tolerance policy (bitwise vs tight-tolerance) | Axis-A asserted bitwise (torch.equal); Axis-B asserted within documented per-dtype thresholds: bf16 atol=2e-2, rtol=1.6e-2; fp16 atol=1e-3, rtol=1e-3 |
| Batch-config sweep / validation helper | Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts |
| Both normalized dims covered | Every correctness/invariance test is parametrized over hidden=4096 and head_dim=128 |
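
The Axis-B thresholds quoted above could be exercised along these lines. This is a hedged sketch: the helper names, the `eps` default, the batch size, and the choice to feed the same low-precision input to both paths are assumptions, and only the atol/rtol values and the two dims come from this PR.

```python
import torch

# Per-dtype thresholds as documented in the PR description.
TOLERANCES = {
    torch.bfloat16: {"atol": 2e-2, "rtol": 1.6e-2},
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}


def rms_norm(x, weight, out_dtype, eps=1e-6):
    """Hypothetical helper: fp32 accumulation, configurable output dtype."""
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xf * inv_rms * weight.float()).to(out_dtype)


def within_tolerance(dtype, dim, seed=0):
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(4, dim, generator=gen).to(dtype)
    w = torch.randn(dim, generator=gen)
    candidate = rms_norm(x, w, out_dtype=dtype)          # low-precision output
    reference = rms_norm(x, w, out_dtype=torch.float32)  # fp32 ground truth
    return torch.allclose(candidate.float(), reference, **TOLERANCES[dtype])


# Sweep both dtypes over both normalized dims (hidden=4096, head_dim=128).
results = {
    (dtype, dim): within_tolerance(dtype, dim)
    for dtype in TOLERANCES
    for dim in (4096, 128)
}
```

In this setup the only error source is the final cast from fp32 to the low-precision dtype, so the check mainly guards against a path that accidentally accumulates in bf16 or fp16.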

Test Environment

| Component | Value |
| --- | --- |
| OS | Ubuntu 22.04.5 LTS (kernel 5.15.0-122-generic) |
| Python | 3.12.3 |
| PyTorch | 2.8.0+cu128 |
| CUDA / cuDNN | 12.8 / 9.10.02 (driver 580.65.06) |
| pytest | 9.0.3 |
| GPU | NVIDIA H20 |

Testing

Run from the repo root with python -m pytest (the -m form puts the repo on

python -m pytest tests/test_rms_norm.py

→ 16 passed, covering:

- Correctness vs an independent hand-written fp32 formula (bitwise, both dims)
- Axis-A batch invariance: row output is independent of batch size — slice and
padding variants, asserted bitwise
- dtype paths: forward follows input dtype; forward_fp32 forces fp32
- Axis-B low-precision (bf16 / fp16) within tolerance of the fp32 reference
- eps inside sqrt (zero input → finite zero output)
- plain weight scaling (rules out the 1 + weight variant)
- shape guard fires on wrong-size / non-1-D weight
- purity (inputs not mutated in place)
- gradient flow (fp32 autograd = backward golden source)
- registry dispatch resolves rms_norm → NativeRMSNormOp
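
Three of the edge-case checks listed above (eps inside the sqrt, plain weight scaling, gradient flow) are easy to illustrate in isolation. The sketch below assumes a hypothetical `rms_norm_fp32` helper and an assumed `eps` default; the real assertions in tests/test_rms_norm.py may be written differently.

```python
import torch


def rms_norm_fp32(x, weight, eps=1e-6):
    """Hypothetical fp32 reference written from the formula in this PR."""
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return xf * inv_rms * weight.float()


gen = torch.Generator().manual_seed(0)

# eps inside the sqrt: an all-zero row gives 0 * rsqrt(eps), which is finite
# and zero. Without eps, rsqrt(0) = inf and 0 * inf would produce NaN.
out_zero_input = rms_norm_fp32(torch.zeros(2, 128), torch.ones(128))

# Plain weight scaling: a zero weight must give a zero output. Under the
# (1 + weight) variant the same call would return the normalized input instead.
x = torch.randn(2, 128, generator=gen)
out_zero_weight = rms_norm_fp32(x, torch.zeros(128))

# Gradient flow: fp32 autograd through the reference yields finite gradients.
x_grad = torch.randn(2, 128, generator=gen).requires_grad_()
rms_norm_fp32(x_grad, torch.ones(128)).sum().backward()
grad_is_finite = bool(torch.isfinite(x_grad.grad).all())
```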

Rebased onto latest upstream/main; registry dispatch for the neighboring
ratio_kl / grpo_loss ops verified intact after conflict resolution.

Checklist

- [x] Pure-PyTorch reference, no custom extension required
- [x] Both Qwen3-8B normalized dims (4096, 128) covered
- [x] Axis-A bitwise batch invariance enforced
- [x] Axis-B per-dtype tolerance documented and tested
- [x] Registered in OpBackend + cuda/rocm/cpu priority maps
- [x] All 16 tests pass locally

---

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Added RMSNorm with multi-backend support across CUDA, ROCm, and CPU, including a pure-PyTorch reference path.
  * Supports fp16/bf16 and fp32 execution, with an option to force fp32 outputs.
  * Includes proper handling of `eps`, weight shape requirements, and output dtype casting.

* **Tests**
  * Added extensive pytest coverage validating correctness vs an fp32 reference, dtype routing, edge cases, input non-mutation, and gradients.
  * Verified operator dispatch resolution for `"rms_norm"`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

- NativeRMSNormOp with forward / forward_fp32 (fp32 ground-truth path)
- covers both normalized dims: hidden=4096 and head_dim=128 (Qwen3 QK-Norm)
- register PYTORCH_NATIVE_RMS_NORM in OpBackend + cpu/cuda/rocm priority map
- tests/test_rms_norm.py: axis-A bitwise batch invariance + dtype tolerance,
  shape guard, purity, gradient flow, registry dispatch (16 tests)

Refs RL-Align#108

coderabbitai Bot commented Jun 20, 2026


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: aabe65fc-7f15-4a87-a710-2d023b33009f

📥 Commits

Reviewing files that changed from the base of the PR and between 5396d27 and 6c50a87.

📒 Files selected for processing (2)
  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py
  • tests/test_rms_norm.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_rms_norm.py
  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py

Walkthrough

Adds NativeRMSNormOp, a pure-PyTorch RMSNorm reference implementation, to rl_engine/kernels/ops/pytorch/norm/rms_norm.py. Registers it as OpBackend.PYTORCH_NATIVE_RMS_NORM in KernelRegistry with dispatch entries for cuda, rocm, and cpu. A new pytest module validates correctness, dtype routing, batch invariance, shape guards, purity, gradients, and registry dispatch.

Changes

NativeRMSNormOp: Implementation, Registry Wiring, and Tests

| Layer / File(s) | Summary |
| --- | --- |
| NativeRMSNormOp class and core _rms_norm math (rl_engine/kernels/ops/pytorch/norm/rms_norm.py) | Defines NativeRMSNormOp with __call__/forward (fp32 accumulation, casts output to x.dtype), forward_fp32 (forces float32 output), and static _rms_norm that validates weight shape, computes rsqrt(mean(x²) + eps) * weight, and casts to output_dtype. |
| OpBackend enum and KernelRegistry dispatch wiring (rl_engine/kernels/registry.py) | Adds PYTORCH_NATIVE_RMS_NORM to OpBackend with the NativeRMSNormOp import path, and extends KernelRegistry._priority_map with rms_norm entries for cuda, rocm, and cpu. |
| Test suite: correctness, dtype, guards, purity, gradients, registry (tests/test_rms_norm.py) | Validates NativeRMSNormOp against a manual fp32 reference for two normalized dimensions, batch/padding invariance (bitwise equality), dtype routing, bf16/fp16 tolerances, eps/zero-input finiteness, linear weight scaling, ValueError on bad weight shapes, input non-mutation, gradient finiteness, and kernel_registry dispatch. |
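
The registry wiring summarized in the table can be pictured with a toy model like the one below. Everything here is hypothetical: the real OpBackend values, the layout of KernelRegistry._priority_map, and the lookup API are not shown in this conversation, so the names `_IMPLEMENTATIONS`, `_PRIORITY_MAP`, and `resolve` are invented purely to illustrate priority-ordered dispatch.

```python
from enum import Enum


class NativeRMSNormOp:
    """Placeholder class standing in for the real op."""


class OpBackend(Enum):
    # The enum value is an assumption; the PR stores an import path here.
    PYTORCH_NATIVE_RMS_NORM = "pytorch_native_rms_norm"


# Hypothetical backend -> implementation table.
_IMPLEMENTATIONS = {OpBackend.PYTORCH_NATIVE_RMS_NORM: NativeRMSNormOp}

# Hypothetical priority map: per device, per op, backends in preference order.
_PRIORITY_MAP = {
    "cuda": {"rms_norm": [OpBackend.PYTORCH_NATIVE_RMS_NORM]},
    "rocm": {"rms_norm": [OpBackend.PYTORCH_NATIVE_RMS_NORM]},
    "cpu": {"rms_norm": [OpBackend.PYTORCH_NATIVE_RMS_NORM]},
}


def resolve(op_name: str, device: str):
    """Return the first available implementation for op_name on device."""
    for backend in _PRIORITY_MAP[device][op_name]:
        impl = _IMPLEMENTATIONS.get(backend)
        if impl is not None:
            return impl
    raise KeyError(f"no backend registered for {op_name!r} on {device!r}")


resolved = {device: resolve("rms_norm", device) for device in ("cuda", "rocm", "cpu")}
```

With only the pure-PyTorch backend registered, every device resolves to the same class; a later Triton, CUDA, or ROCm kernel would presumably be placed ahead of it in the relevant list.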

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🐇 A little kernel woke up one day,
And said "Let's norm these weights the PyTorch way!"
With rsqrt and mean(x²) in sight,
fp32 accumulation shining bright,
On CUDA, ROCm, CPU it'll run — hooray! 🌟

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 26.32% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically summarizes the main changes: introducing NativeRMSNormOp as a pure-PyTorch implementation with associated test coverage. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/norm/rms_norm.py`:
- Around line 5-6: The file rms_norm.py has linting and formatting violations
from black, isort, and trailing-whitespace checks that are blocking CI. Run the
project's pre-commit formatting hooks (typically via a command like 'pre-commit
run --all-files' or 'black' and 'isort' individually) to automatically reformat
the file and fix signature spacing issues around lines 25-31, expression
formatting issues around lines 38-44 and 57-67, and any trailing whitespace
violations. Commit the reformatted file after the hooks complete.

In `@tests/test_rms_norm.py`:
- Around line 10-11: The test_rms_norm.py file has formatting inconsistencies
detected by Black. Run the Black formatter on this file to automatically fix
alignment and spacing issues in inline comments (like those on the _HIDDEN and
_HEAD_DIM constant definitions) and long assertions (around the test assertion
blocks). Apply Black's output and commit the formatted result to resolve the CI
formatting check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 34ccd0aa-23c0-4f09-8d2a-e7d333cea8a2

📥 Commits

Reviewing files that changed from the base of the PR and between 9dfcdbc and 5396d27.

📒 Files selected for processing (3)
  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py
  • rl_engine/kernels/registry.py
  • tests/test_rms_norm.py

Comment thread rl_engine/kernels/ops/pytorch/norm/rms_norm.py
Comment thread tests/test_rms_norm.py Outdated
Resolve CodeRabbit formatting findings on RL-Align#160: black (line-length=100),
isort (profile=black), trailing-whitespace and EOF fixes. No logic change;
16 tests still pass.
@Flink-ddd Flink-ddd added platform: cuda Specific optimizations or bugs in NVIDIA graphics cards (such as FlashInfer, TMA optimizations) priority: high Severe congestion issues require the highest priority for resolution. sprint-0615 labels Jun 21, 2026
@Flink-ddd Flink-ddd requested a review from EthanZero2Hero June 21, 2026 13:10

Labels

needs-gpu-ci platform: cuda Specific optimizations or bugs in NVIDIA graphics cards (such as FlashInfer, TMA optimizations) priority: high Severe congestion issues require the highest priority for resolution. sprint-0615


2 participants