feat(ws1): Add PyTorch RoPE reference operator#167
No actionable comments were generated in the recent review.
Walkthrough: adds the NativeRoPEOp implementation, registration, tests, and documentation.
Estimated code review effort: 2 (Simple), about 15 minutes.
Pre-merge checks: 4 passed, 1 failed (1 warning).
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/test_rope.py (1)
53-126: Quick win: Add boundary-contract tests for invalid inputs.
Given the operator now has strict shape/device contracts, add negative tests for odd `head_dim` and unsupported `positions` rank to prevent future regressions.
Suggested tests
 class TestNativeRoPEOpCorrectness:
@@
     def test_position_zero_is_identity_for_cos(self):
         ...
         assert torch.allclose(out, x.float(), atol=1e-7)
+
+    def test_odd_head_dim_raises(self):
+        op = NativeRoPEOp()
+        x = torch.randn(1, 1, 4, 127)
+        pos = torch.arange(4, dtype=torch.long)
+        with pytest.raises(ValueError, match="even head_dim"):
+            op.forward_fp32(x, pos)
+
+    def test_positions_rank_validation(self):
+        op = NativeRoPEOp()
+        x = torch.randn(1, 1, 4, 128)
+        bad_pos = torch.zeros(1, 4, 1, dtype=torch.long)
+        with pytest.raises(ValueError, match="positions"):
+            op.forward_fp32(x, bad_pos)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_rope.py` around lines 53 - 126, Add two new test methods to the
TestNativeRoPEOpCorrectness class to validate the operator's input contracts:
create a test_odd_head_dim_raises_error method that verifies
NativeRoPEOp.forward_fp32 raises an appropriate error when given tensors with
odd head_dim, and create a test_unsupported_positions_rank_raises_error method
that verifies the operator rejects positions tensors with unsupported rank
(e.g., rank > 2). Use pytest.raises or equivalent assertion to confirm the
expected exceptions are raised for each invalid input scenario.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/rope.md`:
- Around line 38-43: In the documentation table for NativeRoPEOp, the `x`
argument (Query or key tensor with shape [B, H, S, D]) is missing an explicit
precondition. Since NativeRoPEOp splits D into two halves internally, add a
requirement stating that D must be even to the Requirements column for the x
argument. This ensures users understand that odd-width dimension inputs are not
supported and clarifies the expected behavior of the rotate-half semantics.
In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py`:
- Around line 69-71: Add explicit validation to ensure the head dimension D is
even before proceeding with RoPE table construction. After extracting D from
x.shape[-1] (around line 69), insert a check to verify that D is even using
modulo operator, and raise a clear ValueError if D is odd to fail fast at the
contract boundary. Apply the same validation pattern at the second location
mentioned (lines 85-86) where similar dimension extraction occurs.
- Around line 80-83: In the `_compute_cos_sin` function, the `positions` tensor
may reside on a different device than `x` (e.g., `x` on GPU, `positions` on
CPU), which causes device mismatch errors during the frequency computation.
Before converting `positions` to float and computing the frequency values in the
line where `pos_float = positions.float().unsqueeze(-1)` occurs, first ensure
`positions` is moved to the same device as `x` using the appropriate device
transfer operation, then convert to float32. Additionally, validate the rank of
the `positions` tensor to ensure it has the expected shape for the downstream
operations.
---
Nitpick comments:
In `@tests/test_rope.py`:
- Around line 53-126: Add two new test methods to the
TestNativeRoPEOpCorrectness class to validate the operator's input contracts:
create a test_odd_head_dim_raises_error method that verifies
NativeRoPEOp.forward_fp32 raises an appropriate error when given tensors with
odd head_dim, and create a test_unsupported_positions_rank_raises_error method
that verifies the operator rejects positions tensors with unsupported rank
(e.g., rank > 2). Use pytest.raises or equivalent assertion to confirm the
expected exceptions are raised for each invalid input scenario.
📒 Files selected for processing (6)
- docs/operators/README.md
- docs/operators/rope.md
- rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py
- rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py
- rl_engine/kernels/registry.py
- tests/test_rope.py
| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
| Output | `[B, H, S, D]` | See below | Same shape as `x`. |
State the even-head-dim precondition.
NativeRoPEOp splits D into two halves, so the contract should explicitly say D must be even. Without that, odd-width inputs look supported here but won’t preserve the intended rotate-half semantics.
📌 Suggested doc tweak
-| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
+| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; `D` must be even; Qwen3 uses `D=128`. |
D = x.shape[-1]
half = D // 2
Add explicit even head_dim validation before building RoPE tables.
Odd D currently fails later via shape mismatch; fail fast with a clear error at the contract boundary.
Proposed fix
 D = x.shape[-1]
+if D % 2 != 0:
+    raise ValueError(f"RoPE requires even head_dim, got {D}")
 half = D // 2
Also applies to: 85-86
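A minimal sketch of the fail-fast check in context, assuming the inverse-frequency table is built right after `half` is computed. The helper name `build_inv_freq` and the exact frequency formula are illustrative assumptions, not code taken from the PR; only the even-`D` check mirrors the proposed fix.

```python
import torch


def build_inv_freq(x: torch.Tensor, theta: float = 1_000_000.0) -> torch.Tensor:
    """Return fp32 inverse frequencies of shape [D // 2] for the last dim of x.

    Hypothetical helper: only the even-D validation mirrors the proposed fix.
    """
    D = x.shape[-1]
    if D % 2 != 0:
        # Fail fast at the contract boundary instead of a later shape mismatch.
        raise ValueError(f"RoPE requires even head_dim, got {D}")
    half = D // 2
    # Assumed standard RoPE schedule: inv_freq[i] = theta ** (-i / half).
    exponents = torch.arange(half, dtype=torch.float32, device=x.device) / half
    return 1.0 / (theta ** exponents)
```

With this shape of check, an odd `D` such as 127 is rejected before any table is allocated, so the error message names the real cause.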
pos_float = positions.float().unsqueeze(-1)

# freqs: [S, half] or [B, S, half]
freqs = pos_float * inv_freq
Normalize positions to x.device (and validate rank) in _compute_cos_sin.
Current code can crash on mixed-device inputs (x on accelerator, positions on CPU). Convert positions to fp32 on x.device before frequency math.
Proposed fix
-pos_float = positions.float().unsqueeze(-1)
+if positions.dim() not in (1, 2):
+    raise ValueError(f"positions must have shape [S] or [B,S], got dim={positions.dim()}")
+pos_float = positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)
Also applies to: 92-99
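The proposed fix can be exercised in isolation. The sketch below wraps it in a hypothetical helper, `normalize_positions`, and uses a placeholder `inv_freq` purely to show how the trailing axis broadcasts for both accepted ranks; neither name's usage here is taken from the PR.

```python
import torch


def normalize_positions(positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Validate rank, then return positions as fp32 on x.device with a trailing axis."""
    if positions.dim() not in (1, 2):
        raise ValueError(
            f"positions must have shape [S] or [B,S], got dim={positions.dim()}"
        )
    # [S] -> [S, 1] and [B, S] -> [B, S, 1], so the result broadcasts against [half].
    return positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)


x = torch.randn(2, 1, 4, 8)
inv_freq = torch.ones(4)  # placeholder frequencies, used only for the shape check

freqs_1d = normalize_positions(torch.arange(4), x) * inv_freq               # [S, half]
freqs_2d = normalize_positions(torch.arange(4).repeat(2, 1), x) * inv_freq  # [B, S, half]
```

Doing the device move and the dtype cast in a single `.to(...)` call keeps the integer positions from being copied twice.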
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/linear/matmul.py`:
- Around line 25-27: The forward method accepts mixed input dtypes without
validation and silently casts the result to a.dtype, which can hide upstream
mistakes. Add an explicit dtype compatibility check at the beginning of the
forward method that validates tensors a and b have compatible or identical
dtypes, raising a clear error if they don't match. Only proceed with the
forward_fp32 computation and dtype casting after this validation passes.
📒 Files selected for processing (6)
- docs/operators/README.md
- docs/operators/matmul.md
- rl_engine/kernels/ops/pytorch/linear/__init__.py
- rl_engine/kernels/ops/pytorch/linear/matmul.py
- rl_engine/kernels/registry.py
- tests/test_matmul.py
✅ Files skipped from review due to trivial changes (2)
- docs/operators/README.md
- docs/operators/matmul.md
🚧 Files skipped from review as they are similar to previous changes (1)
- rl_engine/kernels/registry.py
def forward(self, a: Tensor, b: Tensor) -> Tensor:
    """Compute `a @ b` and return the input dtype."""
    return self.forward_fp32(a, b).to(dtype=a.dtype)
Validate dtype compatibility before casting output to a.dtype.
On Line 27, mixed input dtypes are silently accepted and the result is forced to a.dtype, which can hide upstream mistakes and unintentionally downcast results. Add an explicit dtype check (or documented promotion rule) before the cast.
Suggested patch
 def forward(self, a: Tensor, b: Tensor) -> Tensor:
     """Compute `a @ b` and return the input dtype."""
+    if a.dtype != b.dtype:
+        raise TypeError(f"NativeMatmulOp.forward expects matching dtypes, got {a.dtype} and {b.dtype}")
     return self.forward_fp32(a, b).to(dtype=a.dtype)
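To make the effect of the suggested check concrete, here is a self-contained sketch. `MatmulSketch` is a hypothetical stand-in for `NativeMatmulOp`, and its `forward_fp32` body (upcast both operands, multiply in fp32) is an assumption about the real operator rather than its actual code.

```python
import torch
from torch import Tensor


class MatmulSketch:
    """Hypothetical stand-in for NativeMatmulOp; not the PR's actual class."""

    def forward_fp32(self, a: Tensor, b: Tensor) -> Tensor:
        # Assumed behavior: upcast both operands and multiply in fp32.
        return a.float() @ b.float()

    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        # Reject mixed dtypes instead of silently casting the result to a.dtype.
        if a.dtype != b.dtype:
            raise TypeError(f"expects matching dtypes, got {a.dtype} and {b.dtype}")
        return self.forward_fp32(a, b).to(dtype=a.dtype)
```

Without the check, a `bfloat16` `a` paired with a `float32` `b` would return a `bfloat16` result and hide the mismatch from the caller.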
Please resolve the conflicts and the CI error first. Thanks.
Summary
Implements `NativeRoPEOp`, the fp32 ground-truth RoPE reference operator for ISSUE #108.
- `theta=1e6` (not the 1e4 default), `head_dim=128`, full-dimension rotation (`half=64`).
- `forward()` follows input dtype; `forward_fp32()` is the fp32 gold standard with internal fp32 accumulation. Both implemented (fp16 and fp32 paths tested).
- `positions`/`theta`: no external cache accepted or returned.
Changes
- `rl_engine/kernels/ops/pytorch/pos/rope.py`: `NativeRoPEOp` (`op_class = "elementwise"`)
- `docs/operators/rope.md`: operator reference doc (shapes, convention, tolerance contract)
- `tests/ops/test_rope.py`: full test suite, all green locally:
  - `torch.equal` between batch=1 slice and batch=N slice, including padded batches and per-row 2D `positions` (`[B, S]`)
  - `forward` vs `forward_fp32` under tolerance contract (fp32: 1e-5, bf16: 2e-2, fp16: 1e-3)
  - `ISSUE_108_OPS_DEV.md` §5 reference formula
  - `[S]` vs `[B,S]`), pure-function checks
Summary by CodeRabbit
New Features
Documentation
Tests
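For readers without the PR checked out, the behavior described in the Summary can be approximated with the sketch below. It assumes the common rotate-half convention (the first half of `D` is paired with the second half) and computes cos/sin on the fly from `positions` and `theta`; the function name `rope_fp32_sketch` is hypothetical, and the code is not the PR's `NativeRoPEOp`.

```python
import torch


def rope_fp32_sketch(
    x: torch.Tensor, positions: torch.Tensor, theta: float = 1_000_000.0
) -> torch.Tensor:
    """Apply rotate-half RoPE to x of shape [B, H, S, D], returning fp32."""
    D = x.shape[-1]
    if D % 2 != 0:
        raise ValueError(f"RoPE requires even head_dim, got {D}")
    if positions.dim() not in (1, 2):
        raise ValueError(f"positions must have shape [S] or [B,S], got dim={positions.dim()}")
    half = D // 2
    exponents = torch.arange(half, dtype=torch.float32, device=x.device) / half
    inv_freq = 1.0 / (theta ** exponents)                      # [half]
    pos = positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)
    freqs = pos * inv_freq                                     # [S, half] or [B, S, half]
    if freqs.dim() == 3:
        freqs = freqs.unsqueeze(1)                             # [B, 1, S, half]: broadcast over heads
    cos, sin = freqs.cos(), freqs.sin()
    x32 = x.float()                                            # all math stays in fp32
    x1, x2 = x32[..., :half], x32[..., half:]
    # Each (x1[i], x2[i]) pair is rotated by the angle freqs[..., i].
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


x = torch.randn(2, 3, 4, 128)
out_1d = rope_fp32_sketch(x, torch.arange(4))
out_2d = rope_fp32_sketch(x, torch.arange(4).repeat(2, 1))
```

Because every pair is rotated by a plain 2D rotation, the per-token vector norm is preserved and position 0 leaves the input unchanged, which makes both properties convenient sanity checks.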