Skip to content

Distributed-scale solve fixes: sharded NaN/Inf reductions#370

Merged
hmgaudecker merged 58 commits into
mainfrom
feat/distributed-solve-fixes
Jun 5, 2026
Merged

Distributed-scale solve fixes: sharded NaN/Inf reductions#370
hmgaudecker merged 58 commits into
mainfrom
feat/distributed-solve-fixes

Conversation

@hmgaudecker
Copy link
Copy Markdown
Member

@hmgaudecker hmgaudecker commented Jun 1, 2026

Re-implements the still-relevant distributed-scale solve fixes from the abandoned #364, fresh on top of the discrete-only-sharding work (#368, now merged to main). Verified on 4 CPU devices (jax_num_cpu_devices=4).

What changed

Sharded NaN/Inf reductions (code). validate_V, contains_nan, log_nan_in_V, and solve's hot-loop diagnostic accumulator route their NaN/Inf checks through module-level @jax.jit helpers v_array_has_nan / v_array_has_inf. GSPMD then partitions each reduction across the V-array's devices (per-device any → all-reduce → replicated scalar). The eager jnp.any(jnp.isnan(V_arr)) alternative can fall back to gathering a sharded V onto the default device before reducing — exhausting device memory at production grid sizes. Regression: two checks that both helpers return a fully-replicated 0-d scalar on a NamedSharding-sharded input.

Batching/sharding tuning guide (docs). New docs/user_guide/tuning.md section contrasting the two grid knobs — batch_size (splay: a time-neutral memory knob) and distributed (shard: the only wall-clock knob, discrete non-transitioning axes only) — when each applies, and the boundaries the constructor enforces (continuous axes can't shard; you can't splay and shard the same axis). The reasoning behind the subtlest part of that guide is written out below — it documents existing behaviour, there is no code change for it.

Rationale for the guide — batch_size chunks the state sweep, not the shock expectation

(Background, not a code change: this is the reasoning the new guide distills.)

Each period's backward step is two nested sweeps:

  1. State sweep. Walk the grid of current states; for each current state compute V = max_a Q. A grid's batch_size chunks this sweep. The chunked axis is a current state — a dimension of the output V — so each chunk's results are written out and freed, and peak memory shrinks.
  2. Shock expectation. Inside one Q(state, action) cell, average the next period's value over the stochastic transition: E[V'] = Σ_{s'} P(s')·V'(s', …), built as a productmap over the next-period stochastic outcomes followed by jnp.average.

A stochastic grid (e.g. health) names an axis in both: a current-period state in the state sweep, and a summed-over next-period outcome in the shock expectation. batch_size acts only on the state sweep — tests/test_stochastic.py::test_stochastic_state_batch_size_is_value_equivalent_to_no_splay pins that solving with batch_size=1 gives a value function identical to batch_size=0.

The shock expectation stays unconditionally fused (batch_size=0). Its axis is summed away, not kept, so chunking it can't shrink the output — it only de-fuses the jnp.average from the V' interpolation, forcing the full stack of per-outcome V' arrays to materialize before the sum instead of streaming into the running total. Same jax.lax.map(batch_size=1) primitive, opposite sign depending on whether the chunked axis is kept (an output) or summed away (reduced) — isolated lax.map + jnp.average probe, K=8 next-period outcomes over an N=4M-cell working set (128 MB = K·N·4B):

batch_size summed-away axis (shock expectation) output axis (state sweep)
0 (→ vmap) 0 MB — fused into the average, streamed 128 MB — all transients live
1 (→ scan) 128 MB — full [K, N] stack materialized before the sum 16 MB — one transient at a time

Scale N to the production state-action working set and K to the joint next-period stochastic cardinality and 0 → 128 MB becomes fits → 271 GiB. To shed memory along a stochastic dimension, shard the current axis (distributed=True) or chunk the state sweep — both keep the shock expectation fused and beat a sequential scan.

Dropped from #364

out_shardings on the solve lowering (#364 piece 1) — obsolete. Its motivating case (continuous distributed=True) is forbidden outright (GridInitializationError), and discrete-distributed V already shards correctly via natural jit output inference (verified: num_devices=4 with and without out_shardings; all distributed tests pass without it). A docstring note on _fail_if_continuous_grid_distributed records the prerequisite (and points at #364's reference implementation) should continuous-axis sharding ever be re-enabled.

Device-count padding (#364 piece 6) landed in #368.

Verification

  • tests/test_stochastic.py (incl. the new batch_size-equivalence test) — relevant tests pass locally
  • pixi run -e type-checking ty — clean
  • prek run --all-files — clean
  • Full suite (pytest tests -n 4; tests/test_distributed.py on 4 CPU devices) — deferred to CI

🤖 Generated with Claude Code

hmgaudecker and others added 30 commits May 26, 2026 12:05
Sharding a continuous state axis forces every interpolation lookup to
all-gather the full V-array per device — the OOM pattern observed at
production scale. Sharding belongs only on discrete state grids, where
Markov transitions over the sharded axis contract via the cheap
all-reduce (output-sized, not V-array-sized).

The check fires at grid construction with a clear message; the older
batch_size + distributed rejection on continuous grids becomes
unreachable and was dropped from the parametrized tests.
Backward induction returns sharded V_arrs; the simulate phase must
consume materialised arrays rather than in-flight kernels. A per-shard
`.block_until_ready()` after the loop drains the device queue without
moving any data to the host. V stays sharded across devices.

The new `_drain_v_arr_shards` helper keeps `solve()` under the C901
complexity ceiling.
Sharded V-arrays now travel through the on-disk format without an
implicit gather. Each shard is written and restored on the same device
mesh via `orbax-checkpoint`; non-array metadata (regimes, ages,
pre-computed result metadata, parameter scaffolding) lives in a
sibling `metadata.pkl` produced via `cloudpickle`.

API:
- `result.save(directory)` → writes `<dir>/arrays/` (orbax) and
  `<dir>/metadata.pkl` (cloudpickle).
- `SimulationResult.load(directory)` → reverse, including the
  reconstruction of every `MappingProxyType` / `PeriodRegimeSimulationData`
  wrapper around the restored arrays.

Replaces the prior `to_pickle` / `from_pickle` pair, which cloudpickled
the whole result and triggered an all-gather on every sharded array's
`__reduce__`.
`save(*, directory, df_additional_targets="all", df_use_labels=True)`
now writes a third sibling artifact alongside `arrays/` and
`metadata.pkl`: `simulated_data.arrow`, a feather dump of the
`to_dataframe` projection. Downstream consumers can read the flat
per-subject view via `pd.read_feather` without re-instantiating a
`SimulationResult`, and pytask tasks can point their `Product` at
the arrow file as a stable single-file anchor for the saved bundle.

Coerces JAX 0-d arrays to Python scalars before writing feather:
regimes whose target function returns a constant get broadcast as a
0-d JAX scalar across the per-regime sub-frame, which pyarrow refuses
to convert. `_coerce_jax_scalar_for_arrow` is mapped element-wise
over the dataframe at the arrow boundary.

`load(*, directory)` reads arrays + metadata only; `simulated_data.arrow`
is a one-way artifact for downstream consumers.

Both `save` and `load` adopt keyword-only argument style consistent
with pylcm convention.

AGENTS.md and `docs/user_guide/solving_and_simulating.md` show the
new API.
- Rename `_drain_v_arr_shards` to `_drain_V_arr_shards`, plus the
  matching V_arr local-variable and test-function casing in
  test_distributed, test_float_dtype_invariants, test_analytical_solution.
- Replace the manual nested-shard loop with a single
  `jax.block_until_ready(solution)` call. JAX walks the pytree and
  blocks per-shard with no host transfer; the materialisation test
  stays green.
- Narrow `SimulationResult.save` / `.load` to `directory: Path`,
  dropping the str alternative. Every caller already passes a Path
  (tests use `tmp_path / "result"`, simulate tasks pass
  `output.parent`).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Four gaps closed so simulate's hot loop and debug outputs match the
materialisation strategy already in place on the solve side:

1. NaN validation in simulate is gated on `log_level`. At `"off"`
   `validate_V` is skipped (no per-(regime, period) host transfer);
   at `"warning"` / `"progress"` it warns rather than raising; only
   `"debug"` raises. Same contract as `solve(log_level=...)`.

2. `log_nan_in_V` only fires when diagnostics are enabled, so the
   per-period `__bool__` host transfer disappears at `"off"`.

3. Post-loop `jax.block_until_ready(simulation_results)` at the end
   of `simulate()` drains the per-period compute graph before
   returning. Mirrors solve's `_drain_V_arr_shards` so downstream
   consumers (`to_dataframe`, `save`) start with concrete arrays.

4. `log_regime_transitions` builds the full `(n_regimes, n_regimes)`
   transition count matrix in a single fused JAX kernel and
   host-transfers it once per period, instead of `O(n_regimes^2)`
   `.item()` calls per period at DEBUG.

Adds two tests in `test_nan_diagnostics.py` covering the gated
validation paths (`"off"` skips, `"warning"` warns rather than
raises) and a placeholder drain test in `test_distributed.py` as a
regression guard.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`log_nan_in_V` and `_simulate_regime_in_period` both already receive
the logger that carries the validation policy. Threading a separate
`diagnostics_enabled` bool through `_simulate_regime_in_period` was
redundant with the logger. Have each function call
`validation_enabled(logger)` itself instead — same pattern as
`log_regime_transitions`, which already self-gates on
`logger.isEnabledFor(logging.DEBUG)`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- `SimulationResult.save` now defaults `df_additional_targets=None`,
  so the `simulated_data.arrow` artifact carries only the base
  columns (states, actions, regime, age, period, V_arr). At ACA
  production scale that is ~860 MB per simulate; the previous
  `"all"` default would bake every DAG leaf in and grow each
  artifact to ~4 GB.
- Update the `save` round-trip test to compare against
  `to_dataframe(use_labels=True)` instead of
  `to_dataframe(additional_targets="all", use_labels=True)`.
- Rename `_period_v_to_array_tree` -> `_period_V_to_array_tree`
  and `_array_tree_to_period_v` -> `_array_tree_to_period_V`,
  and the per-period stat locals `v_min/v_max/v_mean` ->
  `V_min/V_max/V_mean`, and the new
  `test_simulate_log_level_warning_does_not_raise_on_nan_v_arr`
  test -> `..._V_arr`. Finishes the V_arr capitalisation pass
  across the codebase.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous pin (39ac270b) pre-dates aca-model's removal of
`assets_distributed` from `GridConfig` and so passes `distributed=True`
to a `LinSpacedGrid`. pylcm's new `_fail_if_continuous_grid_distributed`
guard rejects that at grid construction, breaking the GPU benchmark
job. Bump to 7af96820 (current head of aca-model's API-restructure
branch), which only sets `distributed=True` on the `pref_type`
DiscreteGrid (and the new discrete-state flags) — all valid targets
for sharding under the discrete-only rule.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
orbax's StandardCheckpointer.save stages each jax.Array on device
before transferring it to host, which doubles peak device usage at
save time. On a device whose memory is comparable to the V_arr,
the staging OOMs even when solve + simulate themselves had room.

Convert single-device leaves to numpy via jax.device_get upfront so
orbax serialises numpy directly. Sharded multi-device leaves are
left alone — orbax already does per-shard host transfer there.
Symmetric `_array_tree_to_jax` lifts numpy leaves back to jax.Array
on load so engine code sees the expected types.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`object` was too tight: callers (including the new helper tests)
subscript the result, which ty rejects on `object`. The merged
array tree is a heterogeneous `dict[str, ...]` of three sub-trees,
so `dict[str, Any]` is the smallest accurate type.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous one-shot orbax save materialised every V_arr leaf on
device before transferring to host. Leaves whose size approaches the
device cap then triggered an out-of-memory at save time even when
solve+simulate themselves had headroom — the engine's scan over
states with `batch_size > 0` keeps only one slice live during
compute, but `np.asarray(jax_array)` requires the whole array as one
contiguous device buffer.

This change matches the save's materialisation profile to the
engine's compute-time peak:

- `_build_chunk_specs(regimes, flat_params)` derives a per-regime
  `_ChunkSpec(chunk_axis, chunk_size)` by walking each V_arr's
  canonical state axis order and picking the outermost state with
  `batch_size > 0`.
- `SimulationResult` carries `chunk_specs`; `simulate()` populates
  it.
- `save()` writes `arrays/` (orbax for `raw_results` and
  `flat_params` — small) and `V_arr/period_<NNNN>/regime_<R>/`,
  where each V_arr leaf is sliced along its splay axis with width
  `chunk_size`, host-transferred per chunk, and dumped as
  `<I>.npy` files alongside a `meta.json` describing how to
  reassemble.
- Multi-device sharded leaves still go through orbax (preserves the
  sharding spec); single-device leaves take the chunked path.
- `load()` mirrors: orbax for the small trees, walk `V_arr/` and
  `np.concatenate` along `chunk_axis` per the recorded meta.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Surface the size distribution of `small_array_tree` on stderr so an OOM
inside `_serialize_arrays_batches_without_dispatcher` reports which leaf
exceeds the device cap. Top-K with path/shape/dtype/GiB, plus aggregate
count and bytes.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Reverses the order inside `SimulationResult.save`: chunked V-array
write goes first, the device-resident grid is dropped via
`self._period_to_regime_to_V_arr = MappingProxyType({})` + `gc.collect`
+ `jax.clear_caches`, and only then does orbax serialise the
`raw_results` / `flat_params` tree.

Background: V100 (16 GiB) orbax save OOM'd at exactly 16.30 GiB.
Instrumentation confirmed the small array tree is 2.568 GiB across
2,106 per-subject leaves — no single leaf is large. The OOM was the
platform allocator asking for ~the device's idle-free block on the
first `data.copy_to_host_async`, blocked by ~12-14 GiB of still
resident solve-grid V-arrays. Freeing those buffers before orbax
runs leaves ~13 GiB of headroom, dwarfing the 2.5 GiB transfer
demand.

Side effect: `self.period_to_regime_to_V_arr` is an empty mapping
after `save` returns. Callers that need the values back must
reload via `SimulationResult.load`. Documented in the `save`
docstring.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The compiled `simulate_functions.argmax_and_max_Q_over_a[period]` and
`next_state` programs inside each `Regime` keep their XLA workspaces
pinned on the device for as long as a Python reference exists. With
the V-array already on disk and dropped, the device still cannot
satisfy orbax's D2H allocation because those workspaces fragment the
BFC pool (or block contiguous allocations under platform allocator).

Reorders `save` so the artifacts that depend on `self._regimes`
(`metadata.pkl`, the in-memory dataframe for `simulated_data.arrow`)
are produced upfront, and `self._regimes` is then replaced with an
empty mapping. `gc.collect` + `jax.clear_caches` after both drops
ensures the underlying device buffers go back to the allocator
before orbax begins its per-leaf transfers.

Existing roundtrip tests captured `result.to_dataframe()` after
`save` — now infeasible. They capture it upfront and compare the
loaded frame against the snapshot instead.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Clearing JAX's compile cache between the regime/V-array drop and the
orbax serialization invalidates the topology orbax records on the
saved arrays; `restore` then refuses with "Topology mismatch
detected." The compiled-program workspaces are freed by `del` +
`gc.collect()` alone — `clear_caches` was added for belt-and-
suspenders memory hygiene but is incompatible with orbax's expected
state. Drop it from the pre-orbax flow.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`to_dataframe` materialises boolean masks via `arr[mask]`, which JAX
implements by gathering the mask to host with `np.asarray(mask)`. On
a 16 GiB device with the V-array still resident, the staging
allocation for that gather (~10 GiB at production grid sizes) cannot
land. Move the chunked V-array write — and the subsequent
`self._period_to_regime_to_V_arr = MappingProxyType({})` drop —
ahead of `to_dataframe`, so the dataframe builds against an
already-trimmed device pool.

Order of operations:
1. write V-array chunks
2. drop V-array refs
3. build dataframe + metadata snapshot (needs `self._regimes`)
4. drop `self._regimes`
5. orbax-save the small array tree
6. write the feather artifact

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`_concatenate_and_filter` pulls each per-period leaf to host via
`np.asarray` and concatenates/masks in numpy. This bypasses JAX's
`expand_bool_indices`, which materialises a device-side staging buffer
of the full concatenated shape per key and OOMs on small devices when
V-array workspaces are still resident in the BFC pool.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
save() now writes metadata.pkl ahead of the dataframe and releases the
compiled Regime objects before calling to_dataframe. This frees the
XLA program workspaces holding ~3-4 GiB of the BFC pool so the
per-period D2H gathers inside _concatenate_and_filter have headroom
for the largest staging copy. When df_additional_targets is set the
targets DAG still needs the compiled programs, so the drop is
deferred to after the dataframe.

_create_flat_dataframe / _process_regime now take regime_name
separately from the optional Regime object; the dataframe code reads
only the name when additional_targets is None, so an empty regimes
mapping is valid in that path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`SimulationResult.save` now opens with `jax.block_until_ready` on the
raw simulation results so any lazy XLA dispatch backing the per-period
arrays runs while the V-array and compiled regimes are still resident.
The deferred kernels would otherwise be triggered inside
`to_dataframe`'s D2H gathers, where the BFC pool is already carrying
the per-key host staging buffers and CUDA's driver-side launch state
has the least room.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Refactor _concatenate_and_filter to walk one key at a time across all
periods. The boolean mask is materialised first, then for each key the
function pulls per-period chunks to host with `np.asarray`, drops each
device reference between chunks, concatenates on host, and applies the
mask. Peak device residency drops from "all per-period dicts" to "one
per-period leaf".

Gated by `LCM_DATAFRAME_SHAPE_CENSUS=1`, the same path now prints a
shape / dtype / sharding summary for the first period's dict plus a
pre-asarray line for every leaf before pulling it to host. The last
`[pre-asarray]` line emitted before an OOM identifies the offending
leaf and whether it is sharded.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Restructure `_concatenate_and_filter` to iterate periods on the outer
loop, building per-key host chunks as it goes, and clear each
period's dict the moment its leaves are on host. Peak device
residency drops to one per-period dict's leaves — sized in MB rather
than the full per-regime tensor stack.

A new `_to_host` helper replaces direct `np.asarray` calls. For
single-device leaves it collapses to `np.asarray`; for sharded leaves
it walks `addressable_shards`, transfers each shard's local data
independently, and reassembles into a host-allocated output via the
shard's `index` slice. The implicit XLA all-gather a `np.asarray`
on a sharded array would trigger — and the contiguous device buffer
that gather needs — is skipped; the contiguous reassembly lives in
host memory.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…eSimulationData as pytree

Two small fixes folded together. `_to_host` was annotated as `object`,
which `ty` rejects when the body reads `.shape` and `.dtype`. The
actual values pulled to host inside `_concatenate_and_filter` are
already typed `FloatND | IntND | BoolND`, so widen the parameter type
accordingly.

In `_lcm.engine`, register `PeriodRegimeSimulationData` as a JAX
pytree via `jax.tree_util.register_dataclass`. Without it,
`jax.block_until_ready` traversals over the
`dict[regime][period] -> PeriodRegimeSimulationData` shape treat each
dataclass instance as an opaque leaf, so the per-subject
`V_arr` / `actions` / `states` / `in_regime` arrays inside are never
waited on. Pytree registration makes the existing drain at the end
of `simulate` actually drain.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Two leftovers from the save-time OOM investigation, both dead weight:

- The `LCM_DATAFRAME_SHAPE_CENSUS` shape-census in result_dataframe.py
  (`_log_shape_census` / `_log_pre_to_host` / `_describe_value`, env-gated,
  off by default) was diagnostic instrumentation. Remove it and the
  `os` / `sys` imports it required.
- The top-of-save `jax.block_until_ready(self._raw_results)` is a no-op:
  registering `PeriodRegimeSimulationData` as a pytree (this branch) makes
  the simulate-end drain materialise every raw-results array, and nothing
  re-defers them before `save`, so the save-time drain blocks on
  already-concrete arrays.

Behavior-preserving: 1031 passed, ty + prek clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Run the forward simulation one subject chunk at a time through the full period
loop, concatenating per-chunk results on the subject axis. Subjects are
independent across the forward path, so chunking bounds the per-period device
workspace without changing results. Per-subject RNG keys are generated for the
full population and sliced by global index (generate_simulation_keys gains a
subject_slice argument), so simulated paths are invariant to subject_batch_size.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
When subject_batch_size is set, device_put each chunk's raw_results to the CPU
as the chunk finishes — block_until_ready forces the D2H so the device frees the
chunk before the next one's period loop allocates, bounding device residency to
a single chunk. raw_results leaves stay jax.Array (CPU-backed). Unbatched runs
keep results on the compute device, avoiding a needless host round-trip for
downstream target evaluation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
to_dataframe's additional_targets evaluation vmaps the target DAG over a
regime's in-regime rows in a single full-population pass — the estimation hot
loop's binding memory peak. When simulate ran with subject_batch_size, reuse it
(stored on the result, persisted across save/load) to chunk that vmap over
subjects, pulling each chunk to host before the next runs. Constant targets (no
per-subject variable) keep their single 0-d value to match the single-pass
dtype. Values are unchanged; device workspace is bounded by the chunk.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Collapse the solution save/load to a single orbax checkpoint. Each V_arr leaf
is moved to host first via _to_host_for_save: single-device leaves are copied in
bounded chunks along their splay axis (so a near-cap leaf never needs a full
contiguous device buffer for the D2H), and sharded leaves pass through orbax
directly, preserving their sharding on load. Removes the per-leaf chunked-.npy
writer, the meta.json layout dispatch, and the chunked load path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The forward-sim save path no longer routes V_arr leaves through a CPU
device_put before orbax. A materialized device array saved directly via
orbax does not double device residency, so the offload bought nothing —
and on a GPU box it recorded a `cpu:0` SingleDeviceShardingMetadata that
no GPU runner can restore (Topology mismatch). Saving each leaf on its
native device makes save and load same-backend on every runner.

Removes the now-dead `_ChunkSpec`/`_build_chunk_specs` machinery,
`_to_host_for_save`, and their tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
A model with `n_subjects` set AOT-compiles its simulate functions once and
dispatches them per subject chunk. The compiled program is now sized for the
chunk (`subject_batch_size`, clamped to the population), not the full
population, so chunked dispatch matches the compiled shape.

To keep every chunk exactly the compiled size, a short final chunk slides its
window back to end at the population tail. The overlapped subjects were already
simulated in the previous chunk and are recomputed bit-identically (per-subject
keys depend only on the global index); the leading overlap is trimmed so each
subject is represented once, in global order.

Without this, `Model(n_subjects=N)` + `simulate(subject_batch_size=B)` with
B < N crashed: the program compiled for [N] rejected the [B]-shaped chunk call.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
hmgaudecker and others added 12 commits June 1, 2026 12:48
The asv GPU peak-mem subprocess reads `peak_bytes_in_use`, a monotonic
high-water mark spanning compile + execute. XLA autotuning allocates
large, run-to-run-variable scratch buffers at compile time; for big
models that transient dwarfs the execution working set, so the reported
peak swung several-fold between identical runs and even inverted
log-off vs log-debug. Small models compile trivially and stayed stable.

Disable autotuning in the subprocess (`--xla_gpu_autotune_level=0`,
appended to any existing `XLA_FLAGS`). The compile footprint becomes
deterministic and matches production, which already runs autotuning
off. A fresh single cold run is the production footprint anyway, so the
measurement stays representative.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
`SimulationResult.save()` writes `simulated_data.arrow` via
`df.to_feather(...)`, which needs pyarrow at runtime. It was only in the
pixi dev dependencies, so any install of pylcm-as-a-package (or a
downstream environment whose pandas build doesn't hard-pull pyarrow)
could hit a missing-pyarrow ImportError at save time. Move it into
`[project.dependencies]` so it ships with the package.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Pure helpers for `subject_batch_size="auto"`: `estimate_peak_bytes`
turns one compiled program's `memory_analysis()` into a peak figure
(temp + argument + output − aliased), and `pick_batch_size` fits the
affine peak-vs-batch line through compile-only probes and returns the
largest batch under the memory budget, clamped to [1, population].

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the `subject_batch_size: int | None = None` knob with
`int | Literal["auto"] = 0`, matching the grid `batch_size=0`
no-batching sentinel:

- `0` (default): one pass over the padded population.
- `> 0`: explicit chunk size. Subject-chunking is single-device, so it
  is rejected under multi-device distribution — the value-function array
  is sharded across the devices and can't be gathered onto one.
- `"auto"`: size the chunk to the device. `compile_all_simulate_functions`
  now also returns the max program peak from `memory_analysis()`
  (alias-corrected); the resolver compiles the one-pass program first
  (reused when it fits), and on overflow fits the affine peak-vs-batch
  line through a half-population probe to pick the largest batch under
  `bytes_limit × (1 − margin)`. Falls back to one pass when the device
  exposes no limit (CPU), the grids are distributed across devices, or
  the model is not AOT-configured.

The engine receives the resolved concrete int (one pass = population,
chunked = the chunk size), so its None path is gone and host-offload
keys off `batch_size < n_subjects`. Results stay invariant to the knob
(per-subject RNG keys sliced by global index). GPU-verified: both the
fits→one-pass and forced-chunk branches reproduce the single-pass output.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Review fixes for the subject_batch_size="auto" path:

- `pick_batch_size` no longer divides by zero when both probes share a
  batch size (a one-subject population makes the full- and half-pop
  probes coincide). It guards `b_hi == b_lo` and falls back to a
  proportional model, clamped to the population.
- `_autotune_compile_batch_size` returns one pass when the half-pop
  probe wouldn't reduce the batch (population below two), before the
  redundant compile.
- `estimate_peak_bytes` is keyword-only, per the project convention.
- Tests: a degenerate identical-batch probe case, and `"auto"` under
  multi-device distribution falling back to one pass (not raising).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
subject_batch_size becomes int-only (0 = single pass, >0 = manual chunk).
Sizing a chunk for a memory-constrained device is now manual trial-and-error
rather than compile-time XLA memory_analysis() probing.

Removes _lcm/simulation/autotune.py (estimate_peak_bytes, pick_batch_size),
the _simulate_peak_cache / _autotune_compile_batch_size / _device_budget_bytes
machinery in model.py, and the peak-bytes return from
compile_all_simulate_functions. The manual chunk path and its
multi-device-distribution guard are unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Pad initial_conditions up front to a single multiple — the device count when
sharding, the chunk size when chunking — instead of per-device padding plus a
per-chunk slide-window overlap. The chunk loop then iterates exact blocks (no
overlap, no per-chunk trim) and one trim_pad_from_raw_results drops the
duplicate-last-subject pad rows.

Renames pad_initial_conditions_for_devices -> pad_initial_conditions_to_multiple
(takes the alignment block size, not regimes) and removes _trim_chunk_subjects.
Results are unchanged: the 7-subject non-dividing invariance test stays green;
per-subject RNG keys are still drawn for the real population and sliced by
global index.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
A positive `batch_size` on a state with a Markov transition chunks the
next-period shock-integration map. That chunking de-fuses the
expected-continuation-value reduction from the value-function
interpolation, so the full joint next-stochastic cross-product is
materialized instead of streamed — peak memory grows with the stochastic
cardinality instead of shrinking.

Reject it at regime construction, pointing the user to shard the current
axis (`distributed=True`) or splay a continuous state instead.
Deterministic states keep `batch_size` as a memory-reducing splay knob.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Base automatically changed from feat/discrete-only-sharding-orbax-save to main June 3, 2026 15:13
hmgaudecker and others added 4 commits June 3, 2026 17:14
With the shock-integration batch_size unwired, a stochastic state's
`batch_size` drives only the outer state-loop splay — a pure memory knob.
Lock in the contract: solving with `batch_size=1` on a Markov state
yields a value function identical to `batch_size=0`.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@hmgaudecker hmgaudecker changed the title Distributed-scale solve fixes: sharded NaN/Inf reductions + stochastic-grid batch_size Distributed-scale solve fixes: sharded NaN/Inf reductions Jun 3, 2026
@hmgaudecker hmgaudecker requested a review from mj023 June 3, 2026 15:43
hmgaudecker and others added 2 commits June 4, 2026 12:13
Document the two orthogonal scaling knobs and the environment flags, since the
splay-vs-shard distinction is easy to get backwards:

- `batch_size` (splay) is a memory knob and time-neutral — it only re-tiles the
  per-period max via `lax.scan`; identical FLOPs. Prefer large, uniform continuous
  axes; use the fewest chunks that fit.
- `distributed` (shard) is the only knob that reduces wall-clock, and only for
  discrete, never-transitioning axes. Covers both construction guards (no
  continuous sharding; no batch_size+distributed on one axis) and why they hold.
- `subject_batch_size` for the forward pass.
- Environment flags: the JAX defaults pylcm sets, and the allocator / preallocate /
  mem-fraction / autotune / command-buffer trade-offs, with a stable multi-GPU recipe.

Registered in the user-guide index and the myst.yml TOC.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Head-to-head per-period execution time is unchanged with kernel autotuning on vs off
(matched config, deep-age steady state) — backward induction is gather/scatter- and
interpolation-bound, not GEMM-bound, so there is little for autotuning to optimize, while
it raises compile time and peak memory. Default to off; turn on only if a per-period
measurement shows a real speedup.

Also note the exclusive-GPU failure mode: with PREALLOCATE=true a contended or leaked
device fails the pool preallocation with an immediate device-0 OUT_OF_MEMORY, distinct
from the 'allocated N, saw 1' visibility case.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Copy link
Copy Markdown
Collaborator

@mj023 mj023 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good change, very helpful guide about batching and sharding!

The second half of the PR description reads like there is a second change, but I can't find it in the Code?

Comment thread docs/user_guide/tuning.md Outdated
Comment on lines +101 to +103
the synchronisation count and inverts the compute/communication ratio. On a sharded
axis keep `batch_size=0`; if its per-device chunk is still too big, add devices or
shard a second axis rather than batching it.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cant quite follow what this means. To shard a second axis, one would need to add devices anyways, rigth?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that is misleading indeed. Now that we disallow sharding continuous axes, you may well want to shard 2x2, so that's supported. But in that example, it did not make any sense.

…ent one

Address mj023's review on #370: the 'shard a second axis' remedy for a
too-big per-device chunk was misleading. Splaying a sharded axis is only
tempting when a device holds >1 block, which is exactly when sharding that
same axis across more devices works instead; the everyday memory fix is to
splay a different, non-sharded axis. Update both the tuning guide and the
matching GridInitializationError message to say so.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@hmgaudecker
Copy link
Copy Markdown
Member Author

The second half of the PR description reads like there is a second change, but I can't find it in the Code?

That was the rationale more than anything, poor cleaning up of the writing from here. Should be better now.

hmgaudecker and others added 3 commits June 5, 2026 11:47
The ty update flagged `pd.Categorical` as `Categorical[object]` while
pandas-stubs infers `Categorical[str]` for the returned values; bumping
the annotation to `pd.Categorical[str]` (commit 53855bf) silenced ty but
broke beartype's module-level claw: `pd.Categorical` isn't subscriptable
at runtime, so the annotation raised `TypeError` during decoration and
the benchmark CI started emitting `BeartypeClawDecorWarning`.

Revert the annotations to bare `pd.Categorical` so beartype can resolve
them, and silence ty at the three return statements instead.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Track aca-model main for the benchmarks env instead of a pinned commit,
matching the dags dependency convention. The previous rev lived on an
aca-model feature branch that has since been squash-merged and deleted.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@hmgaudecker hmgaudecker merged commit 64cf042 into main Jun 5, 2026
10 of 11 checks passed
@hmgaudecker hmgaudecker deleted the feat/distributed-solve-fixes branch June 5, 2026 12:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants