-
Notifications
You must be signed in to change notification settings - Fork 34
feat(ws1): Add PyTorch RoPE reference operator #167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| # RoPE | ||
|
|
||
| RoPE applies rotary position embeddings to per-head query or key tensors. The | ||
| current implementation is a pure PyTorch reference operator for Issue #108 | ||
| ground-truth validation; it is not a fused CUDA or Triton kernel. | ||
|
|
||
| This page documents the PyTorch baseline version. | ||
|
|
||
| ## Entry Point | ||
|
|
||
| ```python | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| rope = kernel_registry.get_op("rope") | ||
| output = rope.forward(x, positions, theta=1_000_000.0) | ||
| reference = rope.forward_fp32(x, positions, theta=1_000_000.0) | ||
| ``` | ||
|
|
||
| The operator can also be imported directly: | ||
|
|
||
| ```python | ||
| from rl_engine.kernels.ops.pytorch.rotary_embedding import NativeRoPEOp | ||
|
|
||
| rope = NativeRoPEOp() | ||
| ``` | ||
|
|
||
| ## Backend | ||
|
|
||
| | Backend | Wrapper | Native symbol | Notes | | ||
| | --- | --- | --- | --- | | ||
| | PyTorch native | `NativeRoPEOp` | None | Reference baseline for Qwen3-style RoPE. | | ||
|
|
||
| `kernel_registry.get_op("rope")` dispatches to the PyTorch native backend on CPU, | ||
| CUDA, and ROCm. CUDA/Triton fused RoPE kernels should compare against this reference. | ||
|
|
||
| ## Tensor Contract | ||
|
|
||
| | Argument | Shape | Dtype | Requirements | | ||
| | --- | --- | --- | --- | | ||
| | `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. | | ||
| | `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. | | ||
| | `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. | | ||
| | Output | `[B, H, S, D]` | See below | Same shape as `x`. | | ||
|
|
||
| `forward(...)` returns the input dtype. `forward_fp32(...)` computes and returns | ||
| `float32` and is the gold-standard reference path. | ||
|
|
||
| ## Reference Semantics | ||
|
|
||
| The implementation uses the Hugging Face rotate-half convention, pairing dimensions | ||
| `(i, i + D/2)` rather than adjacent dimensions. | ||
|
|
||
| ```python | ||
| half = x.shape[-1] // 2 | ||
| inv_freq = 1.0 / (theta ** (torch.arange(0, half, dtype=torch.float32) / half)) | ||
| freqs = positions.float()[..., None] * inv_freq | ||
| cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) | ||
| sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) | ||
|
|
||
| a, b = x.float()[..., :half], x.float()[..., half:] | ||
| rotated = torch.cat([-b, a], dim=-1) | ||
| out = x.float() * cos + rotated * sin | ||
| ``` | ||
|
|
||
| For Qwen3-8B validation, RoPE is applied after QK-Norm and before attention: | ||
|
|
||
| ```text | ||
| RMSNorm(q), RMSNorm(k) -> RoPE(theta=1e6) -> attention | ||
| ``` | ||
|
|
||
| ## Accuracy | ||
|
|
||
| RoPE is categorized as an `elementwise` operator in the numerical contract. | ||
| Expected comparison behavior: | ||
|
|
||
| | Path | Expected dtype | Purpose | | ||
| | --- | --- | --- | | ||
| | `forward` | Same as `x.dtype` | Candidate dtype behavior. | | ||
| | `forward_fp32` | `torch.float32` | Deterministic reference output. | | ||
|
|
||
| Batch invariance is expected to be bitwise: applying RoPE to a full batch and then | ||
| slicing a row must match applying RoPE to that row alone. | ||
|
|
||
| ## Tests | ||
|
|
||
| ```bash | ||
| python -m pytest tests/test_rope.py -q | ||
| ``` | ||
|
|
||
| The test covers shape, dtype behavior, HF rotate-half equivalence, `positions` | ||
| as `[S]` and `[B, S]`, batch invariance, and Qwen3 query/key head shapes. | ||
|
|
||
| ## Implementation Files | ||
|
|
||
| - `rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` | ||
| - `rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py` | ||
| - `rl_engine/kernels/registry.py` | ||
| - `tests/test_rope.py` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| from rl_engine.kernels.ops.pytorch.rotary_embedding.rope import NativeRoPEOp | ||
|
|
||
| __all__ = ["NativeRoPEOp"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| class NativeRoPEOp: | ||
| """Pure PyTorch reference RoPE — GPT-NeoX style (HF rotate-half). | ||
|
|
||
| Qwen3-8B defaults: theta=1e6, head_dim=128, full-dimension rotation (half=64). | ||
| Dimension pairing: (i, i+half) — NOT adjacent (i, i+1). | ||
| cos/sin are computed internally in fp32 from positions and theta — no external | ||
| cos/sin cache is accepted or returned. | ||
| """ | ||
|
|
||
| op_class = "elementwise" | ||
|
|
||
| def __init__(self) -> None: | ||
| pass | ||
|
|
||
| def __call__(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor: | ||
| return self.forward(x, positions, theta=theta) | ||
|
|
||
| def forward(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor: | ||
| """Apply RoPE in input dtype. Cos/sin always computed in fp32.""" | ||
| cos, sin = self._compute_cos_sin(x, positions, theta=theta) | ||
| xf = x | ||
| x1, x2 = xf[..., : xf.shape[-1] // 2], xf[..., xf.shape[-1] // 2 :] | ||
| rotated = torch.cat([-x2, x1], dim=-1) | ||
| out = xf * cos + rotated * sin | ||
| return out.to(dtype=x.dtype) | ||
|
|
||
| def forward_fp32(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor: | ||
| """fp32 gold standard: internal computation and output are fp32.""" | ||
| cos, sin = self._compute_cos_sin(x, positions, theta=theta) | ||
| xf = x.float() | ||
| x1, x2 = xf[..., : xf.shape[-1] // 2], xf[..., xf.shape[-1] // 2 :] | ||
| rotated = torch.cat([-x2, x1], dim=-1) | ||
| out = xf * cos + rotated * sin | ||
| return out | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Helpers | ||
| # ------------------------------------------------------------------ # | ||
|
|
||
| @staticmethod | ||
| def _compute_cos_sin(x: Tensor, positions: Tensor, *, theta: float) -> tuple[Tensor, Tensor]: | ||
| """Compute cos/sin tables in fp32 from positions and theta. | ||
|
|
||
| Args: | ||
| x: [..., D] — only x.shape[-1] (head_dim) is used. | ||
| positions: [S] or [B, S] int64 — absolute token positions. | ||
| theta: RoPE base frequency (Qwen3 = 1e6). | ||
|
|
||
| Returns: | ||
| cos, sin: broadcastable to x shape, fp32. | ||
| """ | ||
| D = x.shape[-1] | ||
| half = D // 2 | ||
|
|
||
|
Comment on lines
+61
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add explicit even Odd Proposed fix D = x.shape[-1]
+ if D % 2 != 0:
+ raise ValueError(f"RoPE requires even head_dim, got {D}")
half = D // 2Also applies to: 85-86 🤖 Prompt for AI Agents |
||
| # inv_freq[i] = 1 / (theta^(2i/D)) = 1 / (theta^(i/half)) | ||
| # shape: [half] | ||
| inv_freq = 1.0 / ( | ||
| theta ** (torch.arange(0, half, dtype=torch.float32, device=x.device) / half) | ||
| ) | ||
|
|
||
| # positions: [S] -> [S, 1] or [B, S] -> [B, S, 1] | ||
| pos_float = positions.float().unsqueeze(-1) | ||
|
|
||
| # freqs: [S, half] or [B, S, half] | ||
| freqs = pos_float * inv_freq | ||
|
Comment on lines
+71
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normalize Current code can crash on mixed-device inputs ( Proposed fix- pos_float = positions.float().unsqueeze(-1)
+ if positions.dim() not in (1, 2):
+ raise ValueError(f"positions must have shape [S] or [B,S], got dim={positions.dim()}")
+ pos_float = positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)Also applies to: 92-99 🤖 Prompt for AI Agents |
||
|
|
||
| # Duplicate to full dim: [S, D] or [B, S, D] | ||
| emb = torch.cat([freqs, freqs], dim=-1) | ||
|
|
||
| cos = emb.cos() | ||
| sin = emb.sin() | ||
|
|
||
| # Reshape for broadcasting with x: [B, H, S, D] | ||
| if positions.dim() == 1: | ||
| # positions [S] -> cos/sin [1, 1, S, D] | ||
| cos = cos.unsqueeze(0).unsqueeze(0) | ||
| sin = sin.unsqueeze(0).unsqueeze(0) | ||
| else: | ||
| # positions [B, S] -> cos/sin [B, 1, S, D] | ||
| cos = cos.unsqueeze(1) | ||
| sin = sin.unsqueeze(1) | ||
|
|
||
| return cos, sin | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
State the even-head-dim precondition.
NativeRoPEOpsplitsDinto two halves, so the contract should explicitly sayDmust be even. Without that, odd-width inputs look supported here but won’t preserve the intended rotate-half semantics.📌 Suggested doc tweak
📝 Committable suggestion
🤖 Prompt for AI Agents