diff --git a/docs/operators/README.md b/docs/operators/README.md index 49ab34b..18ee19e 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -20,5 +20,6 @@ Every operator page should include: - [Fused LogP](fused-logp.md) - [GRPO Loss](grpo-loss.md) +- [RoPE](rope.md) - [Sampling](sampling.md) - [Operator Doc Template](../contributing/operator-doc-template.md) diff --git a/docs/operators/rope.md b/docs/operators/rope.md new file mode 100644 index 0000000..cfc171b --- /dev/null +++ b/docs/operators/rope.md @@ -0,0 +1,98 @@ +# RoPE + +RoPE applies rotary position embeddings to per-head query or key tensors. The +current implementation is a pure PyTorch reference operator for Issue #108 +ground-truth validation; it is not a fused CUDA or Triton kernel. + +This page documents the PyTorch baseline version. + +## Entry Point + +```python +from rl_engine.kernels.registry import kernel_registry + +rope = kernel_registry.get_op("rope") +output = rope.forward(x, positions, theta=1_000_000.0) +reference = rope.forward_fp32(x, positions, theta=1_000_000.0) +``` + +The operator can also be imported directly: + +```python +from rl_engine.kernels.ops.pytorch.rotary_embedding import NativeRoPEOp + +rope = NativeRoPEOp() +``` + +## Backend + +| Backend | Wrapper | Native symbol | Notes | +| --- | --- | --- | --- | +| PyTorch native | `NativeRoPEOp` | None | Reference baseline for Qwen3-style RoPE. | + +`kernel_registry.get_op("rope")` dispatches to the PyTorch native backend on CPU, +CUDA, and ROCm. CUDA/Triton fused RoPE kernels should compare against this reference. + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. | +| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. | +| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. | +| Output | `[B, H, S, D]` | See below | Same shape as `x`. | + +`forward(...)` returns the input dtype. `forward_fp32(...)` computes and returns +`float32` and is the gold-standard reference path. + +## Reference Semantics + +The implementation uses the Hugging Face rotate-half convention, pairing dimensions +`(i, i + D/2)` rather than adjacent dimensions. + +```python +half = x.shape[-1] // 2 +inv_freq = 1.0 / (theta ** (torch.arange(0, half, dtype=torch.float32) / half)) +freqs = positions.float()[..., None] * inv_freq +cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) +sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) + +a, b = x.float()[..., :half], x.float()[..., half:] +rotated = torch.cat([-b, a], dim=-1) +out = x.float() * cos + rotated * sin +``` + +For Qwen3-8B validation, RoPE is applied after QK-Norm and before attention: + +```text +RMSNorm(q), RMSNorm(k) -> RoPE(theta=1e6) -> attention +``` + +## Accuracy + +RoPE is categorized as an `elementwise` operator in the numerical contract. +Expected comparison behavior: + +| Path | Expected dtype | Purpose | +| --- | --- | --- | +| `forward` | Same as `x.dtype` | Candidate dtype behavior. | +| `forward_fp32` | `torch.float32` | Deterministic reference output. | + +Batch invariance is expected to be bitwise: applying RoPE to a full batch and then +slicing a row must match applying RoPE to that row alone. + +## Tests + +```bash +python -m pytest tests/test_rope.py -q +``` + +The test covers shape, dtype behavior, HF rotate-half equivalence, `positions` +as `[S]` and `[B, S]`, batch invariance, and Qwen3 query/key head shapes. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` +- `rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py` +- `rl_engine/kernels/registry.py` +- `tests/test_rope.py` diff --git a/rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py b/rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py new file mode 100644 index 0000000..6054d4c --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from rl_engine.kernels.ops.pytorch.rotary_embedding.rope import NativeRoPEOp + +__all__ = ["NativeRoPEOp"] diff --git a/rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py b/rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py new file mode 100644 index 0000000..72eb93c --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch +from torch import Tensor + + +class NativeRoPEOp: + """Pure PyTorch reference RoPE — GPT-NeoX style (HF rotate-half). + + Qwen3-8B defaults: theta=1e6, head_dim=128, full-dimension rotation (half=64). + Dimension pairing: (i, i+half) — NOT adjacent (i, i+1). + cos/sin are computed internally in fp32 from positions and theta — no external + cos/sin cache is accepted or returned. + """ + + op_class = "elementwise" + + def __init__(self) -> None: + pass + + def __call__(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor: + return self.forward(x, positions, theta=theta) + + def forward(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor: + """Apply RoPE in input dtype. Cos/sin always computed in fp32.""" + cos, sin = self._compute_cos_sin(x, positions, theta=theta) + xf = x + x1, x2 = xf[..., : xf.shape[-1] // 2], xf[..., xf.shape[-1] // 2 :] + rotated = torch.cat([-x2, x1], dim=-1) + out = xf * cos + rotated * sin + return out.to(dtype=x.dtype) + + def forward_fp32(self, x: Tensor, positions: Tensor, *, theta: float = 1_000_000.0) -> Tensor: + """fp32 gold standard: internal computation and output are fp32.""" + cos, sin = self._compute_cos_sin(x, positions, theta=theta) + xf = x.float() + x1, x2 = xf[..., : xf.shape[-1] // 2], xf[..., xf.shape[-1] // 2 :] + rotated = torch.cat([-x2, x1], dim=-1) + out = xf * cos + rotated * sin + return out + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + + @staticmethod + def _compute_cos_sin(x: Tensor, positions: Tensor, *, theta: float) -> tuple[Tensor, Tensor]: + """Compute cos/sin tables in fp32 from positions and theta. + + Args: + x: [..., D] — only x.shape[-1] (head_dim) is used. + positions: [S] or [B, S] int64 — absolute token positions. + theta: RoPE base frequency (Qwen3 = 1e6). + + Returns: + cos, sin: broadcastable to x shape, fp32. + """ + D = x.shape[-1] + half = D // 2 + + # inv_freq[i] = 1 / (theta^(2i/D)) = 1 / (theta^(i/half)) + # shape: [half] + inv_freq = 1.0 / ( + theta ** (torch.arange(0, half, dtype=torch.float32, device=x.device) / half) + ) + + # positions: [S] -> [S, 1] or [B, S] -> [B, S, 1] + pos_float = positions.float().unsqueeze(-1) + + # freqs: [S, half] or [B, S, half] + freqs = pos_float * inv_freq + + # Duplicate to full dim: [S, D] or [B, S, D] + emb = torch.cat([freqs, freqs], dim=-1) + + cos = emb.cos() + sin = emb.sin() + + # Reshape for broadcasting with x: [B, H, S, D] + if positions.dim() == 1: + # positions [S] -> cos/sin [1, 1, S, D] + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + else: + # positions [B, S] -> cos/sin [B, 1, S, D] + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + return cos, sin diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 7d85d26..20efc5c 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -40,6 +40,7 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp" + PYTORCH_NATIVE_ROPE = "rl_engine.kernels.ops.pytorch.rotary_embedding.rope.NativeRoPEOp" class KernelRegistry: @@ -75,16 +76,19 @@ def __init__(self): "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], # Default dispatch logic for new operators + "rope": [OpBackend.PYTORCH_NATIVE_ROPE], }, "rocm": { "logp": [OpBackend.ROCM_AITER, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], + "rope": [OpBackend.PYTORCH_NATIVE_ROPE], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], + "rope": [OpBackend.PYTORCH_NATIVE_ROPE], }, } logger.info(f"KernelRegistry initialized for {device_ctx.device_type}") diff --git a/tests/test_rope.py b/tests/test_rope.py new file mode 100644 index 0000000..21b8479 --- /dev/null +++ b/tests/test_rope.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Tests for NativeRoPEOp — fp32 gold standard RoPE (HF rotate-half convention). + +Validates: +- Axis A: batch invariance (bitwise torch.equal between batch=1 slice and batch=N slice) +- Axis B: accuracy (forward vs forward_fp32 under tolerance_contract thresholds) +- Functional correctness: pure function, dtype, shape, positions [S] vs [B,S] +""" + +from __future__ import annotations + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.rotary_embedding.rope import NativeRoPEOp + +# --------------------------------------------------------------------------- +# Fixtures & helpers +# --------------------------------------------------------------------------- + +QWEN3_HEAD_DIM = 128 +QWEN3_THETA = 1_000_000.0 + + +def _make_inputs(batch: int, n_heads: int, seq: int, head_dim: int, seed: int = 42): + """Deterministic RoPE inputs.""" + gen = torch.Generator().manual_seed(seed) + x = torch.randn(batch, n_heads, seq, head_dim, generator=gen) + positions = torch.arange(seq, dtype=torch.long) + return x, positions + + +def _hf_rotate_half_reference(x, positions, theta=1e6): + """Independent HF rotate-half reference (from ISSUE_108_OPS_DEV §5).""" + D = x.shape[-1] + half = D // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, half, dtype=torch.float32) / half)) + freqs = positions.float()[:, None] * inv_freq[None, :] + cos = torch.cat([freqs.cos()] * 2, dim=-1) + sin = torch.cat([freqs.sin()] * 2, dim=-1) + a, b = x.float()[..., :half], x.float()[..., half:] + rotated = torch.cat([-b, a], dim=-1) + return x.float() * cos + rotated * sin + + +# --------------------------------------------------------------------------- +# Correctness +# --------------------------------------------------------------------------- + + +class TestNativeRoPEOpCorrectness: + """Basic correctness: shape, dtype, purity, HF reference match.""" + + def test_output_shape_matches_input(self): + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + out = op.forward_fp32(x, pos, theta=QWEN3_THETA) + assert out.shape == x.shape + + def test_forward_fp32_returns_fp32(self): + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + out = op.forward_fp32(x, pos) + assert out.dtype == torch.float32 + + def test_forward_fp32_returns_fp32_even_with_bf16_input(self): + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + out = op.forward_fp32(x.bfloat16(), pos) + assert out.dtype == torch.float32 + + def test_call_equals_forward(self): + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + assert torch.equal(op(x, pos, theta=QWEN3_THETA), op.forward(x, pos, theta=QWEN3_THETA)) + + def test_pure_function_no_inplace(self): + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + x_orig = x.clone() + _ = op.forward_fp32(x, pos) + assert torch.equal(x, x_orig), "forward_fp32 modified input in-place" + + def test_matches_hf_reference_bitwise(self): + """NativeRoPEOp must be bitwise identical to the ISSUE_108_OPS_DEV §5 reference.""" + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + our = op.forward_fp32(x, pos, theta=QWEN3_THETA) + ref = _hf_rotate_half_reference(x, pos, theta=QWEN3_THETA) + assert torch.equal(our, ref), ( + f"Not bitwise match with HF reference, max diff: " + f"{(our - ref).abs().max().item():.2e}" + ) + + def test_positions_1d_and_2d_equivalent(self): + """positions [S] and [B, S] (with identical values) must produce same output.""" + op = NativeRoPEOp() + B, H, S, D = 3, 32, 16, QWEN3_HEAD_DIM + x, pos_1d = _make_inputs(B, H, S, D) + pos_2d = pos_1d.unsqueeze(0).expand(B, -1) + out_1d = op.forward_fp32(x, pos_1d) + out_2d = op.forward_fp32(x, pos_2d) + assert torch.equal(out_1d, out_2d) + + def test_op_class_is_elementwise(self): + assert NativeRoPEOp.op_class == "elementwise" + + def test_theta_affects_output(self): + """Different theta must produce different results.""" + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + out_1e6 = op.forward_fp32(x, pos, theta=1_000_000.0) + out_1e4 = op.forward_fp32(x, pos, theta=10_000.0) + assert not torch.equal(out_1e6, out_1e4) + + def test_position_zero_is_identity_for_cos(self): + """At position 0, cos=1 and sin=0, so output should equal input.""" + op = NativeRoPEOp() + x = torch.randn(1, 1, 1, QWEN3_HEAD_DIM) + pos = torch.zeros(1, dtype=torch.long) + out = op.forward_fp32(x, pos) + # cos(0)=1, sin(0)=0 → out = x*1 + rotate_half(x)*0 = x + assert torch.allclose(out, x.float(), atol=1e-7) + + +# --------------------------------------------------------------------------- +# Axis A — Batch invariance (bitwise) +# --------------------------------------------------------------------------- + + +class TestNativeRoPEOpBatchInvariance: + """Axis A: forward_fp32 must be bitwise batch-invariant. + + Golden rule from ISSUE_108: compute on full input first, then slice — + compare against computing on the single-batch slice alone. + """ + + def test_batch1_vs_batchN_bitwise(self): + op = NativeRoPEOp() + x, pos = _make_inputs(4, 32, 16, QWEN3_HEAD_DIM, seed=99) + full_out = op.forward_fp32(x, pos) + for i in range(x.shape[0]): + single_out = op.forward_fp32(x[i : i + 1], pos) + assert torch.equal(full_out[i], single_out[0]), f"Batch invariance broken at row {i}" + + def test_batch_invariance_with_padding(self): + """Padded batch (extra rows) must not affect valid rows.""" + op = NativeRoPEOp() + x_valid, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM, seed=77) + # Pad with garbage + x_padded = torch.cat([x_valid, torch.randn(3, 32, 16, QWEN3_HEAD_DIM)], dim=0) + out_valid = op.forward_fp32(x_valid, pos) + out_padded = op.forward_fp32(x_padded, pos) + assert torch.equal(out_valid[0], out_padded[0]) + assert torch.equal(out_valid[1], out_padded[1]) + + def test_batch_invariance_bf16(self): + """Axis A must hold for bf16 inputs too (forward_fp32 path).""" + op = NativeRoPEOp() + x, pos = _make_inputs(4, 32, 16, QWEN3_HEAD_DIM, seed=55) + x_bf16 = x.bfloat16() + full_out = op.forward_fp32(x_bf16, pos) + single_out = op.forward_fp32(x_bf16[0:1], pos) + assert torch.equal(full_out[0], single_out[0]) + + def test_batch_invariance_positions_2d(self): + """Axis A with per-batch positions [B, S].""" + op = NativeRoPEOp() + B, H, S, D = 3, 32, 16, QWEN3_HEAD_DIM + x, _ = _make_inputs(B, H, S, D, seed=33) + # Different position offsets per batch item + pos_2d = torch.stack([torch.arange(S) + i * 100 for i in range(B)]) + full_out = op.forward_fp32(x, pos_2d) + for i in range(B): + single_out = op.forward_fp32(x[i : i + 1], pos_2d[i : i + 1]) + assert torch.equal( + full_out[i], single_out[0] + ), f"Batch invariance broken at row {i} with 2D positions" + + +# --------------------------------------------------------------------------- +# Axis B — Accuracy (forward vs forward_fp32) +# --------------------------------------------------------------------------- + + +class TestNativeRoPEOpAccuracy: + """Axis B: forward(input_dtype) vs forward_fp32 under tolerance thresholds. + + RoPE is elementwise → expected tolerance from tolerance_contract.yaml: + float32: atol=1e-5 + bfloat16: atol=2e-2 + float16: atol=1e-3 + """ + + @pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-5, 1e-5), + (torch.bfloat16, 2e-2, 1.6e-2), + (torch.float16, 1e-3, 1e-3), + ], + ) + def test_forward_vs_fp32_within_tolerance(self, dtype, atol, rtol): + op = NativeRoPEOp() + x, pos = _make_inputs(2, 32, 16, QWEN3_HEAD_DIM) + x_typed = x.to(dtype) + out_typed = op.forward(x_typed, pos).float() + out_fp32 = op.forward_fp32(x_typed, pos) + diff = (out_typed - out_fp32).abs().max().item() + assert torch.allclose(out_typed, out_fp32, atol=atol, rtol=rtol), ( + f"dtype={dtype}, max_abs_error={diff:.3e} exceeds " f"atol={atol}, rtol={rtol}" + ) + + +# --------------------------------------------------------------------------- +# Qwen3-8B specific shapes +# --------------------------------------------------------------------------- + + +class TestNativeRoPEOpQwen3Shapes: + """Verify with Qwen3-8B actual dimensions.""" + + @pytest.mark.parametrize( + "batch, seq, label", + [(2, 16, "SMALL"), (4, 512, "MEDIUM")], + ) + def test_qwen3_shape(self, batch, seq, label): + op = NativeRoPEOp() + x, pos = _make_inputs(batch, 32, seq, 128, seed=12) + out = op.forward_fp32(x, pos, theta=1_000_000.0) + assert out.shape == (batch, 32, seq, 128) + assert out.dtype == torch.float32 + + def test_qwen3_kv_heads_shape(self): + """RoPE is also applied to K with n_kv_heads=8.""" + op = NativeRoPEOp() + x_k, pos = _make_inputs(2, 8, 16, 128, seed=13) + out = op.forward_fp32(x_k, pos, theta=1_000_000.0) + assert out.shape == (2, 8, 16, 128)