From 9bcd65b66938cd7c840d6bdd8ff3452e3444563c Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sun, 21 Jun 2026 21:42:43 +0800 Subject: [PATCH] feat(ws1): add NativeSiLUOp + NativeSwiGLUOp pure-PyTorch references WS1 ground-truth activation ops for issue #108 (Qwen3-8B gated MLP): - NativeSiLUOp: silu(x) = x * sigmoid(x) - NativeSwiGLUOp: silu(gate) * up (gate/up at intermediate dim) Both expose the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path), pure functions, fp32 accumulation. - register PYTORCH_NATIVE_SILU / PYTORCH_NATIVE_SWIGLU in OpBackend and the cuda/rocm/cpu priority maps - tests/test_swiglu.py: correctness vs fp32 formula, dtype paths, Axis-A batch invariance (slice + padding), purity, gradient flow, shape guard, registry dispatch - docs/operators/activation.md + nav/index wiring --- docs/.nav.yml | 1 + docs/operators/README.md | 1 + docs/operators/activation.md | 112 +++++++++++++++ .../ops/pytorch/activation/__init__.py | 2 + .../kernels/ops/pytorch/activation/swiglu.py | 91 +++++++++++++ rl_engine/kernels/registry.py | 8 ++ tests/test_swiglu.py | 127 ++++++++++++++++++ 7 files changed, 342 insertions(+) create mode 100644 docs/operators/activation.md create mode 100644 rl_engine/kernels/ops/pytorch/activation/__init__.py create mode 100644 rl_engine/kernels/ops/pytorch/activation/swiglu.py create mode 100644 tests/test_swiglu.py diff --git a/docs/.nav.yml b/docs/.nav.yml index 6ba2e50..2dd87ec 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -10,6 +10,7 @@ nav: - getting_started/faq.md - Operators: - operators/README.md + - operators/activation.md - operators/fused-logp.md - operators/grpo-loss.md - operators/ratio-kl.md diff --git a/docs/operators/README.md b/docs/operators/README.md index 4bb7e9e..7609170 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -18,6 +18,7 @@ Every operator page should include: ## Current Pages +- [SiLU / SwiGLU Activation](activation.md) - [Fused LogP](fused-logp.md) - [GRPO Loss](grpo-loss.md) - [Policy Ratio + KL Penalty](ratio-kl.md) diff --git a/docs/operators/activation.md b/docs/operators/activation.md new file mode 100644 index 0000000..77be71e --- /dev/null +++ b/docs/operators/activation.md @@ -0,0 +1,112 @@ +# SiLU / SwiGLU Activation + +The activation operators are the element-wise core of the Qwen3/Llama gated MLP. They are +**WS1 ground-truth references** (issue #108): pure-PyTorch, fp32-accumulating definitions of +the "correct answer" that downstream fused CUDA/Triton MLP kernels are validated against. + +- **SiLU** (`NativeSiLUOp`): `silu(x) = x * sigmoid(x)` — the `hidden_act="silu"` gate. +- **SwiGLU** (`NativeSwiGLUOp`): `swiglu(gate, up) = silu(gate) * up` — the gated MLP middle + stage. `gate` / `up` are the `gate_proj` / `up_proj` outputs (already at the intermediate + width); the following `down_proj` is a plain Matmul and is **not** part of this operator. + +``` +hidden --gate_proj--> gate --\ + swiglu --> down_proj --> hidden +hidden --up_proj----> up ----/ +``` + +## Entry Point +```python +from rl_engine.kernels.registry import kernel_registry + +silu = kernel_registry.get_op("silu") +swiglu = kernel_registry.get_op("swiglu") + +# SiLU: single element-wise activation +y = silu(x) # [..., N] -> [..., N] + +# SwiGLU: gated activation (gate and up must share shape) +h = swiglu(gate, up) # [..., I], [..., I] -> [..., I] +``` + +Both ops expose the WS1 dual-path contract: + +- `forward(...)` — computes in fp32, casts back to the input dtype (Axis-B accuracy + candidate / dtype-behavior path). +- `forward_fp32(...)` — computes and returns fp32 (the ground-truth golden path). + +## Backends + +| Backend | Wrapper | Native symbol | Status | +| --- | --- | --- | --- | +| PyTorch fallback | `NativeSiLUOp` / `NativeSwiGLUOp` | None | fp32 ground-truth reference; CPU and any GPU. | +| CUDA / ROCm / Triton | — | — | Planned: downstream fused MLP kernels validate against this reference. | + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `x` (SiLU) | `[..., N]` | float (fp16/bf16/fp32) | Any shape; last dim arbitrary (Qwen3-8B `I=12288`). | +| `gate` (SwiGLU) | `[..., I]` | float | `gate_proj` output. | +| `up` (SwiGLU) | `[..., I]` | float | `up_proj` output; **must share `gate`'s shape**. | +| output | same as input | `forward`: input dtype · `forward_fp32`: float32 | Same shape as input. | + +Element-wise and shape-agnostic: the Qwen3-8B intermediate dim `I=12288` is just one valid +last-dim size, not a hard requirement. Pure functions — no randomness, no in-place +mutation, device/dtype follow the inputs. + +## Dispatch Behavior + +`kernel_registry.get_op("silu" | "swiglu")` resolves through the `OpBackend` priority map. +On `cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op +(`PYTORCH_NATIVE_SILU` / `PYTORCH_NATIVE_SWIGLU`), so every device dispatches to the +fp32 reference. When fused kernels land, they are prepended to the priority list and the +native op becomes the fallback. + +## Accuracy + +Reference semantics (`forward_fp32`, fp32 accumulation): + +```python +# SiLU +out = x.float() * torch.sigmoid(x.float()) + +# SwiGLU +gate_f = gate.float() +out = gate_f * torch.sigmoid(gate_f) * up.float() +``` + +- **Ground truth**: `forward_fp32` always accumulates in and returns fp32. +- **Dtype path**: `forward` runs the same fp32 math, then casts back to the input dtype; + it is bitwise-equal to `forward_fp32(x).to(dtype)`. +- **Axis A — batch invariance**: element-wise and row-independent, so a row's output is + bitwise-identical regardless of batch size or padding (`torch.equal`, `atol=0`). +- **Axis B — tolerance**: as `elementwise` ops, low-precision tolerance follows the + `elementwise` row of the WS1 numerical contract. + +## Performance Notes + +Reference operators — no fused kernel or benchmark yet. Downstream fused MLP kernels carry +their own benchmarks and are measured against this reference for correctness. + +## Tests + +```bash +python -m pytest tests/test_swiglu.py -v +``` + +Covers: correctness vs an independent fp32 formula, dtype paths, Axis-A batch invariance +(slice + padding), input purity, gradient flow, the SwiGLU shape guard, and registry +dispatch. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/activation/swiglu.py` +- `rl_engine/kernels/registry.py` +- `tests/test_swiglu.py` + +## Known Limitations + +- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). +- SwiGLU requires `gate` and `up` to share shape (raises `ValueError` otherwise); no + broadcasting. diff --git a/rl_engine/kernels/ops/pytorch/activation/__init__.py b/rl_engine/kernels/ops/pytorch/activation/__init__.py new file mode 100644 index 0000000..86cf4c9 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/activation/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors diff --git a/rl_engine/kernels/ops/pytorch/activation/swiglu.py b/rl_engine/kernels/ops/pytorch/activation/swiglu.py new file mode 100644 index 0000000..928fee3 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/activation/swiglu.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch + + +class NativeSiLUOp: + """ + Pure PyTorch native SiLU reference. + out = x * sigmoid(x) (a.k.a. Swish) + + Element-wise activation used by the Qwen3 SwiGLU MLP + (hidden_act="silu"). No hyper-parameters and shape-agnostic, so the + Qwen3-8B intermediate dim (12288) is just one valid last-dim size. + """ + + def __init__(self) -> None: + pass + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Canonical entry: compute in fp32, cast the result back to x.dtype. + This is the dtype-behavior path used as the Axis-B accuracy candidate. + """ + return self._silu(x, output_dtype=x.dtype) + + def forward_fp32(self, x: torch.Tensor) -> torch.Tensor: + """Ground-truth: compute in fp32 and force fp32 output.""" + return self._silu(x, output_dtype=torch.float32) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _silu(x: torch.Tensor, *, output_dtype: torch.dtype) -> torch.Tensor: + x_f = x.float() + out = x_f * torch.sigmoid(x_f) + return out.to(output_dtype) + + +class NativeSwiGLUOp: + """ + Pure PyTorch native SwiGLU reference. + out = silu(gate) * up = (gate * sigmoid(gate)) * up + + Middle stage of the Qwen3/Llama MLP: ``gate`` and ``up`` are the + gate_proj / up_proj outputs (already at intermediate dim 12288). The + following down_proj is a plain Matmul and lives in a separate op. + Element-wise and shape-agnostic; ``gate`` and ``up`` must share shape. + """ + + def __init__(self) -> None: + pass + + def __call__(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + return self.forward(gate, up) + + def forward(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + Canonical entry: compute in fp32, cast the result back to gate.dtype. + This is the dtype-behavior path used as the Axis-B accuracy candidate. + """ + return self._swiglu(gate, up, output_dtype=gate.dtype) + + def forward_fp32(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Ground-truth: compute in fp32 and force fp32 output.""" + return self._swiglu(gate, up, output_dtype=torch.float32) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _swiglu( + gate: torch.Tensor, + up: torch.Tensor, + *, + output_dtype: torch.dtype, + ) -> torch.Tensor: + if gate.shape != up.shape: + raise ValueError( + f"gate and up must share shape, got tuple(gate.shape)=" + f"{tuple(gate.shape)} vs tuple(up.shape)={tuple(up.shape)}" + ) + gate_f = gate.float() + out = gate_f * torch.sigmoid(gate_f) * up.float() + return out.to(output_dtype) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 7aae08f..5294946 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -44,6 +44,8 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp" + PYTORCH_NATIVE_SILU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSiLUOp" + PYTORCH_NATIVE_SWIGLU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSwiGLUOp" class KernelRegistry: @@ -79,6 +81,8 @@ def __init__(self): "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "silu": [OpBackend.PYTORCH_NATIVE_SILU], + "swiglu": [OpBackend.PYTORCH_NATIVE_SWIGLU], # Default dispatch logic for new operators }, "rocm": { @@ -86,12 +90,16 @@ def __init__(self): "attn": [OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "silu": [OpBackend.PYTORCH_NATIVE_SILU], + "swiglu": [OpBackend.PYTORCH_NATIVE_SWIGLU], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "silu": [OpBackend.PYTORCH_NATIVE_SILU], + "swiglu": [OpBackend.PYTORCH_NATIVE_SWIGLU], }, } logger.info(f"KernelRegistry initialized for {device_ctx.device_type}") diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py new file mode 100644 index 0000000..f60def6 --- /dev/null +++ b/tests/test_swiglu.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.activation.swiglu import NativeSiLUOp, NativeSwiGLUOp +from rl_engine.kernels.registry import kernel_registry + +# Qwen3-8B SwiGLU intermediate dim (gate/up_proj output width). +_INTERMEDIATE = 12288 + + +# Shared helper +def _rand(shape, *, seed, dtype=torch.float32): + gen = torch.Generator().manual_seed(seed) + return torch.randn(*shape, generator=gen, dtype=dtype) + + +@pytest.mark.parametrize("dtype", (torch.float32, torch.bfloat16, torch.float16)) +def test_native_silu_matches_fp32_reference(dtype: torch.dtype): + x = torch.linspace(-6.0, 6.0, 33, dtype=dtype).reshape(3, 11) + + fp32_reference = x.float() * torch.sigmoid(x.float()) + result = NativeSiLUOp().forward(x) + + assert result.dtype == dtype + assert torch.equal(result, fp32_reference.to(dtype)) + assert torch.equal(NativeSiLUOp().forward_fp32(x), fp32_reference) + + +@pytest.mark.parametrize("dtype", (torch.float32, torch.bfloat16, torch.float16)) +def test_native_swiglu_matches_fp32_reference(dtype: torch.dtype): + gate = torch.linspace(-4.0, 4.0, 48, dtype=dtype).reshape(2, 3, 8) + up = torch.linspace(0.5, 2.0, 48, dtype=dtype).reshape(2, 3, 8) + + fp32_reference = gate.float() * torch.sigmoid(gate.float()) * up.float() + result = NativeSwiGLUOp().forward(gate, up) + + assert result.dtype == dtype + assert torch.equal(result, fp32_reference.to(dtype)) + assert torch.equal(NativeSwiGLUOp().forward_fp32(gate, up), fp32_reference) + + +def test_native_swiglu_rejects_mismatched_shape(): + gate = torch.randn(2, 3) + up = torch.randn(2, 4) + + with pytest.raises(ValueError, match="share shape"): + NativeSwiGLUOp().forward(gate, up) + + +# Axis A -- batch invariance, bitwise (the WS1 "aligned" property). +# A row's output must not depend on how many rows share the batch. +def test_silu_batch_invariance_slice(): + op = NativeSiLUOp() + x = _rand((8, 32, _INTERMEDIATE), seed=2) + full = op.forward_fp32(x) # compute on full batch... + assert torch.equal(op.forward_fp32(x[:1]), full[:1]) # ...then slice + assert torch.equal(op.forward_fp32(x[3:5]), full[3:5]) + + +def test_swiglu_batch_invariance_slice(): + op = NativeSwiGLUOp() + gate = _rand((8, 32, _INTERMEDIATE), seed=3) + up = _rand((8, 32, _INTERMEDIATE), seed=4) + full = op.forward_fp32(gate, up) + assert torch.equal(op.forward_fp32(gate[:1], up[:1]), full[:1]) + assert torch.equal(op.forward_fp32(gate[3:5], up[3:5]), full[3:5]) + + +def test_silu_batch_invariance_with_padding(): + """Padding extra rows must not perturb the real rows (bitwise).""" + op = NativeSiLUOp() + x = _rand((4, _INTERMEDIATE), seed=5) + padded = torch.cat([x, _rand((6, _INTERMEDIATE), seed=99)], dim=0) + assert torch.equal(op.forward_fp32(padded)[:4], op.forward_fp32(x)) + + +def test_swiglu_batch_invariance_with_padding(): + op = NativeSwiGLUOp() + gate = _rand((4, _INTERMEDIATE), seed=6) + up = _rand((4, _INTERMEDIATE), seed=7) + pad_gate = torch.cat([gate, _rand((6, _INTERMEDIATE), seed=98)], dim=0) + pad_up = torch.cat([up, _rand((6, _INTERMEDIATE), seed=97)], dim=0) + assert torch.equal(op.forward_fp32(pad_gate, pad_up)[:4], op.forward_fp32(gate, up)) + + +# Purity -- inputs not mutated in-place +def test_silu_inputs_not_mutated(): + op = NativeSiLUOp() + x = _rand((2, _INTERMEDIATE), seed=8) + xc = x.clone() + op.forward(x) + op.forward_fp32(x) + assert torch.equal(x, xc) + + +def test_swiglu_inputs_not_mutated(): + op = NativeSwiGLUOp() + gate = _rand((2, _INTERMEDIATE), seed=9) + up = _rand((2, _INTERMEDIATE), seed=10) + gc, uc = gate.clone(), up.clone() + op.forward(gate, up) + op.forward_fp32(gate, up) + assert torch.equal(gate, gc) and torch.equal(up, uc) + + +# Gradient flows (fp32 autograd = backward golden source) +def test_silu_gradient_flows(): + op = NativeSiLUOp() + x = _rand((2, _INTERMEDIATE), seed=11).requires_grad_(True) + op.forward_fp32(x).sum().backward() + assert torch.isfinite(x.grad).all() + + +def test_swiglu_gradient_flows(): + op = NativeSwiGLUOp() + gate = _rand((2, _INTERMEDIATE), seed=12).requires_grad_(True) + up = _rand((2, _INTERMEDIATE), seed=13).requires_grad_(True) + op.forward_fp32(gate, up).sum().backward() + assert torch.isfinite(gate.grad).all() and torch.isfinite(up.grad).all() + + +def test_registry_dispatches_native_activation_ops(): + assert isinstance(kernel_registry.get_op("silu"), NativeSiLUOp) + assert isinstance(kernel_registry.get_op("swiglu"), NativeSwiGLUOp)