Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions benchmarks/_gpu_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class MahlerYumGpuPeakMem(GpuPeakMem):
# Project root: the directory containing the benchmarks/ package.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Marks the peak-memory line on the subprocess's stdout. The subprocess imports
# lcm, whose beartype claw can emit diagnostics to stdout, so the parent locates
# this line instead of parsing stdout wholesale.
_PEAK_MARKER = "__PEAK_BYTES_IN_USE__"


def measure_gpu_peak(bench_module: str, bench_class: str) -> int:
"""Run a benchmark in a subprocess and return peak GPU bytes.
Expand Down Expand Up @@ -58,7 +63,15 @@ def measure_gpu_peak(bench_module: str, bench_class: str) -> int:
f"stderr: {result.stderr!r}"
)
raise RuntimeError(msg)
return int(result.stdout.strip())
for line in result.stdout.splitlines():
if line.startswith(_PEAK_MARKER):
return int(line.removeprefix(_PEAK_MARKER).strip())
msg = (
"GPU memory subprocess produced no peak-bytes line.\n"
f"stdout: {result.stdout!r}\n"
f"stderr: {result.stderr!r}"
)
raise RuntimeError(msg)


def _track_gpu_peak_mem(self):
Expand Down Expand Up @@ -104,4 +117,4 @@ def setup(self):
import jax

stats = jax.local_devices()[0].memory_stats()
print(stats["peak_bytes_in_use"])
print(f"{_PEAK_MARKER} {stats['peak_bytes_in_use']}")
47 changes: 23 additions & 24 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ ty = "ty check"
jax = ">=0.9"
pdbp = "*"
pylcm = { path = ".", editable = true }
# Pin dags to the feat/no-type-check-flag branch (PR
# OpenSourceEconomics/dags#82): its wrappers advertise the `*args,
# **kwargs` forwarder shape on `__annotations__`, so beartype's import
# claw treats them as permissive forwarders. Replace with `dags>=0.6`
# once that PR is released.
dags = { git = "https://github.com/OpenSourceEconomics/dags.git", rev = "cf59c04" }
[tool.pixi.tasks]
asv-compare = "asv compare"
asv-preview = "asv preview"
Expand Down
7 changes: 6 additions & 1 deletion src/lcm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

import jax

# Patch jaxtyping's `"..."` sentinel to survive pickling before any
# `jaxtyping`-subscripted type is created (see the module docstring).
from lcm import _jaxtyping_patch # noqa: F401

with contextlib.suppress(ImportError):
import pdbp # noqa: F401

Expand All @@ -38,11 +42,12 @@
# exception most natural to that subpackage (see `lcm._beartype_conf`).
from beartype.claw import beartype_package

from lcm._beartype_conf import GRID_CONF, PARAMS_CONF
from lcm._beartype_conf import GRID_CONF, PARAMS_CONF, REGIME_BUILDING_CONF

beartype_package("lcm.grids", conf=GRID_CONF)
beartype_package("lcm.shocks", conf=GRID_CONF)
beartype_package("lcm.params", conf=PARAMS_CONF)
beartype_package("lcm.regime_building", conf=REGIME_BUILDING_CONF)

from lcm import shocks # noqa: E402
from lcm._version import __version__ # noqa: E402
Expand Down
4 changes: 4 additions & 0 deletions src/lcm/_beartype_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,7 @@ def _conf(exc: type[Exception]) -> BeartypeConf:

# Used on `Model.solve` and `Model.simulate`.
PARAMS_CONF = _conf(InvalidParamsError)

# Used by the claw on `lcm.regime_building` (regime compilation pipeline,
# part of model construction).
REGIME_BUILDING_CONF = _conf(ModelInitializationError)
36 changes: 36 additions & 0 deletions src/lcm/_jaxtyping_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Make jaxtyping's anonymous-variadic-dim sentinel survive pickling.

jaxtyping marks a `"..."` axis with a module-level `object()` sentinel
(`_anonymous_variadic_dim`). A plain `object()` does not keep its identity
across a pickle round-trip, so cloudpickling a value whose type annotations
reference a `Foo[Array, "..."]` type — which the beartype claw makes
pervasive — yields a type whose variadic-dim marker no longer matches the
live module global. jaxtyping's shape check then trips
`assert type(variadic_dim) is _NamedVariadicDim`.

Replacing the sentinel with a `__reduce__`-backed singleton makes it
round-trip to the same object, so unpickled annotation types stay valid.
This module must be imported before any `jaxtyping`-subscripted type is
created — `lcm/__init__.py` imports it before every other `lcm` submodule.
"""

from typing import Self

from jaxtyping import _array_types


class _AnonymousVariadicDim:
"""Picklable singleton for jaxtyping's `"..."` axis marker."""

_instance: Self | None = None

def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __reduce__(self) -> tuple[type[_AnonymousVariadicDim], tuple[()]]:
return (_AnonymousVariadicDim, ())


_array_types._anonymous_variadic_dim = _AnonymousVariadicDim() # noqa: SLF001
3 changes: 3 additions & 0 deletions src/lcm/regime_building/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def _wrap_with_reduction(

"""

# `kwargs` carries the wrapped function's full input map: the
# `next_regime_to_V_arr` mapping alongside the Float/Int/Bool-valued
# state/action inputs.
def reduced(
**kwargs: MappingProxyType[RegimeName, FloatND] | FloatND | IntND | BoolND,
) -> dict[str, Any]:
Expand Down
6 changes: 5 additions & 1 deletion src/lcm/regime_building/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,11 @@ def _get_weights_func_for_shock(*, name: str, grid: _ShockGrid) -> UserFunction:

@with_signature(args=args, return_annotation="FloatND", enforce=False)
def weights_func_runtime(*a: FloatND, **kwargs: FloatND) -> Float1D: # noqa: ARG001
shock_kw: dict[str, float] = { # ty: ignore[invalid-assignment]
# `float` here covers Python floats from fixed_params; under
# JIT tracing, the runtime values forwarded through `kwargs`
# arrive as JAX tracers (`FloatND`), which are accepted by the
# shock grid's `compute_gridpoints` / `compute_transition_probs`.
shock_kw: dict[str, float | FloatND] = {
**fixed_params,
**{raw: kwargs[qn] for qn, raw in runtime_param_names.items()},
}
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/regime_building/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _find_function_output_grid_indexing(


def collect_state_transitions(
states: Mapping[StateName, Grid],
states: Mapping[StateName, Grid | None],
state_transitions: Mapping[
StateName,
UserFunction | Callable | None | Mapping[RegimeName, UserFunction | Callable],
Expand Down
10 changes: 5 additions & 5 deletions src/lcm/solution/solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from lcm.ages import AgeGrid
from lcm.interfaces import InternalRegime, _build_regime_sharding
from lcm.typing import FloatND, InternalParams, RegimeName, StateName
from lcm.typing import BoolND, FloatND, InternalParams, RegimeName, StateName
from lcm.utils.error_handling import validate_V
from lcm.utils.logging import (
format_duration,
Expand Down Expand Up @@ -107,8 +107,8 @@ def solve(
diagnostic_min: list[FloatND] = []
diagnostic_max: list[FloatND] = []
diagnostic_mean: list[FloatND] = []
running_any_nan: FloatND = jnp.zeros((), dtype=bool)
running_any_inf: FloatND = jnp.zeros((), dtype=bool)
running_any_nan: BoolND = jnp.zeros((), dtype=bool)
running_any_inf: BoolND = jnp.zeros((), dtype=bool)

logger.info("Starting solution")
total_start = time.monotonic()
Expand Down Expand Up @@ -471,8 +471,8 @@ def _emit_post_loop_diagnostics(
solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]],
internal_regimes: MappingProxyType[RegimeName, InternalRegime],
internal_params: InternalParams,
running_any_nan: FloatND,
running_any_inf: FloatND,
running_any_nan: BoolND,
running_any_inf: BoolND,
diagnostic_min: list[FloatND] | None,
diagnostic_max: list[FloatND] | None,
diagnostic_mean: list[FloatND] | None,
Expand Down
13 changes: 8 additions & 5 deletions tests/test_ndimage_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@


def test_map_coordinates_wrong_input_dimensions():
values = jnp.arange(2) # ndim = 1
coordinates = [jnp.array([0]), jnp.array([1])] # len = 2
values = jnp.arange(2, dtype=jnp.int32) # ndim = 1
coordinates = [
jnp.array([0], dtype=jnp.int32),
jnp.array([1], dtype=jnp.int32),
] # len = 2
with pytest.raises(ValueError, match="coordinates must be a sequence of length"):
map_coordinates(values, coordinates)

Expand All @@ -29,7 +32,7 @@ def test_map_coordinates_extrapolation():


def test_nonempty_sum():
a = jnp.arange(3)
a = jnp.arange(3, dtype=jnp.int32)

expected = a + a + a
got = _sum_all([a, a, a])
Expand All @@ -38,7 +41,7 @@ def test_nonempty_sum():


def test_nonempty_prod():
a = jnp.arange(3)
a = jnp.arange(3, dtype=jnp.int32)

expected = a * a * a
got = _multiply_all([a, a, a])
Expand Down Expand Up @@ -75,7 +78,7 @@ def test_linear_indices_and_weights_inside_domain():


def test_linear_indices_and_weights_outside_domain():
coordinates = jnp.array([-1, 2])
coordinates = jnp.array([-1.0, 2.0])

(idx_low, weight_low), (idx_high, weight_high) = _compute_indices_and_weights(
coordinates, input_size=2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_next_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class MockCategory:


def test_create_stochastic_next_func():
labels = jnp.arange(2)
labels = jnp.arange(2, dtype=jnp.int32)
got_func = _create_discrete_stochastic_next_func(
target="t", next_state_name="next_a", labels=labels
)
Expand Down
Loading