feat(ws1): Add PyTorch RoPE reference operator#167

Open
a-kaa wants to merge 2 commits into RL-Align:main from a-kaa:dev-kernel

Conversation

@a-kaa a-kaa commented Jun 21, 2026

Collaborator

Summary

Implements NativeRoPEOp, the fp32 ground-truth RoPE reference operator for ISSUE #108.

  • Shapes: Qwen3-8B defaults — theta=1e6 (not the 1e4 default), head_dim=128, full-dimension rotation (half=64).
  • Dtype contract: forward() follows input dtype; forward_fp32() is the fp32 gold standard with internal fp32 accumulation. Both implemented (fp16 and fp32 paths tested).
  • Purity: no in-place ops, no random state, no global state. cos/sin are computed internally in fp32 from positions/theta — no external cache accepted or returned.
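For orientation, the frequency schedule implied by these defaults can be worked out with the standard library alone. The sketch below assumes the usual RoPE definition `inv_freq[i] = theta ** (-2i / head_dim)` and the HF rotate-half pairing, where element `i` is paired with element `i + half`. The function names are illustrative and are not taken from the PR.

```python
import math

THETA = 1_000_000.0   # Qwen3-8B rope theta quoted in the summary
HEAD_DIM = 128        # full-dimension rotation, so half = 64


def inv_freq(theta: float, head_dim: int) -> list[float]:
    """Per-pair inverse frequencies: theta ** (-2i / head_dim)."""
    half = head_dim // 2
    return [theta ** (-(2 * i) / head_dim) for i in range(half)]


def rope_rotate_half(vec: list[float], position: int, theta: float = THETA) -> list[float]:
    """Rotate one head vector at one position, pairing vec[i] with vec[i + half]."""
    half = len(vec) // 2
    out = [0.0] * len(vec)
    for i, f in enumerate(inv_freq(theta, len(vec))):
        c, s = math.cos(position * f), math.sin(position * f)
        a, b = vec[i], vec[i + half]
        out[i] = a * c - b * s          # x1 * cos - x2 * sin
        out[i + half] = b * c + a * s   # x2 * cos + x1 * sin
    return out


freqs = inv_freq(THETA, HEAD_DIM)
vec = [float(i % 7) - 3.0 for i in range(HEAD_DIM)]
rotated = rope_rotate_half(vec, position=5)
```

With these values the fastest pair rotates by 1 radian per position and the slowest by roughly 1.2e-6 radians per position. Each pair undergoes a plain 2D rotation, so the vector norm is preserved and position 0 leaves the input unchanged.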

Changes

  • rl_engine/kernels/ops/pytorch/pos/rope.py: NativeRoPEOp (op_class = "elementwise")
  • docs/operators/rope.md — operator reference doc (shapes, convention, tolerance contract)
  • tests/ops/test_rope.py — full test suite, all green locally:
    • Axis A (bitwise batch invariance): torch.equal between batch=1 slice and batch=N slice, including padded batches and per-row 2D positions ([B, S])
    • Axis B (accuracy): forward vs forward_fp32 under tolerance contract (fp32: 1e-5, bf16: 2e-2, fp16: 1e-3)
    • Bitwise match against the frozen ISSUE_108_OPS_DEV.md §5 reference formula
    • Shape/dtype correctness, position broadcasting equivalence ([S] vs [B,S]), pure-function checks
    • Qwen3-8B specific shapes (Q heads=32, KV heads=8, head_dim=128)
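As a rough illustration of the Axis A and Axis B checks listed above, the sketch below uses a stand-in `rope_fp32` helper written for this example; the real `NativeRoPEOp` in the PR may differ in detail. It also assumes each tolerance is applied as both `rtol` and `atol` in `torch.allclose`, which the description does not state.

```python
import torch


def rope_fp32(x: torch.Tensor, positions: torch.Tensor, theta: float = 1_000_000.0) -> torch.Tensor:
    """Stand-in fp32 rotate-half RoPE for x of shape [B, H, S, D] (assumed, not the PR code)."""
    D = x.shape[-1]
    half = D // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
    freqs = positions.to(torch.float32).unsqueeze(-1) * inv_freq  # [S, half] or [B, S, half]
    if positions.dim() == 1:
        freqs = freqs.unsqueeze(0)  # [1, S, half]
    freqs = freqs.unsqueeze(1)      # [B or 1, 1, S, half], broadcasts over heads
    cos, sin = freqs.cos(), freqs.sin()
    xf = x.to(torch.float32)
    x1, x2 = xf[..., :half], xf[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


def rope(x: torch.Tensor, positions: torch.Tensor, theta: float = 1_000_000.0) -> torch.Tensor:
    """Input-dtype path: compute in fp32, then cast back to x.dtype."""
    return rope_fp32(x, positions, theta).to(x.dtype)


# Tolerances quoted in the PR description (forward vs forward_fp32).
TOLERANCE = {torch.float32: 1e-5, torch.bfloat16: 2e-2, torch.float16: 1e-3}

torch.manual_seed(0)
B, H, S, D = 4, 8, 16, 128
x = torch.randn(B, H, S, D)
pos = torch.arange(S)

# Axis A: row 0 computed alone must be bitwise equal to row 0 of the full batch.
axis_a_ok = torch.equal(rope_fp32(x[:1], pos)[0], rope_fp32(x, pos)[0])

# Position broadcasting: [S] positions and the same positions expanded to [B, S] agree bitwise.
pos_2d_ok = torch.equal(rope_fp32(x, pos), rope_fp32(x, pos.unsqueeze(0).expand(B, S)))

# Axis B: the input-dtype path stays within tolerance of the fp32 reference.
axis_b_ok = {
    dtype: torch.allclose(
        rope(x.to(dtype), pos).float(), rope_fp32(x.to(dtype), pos), rtol=tol, atol=tol
    )
    for dtype, tol in TOLERANCE.items()
}
```

Axis A uses `torch.equal` rather than `allclose` because batch invariance is an exact-equality property: the same row should not change by a single bit when other rows are present.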

Summary by CodeRabbit

  • New Features

    • Added RoPE (Rotary Position Embeddings) operator with PyTorch native backend support for CPU, CUDA, and ROCm.
    • Registered Matmul operator in backend dispatch system.
  • Documentation

    • Added operator documentation for RoPE and Matmul with usage instructions and reference semantics.
  • Tests

    • Added comprehensive test suite validating RoPE operator functionality, accuracy, and dtype behavior.

coderabbitai Bot commented Jun 21, 2026


Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c415759a-f2ac-4f99-801e-dee58c82057c

📥 Commits

Reviewing files that changed from the base of the PR and between ba5a4cd and 23b1156.

📒 Files selected for processing (2)
  • rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py
  • tests/test_rope.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_rope.py
  • rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py

📝 Walkthrough


Adds NativeRoPEOp, a pure-PyTorch HF rotate-half RoPE reference operator. The class is implemented with forward, forward_fp32, and _compute_cos_sin, packaged under rl_engine.kernels.ops.pytorch.rotary_embedding, registered in OpBackend and KernelRegistry for cpu/cuda/rocm, covered by 241 lines of pytest, and documented in docs/operators/rope.md.

Changes

NativeRoPEOp implementation, registration, tests, and documentation

| Layer | File(s) | Summary |
| --- | --- | --- |
| OpBackend enum and KernelRegistry dispatch for rope | `rl_engine/kernels/registry.py`, `rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py` | Adds `PYTORCH_NATIVE_ROPE` to `OpBackend` wired to `rope.NativeRoPEOp`. Extends `KernelRegistry._priority_map` to route the rope op on cuda/rocm/cpu to that backend. `__init__.py` imports and exports `NativeRoPEOp`. |
| NativeRoPEOp: forward, forward_fp32, and _compute_cos_sin | `rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` | `forward` applies rotate-half with fp32 cos/sin and casts the result back to input dtype. `forward_fp32` keeps the output in fp32. `_compute_cos_sin` builds `inv_freq` from `theta`, computes frequency products from `positions`, and reshapes cos/sin for 1D `[S]` or 2D `[B,S]` broadcasting. |
| Correctness, batch invariance, accuracy, and shape tests | `tests/test_rope.py` | Four test classes: bitwise match against HF fp32 reference and 1D/2D position equivalence; bitwise batch invariance with padded rows and bf16 inputs; dtype-parameterized tolerance checks between `forward` and `forward_fp32`; Qwen3-8B-like shapes including the `n_kv_heads=8` K-path. |
| Operator documentation and README index update | `docs/operators/rope.md`, `docs/operators/README.md` | `rope.md` documents usage, backend dispatch, tensor contract, rotate-half semantics, and accuracy expectations. README Current Pages section gains links for `rope.md` and `matmul.md`. |
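The dispatch wiring described in the walkthrough (a backend enum member plus a priority map keyed by op and device) can be pictured with a toy registry. Everything below is hypothetical and written only for illustration: the `Toy*` names, the `resolve` method, and the shape of the priority map are assumptions, not the actual `rl_engine` code.

```python
from enum import Enum


class ToyRoPEOp:
    """Placeholder standing in for the real operator class."""
    op_class = "elementwise"


class ToyOpBackend(Enum):
    """Hypothetical backend identifiers."""
    PYTORCH_NATIVE_ROPE = "pytorch_native_rope"


# Hypothetical mapping from backend identifier to implementing class.
_IMPLEMENTATIONS = {ToyOpBackend.PYTORCH_NATIVE_ROPE: ToyRoPEOp}


class ToyKernelRegistry:
    # (op name, device) -> backends to try, highest priority first.
    _priority_map = {
        ("rope", device): [ToyOpBackend.PYTORCH_NATIVE_ROPE]
        for device in ("cuda", "rocm", "cpu")
    }

    @classmethod
    def resolve(cls, op_name: str, device: str) -> type:
        """Return the highest-priority implementation for an op on a device."""
        backends = cls._priority_map.get((op_name, device))
        if not backends:
            raise KeyError(f"no backend registered for op={op_name!r} on device={device!r}")
        return _IMPLEMENTATIONS[backends[0]]


op_cls = ToyKernelRegistry.resolve("rope", "cuda")
```

Because all three devices point at the same PyTorch-native backend in this sketch, a later fused kernel could be given priority on one device by prepending it to that device's list while the reference operator stays available as the fallback.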

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

Poem

🐇 A RoPE spun from cosines so bright,
Half the vector rotated just right,
fp32 gold reference in hand,
Qwen3 shapes all tested as planned,
The registry now knows where to look—
The kernel hopped straight into the book!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 35.71% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly summarizes the main change: adding a PyTorch RoPE (Rotary Position Embedding) reference operator implementation, which is the central objective of this PR. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/test_rope.py (1)

53-126: ⚡ Quick win

Add boundary-contract tests for invalid inputs.

Given the operator now has strict shape/device contracts, add negative tests for odd head_dim and unsupported positions rank to prevent future regressions.

Suggested tests
 class TestNativeRoPEOpCorrectness:
@@
     def test_position_zero_is_identity_for_cos(self):
         ...
         assert torch.allclose(out, x.float(), atol=1e-7)
+
+    def test_odd_head_dim_raises(self):
+        op = NativeRoPEOp()
+        x = torch.randn(1, 1, 4, 127)
+        pos = torch.arange(4, dtype=torch.long)
+        with pytest.raises(ValueError, match="even head_dim"):
+            op.forward_fp32(x, pos)
+
+    def test_positions_rank_validation(self):
+        op = NativeRoPEOp()
+        x = torch.randn(1, 1, 4, 128)
+        bad_pos = torch.zeros(1, 4, 1, dtype=torch.long)
+        with pytest.raises(ValueError, match="positions"):
+            op.forward_fp32(x, bad_pos)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_rope.py` around lines 53 - 126, Add two new test methods to the
TestNativeRoPEOpCorrectness class to validate the operator's input contracts:
create a test_odd_head_dim_raises_error method that verifies
NativeRoPEOp.forward_fp32 raises an appropriate error when given tensors with
odd head_dim, and create a test_unsupported_positions_rank_raises_error method
that verifies the operator rejects positions tensors with unsupported rank
(e.g., rank > 2). Use pytest.raises or equivalent assertion to confirm the
expected exceptions are raised for each invalid input scenario.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/rope.md`:
- Around line 38-43: In the documentation table for NativeRoPEOp, the `x`
argument (Query or key tensor with shape [B, H, S, D]) is missing an explicit
precondition. Since NativeRoPEOp splits D into two halves internally, add a
requirement stating that D must be even to the Requirements column for the x
argument. This ensures users understand that odd-width dimension inputs are not
supported and clarifies the expected behavior of the rotate-half semantics.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py`:
- Around line 69-71: Add explicit validation to ensure the head dimension D is
even before proceeding with RoPE table construction. After extracting D from
x.shape[-1] (around line 69), insert a check to verify that D is even using
modulo operator, and raise a clear ValueError if D is odd to fail fast at the
contract boundary. Apply the same validation pattern at the second location
mentioned (lines 85-86) where similar dimension extraction occurs.
- Around line 80-83: In the `_compute_cos_sin` function, the `positions` tensor
may reside on a different device than `x` (e.g., `x` on GPU, `positions` on
CPU), which causes device mismatch errors during the frequency computation.
Before converting `positions` to float and computing the frequency values in the
line where `pos_float = positions.float().unsqueeze(-1)` occurs, first ensure
`positions` is moved to the same device as `x` using the appropriate device
transfer operation, then convert to float32. Additionally, validate the rank of
the `positions` tensor to ensure it has the expected shape for the downstream
operations.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a4031d4f-94d1-4365-abe7-6f45f2958521

📥 Commits

Reviewing files that changed from the base of the PR and between a302be4 and 78d01b1.

📒 Files selected for processing (6)
  • docs/operators/README.md
  • docs/operators/rope.md
  • rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py
  • rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py
  • rl_engine/kernels/registry.py
  • tests/test_rope.py

Comment thread docs/operators/rope.md
Comment on lines +38 to +43
| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
| Output | `[B, H, S, D]` | See below | Same shape as `x`. |


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

State the even-head-dim precondition.

NativeRoPEOp splits D into two halves, so the contract should explicitly say D must be even. Without that, odd-width inputs look supported here but won’t preserve the intended rotate-half semantics.

📌 Suggested doc tweak
-| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
+| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; `D` must be even; Qwen3 uses `D=128`. |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 | Argument | Shape | Dtype | Requirements |
 | --- | --- | --- | --- |
-| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
+| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; `D` must be even; Qwen3 uses `D=128`. |
 | `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
 | `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
 | Output | `[B, H, S, D]` | See below | Same shape as `x`. |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/rope.md` around lines 38 - 43, In the documentation table for
NativeRoPEOp, the `x` argument (Query or key tensor with shape [B, H, S, D]) is
missing an explicit precondition. Since NativeRoPEOp splits D into two halves
internally, add a requirement stating that D must be even to the Requirements
column for the x argument. This ensures users understand that odd-width
dimension inputs are not supported and clarifies the expected behavior of the
rotate-half semantics.

Comment on lines +69 to +71
D = x.shape[-1]
half = D // 2


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add explicit even head_dim validation before building RoPE tables.

Odd D currently fails later via shape mismatch; fail fast with a clear error at the contract boundary.

Proposed fix
         D = x.shape[-1]
+        if D % 2 != 0:
+            raise ValueError(f"RoPE requires even head_dim, got {D}")
         half = D // 2

Also applies to: 85-86

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` around lines 69 - 71,
Add explicit validation to ensure the head dimension D is even before proceeding
with RoPE table construction. After extracting D from x.shape[-1] (around line
69), insert a check to verify that D is even using modulo operator, and raise a
clear ValueError if D is odd to fail fast at the contract boundary. Apply the
same validation pattern at the second location mentioned (lines 85-86) where
similar dimension extraction occurs.

Comment on lines +80 to +83
pos_float = positions.float().unsqueeze(-1)

# freqs: [S, half] or [B, S, half]
freqs = pos_float * inv_freq


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Normalize positions to x.device (and validate rank) in _compute_cos_sin.

Current code can crash on mixed-device inputs (x on accelerator, positions on CPU). Convert positions to fp32 on x.device before frequency math.

Proposed fix
-        pos_float = positions.float().unsqueeze(-1)
+        if positions.dim() not in (1, 2):
+            raise ValueError(f"positions must have shape [S] or [B,S], got dim={positions.dim()}")
+        pos_float = positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)

Also applies to: 92-99

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` around lines 80 - 83,
In the `_compute_cos_sin` function, the `positions` tensor may reside on a
different device than `x` (e.g., `x` on GPU, `positions` on CPU), which causes
device mismatch errors during the frequency computation. Before converting
`positions` to float and computing the frequency values in the line where
`pos_float = positions.float().unsqueeze(-1)` occurs, first ensure `positions`
is moved to the same device as `x` using the appropriate device transfer
operation, then convert to float32. Additionally, validate the rank of the
`positions` tensor to ensure it has the expected shape for the downstream
operations.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/linear/matmul.py`:
- Around line 25-27: The forward method accepts mixed input dtypes without
validation and silently casts the result to a.dtype, which can hide upstream
mistakes. Add an explicit dtype compatibility check at the beginning of the
forward method that validates tensors a and b have compatible or identical
dtypes, raising a clear error if they don't match. Only proceed with the
forward_fp32 computation and dtype casting after this validation passes.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d71dc47d-9508-47c9-a83c-852d642341a6

📥 Commits

Reviewing files that changed from the base of the PR and between 78d01b1 and ba5a4cd.

📒 Files selected for processing (6)
  • docs/operators/README.md
  • docs/operators/matmul.md
  • rl_engine/kernels/ops/pytorch/linear/__init__.py
  • rl_engine/kernels/ops/pytorch/linear/matmul.py
  • rl_engine/kernels/registry.py
  • tests/test_matmul.py
✅ Files skipped from review due to trivial changes (2)
  • docs/operators/README.md
  • docs/operators/matmul.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/registry.py

Comment on lines +25 to +27
def forward(self, a: Tensor, b: Tensor) -> Tensor:
"""Compute `a @ b` and return the input dtype."""
return self.forward_fp32(a, b).to(dtype=a.dtype)


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate dtype compatibility before casting output to a.dtype.

On Line 27, mixed input dtypes are silently accepted and the result is forced to a.dtype, which can hide upstream mistakes and unintentionally downcast results. Add an explicit dtype check (or documented promotion rule) before the cast.

Suggested patch
     def forward(self, a: Tensor, b: Tensor) -> Tensor:
         """Compute `a @ b` and return the input dtype."""
+        if a.dtype != b.dtype:
+            raise TypeError(f"NativeMatmulOp.forward expects matching dtypes, got {a.dtype} and {b.dtype}")
         return self.forward_fp32(a, b).to(dtype=a.dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/linear/matmul.py` around lines 25 - 27, The
forward method accepts mixed input dtypes without validation and silently casts
the result to a.dtype, which can hide upstream mistakes. Add an explicit dtype
compatibility check at the beginning of the forward method that validates
tensors a and b have compatible or identical dtypes, raising a clear error if
they don't match. Only proceed with the forward_fp32 computation and dtype
casting after this validation passes.
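The comment allows either a strict dtype check or a documented promotion rule. As a sketch of the second option, assuming the operator accumulates in fp32 as the `forward_fp32` name suggests, the output dtype could follow `torch.promote_types` instead of `a.dtype`. `matmul_fp32` and `matmul_promoted` are hypothetical names for this example and are not part of the PR.

```python
import torch
from torch import Tensor


def matmul_fp32(a: Tensor, b: Tensor) -> Tensor:
    """Hypothetical fp32 reference: upcast both operands, then multiply."""
    return a.to(torch.float32) @ b.to(torch.float32)


def matmul_promoted(a: Tensor, b: Tensor) -> Tensor:
    """Return a @ b in the promoted dtype of the inputs rather than a.dtype."""
    out_dtype = torch.promote_types(a.dtype, b.dtype)
    return matmul_fp32(a, b).to(out_dtype)


a = torch.randn(2, 3, dtype=torch.float16)
b = torch.randn(3, 4, dtype=torch.float32)

mixed = matmul_promoted(a, b)                       # float16 x float32 promotes to float32
matched = matmul_promoted(a, b.to(torch.float16))   # matching inputs keep their dtype
```

The trade-off is that promotion never downcasts silently but still accepts mixed inputs, whereas the `TypeError` in the suggested patch surfaces an upstream dtype mistake immediately.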

@a-kaa a-kaa changed the title from "Add PyTorch RoPE reference operator" to "feat(ws1): Add PyTorch RoPE reference operator" Jun 22, 2026
@Flink-ddd Flink-ddd assigned Flink-ddd and unassigned Flink-ddd Jun 22, 2026
@Flink-ddd

Collaborator

Please resolve the conflicts and the CI error first. Thanks.
