Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/_lcm/solution/solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,20 @@ def solve(
}
)

# AOT-compile all unique max_Q_over_a functions in parallel.
# AOT-compile all unique max_Q_over_a functions in parallel. The regime's
# declared V-array sharding rides along as `out_shardings` so the solved V
# is emitted on the layout its next-period consumers were lowered against.
compiled_functions = _compile_all_functions(
regimes=regimes,
flat_params=flat_params,
ages=ages,
next_regime_to_V_arr=next_regime_to_V_arr,
regime_V_shardings=MappingProxyType(
{
regime_name: topology.sharding
for regime_name, topology in regime_V_topology.items()
}
),
enable_jit=enable_jit,
max_compilation_workers=max_compilation_workers,
logger=logger,
Expand Down Expand Up @@ -279,6 +287,7 @@ def _compile_all_functions(
flat_params: FlatParams,
ages: AgeGrid,
next_regime_to_V_arr: MappingProxyType[RegimeName, FloatND],
regime_V_shardings: MappingProxyType[RegimeName, jax.NamedSharding | None],
enable_jit: bool,
max_compilation_workers: int | None,
logger: logging.Logger,
Expand All @@ -299,6 +308,14 @@ def _compile_all_functions(
ages: Age grid for the model.
next_regime_to_V_arr: Template with consistent keys and V array shapes
for constructing lowering arguments.
regime_V_shardings: Immutable mapping of each regime's V-array device
sharding, or `None` where the regime distributes no state. Passed as
the `out_shardings` of the compiled `max_Q_over_a` so the solved
array carries the declared sharding its next-period consumers were
lowered against, without a post-hoc `device_put` reshard. Natural jit
output inference reproduces the declared layout for a sharded
discrete axis but not for a sharded continuous one, where the
interpolation read mixes the axis into the output.
enable_jit: Whether to JIT-compile the functions of the internal regimes.
max_compilation_workers: Maximum threads for parallel compilation.
Defaults to `os.cpu_count()`.
Expand Down Expand Up @@ -360,7 +377,9 @@ def _compile_all_functions(
logger.info("%d/%d %s", i, n_unique, label)
logger.info(" lowering ...")
start = time.monotonic()
lowered[func_id] = jax.jit(func).lower(**lower_args)
lowered[func_id] = jax.jit(
func, out_shardings=regime_V_shardings[regime_name]
).lower(**lower_args)
elapsed = time.monotonic() - start
logger.info(" lowered in %s", format_duration(seconds=elapsed))

Expand Down
Loading