1 change: 1 addition & 0 deletions docs/operators/README.md
@@ -20,5 +20,6 @@ Every operator page should include:

- [Fused LogP](fused-logp.md)
- [GRPO Loss](grpo-loss.md)
- [RoPE](rope.md)
- [Sampling](sampling.md)
- [Operator Doc Template](../contributing/operator-doc-template.md)
98 changes: 98 additions & 0 deletions docs/operators/rope.md
@@ -0,0 +1,98 @@
# RoPE

RoPE applies rotary position embeddings to per-head query or key tensors. The
current implementation is a pure PyTorch reference operator for Issue #108
ground-truth validation; it is not a fused CUDA or Triton kernel.

This page documents the PyTorch baseline version.

## Entry Point

```python
from rl_engine.kernels.registry import kernel_registry

rope = kernel_registry.get_op("rope")
output = rope.forward(x, positions, theta=1_000_000.0)
reference = rope.forward_fp32(x, positions, theta=1_000_000.0)
```

The operator can also be imported directly:

```python
from rl_engine.kernels.ops.pytorch.rotary_embedding import NativeRoPEOp

rope = NativeRoPEOp()
```

## Backend

| Backend | Wrapper | Native symbol | Notes |
| --- | --- | --- | --- |
| PyTorch native | `NativeRoPEOp` | None | Reference baseline for Qwen3-style RoPE. |

`kernel_registry.get_op("rope")` dispatches to the PyTorch native backend on CPU,
CUDA, and ROCm. CUDA/Triton fused RoPE kernels should compare against this reference.
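
As a rough illustration of the dispatch described above, the sketch below copies the `rope` and `grpo_loss` priority lists from the registry diff into a plain dictionary. The `pick_backend` helper and its first-available-wins policy are assumptions made for illustration; the real `KernelRegistry` selection logic is not shown in this change.

```python
# Priority lists mirroring the registry diff (only two operators shown).
DISPATCH = {
    "cuda": {
        "grpo_loss": ["TRITON_GRPO_LOSS", "PYTORCH_GRPO_LOSS"],
        "rope": ["PYTORCH_NATIVE_ROPE"],
    },
    "rocm": {
        "grpo_loss": ["TRITON_GRPO_LOSS", "PYTORCH_GRPO_LOSS"],
        "rope": ["PYTORCH_NATIVE_ROPE"],
    },
    "cpu": {
        "grpo_loss": ["PYTORCH_GRPO_LOSS"],
        "rope": ["PYTORCH_NATIVE_ROPE"],
    },
}


def pick_backend(device, op, available):
    """Hypothetical policy: return the first listed backend that is available."""
    for name in DISPATCH[device][op]:
        if name in available:
            return name
    raise LookupError(f"no backend available for {op!r} on {device!r}")


# With Triton unavailable, grpo_loss falls back; rope has a single candidate.
installed = {"PYTORCH_NATIVE_ROPE", "PYTORCH_GRPO_LOSS"}
rope_backends = {dev: pick_backend(dev, "rope", installed) for dev in DISPATCH}
```

Because `rope` lists a single backend on every device, any selection policy would resolve to the PyTorch reference until a fused kernel is added ahead of it.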

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
| Output | `[B, H, S, D]` | See below | Same shape as `x`. |
Comment on lines +38 to +43

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

State the even-head-dim precondition.

NativeRoPEOp splits D into two halves, so the contract should explicitly say D must be even. Without that, odd-width inputs look supported here but won’t preserve the intended rotate-half semantics.

📌 Suggested doc tweak
-| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
+| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; `D` must be even; Qwen3 uses `D=128`. |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/rope.md` around lines 38 - 43, In the documentation table for
NativeRoPEOp, the `x` argument (Query or key tensor with shape [B, H, S, D]) is
missing an explicit precondition. Since NativeRoPEOp splits D into two halves
internally, add a requirement stating that D must be even to the Requirements
column for the x argument. This ensures users understand that odd-width
dimension inputs are not supported and clarifies the expected behavior of the
rotate-half semantics.


`forward(...)` returns a tensor in the input dtype. `forward_fp32(...)` computes in
and returns `float32`; it is the gold-standard reference path.

## Reference Semantics

The implementation uses the Hugging Face rotate-half convention, pairing dimensions
`(i, i + D/2)` rather than adjacent dimensions.

```python
half = x.shape[-1] // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, half, dtype=torch.float32) / half))
freqs = positions.float()[..., None] * inv_freq
cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)
sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)

a, b = x.float()[..., :half], x.float()[..., half:]
rotated = torch.cat([-b, a], dim=-1)
out = x.float() * cos + rotated * sin
```
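
The same rotation can be written out for a single head vector without tensors, which makes the `(i, i + D/2)` pairing explicit. This is a dependency-free sketch using plain Python lists, not the operator's code; the function name `rope_rotate_half` is illustrative.

```python
import math


def rope_rotate_half(x, position, theta=1_000_000.0):
    """Apply rotate-half RoPE to one head vector `x` (a list of floats)."""
    d = len(x)
    if d % 2 != 0:
        raise ValueError(f"RoPE requires even head_dim, got {d}")
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        # inv_freq[i] = theta^(-i/half); the angle scales linearly with position.
        angle = position * theta ** (-i / half)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]  # pair (i, i + D/2), not (2i, 2i + 1)
        out[i] = a * c - b * s         # x * cos + (-b) * sin
        out[i + half] = b * c + a * s  # x * cos + a * sin
    return out


vec = [1.0, 2.0, 3.0, 4.0]
unchanged = rope_rotate_half(vec, 0)   # position 0 rotates by angle 0
rotated = rope_rotate_half(vec, 5)
norm_in = math.sqrt(sum(v * v for v in vec))
norm_out = math.sqrt(sum(v * v for v in rotated))
```

Each pair is rotated by a plain 2D rotation, so the vector norm is preserved up to floating-point rounding and position 0 leaves the input unchanged.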

For Qwen3-8B validation, RoPE is applied after QK-Norm and before attention:

```text
RMSNorm(q), RMSNorm(k) -> RoPE(theta=1e6) -> attention
```
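
One reason RoPE sits directly before attention is a general property of rotary embeddings: the query-key dot product depends only on the offset between the two positions. The sketch below checks this numerically with plain Python lists; the helper names and sample vectors are illustrative and not taken from the repository.

```python
import math


def rope(x, position, theta=1_000_000.0):
    """Rotate-half RoPE on one head vector, pairing index i with i + D/2."""
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        angle = position * theta ** (-i / half)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.7, 2.0]
k = [1.1, 0.4, -0.5, 0.9]

# Same offset (4) between query and key, at two different absolute positions.
score_near = dot(rope(q, 7), rope(k, 3))
score_far = dot(rope(q, 107), rope(k, 103))
```

The two scores agree up to rounding, while a different offset generally gives a different score.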

## Accuracy

RoPE is categorized as an `elementwise` operator in the numerical contract.
Expected comparison behavior:

| Path | Expected dtype | Purpose |
| --- | --- | --- |
| `forward` | Same as `x.dtype` | Candidate dtype behavior. |
| `forward_fp32` | `torch.float32` | Deterministic reference output. |

Batch invariance is expected to be bitwise: applying RoPE to a full batch and then
slicing a row must match applying RoPE to that row alone.
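
A bitwise batch-invariance check can be sketched as follows. It uses plain Python lists and toy operators rather than `NativeRoPEOp`; `bitwise_equal`, `is_batch_invariant`, and both toy operators are hypothetical names introduced for this example.

```python
import struct


def bitwise_equal(a, b):
    """True only if both float sequences have identical IEEE-754 bit patterns."""
    if len(a) != len(b):
        return False
    return struct.pack(f"<{len(a)}d", *a) == struct.pack(f"<{len(b)}d", *b)


def is_batch_invariant(op, batch):
    """Compare each row of a batched call against calling `op` on that row alone."""
    full = op(batch)
    return all(bitwise_equal(full[i], op([row])[0]) for i, row in enumerate(batch))


def scale_rows(batch):
    # Row-wise operator: each output row depends only on its own input row.
    return [[2.0 * v for v in row] for row in batch]


def center_on_batch_mean(batch):
    # Counter-example: every output row depends on the whole batch.
    count = sum(len(row) for row in batch)
    mean = sum(v for row in batch for v in row) / count
    return [[v - mean for v in row] for row in batch]


batch = [[1.0, 2.0], [3.0, 4.0]]
```

Comparing bit patterns is stricter than `==`: it distinguishes `0.0` from `-0.0`, which is what a bitwise expectation implies.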

## Tests

```bash
python -m pytest tests/test_rope.py -q
```

The test covers shape, dtype behavior, HF rotate-half equivalence, `positions`
as `[S]` and `[B, S]`, batch invariance, and Qwen3 query/key head shapes.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py`
- `rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py`
- `rl_engine/kernels/registry.py`
- `tests/test_rope.py`
6 changes: 6 additions & 0 deletions rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from rl_engine.kernels.ops.pytorch.rotary_embedding.rope import NativeRoPEOp

__all__ = ["NativeRoPEOp"]
92 changes: 92 additions & 0 deletions rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

import torch
from torch import Tensor


class NativeRoPEOp:
    """Pure PyTorch reference RoPE — GPT-NeoX style (HF rotate-half).

    Qwen3-8B defaults: theta=1e6, head_dim=128, full-dimension rotation (half=64).
    Dimension pairing: (i, i+half) — NOT adjacent (i, i+1).
    cos/sin are computed internally in fp32 from positions and theta — no external
    cos/sin cache is accepted or returned.
    """

    op_class = "elementwise"

    def __init__(self) -> None:
        pass

    def __call__(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor:
        return self.forward(x, positions, theta=theta)

    def forward(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor:
        """Apply RoPE in input dtype. Cos/sin always computed in fp32."""
        cos, sin = self._compute_cos_sin(x, positions, theta=theta)
        xf = x
        x1, x2 = xf[..., : xf.shape[-1] // 2], xf[..., xf.shape[-1] // 2 :]
        rotated = torch.cat([-x2, x1], dim=-1)
        out = xf * cos + rotated * sin
        return out.to(dtype=x.dtype)

    def forward_fp32(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor:
        """fp32 gold standard: internal computation and output are fp32."""
        cos, sin = self._compute_cos_sin(x, positions, theta=theta)
        xf = x.float()
        x1, x2 = xf[..., : xf.shape[-1] // 2], xf[..., xf.shape[-1] // 2 :]
        rotated = torch.cat([-x2, x1], dim=-1)
        out = xf * cos + rotated * sin
        return out

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _compute_cos_sin(x: Tensor, positions: Tensor, *, theta: float) -> tuple[Tensor, Tensor]:
        """Compute cos/sin tables in fp32 from positions and theta.

        Args:
            x: [..., D] — only x.shape[-1] (head_dim) is used.
            positions: [S] or [B, S] int64 — absolute token positions.
            theta: RoPE base frequency (Qwen3 = 1e6).

        Returns:
            cos, sin: broadcastable to x shape, fp32.
        """
        D = x.shape[-1]
        half = D // 2

Comment on lines +61 to +63

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add explicit even head_dim validation before building RoPE tables.

Odd D currently fails later via shape mismatch; fail fast with a clear error at the contract boundary.

Proposed fix
         D = x.shape[-1]
+        if D % 2 != 0:
+            raise ValueError(f"RoPE requires even head_dim, got {D}")
         half = D // 2

Also applies to: 85-86

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` around lines 69 - 71,
Add explicit validation to ensure the head dimension D is even before proceeding
with RoPE table construction. After extracting D from x.shape[-1] (around line
69), insert a check to verify that D is even using modulo operator, and raise a
clear ValueError if D is odd to fail fast at the contract boundary. Apply the
same validation pattern at the second location mentioned (lines 85-86) where
similar dimension extraction occurs.
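
To make the failure mode in this comment concrete, the sketch below reproduces the operator's split arithmetic on a plain list for `D = 5`. The helper names are illustrative; the error message is the one proposed in the review.

```python
def split_halves(x):
    """Split like the operator does: x[..., :D//2] and x[..., D//2:]."""
    half = len(x) // 2
    return x[:half], x[half:]


def checked_half(head_dim):
    """Fail fast on odd head_dim, mirroring the check proposed in the review."""
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE requires even head_dim, got {head_dim}")
    return head_dim // 2


x1, x2 = split_halves([1.0, 2.0, 3.0, 4.0, 5.0])  # D = 5
table_width = 2 * (5 // 2)  # last-dim size of cat([freqs, freqs]) when D = 5
```

For an odd `D` the two halves have different lengths and the cos/sin table is one element narrower than `x`, so the elementwise multiply cannot broadcast. An explicit check turns that late shape error into a clear message at the contract boundary.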

        # inv_freq[i] = 1 / (theta^(2i/D)) = 1 / (theta^(i/half))
        # shape: [half]
        inv_freq = 1.0 / (
            theta ** (torch.arange(0, half, dtype=torch.float32, device=x.device) / half)
        )

        # positions: [S] -> [S, 1] or [B, S] -> [B, S, 1]
        pos_float = positions.float().unsqueeze(-1)

        # freqs: [S, half] or [B, S, half]
        freqs = pos_float * inv_freq
Comment on lines +71 to +74

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Normalize positions to x.device (and validate rank) in _compute_cos_sin.

Current code can crash on mixed-device inputs (x on accelerator, positions on CPU). Convert positions to fp32 on x.device before frequency math.

Proposed fix
-        pos_float = positions.float().unsqueeze(-1)
+        if positions.dim() not in (1, 2):
+            raise ValueError(f"positions must have shape [S] or [B,S], got dim={positions.dim()}")
+        pos_float = positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)

Also applies to: 92-99

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` around lines 80 - 83,
In the `_compute_cos_sin` function, the `positions` tensor may reside on a
different device than `x` (e.g., `x` on GPU, `positions` on CPU), which causes
device mismatch errors during the frequency computation. Before converting
`positions` to float and computing the frequency values in the line where
`pos_float = positions.float().unsqueeze(-1)` occurs, first ensure `positions`
is moved to the same device as `x` using the appropriate device transfer
operation, then convert to float32. Additionally, validate the rank of the
`positions` tensor to ensure it has the expected shape for the downstream
operations.


        # Duplicate to full dim: [S, D] or [B, S, D]
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = emb.cos()
        sin = emb.sin()

        # Reshape for broadcasting with x: [B, H, S, D]
        if positions.dim() == 1:
            # positions [S] -> cos/sin [1, 1, S, D]
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
        else:
            # positions [B, S] -> cos/sin [B, 1, S, D]
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)

        return cos, sin
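
The reshape at the end of `_compute_cos_sin` can be summarized as shape arithmetic. The sketch below works on shape tuples only and adds the rank check suggested in the review; `cos_sin_shape` and `broadcasts_to` are hypothetical helpers, not part of the module.

```python
def cos_sin_shape(positions_shape, head_dim):
    """Shape of the cos/sin tables for positions of shape [S] or [B, S]."""
    if len(positions_shape) == 1:
        (seq_len,) = positions_shape
        return (1, 1, seq_len, head_dim)       # broadcast over batch and heads
    if len(positions_shape) == 2:
        batch, seq_len = positions_shape
        return (batch, 1, seq_len, head_dim)   # broadcast over heads only
    raise ValueError(
        f"positions must have shape [S] or [B,S], got dim={len(positions_shape)}"
    )


def broadcasts_to(shape, target):
    """Same-rank broadcasting rule: each size matches the target or equals 1."""
    return len(shape) == len(target) and all(
        a == b or a == 1 for a, b in zip(shape, target)
    )


x_shape = (4, 32, 16, 128)  # [B, H, S, D]
shared = cos_sin_shape((16,), 128)
per_row = cos_sin_shape((4, 16), 128)
```

Both table shapes broadcast against `[B, H, S, D]`. A `[B, S]` positions tensor whose `B` or `S` differs from `x` would not.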
4 changes: 4 additions & 0 deletions rl_engine/kernels/registry.py
@@ -40,6 +40,7 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta):
    # Generic fallback
    TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp"
    PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp"
    PYTORCH_NATIVE_ROPE = "rl_engine.kernels.ops.pytorch.rotary_embedding.rope.NativeRoPEOp"


class KernelRegistry:
@@ -75,16 +76,19 @@ def __init__(self):
                "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE],
                "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS],
                # Default dispatch logic for new operators
                "rope": [OpBackend.PYTORCH_NATIVE_ROPE],
            },
            "rocm": {
                "logp": [OpBackend.ROCM_AITER, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE],
                "attn": [OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE],
                "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS],
                "rope": [OpBackend.PYTORCH_NATIVE_ROPE],
            },
            "cpu": {
                "logp": [OpBackend.PYTORCH_NATIVE],
                "attn": [OpBackend.PYTORCH_NATIVE],
                "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS],
                "rope": [OpBackend.PYTORCH_NATIVE_ROPE],
            },
        }
        logger.info(f"KernelRegistry initialized for {device_ctx.device_type}")