feat(testing): add TP-invariant reduction references#103
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (4)
✅ Files skipped from review due to trivial changes (1)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughAdds a TP-invariant reduction contract, reference implementations, shard-aware masked reductions, drift diagnostics, public re-exports, and an extensive test suite (unit, distributed Gloo smoke, large-vocab, and CUDA fp16 gated tests) to validate vocab-sharded selected-logprobs and masked statistics. ChangesTP-Invariant Reduction Contract and Implementation
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Warning Tools execution failed with the following error: Failed to run tools: Ping-pong health check failed Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/testing/reference_ops.py`:
- Around line 124-126: owner_ranks_for_token_ids currently overwrites owners for
overlapping shard_ranges which can violate the “exactly one owner” contract used
by summarize_tp_logprob_drift; add validation to detect overlaps and fail fast:
before iterating shard_ranges verify the ranges are sorted and non-overlapping
(each start >= previous end) or, inside the loop, check where owns is true and
owners != -1 and raise an error identifying the overlapping shard and token_ids;
reference the function owner_ranks_for_token_ids and the variables shard_ranges,
token_ids, owns, and owners when adding the check so overlaps are rejected
rather than silently taking the last match.
- Around line 362-367: The maskless branch in sharded_active_token_count trusts
value_shards[0].device but doesn't validate the rest, causing cross-device
tensors; mirror the device validation from sharded_masked_sum by iterating all
entries in value_shards, check that each.values.device equals the chosen device,
and raise a ValueError (or similar) if any shard is on a different device before
computing count and returning the tensor; reference mask_shards, value_shards,
sharded_active_token_count, sharded_masked_sum, sharded_masked_mean, device, and
count when making the change.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3e77c135-7d94-4528-b2a0-3054de221438
📒 Files selected for processing (4)
docs/design/tp-invariant-reductions.mdrl_engine/testing/__init__.pyrl_engine/testing/reference_ops.pytests/test_tp_invariant_reductions.py
| for rank, (start, end) in enumerate(shard_ranges): | ||
| owns = active & (token_ids >= int(start)) & (token_ids < int(end)) | ||
| owners = torch.where(owns, torch.full_like(owners, rank), owners) |
There was a problem hiding this comment.
Reject overlapping shard ranges instead of silently taking the last match.
owner_ranks_for_token_ids() overwrites owners on every matching range, so an overlapping shard_ranges input silently reports the last rank as the owner. That breaks the “exactly one owner” contract the TP references rely on and can misattribute diagnostics downstream in summarize_tp_logprob_drift(). Add a one-time validation that shard_ranges are sorted and non-overlapping before the loop, or detect owners != -1 when owns is true and raise.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/testing/reference_ops.py` around lines 124 - 126,
owner_ranks_for_token_ids currently overwrites owners for overlapping
shard_ranges which can violate the “exactly one owner” contract used by
summarize_tp_logprob_drift; add validation to detect overlaps and fail fast:
before iterating shard_ranges verify the ranges are sorted and non-overlapping
(each start >= previous end) or, inside the loop, check where owns is true and
owners != -1 and raise an error identifying the overlapping shard and token_ids;
reference the function owner_ranks_for_token_ids and the variables shard_ranges,
token_ids, owns, and owners when adding the check so overlaps are rejected
rather than silently taking the last match.
| if mask_shards is None: | ||
| if not value_shards: | ||
| raise ValueError("value_shards must be provided when mask_shards is None") | ||
| device = value_shards[0].device | ||
| count = sum(int(values.numel()) for values in value_shards) | ||
| return torch.tensor(count, device=device, dtype=reduction_dtype) |
There was a problem hiding this comment.
Validate value_shards devices on the maskless count path.
When mask_shards is None, sharded_active_token_count() trusts value_shards[0].device and never checks the remaining shards. A mixed-device input then returns a count tensor on an arbitrary device, and sharded_masked_mean() can later fail when dividing by a numerator produced on a different device. Mirror the same all-shards device validation used in sharded_masked_sum() before computing count.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/testing/reference_ops.py` around lines 362 - 367, The maskless
branch in sharded_active_token_count trusts value_shards[0].device but doesn't
validate the rest, causing cross-device tensors; mirror the device validation
from sharded_masked_sum by iterating all entries in value_shards, check that
each.values.device equals the chosen device, and raise a ValueError (or similar)
if any shard is on a different device before computing count and returning the
tensor; reference mask_shards, value_shards, sharded_active_token_count,
sharded_masked_sum, sharded_masked_mean, device, and count when making the
change.
Signed-off-by: inaniloquentee <3051000145@qq.com>
8781ab6 to
8334615
Compare
Summary
Closes #102
Validation
py -3.13 -m pre_commit run --all-filespy -3.13 -m mypy --ignore-missing-imports rl_engine/py -3.13 -m pytest rl_engine/tests/test_dispatch.py -vpy -3.13 -m pytest tests/test_tp_invariant_reductions.py tests/test_reference_ops.py -q -rspy -3.13 -m mkdocs build --strict -f mkdocs.yamlpy -3.13 -m pytest -qDCO
Signed-off-by: inaniloquentee <3051000145@qq.com>Summary by CodeRabbit
Documentation
New Features
Tests