Skip to content

feat(kernels): add fused masking + variable-length pack-and-pad op #182

Open
Chen-BUPT wants to merge 6 commits into
RL-Align:mainfrom
Chen-BUPT:feat/pack-and-pad
Open

feat(kernels): add fused masking + variable-length pack-and-pad op #182
Chen-BUPT wants to merge 6 commits into
RL-Align:mainfrom
Chen-BUPT:feat/pack-and-pad

Conversation

@Chen-BUPT

@Chen-BUPT Chen-BUPT commented Jun 23, 2026

Copy link
Copy Markdown

Fused masking + variable-length pack-and-pad op (#42)

What this adds

A pack operator that compacts the active rows of a dense [B, S, *tail]
tensor (selected by a [B, S] mask) into a contiguous [Total_Active, *tail]
tensor, returns per-row cu_seqlens, and scatters gradients back to the dense
layout on the backward pass.

Backend Class Notes
PyTorch (CPU/GPU fallback) NativePackOp Portable reference; defines the numerical contract. Packing order matches SyntheticRLKernelBatch.compact_completion_values.
Triton (CUDA & ROCm) TritonPackOp Host-side prefix-sum for destination indices; gather/scatter kernels move the tail vectors. Numerically identical to the native op.

Registry dispatch for "pack": Triton on GPU, PyTorch fallback on CPU.

Correctness

  • 18/18 tests in tests/test_pack.py pass on an NVIDIA H20 (SM90, CUDA 13.0).
  • Native op validated against the repo's canonical compaction + gradcheck.
  • Triton op validated against the native op (forward + backward); max-abs
    drift = 0.000e+00
    across all benchmarked shapes.

Why it matters: end-to-end VRAM

In RL training only the response / non-padding tokens contribute to the loss.
Packing the hidden states before the vocab projection means the full
[B, S, V] logits are never materialized for masked-out tokens — exactly the
saving #42 targets ("saved memory can be used for larger batches or longer
CoT").

Benchmark: hidden=4096 -> lm_head -> logits -> selected logp, bf16,
B=32, S=1024, comparing dense (full logits) vs pack-then-project.

vocab = 131072

mask_density valid_tokens dense logp (GB) packed logp (GB) VRAM saving speedup pack_drift
0.05 1638 17.31 2.12 87.8 % 0.23 0
0.10 3277 17.33 2.96 82.9 % 0.29 0
0.30 9830 17.43 6.31 63.8 % 0.45 0
0.50 16384 17.53 9.66 44.9 % 0.57 0
1.00 32768 17.78 18.03 -1.4 % 0.74 0

vocab = 32768

mask_density valid_tokens dense logp (GB) packed logp (GB) VRAM saving speedup pack_drift
0.05 1638 4.56 0.77 83.1 % 0.25 0
0.10 3277 4.58 1.01 78.0 % 0.32 0
0.30 9830 4.68 1.96 58.2 % 0.46 0
0.50 16384 4.78 2.91 39.2 % 0.54 0
1.00 32768 5.03 5.28 -5.0 % 0.78 0

(Full data: pack_h20.csv.)

How to read the numbers

  • VRAM saving is the headline. With a sparse loss mask (long prompt, short
    response — the common RL case), packing before the projection cuts peak logp
    memory by up to ~88 % (density 0.05). The 17 GB -> 2 GB drop lets the
    same GPU fit a much larger batch or longer CoT.
  • density = 1.0 (all active) is the control group: saving ≈ 0 % (a small
    pack overhead), as expected — there is nothing to compact, so nothing is
    saved. This confirms the measurement is honest.
  • Latency (speedup < 1) is reported as-is. The pack op itself is
    memory-bound and a touch slower than PyTorch's boolean indexing
    (index_select is already highly tuned). Its absolute cost is 0.06–0.35 ms,
    negligible against the multi-GB memory it saves. This PR targets the VRAM
    win ([FEAT][kernels]: implement Fused Masking and Variable-Length Sequence Packing (Pack-and-Pad) #42's stated motivation);

Reproduce

PYTHONPATH=. python -m pytest tests/test_pack.py -v        # correctness (needs CUDA for Triton cases)

PYTHONPATH=. python benchmarks/benchmark_pack.py \
  --num-prompts 4 --g-sizes 8 --hidden-dim 4096 \
  --mask-densities 0.05,0.1,0.3,0.5,1.0 \
  --completion-lens 1024 --vocab-sizes 32768,131072 \
  --output benchmarks/results/pack_h20.csv

Notes for reviewers

  • Hardware fallback: no GPU / no Triton → automatically dispatches to
    NativePackOp (CPU path); Triton tests skip cleanly.
  • ROCm: the Triton kernels are wavefront-agnostic (plain gather/scatter), so
    they run on ROCm via Triton without CUDA-specific intrinsics.
  • Please add the needs-gpu-ci label so the GPU CI exercises the Triton
    path.

Summary by CodeRabbit

Summary by CodeRabbit

  • New Features

    • Added a fused “pack-and-pad” operation that compacts masked tokens into a packed tensor plus cu_seqlens, with both PyTorch and Triton GPU backends.
    • Updated operation routing so the "pack" op dispatches to the best available backend per platform.
  • Tests

    • Added a full test suite covering forward correctness, cu_seqlens, unpack round-trips, gradient behavior, edge cases, and backend dispatch.
  • Chores

    • Added a benchmark runner to measure latency, peak GPU memory, numerical drift, and export results to CSV.

sadlerchen added 3 commits June 23, 2026 11:46
…AM) (RL-Align#42)

Measure TritonPackOp vs a PyTorch boolean-index baseline for pack latency,
and the end-to-end peak VRAM of dense logp vs pack->logp to quantify the
memory saving on sparse masks (the motivation behind RL-Align#42). Follows the
existing benchmark_ratio_kl.py conventions (CUDA-event median timing,
max_memory_allocated, CSV output, --smoke).
Pack hidden states before the vocab projection so the dense [B,S,V] logits
are never materialized for masked-out tokens, which is the actual RL-Align#42 saving.
Drop the fp32 upcast in selected-logp (use logits - logsumexp) so the dense
path's peak memory reflects the logits, not an fp32 copy. Add --hidden-dim.
@coderabbitai

coderabbitai Bot commented Jun 23, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 325a0dd2-df96-4bab-907c-f3194f32e5ad

📥 Commits

Reviewing files that changed from the base of the PR and between e9e6505 and db15d4f.

📒 Files selected for processing (2)
  • rl_engine/kernels/ops/pytorch/packing/pack.py
  • tests/test_pack.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • rl_engine/kernels/ops/pytorch/packing/pack.py
  • tests/test_pack.py

📝 Walkthrough

Walkthrough

Adds a fused masking and variable-length pack-and-pad operation implementing both a PyTorch-native autograd reference (NativePackOp) and a Triton GPU-accelerated version (TritonPackOp). Both are registered in KernelRegistry for CUDA, ROCm, and CPU dispatch. A test module validates correctness and a benchmark script measures latency and peak VRAM.

Changes

Pack-and-pad fused op

Layer / File(s) Summary
NativePackOp: validation, autograd forward/backward, and unpack
rl_engine/kernels/ops/pytorch/packing/__init__.py, rl_engine/kernels/ops/pytorch/packing/pack.py
Input validation enforces mask dimensionality and shape prefix match. _PackFunction gathers active rows into [Total_Active, *tail], computes cu_seqlens via prefix-sum of per-row active counts, and scatters gradients back using index_copy_. NativePackOp.unpack inverts packing with zeros at inactive positions.
TritonPackOp: JIT gather/scatter kernels and autograd
rl_engine/kernels/ops/triton/packing/__init__.py, rl_engine/kernels/ops/triton/packing/pack.py
Two Triton JIT kernels handle forward gather (active row tails → packed output) and backward scatter (grad_packed → zero-init dense grad). _dest_index computes exclusive prefix-sum dest indices with -1 for inactive rows. TritonPackOp validates GPU device type before dispatching.
KernelRegistry dispatch wiring
rl_engine/kernels/registry.py
OpBackend gains TRITON_PACK and PYTORCH_PACK members. The CUDA and ROCm priority maps route "pack" to [TRITON_PACK, PYTORCH_PACK]; the CPU map routes to [PYTORCH_PACK].
Test suite: native and Triton forward/backward/registry
tests/test_pack.py
Validates forward correctness against canonical compaction and index-based reference, cu_seqlens semantics, all-active and none-active edge cases, unpack round-trip, backward scatter and gradcheck, multi-dim tail, mask-shape rejection, and registry dispatch selection. CUDA-gated tests compare Triton vs native across densities and verify CPU-input rejection.
Benchmark script: pack latency and VRAM sweep
benchmarks/benchmark_pack.py
BenchmarkConfig and CSV_COLUMNS define the sweep schema. _pack_row measures baseline (boolean-index) vs Triton pack latency plus dense vs pack-first peak VRAM per shape. main iterates configs with OOM-per-shape handling and writes CSV/stdout via _write_rows.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • bitborne
  • EthanZero2Hero
  • maxiaosong1124
  • inaniloquentee
  • Flink-ddd

Poem

🐇 Hop, hop, tokens scatter wide,
A mask decides who gets to ride.
Triton gathers, PyTorch packs,
cu_seqlens cover all the tracks.
Dense and fused — the VRAM shrinks,
The benchmark proves it faster, thinks the bunny. 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.22% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat(kernels): add fused masking + variable-length pack-and-pad op' clearly and specifically describes the main change—adding a new fused pack operator with masking and variable-length support for the kernel registry.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
rl_engine/kernels/ops/pytorch/packing/pack.py (1)

27-37: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Remove the duplicate dimension check.

The mask.dim() < 1 guard at Lines 35-36 is dead code — it repeats the identical check at Lines 28-29, and the shape comparison in between never alters mask.dim().

♻️ Proposed cleanup
 def _validate(x: torch.Tensor, mask: torch.Tensor) -> None:
     if mask.dim() < 1:
         raise ValueError("mask must have at least one dimension.")
     if mask.shape != x.shape[: mask.dim()]:
         raise ValueError(
             f"mask shape {tuple(mask.shape)} must match the leading dims of "
             f"x.shape {tuple(x.shape)} (expected {tuple(x.shape[: mask.dim()])})."
         )
-    if mask.dim() < 1:
-        raise ValueError("mask must have at least one dimension.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/packing/pack.py` around lines 27 - 37, The
_validate function contains a duplicate dimension check for mask.dim() < 1 that
appears after the shape validation logic. Remove the second occurrence of the
identical check (the one appearing after the shape comparison) since the
intermediate validation logic does not modify mask.dim() and therefore makes the
repeated check dead code. Keep only the initial dimension check and remove the
redundant check that follows the shape validation.
rl_engine/kernels/ops/triton/packing/pack.py (1)

99-102: 🚀 Performance & Scalability | 🔵 Trivial | ⚖️ Poor tradeoff

Gather grid launches a program for every source row, including inactive ones.

The grid is sized over n_rows = B*S, and each program loads dest and early-exits when the row is inactive (dest < 0). For the low-density packing this op targets (e.g. 0.05), the large majority of launched programs do no work. Launching over the n_active rows via a packed→source inverse index would avoid the wasted launches on the hot path. This is the design tradeoff already acknowledged in the PR, so feel free to defer.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/packing/pack.py` around lines 99 - 102, The grid
in the `_pack_gather_kernel` launch is currently sized over `n_rows` which
includes all source rows, causing programs to be launched for inactive rows that
do no work (they just early-exit when dest < 0). To optimize this for sparse
packing scenarios, resize the grid to be sized over `n_active` instead of
`n_rows`, and introduce a packed-to-source inverse index mapping that the kernel
can use to look up which source rows are actually active. This eliminates wasted
kernel launches on inactive rows while keeping the same kernel logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/benchmark_pack.py`:
- Around line 214-252: The broad except Exception block that catches all
exceptions and sets status to "blocked" masks real execution errors
(kernel/runtime/math failures) as if the candidate backend is simply
unavailable. Replace the generic except Exception handler with specific
exception handling that only catches exceptions that genuinely indicate
unavailability (such as ImportError for missing kernel_registry or RuntimeError
for unavailable backends), and allow other exceptions like runtime/math/kernel
errors to propagate or be handled separately so they surface as actual failures
rather than being silently masked as blocked candidates.

In `@tests/test_pack.py`:
- Around line 18-31: The TritonPackOp import on line 18 happens unconditionally
before the Triton availability check, causing test collection to fail in
environments without Triton. Move the TritonPackOp import into the try block
alongside the triton import, and add a fallback assignment in the except block
(such as TritonPackOp = None) so the name can be safely referenced elsewhere in
the code. This ensures tests are properly skipped by the requires_triton_cuda
marker instead of failing during collection.

---

Nitpick comments:
In `@rl_engine/kernels/ops/pytorch/packing/pack.py`:
- Around line 27-37: The _validate function contains a duplicate dimension check
for mask.dim() < 1 that appears after the shape validation logic. Remove the
second occurrence of the identical check (the one appearing after the shape
comparison) since the intermediate validation logic does not modify mask.dim()
and therefore makes the repeated check dead code. Keep only the initial
dimension check and remove the redundant check that follows the shape
validation.

In `@rl_engine/kernels/ops/triton/packing/pack.py`:
- Around line 99-102: The grid in the `_pack_gather_kernel` launch is currently
sized over `n_rows` which includes all source rows, causing programs to be
launched for inactive rows that do no work (they just early-exit when dest < 0).
To optimize this for sparse packing scenarios, resize the grid to be sized over
`n_active` instead of `n_rows`, and introduce a packed-to-source inverse index
mapping that the kernel can use to look up which source rows are actually
active. This eliminates wasted kernel launches on inactive rows while keeping
the same kernel logic.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b134abab-ef3b-4adc-8ca3-c50dd21ef7ae

📥 Commits

Reviewing files that changed from the base of the PR and between 51b8b21 and b6ed34f.

📒 Files selected for processing (7)
  • benchmarks/benchmark_pack.py
  • rl_engine/kernels/ops/pytorch/packing/__init__.py
  • rl_engine/kernels/ops/pytorch/packing/pack.py
  • rl_engine/kernels/ops/triton/packing/__init__.py
  • rl_engine/kernels/ops/triton/packing/pack.py
  • rl_engine/kernels/registry.py
  • tests/test_pack.py

Comment on lines +214 to +252
try:
from rl_engine.kernels.registry import kernel_registry

candidate_op = kernel_registry.get_op("pack")
if candidate_op.__class__.__name__ != candidate_name:
raise RuntimeError(f"{candidate_name} backend is unavailable")

(cand_packed, _), candidate_ms = _time_ms(
lambda: candidate_op(hidden, mask),
config.device,
warmup=config.warmup,
repeat=config.repeat,
)
speedup = baseline_ms / candidate_ms if candidate_ms else float("inf")
pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item()

# (2) end-to-end peak VRAM: dense (full logits) vs pack-then-project.
flat_ids = ids.reshape(-1)
_reset_peak(config.device)
dense_logits = (hidden.reshape(-1, hidden_dim) @ lm_head)
_ = _selected_logp(dense_logits, flat_ids)
del dense_logits
_sync(config.device)
dense_logp_mem_gb = _peak_memory_gb(config.device)

_reset_peak(config.device)
packed_hidden, _ = candidate_op(hidden, mask)
packed_ids, _ = candidate_op(ids.unsqueeze(-1), mask)
packed_logits = packed_hidden @ lm_head
_ = _selected_logp(packed_logits, packed_ids.squeeze(-1))
del packed_logits, packed_hidden
_sync(config.device)
packed_logp_mem_gb = _peak_memory_gb(config.device)

if dense_logp_mem_gb > 0:
mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb)
except Exception as exc:
status = "blocked"
notes = f"candidate unavailable: {str(exc).splitlines()[0]}"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Don’t swallow all benchmark failures as “candidate unavailable”.

The broad except Exception on Line 250 masks real execution regressions (kernel/runtime/math errors) as "blocked", which can silently produce misleading benchmark results.

Suggested fix
-    else:
-        try:
-            from rl_engine.kernels.registry import kernel_registry
-
-            candidate_op = kernel_registry.get_op("pack")
-            if candidate_op.__class__.__name__ != candidate_name:
-                raise RuntimeError(f"{candidate_name} backend is unavailable")
+    else:
+        try:
+            from rl_engine.kernels.registry import kernel_registry
+            candidate_op = kernel_registry.get_op("pack")
+        except (ImportError, RuntimeError) as exc:
+            status = "blocked"
+            notes = f"candidate unavailable: {str(exc).splitlines()[0]}"
+            candidate_op = None
 
-            (cand_packed, _), candidate_ms = _time_ms(
-                lambda: candidate_op(hidden, mask),
-                config.device,
-                warmup=config.warmup,
-                repeat=config.repeat,
-            )
-            speedup = baseline_ms / candidate_ms if candidate_ms else float("inf")
-            pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item()
+        if candidate_op is not None:
+            (cand_packed, _), candidate_ms = _time_ms(
+                lambda: candidate_op(hidden, mask),
+                config.device,
+                warmup=config.warmup,
+                repeat=config.repeat,
+            )
+            speedup = baseline_ms / candidate_ms if candidate_ms else float("inf")
+            pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item()
 
             # (2) end-to-end peak VRAM: dense (full logits) vs pack-then-project.
             flat_ids = ids.reshape(-1)
             _reset_peak(config.device)
             dense_logits = (hidden.reshape(-1, hidden_dim) @ lm_head)
@@
-            if dense_logp_mem_gb > 0:
-                mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb)
-        except Exception as exc:
-            status = "blocked"
-            notes = f"candidate unavailable: {str(exc).splitlines()[0]}"
+            if dense_logp_mem_gb > 0:
+                mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb)
🧰 Tools
🪛 Ruff (0.15.18)

[warning] 250-250: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_pack.py` around lines 214 - 252, The broad except
Exception block that catches all exceptions and sets status to "blocked" masks
real execution errors (kernel/runtime/math failures) as if the candidate backend
is simply unavailable. Replace the generic except Exception handler with
specific exception handling that only catches exceptions that genuinely indicate
unavailability (such as ImportError for missing kernel_registry or RuntimeError
for unavailable backends), and allow other exceptions like runtime/math/kernel
errors to propagate or be handled separately so they surface as actual failures
rather than being silently masked as blocked candidates.

Source: Linters/SAST tools

Comment thread tests/test_pack.py
Comment on lines +18 to +31
from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
from rl_engine.testing import make_synthetic_rl_kernel_batch

try:
import triton # noqa: F401

_HAS_TRITON = True
except ImportError: # pragma: no cover
_HAS_TRITON = False

requires_triton_cuda = pytest.mark.skipif(
not (_HAS_TRITON and torch.cuda.is_available()),
reason="Triton pack op requires a CUDA device and Triton.",
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify unconditional Triton operator import occurs before optional availability guard.
rg -n -C2 'from rl_engine\.kernels\.ops\.triton\.packing\.pack import TritonPackOp|try:|import triton|except ImportError' tests/test_pack.py

Repository: RL-Align/RL-Kernel

Length of output: 516


🏁 Script executed:

cat -n tests/test_pack.py

Repository: RL-Align/RL-Kernel

Length of output: 10503


Move Triton operator import into try block to prevent test collection failures without Triton.

Line 18 imports TritonPackOp unconditionally, so test collection will fail in environments without Triton before the @skipif marker is evaluated. Move the operator import into the Triton availability try block and add a fallback assignment so it can be safely referenced in conditional code.

Proposed fix
 import pytest
 import torch
 
 from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp
-from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
 from rl_engine.testing import make_synthetic_rl_kernel_batch
 
 try:
     import triton  # noqa: F401
+    from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
 
     _HAS_TRITON = True
 except ImportError:  # pragma: no cover
     _HAS_TRITON = False
+    TritonPackOp = None  # type: ignore[assignment]
@@
 def test_registry_dispatches_pack():
     from rl_engine.kernels.registry import kernel_registry
 
     op = kernel_registry.get_op("pack")
     if _HAS_TRITON and torch.cuda.is_available():
+        assert TritonPackOp is not None
         assert isinstance(op, TritonPackOp)
     else:
         assert isinstance(op, NativePackOp)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
from rl_engine.testing import make_synthetic_rl_kernel_batch
try:
import triton # noqa: F401
_HAS_TRITON = True
except ImportError: # pragma: no cover
_HAS_TRITON = False
requires_triton_cuda = pytest.mark.skipif(
not (_HAS_TRITON and torch.cuda.is_available()),
reason="Triton pack op requires a CUDA device and Triton.",
)
from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp
from rl_engine.testing import make_synthetic_rl_kernel_batch
try:
import triton # noqa: F401
from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
_HAS_TRITON = True
except ImportError: # pragma: no cover
_HAS_TRITON = False
TritonPackOp = None # type: ignore[assignment]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_pack.py` around lines 18 - 31, The TritonPackOp import on line 18
happens unconditionally before the Triton availability check, causing test
collection to fail in environments without Triton. Move the TritonPackOp import
into the try block alongside the triton import, and add a fallback assignment in
the except block (such as TritonPackOp = None) so the name can be safely
referenced elsewhere in the code. This ensures tests are properly skipped by the
requires_triton_cuda marker instead of failing during collection.

The second mask.dim() < 1 guard was dead code: the intervening shape check
does not alter mask.dim(). Addresses CodeRabbit review on RL-Align#182.
packed = flat_x.index_select(0, index)

# cu_seqlens: prefix-sum of per-row active counts, for varlen consumers.
per_row_active = mask.reshape(mask.shape[0], -1).to(torch.int64).sum(dim=1)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

packed uses mask.to(bool), but cu_seqlens sums the raw mask. Non-bool masks can produce wrong prefix sums. Please either require bool masks or compute counts from the bool mask.

Packing selects rows via mask.to(bool) (nonzero == active), but cu_seqlens
summed the raw mask, so a non-bool mask (e.g. values in {0, 2}) inflated the
prefix sum beyond the number of rows actually packed. Count from the same
bool mask so cu_seqlens always matches the packed row count. Adds a
regression test. Addresses review feedback on RL-Align#182.
@Chen-BUPT Chen-BUPT requested a review from inaniloquentee June 24, 2026 03:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants