Emit solved V-arrays sharded via `out_shardings` (continuous-axis prerequisite) #372

Closed
hmgaudecker wants to merge 1 commit into feat/distributed-solve-fixes from feat/restore-out-shardings

@hmgaudecker
Member

Stacked on #370.

Re-adds the `out_shardings` path dropped from #364: `solve` lowers each `max_Q_over_a` with the regime's declared V-array sharding as `out_shardings`, so the solved V is emitted on the layout its next-period `next_regime_to_V_arr` consumers were lowered against, with no post-hoc `device_put` reshard.
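
The mechanism can be sketched in plain JAX. This is a minimal illustration, not the project's code: `toy_max_q_over_a`, the mesh axis name `"states"`, and the array shapes are assumptions invented for the example. Only `jax.jit(..., out_shardings=...)`, `Mesh`, `NamedSharding`, and `PartitionSpec` are actual JAX APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are available; the axis name is illustrative.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("states",))

# Declared layout for the toy V-array: first axis split across the mesh,
# second axis replicated.
v_sharding = NamedSharding(mesh, P("states", None))


def toy_max_q_over_a(q):
    # Hypothetical stand-in for max_Q_over_a: maximise Q over the last
    # (action) axis, leaving a V-array over the remaining state axes.
    return jnp.max(q, axis=-1)


# Passing out_shardings asks the compiled program to emit V directly in the
# declared layout, so no device_put reshard is needed afterwards.
solve_step = jax.jit(toy_max_q_over_a, out_shardings=v_sharding)

n_states = 4 * len(devices)  # divisible by the mesh size
q = jnp.arange(n_states * 3 * 5, dtype=jnp.float32).reshape(n_states, 3, 5)
v = solve_step(q)

# The layout the result actually carries can be inspected here.
layout = v.sharding
```

On a single device the sharding is trivial; the layout only becomes visible with several devices (for CPU experiments, JAX can be started with `XLA_FLAGS=--xla_force_host_platform_device_count=4`).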

For a sharded discrete axis, natural jit output inference already reproduces this layout (the existing `test_solution_running_on_multiple_cpus` covers it), so this is a no-op there. Its purpose is to enable a sharded continuous axis (e.g. `aime`): there the next-period interpolation read folds the axis into the output, so without explicit `out_shardings` the solved V returns on a layout its consumers were not lowered against. This is the prerequisite for lifting `_fail_if_continuous_grid_distributed` and sharding `aime` to scale past the 3-device `pref_type` ceiling.
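
A toy single-device sketch of why a continuous axis behaves differently from a discrete one. The grid, the law of motion `next_state`, and the quadratic `v_next` are made up for illustration and are not taken from the project. The point is only that an interpolated read at grid index `i` can depend on next-period values stored at entirely different indices along the same axis, whereas a discrete axis is read at its own index.

```python
import jax.numpy as jnp

# Hypothetical continuous state grid and a toy next-period V on that grid.
grid = jnp.linspace(0.0, 10.0, 8)
v_next = grid**2

# Toy law of motion: the state at grid point i moves to 10 - grid[i], which is
# the grid point at the opposite end of the axis.
next_state = 10.0 - grid

# Interpolated read of next-period V at the next-period states. Entry i of
# v_read is built from v_next near index 7 - i, not from index i, so a device
# holding only a slice of the axis would need values held by other devices.
v_read = jnp.interp(next_state, grid, v_next)
```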

Ported from #364's 3009542 (the `batch_size>0` + `distributed=True` combo its original test used is now rejected at grid construction, so the existing distributed suite is the regression guard).

🤖 Generated with Claude Code

`solve` lowers each unique `max_Q_over_a` with the regime's declared
V-array sharding as `out_shardings`, so the compiled XLA program produces
V already partitioned across the right devices — the layout its
next-period `next_regime_to_V_arr` consumers were lowered against, with no
post-hoc `device_put` reshard.

Natural jit output inference reproduces the declared layout for a sharded
discrete axis (the existing `test_solution_running_on_multiple_cpus`
covers that), but not for a sharded continuous axis, where the
interpolation read folds the axis into the output — so this is the
prerequisite for continuous-axis (e.g. `aime`) sharding once the
`_fail_if_continuous_grid_distributed` guard is lifted.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@hmgaudecker
Member Author

Closing: continuous-axis sharding (the only thing this enables) doesn't pan out. The `aime`-4× experiment on Marvin confirmed the pathology `_fail_if_continuous_grid_distributed` documents: with `out_shardings` the solved V does carry the declared sharding (it runs on 4 GPUs), but every next-period interpolation still all-gathers the full V-array per device; XLA flags 1–1.9 TiB input/output arguments. Net: ~28.5 s per 6-regime age vs ~10 s for discrete `pref_type`-3× (all-reduce), i.e. ~2.8× slower. The guard stays; this prerequisite isn't worth landing.
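
One way to look for such gathers on a small case is to inspect the compiled program text. The sketch below is an assumption-laden toy, not how the measurement above was taken (that came from XLA's reported argument sizes on the GPU run): `interp_read`, the mesh axis name `"grid"`, and the sizes are hypothetical, and counting the substring `all-gather` is only a rough heuristic. On a single device no collectives are expected at all.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("grid",))
sharded = NamedSharding(mesh, P("grid"))


def interp_read(v_next, grid, query):
    # Hypothetical stand-in for a next-period interpolation read.
    return jnp.interp(query, grid, v_next)


n = 8 * len(jax.devices())  # divisible by the mesh size
grid = jnp.linspace(0.0, 1.0, n)
v_next = jax.device_put(grid**2, sharded)     # V sharded along the grid axis
query = jax.device_put(1.0 - grid, sharded)   # query points, same layout

# Lower and compile with the output pinned to the same sharded layout, then
# read the optimised program as text.
compiled = (
    jax.jit(interp_read, out_shardings=sharded)
    .lower(v_next, grid, query)
    .compile()
)
hlo_text = compiled.as_text()

# Rough heuristic: count mentions of all-gather collectives in the program.
n_all_gather = hlo_text.count("all-gather")
```

To try this with several CPU devices, set `XLA_FLAGS=--xla_force_host_platform_device_count=4` before importing JAX.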

@hmgaudecker hmgaudecker closed this Jun 3, 2026
@hmgaudecker hmgaudecker deleted the feat/restore-out-shardings branch June 3, 2026 20:08
@github-actions

github-actions Bot commented Jun 3, 2026

Benchmark comparison (main → HEAD)

Comparing 820a2475 (main) → e772e9de (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 14.887 s | 14.806 s | 0.99 | |
| | peak GPU mem | 581 MB | 581 MB | 1.00 | |
| | compilation time | 274.26 s | 270.76 s | 0.99 | |
| | peak CPU mem | 6.88 GB | 7.40 GB | 1.08 | |
| aca-baseline-debug | execution time | 76.769 s | 75.855 s | 0.99 | |
| | peak GPU mem | 581 MB | 581 MB | 1.00 | |
| | compilation time | 363.98 s | 364.96 s | 1.00 | |
| | peak CPU mem | 7.57 GB | 7.50 GB | 0.99 | |
| Mahler-Yum | execution time | 4.793 s | 4.896 s | 1.02 | |
| | peak GPU mem | 520 MB | 520 MB | 1.00 | |
| | compilation time | 13.23 s | 13.30 s | 1.01 | |
| | peak CPU mem | 1.55 GB | 1.55 GB | 1.00 | |
| Precautionary Savings - Solve | execution time | 29.0 ms | 25.7 ms | 0.89 | |
| | peak GPU mem | 8 MB | 8 MB | 1.00 | |
| | compilation time | 2.12 s | 2.09 s | 0.99 | |
| | peak CPU mem | 1.14 GB | 1.15 GB | 1.01 | |
| Precautionary Savings - Simulate | execution time | 88.7 ms | 90.1 ms | 1.02 | |
| | peak GPU mem | 162 MB | 162 MB | 1.00 | |
| | compilation time | 4.40 s | 4.47 s | 1.02 | |
| | peak CPU mem | 1.34 GB | 1.34 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 119.9 ms | 119.8 ms | 1.00 | |
| | peak GPU mem | 566 MB | 566 MB | 1.00 | |
| | compilation time | 5.79 s | 5.86 s | 1.01 | |
| | peak CPU mem | 1.30 GB | 1.30 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 231.3 ms | 227.9 ms | 0.99 | |
| | peak GPU mem | 2.18 GB | 2.18 GB | 1.00 | |
| | compilation time | 6.27 s | 6.20 s | 0.99 | |
| | peak CPU mem | 1.36 GB | 1.35 GB | 0.99 | |
