From e2bb04e0b896103620d06bac90cd28a4abe49b88 Mon Sep 17 00:00:00 2001 From: sadlerchen Date: Tue, 23 Jun 2026 11:38:30 +0800 Subject: [PATCH 1/5] feat(kernels): add fused masking + variable-length pack-and-pad op (#42) --- .../kernels/ops/pytorch/packing/__init__.py | 2 + rl_engine/kernels/ops/pytorch/packing/pack.py | 110 ++++++++ .../kernels/ops/triton/packing/__init__.py | 2 + rl_engine/kernels/ops/triton/packing/pack.py | 156 +++++++++++ rl_engine/kernels/registry.py | 7 + tests/test_pack.py | 257 ++++++++++++++++++ 6 files changed, 534 insertions(+) create mode 100644 rl_engine/kernels/ops/pytorch/packing/__init__.py create mode 100644 rl_engine/kernels/ops/pytorch/packing/pack.py create mode 100644 rl_engine/kernels/ops/triton/packing/__init__.py create mode 100644 rl_engine/kernels/ops/triton/packing/pack.py create mode 100644 tests/test_pack.py diff --git a/rl_engine/kernels/ops/pytorch/packing/__init__.py b/rl_engine/kernels/ops/pytorch/packing/__init__.py new file mode 100644 index 0000000..86cf4c9 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/packing/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors diff --git a/rl_engine/kernels/ops/pytorch/packing/pack.py b/rl_engine/kernels/ops/pytorch/packing/pack.py new file mode 100644 index 0000000..5a79064 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/packing/pack.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""PyTorch-native fallback for fused masking + variable-length packing. + +During RL training (PPO/DPO/GRPO) only the generated, non-padding tokens +contribute to the loss. Materializing full ``[B, S, ...]`` logits for the +masked-out positions wastes VRAM. This op compacts the active rows of a dense +``[B, S, ...]`` tensor into a contiguous ``[Total_Active, ...]`` tensor (and +scatters gradients back on the backward pass), so downstream loss kernels only +ever touch active tokens. + +This is the portable reference path that defines the numerical contract for the +Triton / CUDA / ROCm native kernels (issue #42). The packing order is +row-major over the flattened ``[B, S]`` grid, identical to +``x.reshape(-1, *tail)[mask.reshape(-1)]`` and to +``SyntheticRLKernelBatch.compact_completion_values`` used by the tests. +""" + +from __future__ import annotations + +from typing import Tuple + +import torch + + +def _validate(x: torch.Tensor, mask: torch.Tensor) -> None: + if mask.dim() < 1: + raise ValueError("mask must have at least one dimension.") + if mask.shape != x.shape[: mask.dim()]: + raise ValueError( + f"mask shape {tuple(mask.shape)} must match the leading dims of " + f"x.shape {tuple(x.shape)} (expected {tuple(x.shape[: mask.dim()])})." + ) + if mask.dim() < 1: + raise ValueError("mask must have at least one dimension.") + + +class _PackFunction(torch.autograd.Function): + """forward: gather active rows; backward: scatter grads to active rows.""" + + @staticmethod + def forward(ctx, x: torch.Tensor, mask: torch.Tensor): + _validate(x, mask) + lead = mask.dim() + tail_shape = x.shape[lead:] + + flat_mask = mask.reshape(-1).to(torch.bool) + flat_x = x.reshape(-1, *tail_shape) if tail_shape else x.reshape(-1) + + index = flat_mask.nonzero(as_tuple=False).squeeze(-1) + packed = flat_x.index_select(0, index) + + # cu_seqlens: prefix-sum of per-row active counts, for varlen consumers. + per_row_active = mask.reshape(mask.shape[0], -1).to(torch.int64).sum(dim=1) + cu_seqlens = torch.zeros(mask.shape[0] + 1, dtype=torch.int64, device=x.device) + torch.cumsum(per_row_active, dim=0, out=cu_seqlens[1:]) + + ctx.save_for_backward(index) + ctx.flat_rows = flat_x.shape[0] + ctx.tail_shape = tail_shape + ctx.x_shape = tuple(x.shape) + ctx.x_dtype = x.dtype + return packed, cu_seqlens + + @staticmethod + def backward(ctx, grad_packed: torch.Tensor, grad_cu_seqlens): + (index,) = ctx.saved_tensors + tail_shape = ctx.tail_shape + + grad_flat = grad_packed.new_zeros( + (ctx.flat_rows, *tail_shape) if tail_shape else (ctx.flat_rows,) + ) + grad_flat.index_copy_(0, index, grad_packed) + grad_x = grad_flat.reshape(ctx.x_shape) + # grad w.r.t. mask is undefined (boolean selector). + return grad_x, None + + +class NativePackOp: + """PyTorch-native fused masking + variable-length packing (pack-and-pad). + + Forward packs the active rows of ``x`` (selected by ``mask``) into a + contiguous ``[Total_Active, *tail]`` tensor and returns the per-row + ``cu_seqlens`` prefix-sum. Backward scatters the upstream gradient back to + the original ``[B, S, *tail]`` layout, leaving zeros at inactive positions. + """ + + def __call__(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return _PackFunction.apply(x, mask) + + @staticmethod + def unpack( + packed: torch.Tensor, + mask: torch.Tensor, + *, + tail_shape: Tuple[int, ...] | None = None, + ) -> torch.Tensor: + """Scatter a packed ``[Total_Active, *tail]`` tensor back to a dense + ``[*mask.shape, *tail]`` tensor with zeros at inactive positions. + + This is the explicit (non-autograd) inverse used by diagnostics; the + backward pass of :class:`_PackFunction` performs the same scatter. + """ + flat_mask = mask.reshape(-1).to(torch.bool) + tail = tuple(packed.shape[1:]) if tail_shape is None else tuple(tail_shape) + out = packed.new_zeros((flat_mask.numel(), *tail)) + index = flat_mask.nonzero(as_tuple=False).squeeze(-1) + out.index_copy_(0, index, packed) + return out.reshape(*mask.shape, *tail) diff --git a/rl_engine/kernels/ops/triton/packing/__init__.py b/rl_engine/kernels/ops/triton/packing/__init__.py new file mode 100644 index 0000000..86cf4c9 --- /dev/null +++ b/rl_engine/kernels/ops/triton/packing/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors diff --git a/rl_engine/kernels/ops/triton/packing/pack.py b/rl_engine/kernels/ops/triton/packing/pack.py new file mode 100644 index 0000000..2bbb74c --- /dev/null +++ b/rl_engine/kernels/ops/triton/packing/pack.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Triton fused masking + variable-length packing (pack-and-pad) op (issue #42). + +Packs the active rows of a dense ``[B, S, *tail]`` tensor (selected by a +``[B, S]`` mask) into a contiguous ``[Total_Active, *tail]`` tensor, and scatters +gradients back to the dense layout on the backward pass. The packing order is +row-major over the flattened ``[B, S]`` grid, matching ``NativePackOp`` (the +numerical contract for this op). + +The active-token destination indices are computed with a cheap exclusive +prefix-sum over the flattened mask (small ``[B*S]`` tensor); the heavy +``tail``-vector movement runs in a Triton gather kernel, and the backward is a +symmetric scatter kernel. ``cu_seqlens`` (per-row prefix-sum) is returned for +varlen consumers, identical to the native op. +""" + +from __future__ import annotations + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +# Tail-vector tile width. +_BLOCK_T = 1024 + + +@triton.jit +def _pack_gather_kernel( + src_ptr, # [n_rows, T] dense, flattened over (B, S) + dst_ptr, # [n_active, T] packed + dest_ptr, # [n_rows] int64: packed row index for an active row, else -1 + T, + BLOCK_T: tl.constexpr, +): + """One program per (source row, tail-tile). Active rows copy their tail + vector into the packed buffer at ``dest_ptr[row]``; inactive rows are skipped.""" + row = tl.program_id(0) + dest = tl.load(dest_ptr + row) + if dest >= 0: + t0 = tl.program_id(1) * BLOCK_T + cols = t0 + tl.arange(0, BLOCK_T) + cmask = cols < T + src = tl.load(src_ptr + row.to(tl.int64) * T + cols, mask=cmask, other=0.0) + tl.store(dst_ptr + dest.to(tl.int64) * T + cols, src, mask=cmask) + + +@triton.jit +def _pack_scatter_kernel( + grad_packed_ptr, # [n_active, T] + grad_src_ptr, # [n_rows, T], pre-zeroed + dest_ptr, # [n_rows] int64 + T, + BLOCK_T: tl.constexpr, +): + """Backward: scatter the packed gradient back to the active source rows. + Inactive rows stay zero (grad_src is pre-zeroed).""" + row = tl.program_id(0) + dest = tl.load(dest_ptr + row) + if dest >= 0: + t0 = tl.program_id(1) * BLOCK_T + cols = t0 + tl.arange(0, BLOCK_T) + cmask = cols < T + g = tl.load(grad_packed_ptr + dest.to(tl.int64) * T + cols, mask=cmask, other=0.0) + tl.store(grad_src_ptr + row.to(tl.int64) * T + cols, g, mask=cmask) + + +def _dest_index(flat_mask: torch.Tensor) -> Tuple[torch.Tensor, int]: + """Map each flattened row to its packed destination index (active rows get an + exclusive prefix-sum position; inactive rows get -1). Returns (dest, n_active).""" + active = flat_mask.to(torch.bool) + counts = active.to(torch.int64) + # Exclusive prefix sum: position of each active row in the packed buffer. + excl = torch.cumsum(counts, dim=0) - counts + n_active = int(counts.sum().item()) + dest = torch.where(active, excl, torch.full_like(excl, -1)) + return dest.contiguous(), n_active + + +class _PackFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, mask: torch.Tensor): + lead = mask.dim() + tail_shape = x.shape[lead:] + n_rows = 1 + for s in mask.shape: + n_rows *= int(s) + T = 1 + for s in tail_shape: + T *= int(s) + + src = x.reshape(n_rows, T).contiguous() + flat_mask = mask.reshape(-1) + dest, n_active = _dest_index(flat_mask) + + packed = torch.empty(n_active, T, device=x.device, dtype=x.dtype) + if n_active > 0: + grid = (n_rows, triton.cdiv(T, _BLOCK_T)) + _pack_gather_kernel[grid](src, packed, dest, T, BLOCK_T=_BLOCK_T) + + per_row_active = flat_mask.to(torch.bool).reshape(mask.shape[0], -1).to(torch.int64).sum(1) + cu_seqlens = torch.zeros(mask.shape[0] + 1, dtype=torch.int64, device=x.device) + torch.cumsum(per_row_active, dim=0, out=cu_seqlens[1:]) + + ctx.save_for_backward(dest) + ctx.n_rows = n_rows + ctx.T = T + ctx.x_shape = tuple(x.shape) + ctx.x_dtype = x.dtype + out_tail = tuple(tail_shape) + packed_out = packed.reshape(n_active, *out_tail) if out_tail else packed.reshape(n_active) + return packed_out, cu_seqlens + + @staticmethod + def backward(ctx, grad_packed: torch.Tensor, grad_cu_seqlens): + (dest,) = ctx.saved_tensors + n_rows, T = ctx.n_rows, ctx.T + n_active = grad_packed.shape[0] + + gp = grad_packed.reshape(n_active, T).contiguous() + grad_src = torch.zeros(n_rows, T, device=grad_packed.device, dtype=grad_packed.dtype) + if n_active > 0: + grid = (n_rows, triton.cdiv(T, _BLOCK_T)) + _pack_scatter_kernel[grid](gp, grad_src, dest, T, BLOCK_T=_BLOCK_T) + + grad_x = grad_src.reshape(ctx.x_shape) + return grad_x, None + + +class TritonPackOp: + """Triton fused masking + variable-length packing (pack-and-pad). + + Forward packs the active rows of ``x`` (selected by ``mask``) into a + contiguous ``[Total_Active, *tail]`` tensor and returns the per-row + ``cu_seqlens`` prefix-sum. Backward scatters the upstream gradient back to the + original ``[*mask.shape, *tail]`` layout, leaving zeros at inactive positions. + Numerically identical to ``NativePackOp``; CUDA & ROCm via Triton. + """ + + def __call__(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if x.device.type not in ("cuda", "xpu", "hip"): + raise RuntimeError( + "TritonPackOp requires a GPU tensor (CUDA / ROCm / XPU), got " + f"device '{x.device}'." + ) + if mask.dim() < 1: + raise ValueError("mask must have at least one dimension.") + if mask.shape != x.shape[: mask.dim()]: + raise ValueError( + f"mask shape {tuple(mask.shape)} must match the leading dims of " + f"x.shape {tuple(x.shape)} (expected {tuple(x.shape[: mask.dim()])})." + ) + return _PackFunction.apply(x, mask) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..5c4490f 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -49,6 +49,10 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): TRITON_RATIO_KL = "rl_engine.kernels.ops.triton.loss.ratio_kl.TritonRatioKLOp" PYTORCH_RATIO_KL = "rl_engine.kernels.ops.pytorch.loss.ratio_kl.NativeRatioKLOp" + # Fused masking + variable-length packing (pack-and-pad), [B,S,...] -> [Total_Active,...] + TRITON_PACK = "rl_engine.kernels.ops.triton.packing.pack.TritonPackOp" + PYTORCH_PACK = "rl_engine.kernels.ops.pytorch.packing.pack.NativePackOp" + # Generic fallback TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp" @@ -90,6 +94,7 @@ def __init__(self): "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], # Default dispatch logic for new operators + "pack": [OpBackend.TRITON_PACK, OpBackend.PYTORCH_PACK], }, "rocm": { "logp": [OpBackend.ROCM_AITER, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], @@ -101,6 +106,7 @@ def __init__(self): "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "pack": [OpBackend.TRITON_PACK, OpBackend.PYTORCH_PACK], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], @@ -108,6 +114,7 @@ def __init__(self): "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "pack": [OpBackend.PYTORCH_PACK], }, } logger.info(f"KernelRegistry initialized for {device_ctx.device_type}") diff --git a/tests/test_pack.py b/tests/test_pack.py new file mode 100644 index 0000000..f67aa19 --- /dev/null +++ b/tests/test_pack.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Tests for the fused masking + pack-and-pad op (issue #42). + +The PyTorch-native op is the portable reference that defines the numerical +contract for the Triton / CUDA / ROCm native kernels. Native correctness is +checked against ``SyntheticRLKernelBatch.compact_completion_values`` (the +canonical compaction already used elsewhere in the repo) and a plain +index-based reference; the Triton op is validated against the native op. +Native tests run on CPU; Triton tests require a CUDA/ROCm device. +""" + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp +from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp +from rl_engine.testing import make_synthetic_rl_kernel_batch + +try: + import triton # noqa: F401 + + _HAS_TRITON = True +except ImportError: # pragma: no cover + _HAS_TRITON = False + +requires_triton_cuda = pytest.mark.skipif( + not (_HAS_TRITON and torch.cuda.is_available()), + reason="Triton pack op requires a CUDA device and Triton.", +) + +_NUM_PROMPTS = 3 +_SPP = 4 +_COMP_LEN = 6 +_VOCAB = 64 + + +def _batch(seed=0, *, device="cpu", valid_density=0.8): + return make_synthetic_rl_kernel_batch( + num_prompts=_NUM_PROMPTS, + samples_per_prompt=_SPP, + prompt_len=0, + completion_len=_COMP_LEN, + vocab_size=_VOCAB, + valid_density=valid_density, + device=device, + seed=seed, + ) + + +def _dense(batch, seed, *, vocab=_VOCAB, device="cpu"): + gen = torch.Generator(device=device).manual_seed(seed) + return torch.randn(batch.batch_size, batch.completion_len, vocab, generator=gen, device=device) + + +# forward correctness +def test_pack_matches_batch_compaction(): + """Packed output must equal the repo's canonical compact_completion_values.""" + batch = _batch(seed=0) + x = _dense(batch, seed=100) + op = NativePackOp() + + packed, cu_seqlens = op(x, batch.completion_mask) + expected = batch.compact_completion_values(x) + + assert packed.shape == expected.shape + assert torch.equal(packed, expected) + # Total active tokens equals the mask sum and the last cu_seqlens entry. + assert int(cu_seqlens[-1].item()) == int(batch.completion_mask.sum().item()) + assert cu_seqlens.numel() == batch.batch_size + 1 + + +def test_pack_matches_index_reference(): + batch = _batch(seed=1, valid_density=0.5) + x = _dense(batch, seed=101) + op = NativePackOp() + + packed, _ = op(x, batch.completion_mask) + flat_mask = batch.completion_mask.reshape(-1) + ref = x.reshape(-1, x.shape[-1])[flat_mask] + assert torch.equal(packed, ref) + + +def test_cu_seqlens_is_per_row_prefix_sum(): + batch = _batch(seed=2, valid_density=0.6) + x = _dense(batch, seed=102) + op = NativePackOp() + + _, cu_seqlens = op(x, batch.completion_mask) + per_row = batch.completion_mask.reshape(batch.batch_size, -1).sum(dim=1) + expected = torch.zeros(batch.batch_size + 1, dtype=torch.int64) + torch.cumsum(per_row.to(torch.int64), dim=0, out=expected[1:]) + assert torch.equal(cu_seqlens, expected) + + +def test_pack_all_active_is_identity_flatten(): + batch = _batch(seed=3, valid_density=1.0) + x = _dense(batch, seed=103) + op = NativePackOp() + + packed, _ = op(x, batch.completion_mask) + assert packed.shape[0] == batch.batch_size * batch.completion_len + assert torch.equal(packed, x.reshape(-1, x.shape[-1])) + + +def test_pack_none_active_is_empty(): + batch = _batch(seed=4, valid_density=0.0) + x = _dense(batch, seed=104) + op = NativePackOp() + + packed, cu_seqlens = op(x, batch.completion_mask) + assert packed.shape[0] == 0 + assert int(cu_seqlens[-1].item()) == 0 + + +# unpack / round-trip +def test_unpack_round_trip_zeros_inactive(): + batch = _batch(seed=5, valid_density=0.7) + x = _dense(batch, seed=105) + op = NativePackOp() + + packed, _ = op(x, batch.completion_mask) + restored = op.unpack(packed, batch.completion_mask) + + mask = batch.completion_mask + active = mask.unsqueeze(-1).expand_as(x) + # Active positions are restored exactly; inactive positions are zeroed. + assert torch.equal(restored[active], x[active]) + assert torch.all(restored[~active] == 0.0) + + +# backward (scatter) correctness +def test_backward_scatters_grad_to_active_rows(): + batch = _batch(seed=6, valid_density=0.7) + x = _dense(batch, seed=106).requires_grad_(True) + op = NativePackOp() + + packed, _ = op(x, batch.completion_mask) + g = torch.randn_like(packed) + packed.backward(g) + + # The gradient w.r.t. x is the scatter of g back to the active rows. + expected_grad = op.unpack(g, batch.completion_mask) + assert x.grad is not None + assert torch.equal(x.grad, expected_grad) + # Inactive positions receive zero gradient. + inactive = ~batch.completion_mask.unsqueeze(-1).expand_as(x) + assert torch.all(x.grad[inactive] == 0.0) + + +def test_backward_gradcheck_double(): + """Analytic scatter backward must match numerical gradients (float64).""" + batch = _batch(seed=7, valid_density=0.6) + mask = batch.completion_mask + x = torch.randn(batch.batch_size, batch.completion_len, 3, dtype=torch.float64).requires_grad_( + True + ) + op = NativePackOp() + + # Only the packed tensor is differentiable; cu_seqlens is integer. + assert torch.autograd.gradcheck(lambda t: op(t, mask)[0], (x,), eps=1e-6, atol=1e-6) + + +# multi-dim tail and validation +def test_pack_supports_multidim_tail(): + mask = torch.tensor([[True, False, True], [False, True, True]]) + x = torch.randn(2, 3, 4, 5) + op = NativePackOp() + + packed, cu_seqlens = op(x, mask) + assert packed.shape == (4, 4, 5) + assert torch.equal(packed, x.reshape(-1, 4, 5)[mask.reshape(-1)]) + assert cu_seqlens.tolist() == [0, 2, 4] + + +def test_pack_rejects_mismatched_mask_shape(): + x = torch.randn(2, 3, 4) + bad_mask = torch.ones(2, 5, dtype=torch.bool) + op = NativePackOp() + with pytest.raises(ValueError): + op(x, bad_mask) + + +# registry dispatch +def test_registry_dispatches_pack(): + from rl_engine.kernels.registry import kernel_registry + + op = kernel_registry.get_op("pack") + if _HAS_TRITON and torch.cuda.is_available(): + assert isinstance(op, TritonPackOp) + else: + assert isinstance(op, NativePackOp) + + +# Triton fused op (validated against the native reference) +@requires_triton_cuda +@pytest.mark.parametrize("valid_density", [1.0, 0.7, 0.0]) +def test_triton_forward_matches_native(valid_density): + batch = _batch(seed=10, device="cuda", valid_density=valid_density) + x = _dense(batch, seed=110, device="cuda") + packed_t, cu_t = TritonPackOp()(x, batch.completion_mask) + packed_n, cu_n = NativePackOp()(x, batch.completion_mask) + assert torch.equal(packed_t, packed_n) + assert torch.equal(cu_t, cu_n) + + +@requires_triton_cuda +def test_triton_backward_matches_native(): + batch = _batch(seed=11, device="cuda", valid_density=0.7) + x0 = _dense(batch, seed=111, device="cuda") + g = torch.randn(int(batch.completion_mask.sum()), _VOCAB, device="cuda") + + xt = x0.clone().requires_grad_(True) + pt, _ = TritonPackOp()(xt, batch.completion_mask) + pt.backward(g) + + xn = x0.clone().requires_grad_(True) + pn, _ = NativePackOp()(xn, batch.completion_mask) + pn.backward(g) + + assert xt.grad is not None + assert torch.equal(xt.grad, xn.grad) + + +@requires_triton_cuda +def test_triton_supports_multidim_tail(): + mask = torch.tensor([[True, False, True], [False, True, True]], device="cuda") + x = torch.randn(2, 3, 4, 5, device="cuda") + packed_t, cu_t = TritonPackOp()(x, mask) + packed_n, cu_n = NativePackOp()(x, mask) + assert packed_t.shape == (4, 4, 5) + assert torch.equal(packed_t, packed_n) + assert cu_t.tolist() == [0, 2, 4] + + +@requires_triton_cuda +def test_triton_inactive_rows_do_not_leak(): + """Garbage values at masked rows must not appear in the packed output.""" + batch = _batch(seed=12, device="cuda", valid_density=0.6) + x = _dense(batch, seed=112, device="cuda") + inactive = ~batch.completion_mask.unsqueeze(-1).expand_as(x) + x_pert = x.clone() + x_pert[inactive] = 1e9 + + base, _ = TritonPackOp()(x, batch.completion_mask) + pert, _ = TritonPackOp()(x_pert, batch.completion_mask) + assert torch.equal(base, pert) + + +@requires_triton_cuda +def test_triton_requires_gpu_tensor(): + op = TritonPackOp() + x = torch.randn(2, 3, 4) + mask = torch.ones(2, 3, dtype=torch.bool) + with pytest.raises(RuntimeError): + op(x, mask) From a4e8fd901566bf7c9bb143377a1a873245419896 Mon Sep 17 00:00:00 2001 From: sadlerchen Date: Tue, 23 Jun 2026 12:07:07 +0800 Subject: [PATCH 2/5] perf(benchmarks): add pack-and-pad benchmark (latency + end-to-end VRAM) (#42) Measure TritonPackOp vs a PyTorch boolean-index baseline for pack latency, and the end-to-end peak VRAM of dense logp vs pack->logp to quantify the memory saving on sparse masks (the motivation behind #42). Follows the existing benchmark_ratio_kl.py conventions (CUDA-event median timing, max_memory_allocated, CSV output, --smoke). --- benchmarks/benchmark_pack.py | 370 +++++++++++++++++++++++++++++++++++ 1 file changed, 370 insertions(+) create mode 100644 benchmarks/benchmark_pack.py diff --git a/benchmarks/benchmark_pack.py b/benchmarks/benchmark_pack.py new file mode 100644 index 0000000..3fec56c --- /dev/null +++ b/benchmarks/benchmark_pack.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Benchmark for the fused masking + pack-and-pad op (issue #42). + +Two measurements per shape: + +1. Pack latency: TritonPackOp vs a PyTorch boolean-index baseline + (``x.reshape(-1, T)[mask]``), with max-abs drift between the two. +2. End-to-end peak VRAM: the motivation behind #42. Computing selected + log-probs on the *dense* ``[B, S, V]`` tensor materializes full-sequence + logits for masked-out tokens; packing first lets the log-prob run only on + the ``[Total_Active, V]`` active rows. We report the peak GPU memory of + ``dense logp`` vs ``pack -> logp`` to quantify the saving when the mask is + sparse (long prompt, short response, padded batches). +""" + +from __future__ import annotations + +import argparse +import csv +import statistics +import sys +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import torch + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp # noqa: E402 +from rl_engine.testing import make_synthetic_rl_kernel_batch # noqa: E402 + +CSV_COLUMNS = [ + "timestamp", + "case", + "candidate", + "device", + "dtype", + "num_prompts", + "samples_per_prompt", + "completion_len", + "vocab_size", + "mask_density", + "valid_tokens", + "baseline_ms", + "candidate_ms", + "speedup", + "pack_drift", + "dense_logp_mem_gb", + "packed_logp_mem_gb", + "mem_saving_pct", + "status", + "notes", +] + + +@dataclass(frozen=True) +class BenchmarkConfig: + case: str + device: torch.device + dtype: torch.dtype + num_prompts: int + samples_per_prompt: int + completion_len: int + vocab_size: int + mask_density: float + seed: int + warmup: int + repeat: int + + +def _parse_int_list(value: str) -> list[int]: + return [int(item) for item in value.split(",") if item] + + +def _parse_float_list(value: str) -> list[float]: + return [float(item) for item in value.split(",") if item] + + +def _parse_dtype(value: str) -> torch.dtype: + normalized = value.lower() + if normalized in {"fp16", "float16", "half"}: + return torch.float16 + if normalized in {"bf16", "bfloat16"}: + return torch.bfloat16 + if normalized in {"fp32", "float32"}: + return torch.float32 + raise ValueError(f"unsupported dtype: {value}") + + +def _sync(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _time_ms(fn, device: torch.device, *, warmup: int = 3, repeat: int = 10) -> tuple[Any, float]: + result = None + for _ in range(max(0, warmup)): + result = fn() + _sync(device) + + elapsed: list[float] = [] + for _ in range(max(1, repeat)): + if device.type == "cuda": + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + result = fn() + end.record() + end.synchronize() + elapsed.append(start.elapsed_time(end)) + else: + _sync(device) + start_time = time.perf_counter() + result = fn() + _sync(device) + elapsed.append((time.perf_counter() - start_time) * 1000.0) + + _sync(device) + return result, statistics.median(elapsed) + + +def _peak_memory_gb(device: torch.device) -> float: + if device.type != "cuda": + return 0.0 + return torch.cuda.max_memory_allocated(device) / (1024**3) + + +def _reset_peak(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + +def _baseline_pack(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """PyTorch reference: flatten and boolean-index the active rows.""" + return x.reshape(-1, x.shape[-1])[mask.reshape(-1)] + + +def _selected_logp(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: + """log_softmax(logits)[ids] over the last dim; ids shape == logits.shape[:-1].""" + logp = torch.log_softmax(logits.float(), dim=-1) + return logp.gather(-1, ids.long().unsqueeze(-1)).squeeze(-1) + + +def _pack_row(config: BenchmarkConfig) -> dict[str, Any]: + candidate_name = "TritonPackOp" + + batch = make_synthetic_rl_kernel_batch( + num_prompts=config.num_prompts, + samples_per_prompt=config.samples_per_prompt, + prompt_len=0, + completion_len=config.completion_len, + vocab_size=config.vocab_size, + valid_density=config.mask_density, + dtype=config.dtype, + device=config.device, + seed=config.seed, + ) + + logit_shape = (batch.batch_size, batch.completion_len, config.vocab_size) + logits = torch.randn(logit_shape, device=config.device, dtype=config.dtype) + mask = batch.completion_mask + ids = batch.token_ids + + native_pack = NativePackOp() + + status = "pass" + notes = "" + baseline_ms: float | str = "" + candidate_ms: float | str = "" + speedup: float | str = "" + pack_drift: float | str = "" + dense_logp_mem_gb: float | str = "" + packed_logp_mem_gb: float | str = "" + mem_saving_pct: float | str = "" + + # (1) pack latency: PyTorch boolean-index baseline vs Triton candidate. + _reset_peak(config.device) + base_packed, baseline_ms = _time_ms( + lambda: _baseline_pack(logits, mask), + config.device, + warmup=config.warmup, + repeat=config.repeat, + ) + + if config.device.type != "cuda": + status = "blocked" + notes = "candidate requires CUDA" + else: + try: + from rl_engine.kernels.registry import kernel_registry + + candidate_op = kernel_registry.get_op("pack") + if candidate_op.__class__.__name__ != candidate_name: + raise RuntimeError(f"{candidate_name} backend is unavailable") + + (cand_packed, _), candidate_ms = _time_ms( + lambda: candidate_op(logits, mask), + config.device, + warmup=config.warmup, + repeat=config.repeat, + ) + speedup = baseline_ms / candidate_ms if candidate_ms else float("inf") + pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item() + + # (2) end-to-end peak VRAM: dense logp vs pack->logp. + _reset_peak(config.device) + _ = _selected_logp(logits, ids) + _sync(config.device) + dense_logp_mem_gb = _peak_memory_gb(config.device) + + _reset_peak(config.device) + packed_logits, _ = candidate_op(logits, mask) + packed_ids, _ = candidate_op(ids.unsqueeze(-1), mask) + _ = _selected_logp(packed_logits, packed_ids.squeeze(-1)) + _sync(config.device) + packed_logp_mem_gb = _peak_memory_gb(config.device) + + if dense_logp_mem_gb > 0: + mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb) + except Exception as exc: + status = "blocked" + notes = f"candidate unavailable: {str(exc).splitlines()[0]}" + + metadata = batch.benchmark_metadata() + timing_mode = "cuda_event_median_ms" if config.device.type == "cuda" else "wall_median_ms" + timing_notes = f"warmup={config.warmup}; repeat={config.repeat}; {timing_mode}" + notes = f"{notes}; {timing_notes}" if notes else timing_notes + + def _fmt(value: Any, spec: str) -> Any: + return format(value, spec) if isinstance(value, float) else value + + return { + "timestamp": datetime.now(timezone.utc).isoformat(), + "case": config.case, + "candidate": candidate_name, + "device": str(config.device), + "dtype": str(config.dtype), + "num_prompts": config.num_prompts, + "samples_per_prompt": config.samples_per_prompt, + "completion_len": config.completion_len, + "vocab_size": config.vocab_size, + "mask_density": config.mask_density, + "valid_tokens": metadata["valid_tokens"], + "baseline_ms": f"{baseline_ms:.4f}" if isinstance(baseline_ms, float) else baseline_ms, + "candidate_ms": _fmt(candidate_ms, ".4f"), + "speedup": _fmt(speedup, ".2f"), + "pack_drift": _fmt(pack_drift, ".3e"), + "dense_logp_mem_gb": _fmt(dense_logp_mem_gb, ".6f"), + "packed_logp_mem_gb": _fmt(packed_logp_mem_gb, ".6f"), + "mem_saving_pct": _fmt(mem_saving_pct, ".2f"), + "status": status, + "notes": notes, + } + + +def _write_rows(rows: list[dict[str, Any]], output: Path | None) -> None: + if output is None: + writer = csv.DictWriter(sys.stdout, fieldnames=CSV_COLUMNS) + writer.writeheader() + writer.writerows(rows) + return + + output.parent.mkdir(parents=True, exist_ok=True) + exists = output.exists() and output.stat().st_size > 0 + with output.open("a", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=CSV_COLUMNS) + if not exists: + writer.writeheader() + writer.writerows(rows) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Fused pack-and-pad RL-Kernel benchmark runner") + parser.add_argument("--case", default="pack", choices=["pack"]) + parser.add_argument("--smoke", action="store_true", help="Run a small local-development shape") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--dtype", default="bfloat16") + parser.add_argument("--num-prompts", type=int, default=2) + parser.add_argument("--g-sizes", default="8", help="Comma-separated samples-per-prompt values") + parser.add_argument("--completion-lens", default="1024") + parser.add_argument("--vocab-sizes", default="32768,131072") + parser.add_argument( + "--mask-densities", + default="0.1,0.3,1.0", + help="Active-token fraction; sparse masks show the largest VRAM saving", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--repeat", type=int, default=10) + parser.add_argument("--output", type=Path, default=None) + return parser + + +def main() -> None: + args = build_arg_parser().parse_args() + device = torch.device(args.device) + dtype = _parse_dtype(args.dtype) + + if args.smoke: + num_prompts = 1 + g_sizes = [2] + completion_lens = [8] + vocab_sizes = [128] + mask_densities = [0.5, 1.0] + else: + num_prompts = args.num_prompts + g_sizes = _parse_int_list(args.g_sizes) + completion_lens = _parse_int_list(args.completion_lens) + vocab_sizes = _parse_int_list(args.vocab_sizes) + mask_densities = _parse_float_list(args.mask_densities) + + rows: list[dict[str, Any]] = [] + for samples_per_prompt in g_sizes: + for completion_len in completion_lens: + for vocab_size in vocab_sizes: + for mask_density in mask_densities: + config = BenchmarkConfig( + case=args.case, + device=device, + dtype=dtype, + num_prompts=num_prompts, + samples_per_prompt=samples_per_prompt, + completion_len=completion_len, + vocab_size=vocab_size, + mask_density=mask_density, + seed=args.seed, + warmup=args.warmup, + repeat=args.repeat, + ) + try: + rows.append(_pack_row(config)) + except torch.cuda.OutOfMemoryError as exc: + rows.append( + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "case": args.case, + "candidate": "TritonPackOp", + "device": str(device), + "dtype": str(dtype), + "num_prompts": num_prompts, + "samples_per_prompt": samples_per_prompt, + "completion_len": completion_len, + "vocab_size": vocab_size, + "mask_density": mask_density, + "valid_tokens": "", + "baseline_ms": "", + "candidate_ms": "", + "speedup": "", + "pack_drift": "", + "dense_logp_mem_gb": "", + "packed_logp_mem_gb": "", + "mem_saving_pct": "", + "status": "oom", + "notes": str(exc), + } + ) + + _write_rows(rows, args.output) + + +if __name__ == "__main__": + main() From b6ed34f8b785381904b9004ab12bd23fd9f962d1 Mon Sep 17 00:00:00 2001 From: sadlerchen Date: Tue, 23 Jun 2026 12:18:20 +0800 Subject: [PATCH 3/5] perf(benchmarks): model real hidden->lm_head->logp chain for pack VRAM Pack hidden states before the vocab projection so the dense [B,S,V] logits are never materialized for masked-out tokens, which is the actual #42 saving. Drop the fp32 upcast in selected-logp (use logits - logsumexp) so the dense path's peak memory reflects the logits, not an fp32 copy. Add --hidden-dim. --- benchmarks/benchmark_pack.py | 56 +++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/benchmarks/benchmark_pack.py b/benchmarks/benchmark_pack.py index 3fec56c..2df7f92 100644 --- a/benchmarks/benchmark_pack.py +++ b/benchmarks/benchmark_pack.py @@ -33,7 +33,6 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp # noqa: E402 from rl_engine.testing import make_synthetic_rl_kernel_batch # noqa: E402 CSV_COLUMNS = [ @@ -46,6 +45,7 @@ "samples_per_prompt", "completion_len", "vocab_size", + "hidden_dim", "mask_density", "valid_tokens", "baseline_ms", @@ -69,6 +69,7 @@ class BenchmarkConfig: samples_per_prompt: int completion_len: int vocab_size: int + hidden_dim: int mask_density: float seed: int warmup: int @@ -144,9 +145,16 @@ def _baseline_pack(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: def _selected_logp(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor: - """log_softmax(logits)[ids] over the last dim; ids shape == logits.shape[:-1].""" - logp = torch.log_softmax(logits.float(), dim=-1) - return logp.gather(-1, ids.long().unsqueeze(-1)).squeeze(-1) + """selected log-prob = logits[ids] - logsumexp(logits) over the last dim. + + Computed without materializing a full [*, V] log_softmax tensor and without + upcasting the whole logits tensor to fp32, so the dense path's peak memory + is dominated by the input logits themselves -- which is exactly the cost + that packing removes for masked-out tokens (issue #42). + """ + selected = logits.gather(-1, ids.long().unsqueeze(-1)).squeeze(-1) + lse = torch.logsumexp(logits, dim=-1) + return selected - lse def _pack_row(config: BenchmarkConfig) -> dict[str, Any]: @@ -164,13 +172,21 @@ def _pack_row(config: BenchmarkConfig) -> dict[str, Any]: seed=config.seed, ) - logit_shape = (batch.batch_size, batch.completion_len, config.vocab_size) - logits = torch.randn(logit_shape, device=config.device, dtype=config.dtype) + # Real RL chain: hidden -> (lm_head) -> logits -> selected logp. + # Packing hidden *before* the vocab projection means the full [B, S, V] + # logits are never materialized for masked-out tokens (issue #42). + hidden_dim = config.hidden_dim + hidden = torch.randn( + (batch.batch_size, batch.completion_len, hidden_dim), + device=config.device, + dtype=config.dtype, + ) + lm_head = torch.randn( + (hidden_dim, config.vocab_size), device=config.device, dtype=config.dtype + ) mask = batch.completion_mask ids = batch.token_ids - native_pack = NativePackOp() - status = "pass" notes = "" baseline_ms: float | str = "" @@ -181,10 +197,11 @@ def _pack_row(config: BenchmarkConfig) -> dict[str, Any]: packed_logp_mem_gb: float | str = "" mem_saving_pct: float | str = "" - # (1) pack latency: PyTorch boolean-index baseline vs Triton candidate. + # (1) pack latency: PyTorch boolean-index baseline vs Triton candidate + # (packing the hidden states, the [*, D] tensor moved in the real chain). _reset_peak(config.device) base_packed, baseline_ms = _time_ms( - lambda: _baseline_pack(logits, mask), + lambda: _baseline_pack(hidden, mask), config.device, warmup=config.warmup, repeat=config.repeat, @@ -202,7 +219,7 @@ def _pack_row(config: BenchmarkConfig) -> dict[str, Any]: raise RuntimeError(f"{candidate_name} backend is unavailable") (cand_packed, _), candidate_ms = _time_ms( - lambda: candidate_op(logits, mask), + lambda: candidate_op(hidden, mask), config.device, warmup=config.warmup, repeat=config.repeat, @@ -210,16 +227,21 @@ def _pack_row(config: BenchmarkConfig) -> dict[str, Any]: speedup = baseline_ms / candidate_ms if candidate_ms else float("inf") pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item() - # (2) end-to-end peak VRAM: dense logp vs pack->logp. + # (2) end-to-end peak VRAM: dense (full logits) vs pack-then-project. + flat_ids = ids.reshape(-1) _reset_peak(config.device) - _ = _selected_logp(logits, ids) + dense_logits = (hidden.reshape(-1, hidden_dim) @ lm_head) + _ = _selected_logp(dense_logits, flat_ids) + del dense_logits _sync(config.device) dense_logp_mem_gb = _peak_memory_gb(config.device) _reset_peak(config.device) - packed_logits, _ = candidate_op(logits, mask) + packed_hidden, _ = candidate_op(hidden, mask) packed_ids, _ = candidate_op(ids.unsqueeze(-1), mask) + packed_logits = packed_hidden @ lm_head _ = _selected_logp(packed_logits, packed_ids.squeeze(-1)) + del packed_logits, packed_hidden _sync(config.device) packed_logp_mem_gb = _peak_memory_gb(config.device) @@ -247,6 +269,7 @@ def _fmt(value: Any, spec: str) -> Any: "samples_per_prompt": config.samples_per_prompt, "completion_len": config.completion_len, "vocab_size": config.vocab_size, + "hidden_dim": config.hidden_dim, "mask_density": config.mask_density, "valid_tokens": metadata["valid_tokens"], "baseline_ms": f"{baseline_ms:.4f}" if isinstance(baseline_ms, float) else baseline_ms, @@ -287,6 +310,7 @@ def build_arg_parser() -> argparse.ArgumentParser: parser.add_argument("--g-sizes", default="8", help="Comma-separated samples-per-prompt values") parser.add_argument("--completion-lens", default="1024") parser.add_argument("--vocab-sizes", default="32768,131072") + parser.add_argument("--hidden-dim", type=int, default=4096) parser.add_argument( "--mask-densities", default="0.1,0.3,1.0", @@ -310,12 +334,14 @@ def main() -> None: completion_lens = [8] vocab_sizes = [128] mask_densities = [0.5, 1.0] + hidden_dim = 64 else: num_prompts = args.num_prompts g_sizes = _parse_int_list(args.g_sizes) completion_lens = _parse_int_list(args.completion_lens) vocab_sizes = _parse_int_list(args.vocab_sizes) mask_densities = _parse_float_list(args.mask_densities) + hidden_dim = args.hidden_dim rows: list[dict[str, Any]] = [] for samples_per_prompt in g_sizes: @@ -330,6 +356,7 @@ def main() -> None: samples_per_prompt=samples_per_prompt, completion_len=completion_len, vocab_size=vocab_size, + hidden_dim=hidden_dim, mask_density=mask_density, seed=args.seed, warmup=args.warmup, @@ -349,6 +376,7 @@ def main() -> None: "samples_per_prompt": samples_per_prompt, "completion_len": completion_len, "vocab_size": vocab_size, + "hidden_dim": hidden_dim, "mask_density": mask_density, "valid_tokens": "", "baseline_ms": "", From e9e6505284ce82484bb17e0669467253abf228d1 Mon Sep 17 00:00:00 2001 From: sadlerchen Date: Tue, 23 Jun 2026 14:18:25 +0800 Subject: [PATCH 4/5] refactor(kernels): remove duplicate mask.dim() check in pack _validate The second mask.dim() < 1 guard was dead code: the intervening shape check does not alter mask.dim(). Addresses CodeRabbit review on #182. --- rl_engine/kernels/ops/pytorch/packing/pack.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/rl_engine/kernels/ops/pytorch/packing/pack.py b/rl_engine/kernels/ops/pytorch/packing/pack.py index 5a79064..e5c62f3 100644 --- a/rl_engine/kernels/ops/pytorch/packing/pack.py +++ b/rl_engine/kernels/ops/pytorch/packing/pack.py @@ -32,8 +32,6 @@ def _validate(x: torch.Tensor, mask: torch.Tensor) -> None: f"mask shape {tuple(mask.shape)} must match the leading dims of " f"x.shape {tuple(x.shape)} (expected {tuple(x.shape[: mask.dim()])})." ) - if mask.dim() < 1: - raise ValueError("mask must have at least one dimension.") class _PackFunction(torch.autograd.Function): From db15d4f85c50f811dbe4e4c3223d2af24bf47f60 Mon Sep 17 00:00:00 2001 From: sadlerchen Date: Wed, 24 Jun 2026 11:12:33 +0800 Subject: [PATCH 5/5] fix(kernels): count cu_seqlens from bool mask in NativePackOp Packing selects rows via mask.to(bool) (nonzero == active), but cu_seqlens summed the raw mask, so a non-bool mask (e.g. values in {0, 2}) inflated the prefix sum beyond the number of rows actually packed. Count from the same bool mask so cu_seqlens always matches the packed row count. Adds a regression test. Addresses review feedback on #182. --- rl_engine/kernels/ops/pytorch/packing/pack.py | 5 ++++- tests/test_pack.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/rl_engine/kernels/ops/pytorch/packing/pack.py b/rl_engine/kernels/ops/pytorch/packing/pack.py index e5c62f3..6a8a0f5 100644 --- a/rl_engine/kernels/ops/pytorch/packing/pack.py +++ b/rl_engine/kernels/ops/pytorch/packing/pack.py @@ -50,7 +50,10 @@ def forward(ctx, x: torch.Tensor, mask: torch.Tensor): packed = flat_x.index_select(0, index) # cu_seqlens: prefix-sum of per-row active counts, for varlen consumers. - per_row_active = mask.reshape(mask.shape[0], -1).to(torch.int64).sum(dim=1) + # Count from the bool mask so a non-bool mask (e.g. {0, 2}) matches the + # number of rows actually packed above (nonzero == active). + bool_mask = flat_mask.reshape(mask.shape[0], -1) + per_row_active = bool_mask.to(torch.int64).sum(dim=1) cu_seqlens = torch.zeros(mask.shape[0] + 1, dtype=torch.int64, device=x.device) torch.cumsum(per_row_active, dim=0, out=cu_seqlens[1:]) diff --git a/tests/test_pack.py b/tests/test_pack.py index f67aa19..049b776 100644 --- a/tests/test_pack.py +++ b/tests/test_pack.py @@ -94,6 +94,23 @@ def test_cu_seqlens_is_per_row_prefix_sum(): assert torch.equal(cu_seqlens, expected) +def test_non_bool_mask_cu_seqlens_matches_packed_rows(): + """A non-bool mask (nonzero == active) must not over-count cu_seqlens. + + Counting active rows from the raw integer mask (e.g. values in {0, 2}) + would inflate cu_seqlens beyond the number of rows actually packed. + """ + op = NativePackOp() + mask = torch.tensor([[0, 2, 0], [2, 2, 0]], dtype=torch.int32) + x = torch.randn(2, 3, 4) + + packed, cu_seqlens = op(x, mask) + # 3 nonzero positions -> 3 packed rows; cu_seqlens must end at 3, not 6. + assert packed.shape[0] == 3 + assert cu_seqlens.tolist() == [0, 1, 3] + assert torch.equal(packed, x.reshape(-1, 4)[mask.reshape(-1).to(torch.bool)]) + + def test_pack_all_active_is_identity_flatten(): batch = _batch(seed=3, valid_density=1.0) x = _dense(batch, seed=103)