Emit solved V-arrays sharded via out_shardings (continuous-axis prerequisite)#372
Closed
hmgaudecker wants to merge 1 commit into
Closed
Emit solved V-arrays sharded via out_shardings (continuous-axis prerequisite)#372hmgaudecker wants to merge 1 commit into
hmgaudecker wants to merge 1 commit into
Conversation
`solve` lowers each unique `max_Q_over_a` with the regime's declared V-array sharding as `out_shardings`, so the compiled XLA program produces V already partitioned across the right devices — the layout its next-period `next_regime_to_V_arr` consumers were lowered against, with no post-hoc `device_put` reshard. Natural jit output inference reproduces the declared layout for a sharded discrete axis (the existing `test_solution_running_on_multiple_cpus` covers that), but not for a sharded continuous axis, where the interpolation read folds the axis into the output — so this is the prerequisite for continuous-axis (e.g. `aime`) sharding once the `_fail_if_continuous_grid_distributed` guard is lifted. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Member
Author
|
Closing — continuous-axis sharding (the only thing this enables) doesn't pan out. The aime-4× experiment on Marvin confirmed the pathology |
Benchmark comparison (main → HEAD)Comparing
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked on #370.
Re-adds the
out_shardingspath dropped from #364:solvelowers eachmax_Q_over_awith the regime's declared V-array sharding asout_shardings, so the solved V is emitted on the layout its next-periodnext_regime_to_V_arrconsumers were lowered against — no post-hocdevice_putreshard.For a sharded discrete axis, natural jit output inference already reproduces this layout (the existing
test_solution_running_on_multiple_cpuscovers it), so this is a no-op there. Its purpose is to enable a sharded continuous axis (e.g.aime): there the next-period interpolation read folds the axis into the output, so without explicitout_shardingsthe solved V returns on a layout its consumers were not lowered against. This is the prerequisite for lifting_fail_if_continuous_grid_distributedand shardingaimeto scale past the 3-devicepref_typeceiling.Ported from #364's
3009542(thebatch_size>0 + distributed=Truecombo its original test used is now rejected at grid construction, so the existing distributed suite is the regression guard).🤖 Generated with Claude Code