Merged
58 commits
ed49742
Reject distributed=True on continuous grids (A)
hmgaudecker May 26, 2026
3ad860a
Drain every V_arr shard at end of solve (B)
hmgaudecker May 26, 2026
d062a75
SimulationResult.save/load via orbax + cloudpickle (C)
hmgaudecker May 26, 2026
b05d3c9
SimulationResult.save also writes simulated_data.arrow
hmgaudecker May 26, 2026
6cada5d
V_arr capitalisation, simpler drain, Path-only save/load
hmgaudecker May 26, 2026
88facaf
Mirror solve's materialisation discipline in simulate
hmgaudecker May 26, 2026
3c9d445
Self-gate validation; drop redundant diagnostics_enabled threading
hmgaudecker May 26, 2026
1a157a1
Default save's df_additional_targets to None; finish V_arr casing sweep
hmgaudecker May 26, 2026
6a8db1d
Bump pinned aca-model in benchmarks to 7af96820
hmgaudecker May 26, 2026
eedc35c
Pull single-device V_arr to host before orbax serialization
hmgaudecker May 27, 2026
6ba9c15
Type host/jax array-tree helpers as dict[str, Any]
hmgaudecker May 27, 2026
43f0a08
Save V_arr per (period, regime), chunked along the engine's splay axis
hmgaudecker May 27, 2026
3486f60
Print biggest jax.Array leaves before orbax save
hmgaudecker May 27, 2026
02e5853
Save V-array first, drop refs, then orbax the small tree
hmgaudecker May 28, 2026
64baa9b
Drop self._regimes before orbax to release pinned XLA workspaces
hmgaudecker May 28, 2026
f999109
Drop jax.clear_caches before orbax to keep topology metadata intact
hmgaudecker May 28, 2026
0130bf2
Save V-array first, then build dataframe with a clearer pool
hmgaudecker May 28, 2026
b84a0a2
Filter regime data on host to avoid JAX bool-index OOM
hmgaudecker May 28, 2026
f96a595
Drop self._regimes before to_dataframe when no targets
hmgaudecker May 28, 2026
541f191
Drain raw_results before save's pool reshuffling
hmgaudecker May 29, 2026
577a383
Stream _concatenate_and_filter column-by-column with shape census
hmgaudecker May 29, 2026
df8e060
Walk period_dicts period-by-period; transfer via shard iteration
hmgaudecker May 29, 2026
907df02
Type _to_host input as FloatND | IntND | BoolND; register PeriodRegim…
hmgaudecker May 29, 2026
709a992
Drop OOM-hunt scaffolding: dataframe shape-census + redundant save drain
hmgaudecker May 29, 2026
8d062aa
Add subject_batch_size: chunk the forward simulation over subjects
hmgaudecker May 31, 2026
9b5bc3c
Offload each subject chunk to host when batching
hmgaudecker May 31, 2026
95ae8df
Chunk to_dataframe target evaluation by subject_batch_size
hmgaudecker May 31, 2026
c701944
Save the solution via orbax only, dropping the chunked-.npy path
hmgaudecker May 31, 2026
f2e7179
Save period_to_regime_to_V_arr directly via orbax; drop chunk_specs
hmgaudecker May 31, 2026
6df7f71
Compile simulate functions for the chunk shape; overlap the final chunk
hmgaudecker May 31, 2026
960d3bf
Overwrite existing checkpoints on save (orbax force=True)
hmgaudecker Jun 1, 2026
1b85508
Allow non-device-multiple subject counts under distributed grids
hmgaudecker Jun 1, 2026
5d4a430
Cloudpickle flat_params in metadata instead of the orbax array tree
hmgaudecker Jun 1, 2026
1759b94
Honor stochastic-grid batch_size in the Q-and-F shock integration
hmgaudecker Jun 1, 2026
4cdee1f
Keep V-array NaN/Inf reductions sharded via jit-wrapped helpers
hmgaudecker Jun 1, 2026
225d7a2
Document the out_shardings prerequisite for continuous-axis sharding
hmgaudecker Jun 1, 2026
a9eecf5
Measure GPU peak mem with autotuning off for a deterministic compile
hmgaudecker Jun 1, 2026
4d665cc
Merge branch 'feat/discrete-only-sharding-orbax-save' into feat/distr…
hmgaudecker Jun 1, 2026
cf0de59
Declare pyarrow as a runtime dependency
hmgaudecker Jun 2, 2026
7bb47e9
Add subject-batch autotuning core: peak estimate + size picker
hmgaudecker Jun 2, 2026
091718a
Wire subject_batch_size="auto"/0/>0 with a distributed guard
hmgaudecker Jun 2, 2026
d48b947
Harden autotune: degenerate probes, kw-only estimate, tiny-pop guard
hmgaudecker Jun 2, 2026
4c30dd6
Merge branch 'feat/discrete-only-sharding-orbax-save' into feat/distr…
hmgaudecker Jun 2, 2026
dd73843
Drop subject_batch_size="auto" autotuning
hmgaudecker Jun 3, 2026
e31b77e
Unify subject-axis padding via repeat-last
hmgaudecker Jun 3, 2026
d605c4b
Merge branch 'feat/discrete-only-sharding-orbax-save' into feat/distr…
hmgaudecker Jun 3, 2026
9142c8f
[pre-commit.ci] pre-commit autoupdate (#371)
pre-commit-ci[bot] Jun 3, 2026
48906ba
Guard against batch_size>0 on stochastic states
hmgaudecker Jun 3, 2026
8cf896b
Merge main into feat/distributed-solve-fixes
hmgaudecker Jun 3, 2026
89551eb
Revert "Guard against batch_size>0 on stochastic states"
hmgaudecker Jun 3, 2026
d26db9b
Revert "Honor stochastic-grid batch_size in the Q-and-F shock integra…
hmgaudecker Jun 3, 2026
ced2905
Test stochastic-state batch_size as a value-preserving splay
hmgaudecker Jun 3, 2026
fc204d1
docs(user_guide): add Performance and Memory Tuning page
hmgaudecker Jun 4, 2026
0391808
docs(tuning): record that autotune does not speed gather-bound solves
hmgaudecker Jun 4, 2026
7091ebe
Clarify splay-vs-shard guidance: shard the same axis / splay a differ…
hmgaudecker Jun 5, 2026
53855bf
Update ty, CI, and pre-commit hooks.
hmgaudecker Jun 5, 2026
3b1cf4a
Quiet beartype claw warnings on Categorical return annotations
hmgaudecker Jun 5, 2026
591da11
Point benchmark aca-model dependency at main
hmgaudecker Jun 5, 2026
12 changes: 6 additions & 6 deletions .github/workflows/main.yml
@@ -33,7 +33,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.6
with:
pixi-version: v0.69.0
pixi-version: v0.70.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cpu
@@ -49,7 +49,7 @@ jobs:
if: runner.os == 'Linux' && matrix.python-version == '3.14'
- name: Upload coverage report
if: runner.os == 'Linux' && matrix.python-version == '3.14'
uses: codecov/codecov-action@v6
uses: codecov/codecov-action@v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
run-ty:
@@ -64,7 +64,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.6
with:
pixi-version: v0.69.0
pixi-version: v0.70.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: type-checking
@@ -87,7 +87,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.6
with:
pixi-version: v0.69.0
pixi-version: v0.70.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cuda12
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.6
with:
pixi-version: v0.69.0
pixi-version: v0.70.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cuda12
@@ -121,7 +121,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.6
with:
pixi-version: v0.69.0
pixi-version: v0.70.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: docs
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.21.2
rev: v2.23.0
hooks:
- id: pyproject-fmt
- repo: https://github.com/lyz-code/yamlfix
@@ -54,7 +54,7 @@ repos:
hooks:
- id: check-github-workflows
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.14
rev: v0.15.16
hooks:
- id: ruff-check
args:
1 change: 1 addition & 0 deletions docs/myst.yml
@@ -46,6 +46,7 @@ project:
- file: user_guide/parameters.md
- file: user_guide/pandas_interop.md
- file: user_guide/solving_and_simulating.md
- file: user_guide/tuning.md
- file: user_guide/benchmarking.md
- file: user_guide/debugging.md
- file: explanations/index.md
1 change: 1 addition & 0 deletions docs/user_guide/index.md
@@ -20,4 +20,5 @@ philosophy. The remaining pages cover individual topics in depth.
- [Parameters](parameters.md)
- [Working with DataFrames and Series](pandas_interop.md)
- [Solving and Simulating](solving_and_simulating.md)
- [Performance and Memory Tuning](tuning.md)
- [Debugging](debugging.md)
215 changes: 215 additions & 0 deletions docs/user_guide/tuning.md
@@ -0,0 +1,215 @@
---
title: Performance and Memory Tuning
---

# Performance and Memory Tuning

Two questions decide how a model runs on accelerators: *does it fit in memory*, and *are
the devices used well*. pylcm keeps them separate as two independent knobs on every grid
— `batch_size` (splay) and `distributed` (shard) — plus a forward-simulation chunk size
and a handful of XLA environment flags. This page explains what each does, when it
helps, and the trade-offs that are easy to get backwards.

The one-line model:

- **`batch_size` (splay) is a memory knob. It is time-neutral.**
- **`distributed` (shard) is a speed knob. It applies only to discrete,
non-transitioning axes.**

Keeping these straight is the whole game: splaying is never a way to buy speed, and
sharding is the only knob that is.

## The two grid knobs

Every grid — `DiscreteGrid` and every continuous grid (`LinSpacedGrid`, `LogSpacedGrid`,
`IrregSpacedGrid`, the piecewise variants) — takes both:

```python
from lcm.grids import DiscreteGrid, LinSpacedGrid

# A permanent (never-transitioning) discrete state, sharded one block per device (speed):
pref_type = DiscreteGrid(PrefType, distributed=True)

# A continuous state, scan-chunked into pieces to save memory (time-neutral):
assets = LinSpacedGrid(start=0.0, stop=1_000.0, n_points=200, batch_size=50)
```

| knob | what it does | what it buys | applies to |
| -------------------------- | ------------------------------------------------------------------------ | ----------------- | ------------------------------------- |
| `batch_size=k` (splay) | `lax.scan` the per-period work over chunks of `k` points along that axis | lower peak memory | any axis |
| `distributed=True` (shard) | place that axis's blocks on separate devices | parallel speedup | discrete, non-transitioning axes only |

`batch_size=0` (the default) means "no splay" — one kernel per period over the full
axis. `distributed=False` (the default) means "not sharded".

## `batch_size`: splay for memory, time-neutral

At each period, backward induction builds the value array over every (state, action)
combination and maximises over actions. `batch_size=k` only changes how that work is
*tiled*: instead of one big `vmap`, it runs a `lax.scan` over chunks of `k` points along
the chosen axis. **The total FLOPs are identical** — every combination is still
evaluated exactly once — so the wall-clock barely moves. What drops is peak memory,
because only one chunk's intermediate is live at a time.
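
The tiling argument can be illustrated with a minimal pure-Python sketch. It is an analogue only, not pylcm's `vmap`/`lax.scan` implementation; `utility`, `solve_full`, and `solve_splayed` are hypothetical names, and the "peak" is simply the number of table entries held at once.

```python
def utility(state: float, action: float) -> float:
    """Toy per-(state, action) payoff; a stand-in for the model's real objective."""
    return -((state - action) ** 2)


def solve_full(states: list[float], actions: list[float]) -> tuple[list[float], int]:
    """Evaluate the whole (state x action) table in one pass, then maximise."""
    table = [[utility(s, a) for a in actions] for s in states]
    peak_entries = len(states) * len(actions)  # everything is live at once
    return [max(row) for row in table], peak_entries


def solve_splayed(
    states: list[float], actions: list[float], batch_size: int
) -> tuple[list[float], int]:
    """Same evaluations, but only one chunk of `batch_size` states is live at a time."""
    values: list[float] = []
    peak_entries = 0
    for start in range(0, len(states), batch_size):
        chunk = states[start : start + batch_size]
        table = [[utility(s, a) for a in actions] for s in chunk]
        peak_entries = max(peak_entries, len(chunk) * len(actions))
        values.extend(max(row) for row in table)
    return values, peak_entries


states = [float(i) for i in range(8)]
actions = [0.0, 2.5, 5.0]
full_values, full_peak = solve_full(states, actions)
splay_values, splay_peak = solve_splayed(states, actions, batch_size=2)
```

Both functions call `utility` once per combination and return the same values; only the size of the largest intermediate differs (24 entries versus 6 here).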

Splay stays time-neutral as long as each chunk still has enough parallel work to
saturate the device — and in a real model it does, because the other grid dimensions
(assets × savings × shocks × …) provide ample parallelism inside every chunk.

It stops being free only at the extremes:

- **Over-chunking** (very small `batch_size` → many tiny chunks): per-launch overhead
piles up, and a chunk can get too small to saturate the device. This bites hardest
when CUDA graphs are off (see [Environment flags](#environment-flags)), because every
chunk is then launched individually.
- **Under-chunking** (`batch_size=0`, the whole axis in one pass): the full intermediate
must fit at once. If that forces the allocator to spill or to shrink fusion tiles, the
unsplayed solve can be *slower* than a splayed one, which is the whole reason the knob exists.

**Which axis to splay.** Prefer a large, *uniform* axis:

- Continuous axes (savings, assets, accumulated earnings) are ideal: they have many
points (fine control over the chunk count) and are full-size in every regime, so the
relief is uniform.
- A discrete axis that *collapses* in some regimes — for example a lagged choice that is
fixed when the agent is forced out of the labour market — gives lumpy relief: splaying
it does nothing in the regimes where it is already a singleton.

**Rule: use the fewest chunks that fit.** Halving memory needs only two chunks
(`batch_size = n_points / 2`), not `batch_size = 1`.
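
As a rough sketch of that rule, the helper below turns a target chunk count into a `batch_size`. The function name is hypothetical and not part of pylcm; it assumes only the convention stated above that `batch_size=0` means "no splay".

```python
import math


def batch_size_for_chunks(n_points: int, n_chunks: int) -> int:
    """Return the batch_size that splits an axis of n_points into n_chunks chunks.

    A single chunk maps to 0, the "no splay" default described above.
    """
    if n_chunks <= 1:
        return 0
    return math.ceil(n_points / n_chunks)


# Halving memory on a 200-point axis needs two chunks of 100 points each.
half = batch_size_for_chunks(n_points=200, n_chunks=2)
```

Start from two chunks and raise the count only until the solve fits.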

## `distributed`: shard for speed (discrete, non-transitioning axes)

`distributed=True` places the blocks of an axis on separate devices and solves them in
parallel. It is the only knob that reduces wall-clock — but it is legal only for a
narrow class of axes, and pylcm enforces the boundaries at construction time.

**It runs communication-free only for axes the agent never transitions along.** If an
agent's position on the axis is fixed for life (a permanent type, a fixed group), each
block's value function is independent of the others, so the blocks sit on different
devices with *zero* cross-device traffic. An axis the agent *moves along* (health,
wealth, a lagged choice) couples the blocks: every period would need an all-to-all
exchange, and the communication swamps the compute.

Two guards make this concrete — both raise `GridInitializationError` at construction:

- **Continuous axes cannot be sharded.** `distributed=True` on any continuous grid is
rejected. (Continuous-axis sharding would require the solved value array to carry an
explicit output sharding; that path is not enabled.)
- **You cannot splay and shard the same axis.** `batch_size > 0` together with
`distributed=True` is rejected: each batch is its own dispatch, and on a sharded axis
every dispatch carries a per-period cross-device collective, so batching multiplies
the synchronisation count (`×ceil(n_per_device / batch_size)`) and inverts the
compute/communication ratio. Keep `batch_size=0` on the sharded axis. When a device's
chunk is too big, shed memory by splaying a *different*, non-sharded axis — usually
the practical fix, since it needs no extra devices. If you do have spare devices,
shard the same axis across more of them: that helps precisely when a device holds more
than one block (`n_points / n_devices > 1`), the only case where splaying the sharded
axis would have helped anyway, and it shrinks the per-device chunk *and* adds
parallelism with no extra collectives.
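
The synchronisation multiplier quoted above is plain arithmetic. The sketch below evaluates it for a few hypothetical configurations; pylcm rejects the splay-plus-shard combination, so the helper (a made-up name) only illustrates why.

```python
import math


def dispatches_per_period(n_points: int, n_devices: int, batch_size: int) -> int:
    """Dispatches per period on a sharded axis: ceil(n_per_device / batch_size).

    batch_size=0 means no splay, i.e. one dispatch per period.
    """
    n_per_device = math.ceil(n_points / n_devices)
    if batch_size == 0:
        return 1
    return math.ceil(n_per_device / batch_size)


# 12 blocks on 3 devices: 4 blocks per device.
unsplayed = dispatches_per_period(n_points=12, n_devices=3, batch_size=0)
fully_splayed = dispatches_per_period(n_points=12, n_devices=3, batch_size=1)
# One block per device: splaying the sharded axis could not have helped anyway.
one_block_each = dispatches_per_period(n_points=3, n_devices=3, batch_size=1)
```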

```{note}
Sharding divides the state space across devices, so it also *reduces* per-device memory — a
sharded model often needs no splay at all. Reach for splay only if a single device still
cannot hold its share.
```

## Forward simulation: `subject_batch_size`

Solving is one memory profile; simulating a large panel forward is another.
`Model.simulate(..., subject_batch_size=k)` chunks the simulated subjects so only one
chunk is resident at a time:

- `subject_batch_size=0` (the default) simulates all subjects in a single pass.
- `subject_batch_size=k` walks the panel in chunks of `k`.

Like grid `batch_size`, this is a time-neutral memory knob — raise the chunk count if
the simulated panel does not fit, and otherwise leave it at a single pass.
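
To make the chunking concrete, here is a minimal sketch of how a panel divides into subject ranges. `subject_chunks` is a hypothetical helper, and it does not model any padding pylcm may apply internally to the final chunk.

```python
def subject_chunks(n_subjects: int, subject_batch_size: int) -> list[tuple[int, int]]:
    """Return half-open (start, stop) subject ranges, one per chunk.

    subject_batch_size=0 means a single pass over all subjects.
    """
    if subject_batch_size == 0:
        return [(0, n_subjects)]
    return [
        (start, min(start + subject_batch_size, n_subjects))
        for start in range(0, n_subjects, subject_batch_size)
    ]


single_pass = subject_chunks(n_subjects=10, subject_batch_size=0)
three_chunks = subject_chunks(n_subjects=10, subject_batch_size=4)
```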

## Worked example

Measured on 80 GB A100s, one six-regime lifecycle model:

- **One GPU, every axis batched** — full solve + simulate ≈ **1 h 37 m**.
- **Three GPUs, the permanent-type axis sharded one block per device** — a *heavier*
policy-overlay variant of the same model ≈ **59 m**. The shard more than offsets the
extra per-regime work: three devices beat one even on a bigger problem.
- **Two single-GPU runs that differ only in which axis is chunked for memory** finished
within about a minute of each other (≈ 1 h 37 m vs ≈ 1 h 38 m): direct confirmation
that the choice of splay axis is time-neutral; only the device count moved the wall-clock time.

The takeaway is the one-line model: the multiplicative speedup comes from *sharding*
across devices, not from any choice of `batch_size`.

## Environment flags

pylcm sets two JAX defaults at import and leaves the rest to the environment.

**Set by pylcm (override before importing `lcm`):**

- `XLA_PYTHON_CLIENT_PREALLOCATE=false` — allocate GPU memory on demand instead of
grabbing a fixed fraction up front. This plays nicely with other processes and makes
`nvidia-smi` and memory benchmarks reflect real usage.
- `JAX_COMPILATION_CACHE_DIR=~/.cache/jax` — persist the JIT cache so repeated runs of a
large (many-regime) model skip the multi-minute compile.
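
A minimal sketch of overriding these defaults from Python, assuming (as the heading above implies) that pylcm leaves values already present in the environment untouched. The cache directory is a made-up example path.

```python
import os

# Set overrides first: they must be in the environment before `lcm` is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
os.environ["JAX_COMPILATION_CACHE_DIR"] = os.path.expanduser("~/.cache/jax-my-project")

# Only now import the package (commented out so the sketch runs without pylcm):
# import lcm
```

Exporting the same variables in the shell before launching Python has the same effect.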

**Knobs you set yourself**, with the trade-off each carries:

- `XLA_PYTHON_CLIENT_PREALLOCATE=true` — preallocate a single pool. At production scale
a stable pool avoids fragmentation and reduces allocator churn across the solve; pair
it with `XLA_PYTHON_CLIENT_MEM_FRACTION`.
- `XLA_PYTHON_CLIENT_MEM_FRACTION=0.90` — the fraction of device memory the preallocated
pool claims. The remainder stays as non-pool headroom that the driver, collectives,
and CUDA graphs draw on; leave enough for them on a multi-GPU run.
- `XLA_PYTHON_CLIENT_ALLOCATOR=default` — keep JAX's pooled BFC allocator. The
`platform` setting (per-op `cudaMalloc`/`cudaFree`) is dramatically slower; avoid it.
- `XLA_FLAGS=--xla_gpu_autotune_level=0` — disable kernel autotuning. Off gives a
deterministic, lower-memory compile; on searches for faster GEMM/conv kernels but
reserves the largest candidate's scratch at compile time, which can re-trigger an OOM
on a model that barely fits. **Default to off.** Backward induction is dominated by
gather/scatter and interpolation over the state-action grid, not dense GEMMs, so
autotuning has little to optimize: head-to-head, the per-period execution time is
unchanged on/off (matched to logging precision), while compile time and peak memory
both rise. Turn it on only if a measurement on your model shows an actual per-period
speedup.
- `XLA_FLAGS=--xla_gpu_enable_command_buffer=` (empty, i.e. disabled) — turn off CUDA
graphs. Command buffers batch kernel launches but consume non-pool driver memory;
disabling them frees that headroom at the cost of per-launch overhead. That overhead
lands hardest on splay-heavy configs (many small kernels), so a heavily-splayed model
pays more for disabling them.

```{warning}
Sharding only helps if the devices are actually visible *and* exclusively yours. If your
launcher grants N GPUs but `CUDA_VISIBLE_DEVICES` exposes only one, a model declared
`distributed=True` silently runs on a single device: the classic "allocated 3, saw 1".
At startup, before the solve, assert that `jax.device_count()` matches what you sharded for.
And with `XLA_PYTHON_CLIENT_PREALLOCATE=true`, a GPU that another job or a leaked process
is already using fails the pool preallocation outright, with an `OUT_OF_MEMORY` on device 0
within seconds of startup (not mid-solve), so request GPUs exclusively.
```
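
One way to write the startup check from the warning above is sketched here. `check_device_count` is a hypothetical helper; it takes the visible count as an argument so it runs without JAX, and the comments show where `jax.device_count()` would be passed in.

```python
def check_device_count(expected: int, visible: int) -> None:
    """Fail fast if the visible device count differs from what the model shards for."""
    if visible != expected:
        raise RuntimeError(
            f"Sharding expects {expected} devices, but {visible} are visible; "
            "check CUDA_VISIBLE_DEVICES and the launcher's GPU allocation."
        )


# At startup, before the solve (requires JAX):
#     import jax
#     check_device_count(expected=3, visible=jax.device_count())
check_device_count(expected=3, visible=3)  # passes silently
```

Raising instead of warning stops a mis-provisioned job within seconds rather than letting it run on one device.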

**A stable multi-GPU configuration.** One environment that holds up at production scale,
trading compile-time kernel search and launch batching for memory headroom:

```bash
export XLA_PYTHON_CLIENT_PREALLOCATE=true
export XLA_PYTHON_CLIENT_ALLOCATOR=default # pooled BFC
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.90
export XLA_FLAGS='--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer='
```

Command buffers are the one knob to revisit once a model fits comfortably: re-enabling
them amortizes launch overhead, at the cost of the non-pool driver memory they consume.
Autotuning, by contrast, has not been observed to speed these gather-bound solves, so
leaving it off costs nothing and keeps the memory headroom.
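
If `XLA_FLAGS` is already set in your environment, a plain `export` overwrites it. The sketch below merges the wanted flags into an existing value; `merge_xla_flags` is a hypothetical helper, and the existing value in the example is assumed purely for illustration.

```python
def merge_xla_flags(existing: str, wanted: list[str]) -> str:
    """Combine flags, letting `wanted` replace any same-named flag in `existing`."""
    wanted_names = {flag.split("=", 1)[0] for flag in wanted}
    kept = [
        flag for flag in existing.split() if flag.split("=", 1)[0] not in wanted_names
    ]
    return " ".join(kept + wanted)


merged = merge_xla_flags(
    existing="--xla_force_host_platform_device_count=8 --xla_gpu_autotune_level=4",
    wanted=["--xla_gpu_autotune_level=0", "--xla_gpu_enable_command_buffer="],
)
# Before importing JAX or lcm: os.environ["XLA_FLAGS"] = merged
```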

## Checklist

- Shard a never-transitioning discrete axis across devices for speed
(`distributed=True`).
- Keep `batch_size=0` on a sharded axis: never splay and shard the same axis.
- If a single device still can't hold its share, splay a large continuous axis, using
the fewest chunks that fit.
- Never expect splay to speed anything up: it only buys memory.
- Chunk the forward pass with `subject_batch_size` if the simulated panel doesn't fit.
- Verify `jax.device_count()` matches your sharding before the solve.