Distributed-scale solve fixes: sharded NaN/Inf reductions#370
Conversation
Sharding a continuous state axis forces every interpolation lookup to all-gather the full V-array per device — the OOM pattern observed at production scale. Sharding belongs only on discrete state grids, where Markov transitions over the sharded axis contract via the cheap all-reduce (output-sized, not V-array-sized). The check fires at grid construction with a clear message; the older batch_size + distributed rejection on continuous grids becomes unreachable and was dropped from the parametrized tests.
Backward induction returns sharded V_arrs; the simulate phase must consume materialised arrays rather than in-flight kernels. A per-shard `.block_until_ready()` after the loop drains the device queue without moving any data to the host. V stays sharded across devices. The new `_drain_v_arr_shards` helper keeps `solve()` under the C901 complexity ceiling.
Sharded V-arrays now travel through the on-disk format without an implicit gather. Each shard is written and restored on the same device mesh via `orbax-checkpoint`; non-array metadata (regimes, ages, pre-computed result metadata, parameter scaffolding) lives in a sibling `metadata.pkl` produced via `cloudpickle`. API: - `result.save(directory)` → writes `<dir>/arrays/` (orbax) and `<dir>/metadata.pkl` (cloudpickle). - `SimulationResult.load(directory)` → reverse, including the reconstruction of every `MappingProxyType` / `PeriodRegimeSimulationData` wrapper around the restored arrays. Replaces the prior `to_pickle` / `from_pickle` pair, which cloudpickled the whole result and triggered an all-gather on every sharded array's `__reduce__`.
`save(*, directory, df_additional_targets="all", df_use_labels=True)` now writes a third sibling artifact alongside `arrays/` and `metadata.pkl`: `simulated_data.arrow`, a feather dump of the `to_dataframe` projection. Downstream consumers can read the flat per-subject view via `pd.read_feather` without re-instantiating a `SimulationResult`, and pytask tasks can point their `Product` at the arrow file as a stable single-file anchor for the saved bundle. Coerces JAX 0-d arrays to Python scalars before writing feather: regimes whose target function returns a constant get broadcast as a 0-d JAX scalar across the per-regime sub-frame, which pyarrow refuses to convert. `_coerce_jax_scalar_for_arrow` is mapped element-wise over the dataframe at the arrow boundary. `load(*, directory)` reads arrays + metadata only; `simulated_data.arrow` is a one-way artifact for downstream consumers. Both `save` and `load` adopt keyword-only argument style consistent with pylcm convention. AGENTS.md and `docs/user_guide/solving_and_simulating.md` show the new API.
- Rename `_drain_v_arr_shards` to `_drain_V_arr_shards`, plus the matching V_arr local-variable and test-function casing in test_distributed, test_float_dtype_invariants, test_analytical_solution. - Replace the manual nested-shard loop with a single `jax.block_until_ready(solution)` call. JAX walks the pytree and blocks per-shard with no host transfer; the materialisation test stays green. - Narrow `SimulationResult.save` / `.load` to `directory: Path`, dropping the str alternative. Every caller already passes a Path (tests use `tmp_path / "result"`, simulate tasks pass `output.parent`). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Four gaps closed so simulate's hot loop and debug outputs match the materialisation strategy already in place on the solve side: 1. NaN validation in simulate is gated on `log_level`. At `"off"` `validate_V` is skipped (no per-(regime, period) host transfer); at `"warning"` / `"progress"` it warns rather than raising; only `"debug"` raises. Same contract as `solve(log_level=...)`. 2. `log_nan_in_V` only fires when diagnostics are enabled, so the per-period `__bool__` host transfer disappears at `"off"`. 3. Post-loop `jax.block_until_ready(simulation_results)` at the end of `simulate()` drains the per-period compute graph before returning. Mirrors solve's `_drain_V_arr_shards` so downstream consumers (`to_dataframe`, `save`) start with concrete arrays. 4. `log_regime_transitions` builds the full `(n_regimes, n_regimes)` transition count matrix in a single fused JAX kernel and host-transfers it once per period, instead of `O(n_regimes^2)` `.item()` calls per period at DEBUG. Adds two tests in `test_nan_diagnostics.py` covering the gated validation paths (`"off"` skips, `"warning"` warns rather than raises) and a placeholder drain test in `test_distributed.py` as a regression guard. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`log_nan_in_V` and `_simulate_regime_in_period` both already receive the logger that carries the validation policy. Threading a separate `diagnostics_enabled` bool through `_simulate_regime_in_period` was redundant with the logger. Have each function call `validation_enabled(logger)` itself instead — same pattern as `log_regime_transitions`, which already self-gates on `logger.isEnabledFor(logging.DEBUG)`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- `SimulationResult.save` now defaults `df_additional_targets=None`, so the `simulated_data.arrow` artifact carries only the base columns (states, actions, regime, age, period, V_arr). At ACA production scale that is ~860 MB per simulate; the previous `"all"` default would bake every DAG leaf in and grow each artifact to ~4 GB. - Update the `save` round-trip test to compare against `to_dataframe(use_labels=True)` instead of `to_dataframe(additional_targets="all", use_labels=True)`. - Rename `_period_v_to_array_tree` -> `_period_V_to_array_tree` and `_array_tree_to_period_v` -> `_array_tree_to_period_V`, and the per-period stat locals `v_min/v_max/v_mean` -> `V_min/V_max/V_mean`, and the new `test_simulate_log_level_warning_does_not_raise_on_nan_v_arr` test -> `..._V_arr`. Finishes the V_arr capitalisation pass across the codebase. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous pin (39ac270b) pre-dates aca-model's removal of `assets_distributed` from `GridConfig` and so passes `distributed=True` to a `LinSpacedGrid`. pylcm's new `_fail_if_continuous_grid_distributed` guard rejects that at grid construction, breaking the GPU benchmark job. Bump to 7af96820 (current head of aca-model's API-restructure branch), which only sets `distributed=True` on the `pref_type` DiscreteGrid (and the new discrete-state flags) — all valid targets for sharding under the discrete-only rule. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
orbax's StandardCheckpointer.save stages each jax.Array on device before transferring it to host, which doubles peak device usage at save time. On a device whose memory is comparable to the V_arr, the staging OOMs even when solve + simulate themselves had room. Convert single-device leaves to numpy via jax.device_get upfront so orbax serialises numpy directly. Sharded multi-device leaves are left alone — orbax already does per-shard host transfer there. Symmetric `_array_tree_to_jax` lifts numpy leaves back to jax.Array on load so engine code sees the expected types. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`object` was too tight: callers (including the new helper tests) subscript the result, which ty rejects on `object`. The merged array tree is a heterogeneous `dict[str, ...]` of three sub-trees, so `dict[str, Any]` is the smallest accurate type. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous one-shot orbax save materialised every V_arr leaf on device before transferring to host. Leaves whose size approaches the device cap then triggered an out-of-memory at save time even when solve+simulate themselves had headroom — the engine's scan over states with `batch_size > 0` keeps only one slice live during compute, but `np.asarray(jax_array)` requires the whole array as one contiguous device buffer. This change matches the save's materialisation profile to the engine's compute-time peak: - `_build_chunk_specs(regimes, flat_params)` derives a per-regime `_ChunkSpec(chunk_axis, chunk_size)` by walking each V_arr's canonical state axis order and picking the outermost state with `batch_size > 0`. - `SimulationResult` carries `chunk_specs`; `simulate()` populates it. - `save()` writes `arrays/` (orbax for `raw_results` and `flat_params` — small) and `V_arr/period_<NNNN>/regime_<R>/`, where each V_arr leaf is sliced along its splay axis with width `chunk_size`, host-transferred per chunk, and dumped as `<I>.npy` files alongside a `meta.json` describing how to reassemble. - Multi-device sharded leaves still go through orbax (preserves the sharding spec); single-device leaves take the chunked path. - `load()` mirrors: orbax for the small trees, walk `V_arr/` and `np.concatenate` along `chunk_axis` per the recorded meta. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Surface the size distribution of `small_array_tree` on stderr so an OOM inside `_serialize_arrays_batches_without_dispatcher` reports which leaf exceeds the device cap. Top-K with path/shape/dtype/GiB, plus aggregate count and bytes. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Reverses the order inside `SimulationResult.save`: chunked V-array
write goes first, the device-resident grid is dropped via
`self._period_to_regime_to_V_arr = MappingProxyType({})` + `gc.collect`
+ `jax.clear_caches`, and only then does orbax serialise the
`raw_results` / `flat_params` tree.
Background: V100 (16 GiB) orbax save OOM'd at exactly 16.30 GiB.
Instrumentation confirmed the small array tree is 2.568 GiB across
2,106 per-subject leaves — no single leaf is large. The OOM was the
platform allocator asking for ~the device's idle-free block on the
first `data.copy_to_host_async`, blocked by ~12-14 GiB of still
resident solve-grid V-arrays. Freeing those buffers before orbax
runs leaves ~13 GiB of headroom, dwarfing the 2.5 GiB transfer
demand.
Side effect: `self.period_to_regime_to_V_arr` is an empty mapping
after `save` returns. Callers that need the values back must
reload via `SimulationResult.load`. Documented in the `save`
docstring.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The compiled `simulate_functions.argmax_and_max_Q_over_a[period]` and `next_state` programs inside each `Regime` keep their XLA workspaces pinned on the device for as long as a Python reference exists. With the V-array already on disk and dropped, the device still cannot satisfy orbax's D2H allocation because those workspaces fragment the BFC pool (or block contiguous allocations under platform allocator). Reorders `save` so the artifacts that depend on `self._regimes` (`metadata.pkl`, the in-memory dataframe for `simulated_data.arrow`) are produced upfront, and `self._regimes` is then replaced with an empty mapping. `gc.collect` + `jax.clear_caches` after both drops ensures the underlying device buffers go back to the allocator before orbax begins its per-leaf transfers. Existing roundtrip tests captured `result.to_dataframe()` after `save` — now infeasible. They capture it upfront and compare the loaded frame against the snapshot instead. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Clearing JAX's compile cache between the regime/V-array drop and the orbax serialization invalidates the topology orbax records on the saved arrays; `restore` then refuses with "Topology mismatch detected." The compiled-program workspaces are freed by `del` + `gc.collect()` alone — `clear_caches` was added for belt-and- suspenders memory hygiene but is incompatible with orbax's expected state. Drop it from the pre-orbax flow. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`to_dataframe` materialises boolean masks via `arr[mask]`, which JAX
implements by gathering the mask to host with `np.asarray(mask)`. On
a 16 GiB device with the V-array still resident, the staging
allocation for that gather (~10 GiB at production grid sizes) cannot
land. Move the chunked V-array write — and the subsequent
`self._period_to_regime_to_V_arr = MappingProxyType({})` drop —
ahead of `to_dataframe`, so the dataframe builds against an
already-trimmed device pool.
Order of operations:
1. write V-array chunks
2. drop V-array refs
3. build dataframe + metadata snapshot (needs `self._regimes`)
4. drop `self._regimes`
5. orbax-save the small array tree
6. write the feather artifact
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`_concatenate_and_filter` pulls each per-period leaf to host via `np.asarray` and concatenates/masks in numpy. This bypasses JAX's `expand_bool_indices`, which materialises a device-side staging buffer of the full concatenated shape per key and OOMs on small devices when V-array workspaces are still resident in the BFC pool. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
save() now writes metadata.pkl ahead of the dataframe and releases the compiled Regime objects before calling to_dataframe. This frees the XLA program workspaces holding ~3-4 GiB of the BFC pool so the per-period D2H gathers inside _concatenate_and_filter have headroom for the largest staging copy. When df_additional_targets is set the targets DAG still needs the compiled programs, so the drop is deferred to after the dataframe. _create_flat_dataframe / _process_regime now take regime_name separately from the optional Regime object; the dataframe code reads only the name when additional_targets is None, so an empty regimes mapping is valid in that path. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`SimulationResult.save` now opens with `jax.block_until_ready` on the raw simulation results so any lazy XLA dispatch backing the per-period arrays runs while the V-array and compiled regimes are still resident. The deferred kernels would otherwise be triggered inside `to_dataframe`'s D2H gathers, where the BFC pool is already carrying the per-key host staging buffers and CUDA's driver-side launch state has the least room. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Refactor _concatenate_and_filter to walk one key at a time across all periods. The boolean mask is materialised first, then for each key the function pulls per-period chunks to host with `np.asarray`, drops each device reference between chunks, concatenates on host, and applies the mask. Peak device residency drops from "all per-period dicts" to "one per-period leaf". Gated by `LCM_DATAFRAME_SHAPE_CENSUS=1`, the same path now prints a shape / dtype / sharding summary for the first period's dict plus a pre-asarray line for every leaf before pulling it to host. The last `[pre-asarray]` line emitted before an OOM identifies the offending leaf and whether it is sharded. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Restructure `_concatenate_and_filter` to iterate periods on the outer loop, building per-key host chunks as it goes, and clear each period's dict the moment its leaves are on host. Peak device residency drops to one per-period dict's leaves — sized in MB rather than the full per-regime tensor stack. A new `_to_host` helper replaces direct `np.asarray` calls. For single-device leaves it collapses to `np.asarray`; for sharded leaves it walks `addressable_shards`, transfers each shard's local data independently, and reassembles into a host-allocated output via the shard's `index` slice. The implicit XLA all-gather a `np.asarray` on a sharded array would trigger — and the contiguous device buffer that gather needs — is skipped; the contiguous reassembly lives in host memory. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…eSimulationData as pytree Two small fixes folded together. `_to_host` was annotated as `object`, which `ty` rejects when the body reads `.shape` and `.dtype`. The actual values pulled to host inside `_concatenate_and_filter` are already typed `FloatND | IntND | BoolND`, so widen the parameter type accordingly. In `_lcm.engine`, register `PeriodRegimeSimulationData` as a JAX pytree via `jax.tree_util.register_dataclass`. Without it, `jax.block_until_ready` traversals over the `dict[regime][period] -> PeriodRegimeSimulationData` shape treat each dataclass instance as an opaque leaf, so the per-subject `V_arr` / `actions` / `states` / `in_regime` arrays inside are never waited on. Pytree registration makes the existing drain at the end of `simulate` actually drain. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Two leftovers from the save-time OOM investigation, both dead weight: - The `LCM_DATAFRAME_SHAPE_CENSUS` shape-census in result_dataframe.py (`_log_shape_census` / `_log_pre_to_host` / `_describe_value`, env-gated, off by default) was diagnostic instrumentation. Remove it and the `os` / `sys` imports it required. - The top-of-save `jax.block_until_ready(self._raw_results)` is a no-op: registering `PeriodRegimeSimulationData` as a pytree (this branch) makes the simulate-end drain materialise every raw-results array, and nothing re-defers them before `save`, so the save-time drain blocks on already-concrete arrays. Behavior-preserving: 1031 passed, ty + prek clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Run the forward simulation one subject chunk at a time through the full period loop, concatenating per-chunk results on the subject axis. Subjects are independent across the forward path, so chunking bounds the per-period device workspace without changing results. Per-subject RNG keys are generated for the full population and sliced by global index (generate_simulation_keys gains a subject_slice argument), so simulated paths are invariant to subject_batch_size. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
When subject_batch_size is set, device_put each chunk's raw_results to the CPU as the chunk finishes — block_until_ready forces the D2H so the device frees the chunk before the next one's period loop allocates, bounding device residency to a single chunk. raw_results leaves stay jax.Array (CPU-backed). Unbatched runs keep results on the compute device, avoiding a needless host round-trip for downstream target evaluation. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
to_dataframe's additional_targets evaluation vmaps the target DAG over a regime's in-regime rows in a single full-population pass — the estimation hot loop's binding memory peak. When simulate ran with subject_batch_size, reuse it (stored on the result, persisted across save/load) to chunk that vmap over subjects, pulling each chunk to host before the next runs. Constant targets (no per-subject variable) keep their single 0-d value to match the single-pass dtype. Values are unchanged; device workspace is bounded by the chunk. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Collapse the solution save/load to a single orbax checkpoint. Each V_arr leaf is moved to host first via _to_host_for_save: single-device leaves are copied in bounded chunks along their splay axis (so a near-cap leaf never needs a full contiguous device buffer for the D2H), and sharded leaves pass through orbax directly, preserving their sharding on load. Removes the per-leaf chunked-.npy writer, the meta.json layout dispatch, and the chunked load path. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The forward-sim save path no longer routes V_arr leaves through a CPU device_put before orbax. A materialized device array saved directly via orbax does not double device residency, so the offload bought nothing — and on a GPU box it recorded a `cpu:0` SingleDeviceShardingMetadata that no GPU runner can restore (Topology mismatch). Saving each leaf on its native device makes save and load same-backend on every runner. Removes the now-dead `_ChunkSpec`/`_build_chunk_specs` machinery, `_to_host_for_save`, and their tests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
A model with `n_subjects` set AOT-compiles its simulate functions once and dispatches them per subject chunk. The compiled program is now sized for the chunk (`subject_batch_size`, clamped to the population), not the full population, so chunked dispatch matches the compiled shape. To keep every chunk exactly the compiled size, a short final chunk slides its window back to end at the population tail. The overlapped subjects were already simulated in the previous chunk and are recomputed bit-identically (per-subject keys depend only on the global index); the leading overlap is trimmed so each subject is represented once, in global order. Without this, `Model(n_subjects=N)` + `simulate(subject_batch_size=B)` with B < N crashed: the program compiled for [N] rejected the [B]-shaped chunk call. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The asv GPU peak-mem subprocess reads `peak_bytes_in_use`, a monotonic high-water mark spanning compile + execute. XLA autotuning allocates large, run-to-run-variable scratch buffers at compile time; for big models that transient dwarfs the execution working set, so the reported peak swung several-fold between identical runs and even inverted log-off vs log-debug. Small models compile trivially and stayed stable. Disable autotuning in the subprocess (`--xla_gpu_autotune_level=0`, appended to any existing `XLA_FLAGS`). The compile footprint becomes deterministic and matches production, which already runs autotuning off. A fresh single cold run is the production footprint anyway, so the measurement stays representative. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ibuted-solve-fixes
`SimulationResult.save()` writes `simulated_data.arrow` via `df.to_feather(...)`, which needs pyarrow at runtime. It was only in the pixi dev dependencies, so any install of pylcm-as-a-package (or a downstream environment whose pandas build doesn't hard-pull pyarrow) could hit a missing-pyarrow ImportError at save time. Move it into `[project.dependencies]` so it ships with the package. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Pure helpers for `subject_batch_size="auto"`: `estimate_peak_bytes` turns one compiled program's `memory_analysis()` into a peak figure (temp + argument + output − aliased), and `pick_batch_size` fits the affine peak-vs-batch line through compile-only probes and returns the largest batch under the memory budget, clamped to [1, population]. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the `subject_batch_size: int | None = None` knob with `int | Literal["auto"] = 0`, matching the grid `batch_size=0` no-batching sentinel: - `0` (default): one pass over the padded population. - `> 0`: explicit chunk size. Subject-chunking is single-device, so it is rejected under multi-device distribution — the value-function array is sharded across the devices and can't be gathered onto one. - `"auto"`: size the chunk to the device. `compile_all_simulate_functions` now also returns the max program peak from `memory_analysis()` (alias-corrected); the resolver compiles the one-pass program first (reused when it fits), and on overflow fits the affine peak-vs-batch line through a half-population probe to pick the largest batch under `bytes_limit × (1 − margin)`. Falls back to one pass when the device exposes no limit (CPU), the grids are distributed across devices, or the model is not AOT-configured. The engine receives the resolved concrete int (one pass = population, chunked = the chunk size), so its None path is gone and host-offload keys off `batch_size < n_subjects`. Results stay invariant to the knob (per-subject RNG keys sliced by global index). GPU-verified: both the fits→one-pass and forced-chunk branches reproduce the single-pass output. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Review fixes for the subject_batch_size="auto" path: - `pick_batch_size` no longer divides by zero when both probes share a batch size (a one-subject population makes the full- and half-pop probes coincide). It guards `b_hi == b_lo` and falls back to a proportional model, clamped to the population. - `_autotune_compile_batch_size` returns one pass when the half-pop probe wouldn't reduce the batch (population below two), before the redundant compile. - `estimate_peak_bytes` is keyword-only, per the project convention. - Tests: a degenerate identical-batch probe case, and `"auto"` under multi-device distribution falling back to one pass (not raising). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ibuted-solve-fixes
subject_batch_size becomes int-only (0 = single pass, >0 = manual chunk). Sizing a chunk for a memory-constrained device is now manual trial-and-error rather than compile-time XLA memory_analysis() probing. Removes _lcm/simulation/autotune.py (estimate_peak_bytes, pick_batch_size), the _simulate_peak_cache / _autotune_compile_batch_size / _device_budget_bytes machinery in model.py, and the peak-bytes return from compile_all_simulate_functions. The manual chunk path and its multi-device-distribution guard are unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Pad initial_conditions up front to a single multiple — the device count when sharding, the chunk size when chunking — instead of per-device padding plus a per-chunk slide-window overlap. The chunk loop then iterates exact blocks (no overlap, no per-chunk trim) and one trim_pad_from_raw_results drops the duplicate-last-subject pad rows. Renames pad_initial_conditions_for_devices -> pad_initial_conditions_to_multiple (takes the alignment block size, not regimes) and removes _trim_chunk_subjects. Results are unchanged: the 7-subject non-dividing invariance test stays green; per-subject RNG keys are still drawn for the real population and sliced by global index. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ibuted-solve-fixes
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
A positive `batch_size` on a state with a Markov transition chunks the next-period shock-integration map. That chunking de-fuses the expected-continuation-value reduction from the value-function interpolation, so the full joint next-stochastic cross-product is materialized instead of streamed — peak memory grows with the stochastic cardinality instead of shrinking. Reject it at regime construction, pointing the user to shard the current axis (`distributed=True`) or splay a continuous state instead. Deterministic states keep `batch_size` as a memory-reducing splay knob. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This reverts commit 48906ba.
…tion" This reverts commit 1759b94.
With the shock-integration batch_size unwired, a stochastic state's `batch_size` drives only the outer state-loop splay — a pure memory knob. Lock in the contract: solving with `batch_size=1` on a Markov state yields a value function identical to `batch_size=0`. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Document the two orthogonal scaling knobs and the environment flags, since the splay-vs-shard distinction is easy to get backwards: - `batch_size` (splay) is a memory knob and time-neutral — it only re-tiles the per-period max via `lax.scan`; identical FLOPs. Prefer large, uniform continuous axes; use the fewest chunks that fit. - `distributed` (shard) is the only knob that reduces wall-clock, and only for discrete, never-transitioning axes. Covers both construction guards (no continuous sharding; no batch_size+distributed on one axis) and why they hold. - `subject_batch_size` for the forward pass. - Environment flags: the JAX defaults pylcm sets, and the allocator / preallocate / mem-fraction / autotune / command-buffer trade-offs, with a stable multi-GPU recipe. Registered in the user-guide index and the myst.yml TOC. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Head-to-head per-period execution time is unchanged with kernel autotuning on vs off (matched config, deep-age steady state) — backward induction is gather/scatter- and interpolation-bound, not GEMM-bound, so there is little for autotuning to optimize, while it raises compile time and peak memory. Default to off; turn on only if a per-period measurement shows a real speedup. Also note the exclusive-GPU failure mode: with PREALLOCATE=true a contended or leaked device fails the pool preallocation with an immediate device-0 OUT_OF_MEMORY, distinct from the 'allocated N, saw 1' visibility case. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
mj023
left a comment
There was a problem hiding this comment.
Good change, very helpful guide about batching and sharding!
The second half of the PR description reads like there is a second change, but I can't find it in the Code?
| the synchronisation count and inverts the compute/communication ratio. On a sharded | ||
| axis keep `batch_size=0`; if its per-device chunk is still too big, add devices or | ||
| shard a second axis rather than batching it. |
There was a problem hiding this comment.
I cant quite follow what this means. To shard a second axis, one would need to add devices anyways, rigth?
There was a problem hiding this comment.
Thanks, that is misleading indeed. Now that we disallow sharding continuous axes, you may well want to shard 2x2, so that's supported. But in that example, it did not make any sense.
…ent one Address mj023's review on #370: the 'shard a second axis' remedy for a too-big per-device chunk was misleading. Splaying a sharded axis is only tempting when a device holds >1 block, which is exactly when sharding that same axis across more devices works instead; the everyday memory fix is to splay a different, non-sharded axis. Update both the tuning guide and the matching GridInitializationError message to say so. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
That was the rationale more than anything, poor cleaning up of the writing from here. Should be better now. |
The ty update flagged `pd.Categorical` as `Categorical[object]` while pandas-stubs infers `Categorical[str]` for the returned values; bumping the annotation to `pd.Categorical[str]` (commit 53855bf) silenced ty but broke beartype's module-level claw: `pd.Categorical` isn't subscriptable at runtime, so the annotation raised `TypeError` during decoration and the benchmark CI started emitting `BeartypeClawDecorWarning`. Revert the annotations to bare `pd.Categorical` so beartype can resolve them, and silence ty at the three return statements instead. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Track aca-model main for the benchmarks env instead of a pinned commit, matching the dags dependency convention. The previous rev lived on an aca-model feature branch that has since been squash-merged and deleted. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Re-implements the still-relevant distributed-scale solve fixes from the abandoned #364, fresh on top of the discrete-only-sharding work (#368, now merged to
main). Verified on 4 CPU devices (jax_num_cpu_devices=4).What changed
Sharded NaN/Inf reductions (code).
validate_V,contains_nan,log_nan_in_V, and solve's hot-loop diagnostic accumulator route their NaN/Inf checks through module-level@jax.jithelpersv_array_has_nan/v_array_has_inf. GSPMD then partitions each reduction across the V-array's devices (per-deviceany→ all-reduce → replicated scalar). The eagerjnp.any(jnp.isnan(V_arr))alternative can fall back to gathering a sharded V onto the default device before reducing — exhausting device memory at production grid sizes. Regression: two checks that both helpers return a fully-replicated 0-d scalar on aNamedSharding-sharded input.Batching/sharding tuning guide (docs). New
docs/user_guide/tuning.mdsection contrasting the two grid knobs —batch_size(splay: a time-neutral memory knob) anddistributed(shard: the only wall-clock knob, discrete non-transitioning axes only) — when each applies, and the boundaries the constructor enforces (continuous axes can't shard; you can't splay and shard the same axis). The reasoning behind the subtlest part of that guide is written out below — it documents existing behaviour, there is no code change for it.Rationale for the guide —
batch_sizechunks the state sweep, not the shock expectation(Background, not a code change: this is the reasoning the new guide distills.)
Each period's backward step is two nested sweeps:
V = max_a Q. A grid'sbatch_sizechunks this sweep. The chunked axis is a current state — a dimension of the outputV— so each chunk's results are written out and freed, and peak memory shrinks.Q(state, action)cell, average the next period's value over the stochastic transition:E[V'] = Σ_{s'} P(s')·V'(s', …), built as a productmap over the next-period stochastic outcomes followed byjnp.average.A stochastic grid (e.g.
health) names an axis in both: a current-period state in the state sweep, and a summed-over next-period outcome in the shock expectation.batch_sizeacts only on the state sweep —tests/test_stochastic.py::test_stochastic_state_batch_size_is_value_equivalent_to_no_splaypins that solving withbatch_size=1gives a value function identical tobatch_size=0.The shock expectation stays unconditionally fused (
batch_size=0). Its axis is summed away, not kept, so chunking it can't shrink the output — it only de-fuses thejnp.averagefrom theV'interpolation, forcing the full stack of per-outcomeV'arrays to materialize before the sum instead of streaming into the running total. Samejax.lax.map(batch_size=1)primitive, opposite sign depending on whether the chunked axis is kept (an output) or summed away (reduced) — isolatedlax.map+jnp.averageprobe,K=8next-period outcomes over anN=4M-cell working set (128 MB = K·N·4B):batch_size0(→vmap)1(→scan)[K, N]stack materialized before the sumScale
Nto the production state-action working set andKto the joint next-period stochastic cardinality and0 → 128 MBbecomesfits → 271 GiB. To shed memory along a stochastic dimension, shard the current axis (distributed=True) or chunk the state sweep — both keep the shock expectation fused and beat a sequential scan.Dropped from #364
out_shardingson the solve lowering (#364 piece 1) — obsolete. Its motivating case (continuousdistributed=True) is forbidden outright (GridInitializationError), and discrete-distributed V already shards correctly via natural jit output inference (verified:num_devices=4with and withoutout_shardings; all distributed tests pass without it). A docstring note on_fail_if_continuous_grid_distributedrecords the prerequisite (and points at #364's reference implementation) should continuous-axis sharding ever be re-enabled.Device-count padding (#364 piece 6) landed in #368.
Verification
tests/test_stochastic.py(incl. the newbatch_size-equivalence test) — relevant tests pass locallypixi run -e type-checking ty— cleanprek run --all-files— cleanpytest tests -n 4;tests/test_distributed.pyon 4 CPU devices) — deferred to CI🤖 Generated with Claude Code