Skip to content

feat(testing): add TP-invariant reduction references#103

Open
inaniloquentee wants to merge 1 commit into
mainfrom
feat/tp-invariant-reductions
Open

feat(testing): add TP-invariant reduction references#103
inaniloquentee wants to merge 1 commit into
mainfrom
feat/tp-invariant-reductions

Conversation

@inaniloquentee

@inaniloquentee inaniloquentee commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • add TP-invariant selected-logprob reference helpers for simulated vocab shards and real torch.distributed all-reduce paths
  • add sharded and distributed masked sum/count/mean helpers so loss denominators stay invariant across TP and micro-batch partitions
  • document the TP-invariant reduction contract, dtype policy, diagnostics, and test entry points
  • add parity tests covering TP=1/2/3/4/8, uneven vocab shards, owner-rank target tokens, masks, dtype matrix, gradients, GRPO loss composition, vocab-parallel lm_head, production-scale vocab smoke, CUDA smoke, and Gloo all-reduce smoke

Closes #102

Validation

  • py -3.13 -m pre_commit run --all-files
  • py -3.13 -m mypy --ignore-missing-imports rl_engine/
  • py -3.13 -m pytest rl_engine/tests/test_dispatch.py -v
  • py -3.13 -m pytest tests/test_tp_invariant_reductions.py tests/test_reference_ops.py -q -rs
  • py -3.13 -m mkdocs build --strict -f mkdocs.yaml
  • py -3.13 -m pytest -q

DCO

  • Commit includes Signed-off-by: inaniloquentee <3051000145@qq.com>

Summary by CodeRabbit

  • Documentation

    • Added design note specifying TP-invariant logprob reduction semantics, dtype policy, and expected diagnostics.
  • New Features

    • Exposed TP-aware reference utilities and global vocab-sharded masked reduction helpers.
    • Added drift-summary diagnostics for locating worst-token deviations.
  • Tests

    • Added comprehensive tests for TP invariance (uneven shards, temperature, masked semantics, gradients) plus distributed and CUDA smoke tests.

@coderabbitai

coderabbitai Bot commented Jun 12, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9f772554-3fd1-4f96-8953-da351bc81051

📥 Commits

Reviewing files that changed from the base of the PR and between 8781ab6 and 8334615.

📒 Files selected for processing (4)
  • docs/design/tp-invariant-reductions.md
  • rl_engine/testing/__init__.py
  • rl_engine/testing/reference_ops.py
  • tests/test_tp_invariant_reductions.py
✅ Files skipped from review due to trivial changes (1)
  • docs/design/tp-invariant-reductions.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • rl_engine/testing/init.py
  • rl_engine/testing/reference_ops.py

📝 Walkthrough

Walkthrough

Adds a TP-invariant reduction contract, reference implementations, shard-aware masked reductions, drift diagnostics, public re-exports, and an extensive test suite (unit, distributed Gloo smoke, large-vocab, and CUDA fp16 gated tests) to validate vocab-sharded selected-logprobs and masked statistics.

Changes

TP-Invariant Reduction Contract and Implementation

Layer / File(s) Summary
Design Specification and Reduction Contract
docs/design/tp-invariant-reductions.md
Design note specifying TP-invariant semantics: global max + exp-sum reduction for selected logprobs across vocab shards, dtype/reduction-state policy, global masked-sum and active-token all-reduce semantics, and diagnostic/reporting fields.
TP Vocab Sharding Infrastructure
rl_engine/testing/reference_ops.py
vocab_shard_ranges computes uneven-tail shard boundaries; shard_logits_by_vocab slices logits into per-rank views; owner_ranks_for_token_ids maps token ids to owning TP ranks with validation.
TP-Invariant Selected Logprob References
rl_engine/testing/reference_ops.py
selected_logprobs_tp_reference computes selected logprobs from shard slices using global max/exp-sum across shards and enforces single-shard coverage; selected_logprobs_distributed_tp_reference reproduces semantics with torch.distributed all-reduce and distributed coverage checks.
Masked Reductions and Drift Diagnostics
rl_engine/testing/reference_ops.py
sharded_masked_sum/count/mean and distributed_masked_sum/count/mean implement global masked reductions with explicit reduction dtype control; summarize_tp_logprob_drift reports abs/relative errors, worst-token localization, and owner-rank/vocab-range attribution.
Module Exports and Public API
rl_engine/testing/__init__.py
Re-exports new reference helpers (vocab sharding, selected-logprob refs, sharded/distributed masked reductions, drift summarizer) by extending imports and __all__.
Test Infrastructure and Core Correctness
tests/test_tp_invariant_reductions.py
Test module with deterministic RNG and logit builders; core parity tests validating TP-selected logprobs match full-vocab references across TP sizes, uneven shard layouts, explicit offsets, dtypes, masked ignore-index handling, coverage enforcement, and temperature invariance.
Diagnostic Reporting and Gradient Consistency
tests/test_tp_invariant_reductions.py
Tests assert masked reductions use a global denominator, validate summarize_tp_logprob_drift owner attribution and relative errors, check TP-reference gradients match full-vocab gradients, and verify end-to-end GRPO loss/policy/KL invariance.
Distributed and Large-Scale Validation
tests/test_tp_invariant_reductions.py
Distributed Gloo smoke test using spawned processes; large-vocab shaped-matrix smoke tests across multiple TP sizes; production-scale tail-shard scenario; CUDA-gated fp16 smoke test for hardware tolerance checks.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • Flink-ddd
  • EthanZero2Hero

Poem

🐰 I split the vocab like carrot rows so neat,
Gathered max and sums so every rank can meet,
All-reduce whispers, numerators greet,
No drift remains where choices once did meet,
A rabbit hops—parity tastes sweet! 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 39.02% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title 'feat(testing): add TP-invariant reduction references' clearly and concisely summarizes the main change: adding reference implementations and test infrastructure for tensor-parallel-invariant reductions.
Linked Issues check ✅ Passed The pull request fully addresses all coding objectives from issue #102: providing TP-invariant reduction implementations for selected logprobs, masked reductions, reference helpers for both simulated and distributed TP, comprehensive test coverage across TP sizes/dtypes, and actionable diagnostics with drift reporting.
Out of Scope Changes check ✅ Passed All changes directly support the TP-invariant reduction contract: documentation, reference implementations, sharded/distributed helpers, and parity tests. No unrelated modifications or scope creep detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/tp-invariant-reductions

Warning

Tools execution failed with the following error:

Failed to run tools: Ping-pong health check failed


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/testing/reference_ops.py`:
- Around line 124-126: owner_ranks_for_token_ids currently overwrites owners for
overlapping shard_ranges which can violate the “exactly one owner” contract used
by summarize_tp_logprob_drift; add validation to detect overlaps and fail fast:
before iterating shard_ranges verify the ranges are sorted and non-overlapping
(each start >= previous end) or, inside the loop, check where owns is true and
owners != -1 and raise an error identifying the overlapping shard and token_ids;
reference the function owner_ranks_for_token_ids and the variables shard_ranges,
token_ids, owns, and owners when adding the check so overlaps are rejected
rather than silently taking the last match.
- Around line 362-367: The maskless branch in sharded_active_token_count trusts
value_shards[0].device but doesn't validate the rest, causing cross-device
tensors; mirror the device validation from sharded_masked_sum by iterating all
entries in value_shards, check that each.values.device equals the chosen device,
and raise a ValueError (or similar) if any shard is on a different device before
computing count and returning the tensor; reference mask_shards, value_shards,
sharded_active_token_count, sharded_masked_sum, sharded_masked_mean, device, and
count when making the change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3e77c135-7d94-4528-b2a0-3054de221438

📥 Commits

Reviewing files that changed from the base of the PR and between 04c014d and 8781ab6.

📒 Files selected for processing (4)
  • docs/design/tp-invariant-reductions.md
  • rl_engine/testing/__init__.py
  • rl_engine/testing/reference_ops.py
  • tests/test_tp_invariant_reductions.py

Comment on lines +124 to +126
for rank, (start, end) in enumerate(shard_ranges):
owns = active & (token_ids >= int(start)) & (token_ids < int(end))
owners = torch.where(owns, torch.full_like(owners, rank), owners)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject overlapping shard ranges instead of silently taking the last match.

owner_ranks_for_token_ids() overwrites owners on every matching range, so an overlapping shard_ranges input silently reports the last rank as the owner. That breaks the “exactly one owner” contract the TP references rely on and can misattribute diagnostics downstream in summarize_tp_logprob_drift(). Add a one-time validation that shard_ranges are sorted and non-overlapping before the loop, or detect owners != -1 when owns is true and raise.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/testing/reference_ops.py` around lines 124 - 126,
owner_ranks_for_token_ids currently overwrites owners for overlapping
shard_ranges which can violate the “exactly one owner” contract used by
summarize_tp_logprob_drift; add validation to detect overlaps and fail fast:
before iterating shard_ranges verify the ranges are sorted and non-overlapping
(each start >= previous end) or, inside the loop, check where owns is true and
owners != -1 and raise an error identifying the overlapping shard and token_ids;
reference the function owner_ranks_for_token_ids and the variables shard_ranges,
token_ids, owns, and owners when adding the check so overlaps are rejected
rather than silently taking the last match.

Comment on lines +362 to +367
if mask_shards is None:
if not value_shards:
raise ValueError("value_shards must be provided when mask_shards is None")
device = value_shards[0].device
count = sum(int(values.numel()) for values in value_shards)
return torch.tensor(count, device=device, dtype=reduction_dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate value_shards devices on the maskless count path.

When mask_shards is None, sharded_active_token_count() trusts value_shards[0].device and never checks the remaining shards. A mixed-device input then returns a count tensor on an arbitrary device, and sharded_masked_mean() can later fail when dividing by a numerator produced on a different device. Mirror the same all-shards device validation used in sharded_masked_sum() before computing count.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/testing/reference_ops.py` around lines 362 - 367, The maskless
branch in sharded_active_token_count trusts value_shards[0].device but doesn't
validate the rest, causing cross-device tensors; mirror the device validation
from sharded_masked_sum by iterating all entries in value_shards, check that
each.values.device equals the chosen device, and raise a ValueError (or similar)
if any shard is on a different device before computing count and returning the
tensor; reference mask_shards, value_shards, sharded_active_token_count,
sharded_masked_sum, sharded_masked_mean, device, and count when making the
change.

Signed-off-by: inaniloquentee <3051000145@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEAT] TP-invariant reductions for FSDP(TP=1) vs TP>1 rollout/training parity

1 participant