diff --git a/benchmarks/_gpu_mem.py b/benchmarks/_gpu_mem.py index 45e328216..591b6ba18 100644 --- a/benchmarks/_gpu_mem.py +++ b/benchmarks/_gpu_mem.py @@ -26,6 +26,11 @@ class MahlerYumGpuPeakMem(GpuPeakMem): # Project root: the directory containing the benchmarks/ package. _PROJECT_ROOT = Path(__file__).resolve().parent.parent +# Marks the peak-memory line on the subprocess's stdout. The subprocess imports +# lcm, whose beartype claw can emit diagnostics to stdout, so the parent locates +# this line instead of parsing stdout wholesale. +_PEAK_MARKER = "__PEAK_BYTES_IN_USE__" + def measure_gpu_peak(bench_module: str, bench_class: str) -> int: """Run a benchmark in a subprocess and return peak GPU bytes. @@ -58,7 +63,15 @@ def measure_gpu_peak(bench_module: str, bench_class: str) -> int: f"stderr: {result.stderr!r}" ) raise RuntimeError(msg) - return int(result.stdout.strip()) + for line in result.stdout.splitlines(): + if line.startswith(_PEAK_MARKER): + return int(line.removeprefix(_PEAK_MARKER).strip()) + msg = ( + "GPU memory subprocess produced no peak-bytes line.\n" + f"stdout: {result.stdout!r}\n" + f"stderr: {result.stderr!r}" + ) + raise RuntimeError(msg) def _track_gpu_peak_mem(self): @@ -104,4 +117,4 @@ def setup(self): import jax stats = jax.local_devices()[0].memory_stats() - print(stats["peak_bytes_in_use"]) + print(f"{_PEAK_MARKER} {stats['peak_bytes_in_use']}") diff --git a/pixi.lock b/pixi.lock index 05c13a7af..f46905259 100644 --- a/pixi.lock +++ b/pixi.lock @@ -275,7 +275,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/8d/2d/f61c918d9edc2127068f0d5ad4604fedd9bfd393f464219090f3279c73f7/estimagic-0.5.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl @@ -588,7 +588,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5f/78/a3d9ceda0793f4fb43daa292af7b801932611a1aed442636ddfc93d58c7a/jax_cuda12_pjrt-0.10.0-py3-none-manylinux_2_27_x86_64.whl @@ -888,7 +888,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/21/98/77f15d81fd0637da454e453c8456d4a2b5c8b2e66823b4237ee8689152cf/jax_cuda13_pjrt-0.10.0-py3-none-manylinux_2_27_x86_64.whl @@ -1158,7 +1158,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a1/8e/b2a08ffc51c93842de71f7f988865cebfa7f43d6721957812dc8cc8b9d40/jaxlib-0.10.0-cp314-cp314-manylinux_2_27_x86_64.whl @@ -1399,7 +1399,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-1.3.2-h8088a28_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a7/25/e1e52a21786b321fb6a2edf9ef9971aa70f06bb2738aef9afd6d8f46a441/jaxlib-0.10.0-cp314-cp314-macosx_11_0_arm64.whl @@ -1637,7 +1637,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-1.3.2-hfd05255_2.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/08/26e6a3ecf0a95f1ec0dcd7a668d5c9a72e581c40fe4ae51e102ca63174c5/jaxlib-0.10.0-cp314-cp314-win_amd64.whl @@ -1895,7 +1895,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a1/8e/b2a08ffc51c93842de71f7f988865cebfa7f43d6721957812dc8cc8b9d40/jaxlib-0.10.0-cp314-cp314-manylinux_2_27_x86_64.whl @@ -2146,7 +2146,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-1.3.2-h8088a28_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a7/25/e1e52a21786b321fb6a2edf9ef9971aa70f06bb2738aef9afd6d8f46a441/jaxlib-0.10.0-cp314-cp314-macosx_11_0_arm64.whl @@ -2393,7 +2393,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-1.3.2-hfd05255_2.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/08/26e6a3ecf0a95f1ec0dcd7a668d5c9a72e581c40fe4ae51e102ca63174c5/jaxlib-0.10.0-cp314-cp314-win_amd64.whl @@ -2648,7 +2648,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-1.3.2-h8088a28_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl @@ -2911,7 +2911,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a1/8e/b2a08ffc51c93842de71f7f988865cebfa7f43d6721957812dc8cc8b9d40/jaxlib-0.10.0-cp314-cp314-manylinux_2_27_x86_64.whl @@ -3166,7 +3166,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-1.3.2-h8088a28_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a7/25/e1e52a21786b321fb6a2edf9ef9971aa70f06bb2738aef9afd6d8f46a441/jaxlib-0.10.0-cp314-cp314-macosx_11_0_arm64.whl @@ -3417,7 +3417,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-1.3.2-hfd05255_2.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/08/26e6a3ecf0a95f1ec0dcd7a668d5c9a72e581c40fe4ae51e102ca63174c5/jaxlib-0.10.0-cp314-cp314-win_amd64.whl @@ -3716,7 +3716,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5f/78/a3d9ceda0793f4fb43daa292af7b801932611a1aed442636ddfc93d58c7a/jax_cuda12_pjrt-0.10.0-py3-none-manylinux_2_27_x86_64.whl @@ -4029,7 +4029,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.2-h25fd6f3_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/21/98/77f15d81fd0637da454e453c8456d4a2b5c8b2e66823b4237ee8689152cf/jax_cuda13_pjrt-0.10.0-py3-none-manylinux_2_27_x86_64.whl @@ -4308,7 +4308,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-1.3.2-h8088a28_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl @@ -4673,7 +4673,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a1/8e/b2a08ffc51c93842de71f7f988865cebfa7f43d6721957812dc8cc8b9d40/jaxlib-0.10.0-cp314-cp314-manylinux_2_27_x86_64.whl @@ -4964,7 +4964,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a7/25/e1e52a21786b321fb6a2edf9ef9971aa70f06bb2738aef9afd6d8f46a441/jaxlib-0.10.0-cp314-cp314-macosx_11_0_arm64.whl @@ -5274,7 +5274,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-ng-2.3.3-h0261ad2_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl + - pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 - pypi: https://files.pythonhosted.org/packages/43/f5/ee39c6e92acc742c052f137b47c210cd0a1b72dcd3f98495528bb4d27761/flatten_dict-0.4.2-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/70/aa/dfac6d72cc35bc07e7587115b6946e333ef4ccb2e6cd26ecf639438c5d26/jax-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/08/26e6a3ecf0a95f1ec0dcd7a668d5c9a72e581c40fe4ae51e102ca63174c5/jaxlib-0.10.0-cp314-cp314-win_amd64.whl @@ -7542,10 +7542,9 @@ packages: purls: [] size: 210103 timestamp: 1771943128249 -- pypi: https://files.pythonhosted.org/packages/2c/c1/a662f0a8f6e024fca239d493f278d9adf5de1c8408af46a53a76beb13534/dags-0.5.1-py3-none-any.whl +- pypi: git+https://github.com/OpenSourceEconomics/dags.git?rev=cf59c04#cf59c04c6ba07b7c54ca763dc155deea3341a480 name: dags - version: 0.5.1 - sha256: e9fd9fbe0536784fe8b8ce58ea194801b1de39d7364941d4a1f2d8240c14123d + version: 0.5.2.dev6+gcf59c04c6 requires_dist: - flatten-dict - networkx>=3.6 @@ -14086,8 +14085,8 @@ packages: timestamp: 1774796815820 - pypi: ./ name: pylcm - version: 0.0.2.dev136+ga5d932a48.d20260513 - sha256: fc3e3cff622b9db6e1f5a01c6cfad41879d5d1efdb982743f18b8ca3edf585a7 + version: 0.0.2.dev182+g1a92dffec.d20260514 + sha256: 2ac3c0e6987658df12a93886cc7ea80b12456826f61aa1720a0c56e8fd828128 requires_dist: - beartype>=0.21 - cloudpickle>=3.1.2 diff --git a/pyproject.toml b/pyproject.toml index 3651a96a5..6e8f346d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,12 @@ ty = "ty check" jax = ">=0.9" pdbp = "*" pylcm = { path = ".", editable = true } +# Pin dags to the feat/no-type-check-flag branch (PR +# OpenSourceEconomics/dags#82): its wrappers advertise the `*args, +# **kwargs` forwarder shape on `__annotations__`, so beartype's import +# claw treats them as permissive forwarders. Replace with `dags>=0.6` +# once that PR is released. +dags = { git = "https://github.com/OpenSourceEconomics/dags.git", rev = "cf59c04" } [tool.pixi.tasks] asv-compare = "asv compare" asv-preview = "asv preview" diff --git a/src/lcm/__init__.py b/src/lcm/__init__.py index bb00b6127..e18fdc350 100644 --- a/src/lcm/__init__.py +++ b/src/lcm/__init__.py @@ -25,6 +25,10 @@ import jax +# Patch jaxtyping's `"..."` sentinel to survive pickling before any +# `jaxtyping`-subscripted type is created (see the module docstring). +from lcm import _jaxtyping_patch # noqa: F401 + with contextlib.suppress(ImportError): import pdbp # noqa: F401 @@ -38,11 +42,12 @@ # exception most natural to that subpackage (see `lcm._beartype_conf`). from beartype.claw import beartype_package -from lcm._beartype_conf import GRID_CONF, PARAMS_CONF +from lcm._beartype_conf import GRID_CONF, PARAMS_CONF, REGIME_BUILDING_CONF beartype_package("lcm.grids", conf=GRID_CONF) beartype_package("lcm.shocks", conf=GRID_CONF) beartype_package("lcm.params", conf=PARAMS_CONF) +beartype_package("lcm.regime_building", conf=REGIME_BUILDING_CONF) from lcm import shocks # noqa: E402 from lcm._version import __version__ # noqa: E402 diff --git a/src/lcm/_beartype_conf.py b/src/lcm/_beartype_conf.py index b68ff2034..6d3f1175f 100644 --- a/src/lcm/_beartype_conf.py +++ b/src/lcm/_beartype_conf.py @@ -46,3 +46,7 @@ def _conf(exc: type[Exception]) -> BeartypeConf: # Used on `Model.solve` and `Model.simulate`. PARAMS_CONF = _conf(InvalidParamsError) + +# Used by the claw on `lcm.regime_building` (regime compilation pipeline, +# part of model construction). +REGIME_BUILDING_CONF = _conf(ModelInitializationError) diff --git a/src/lcm/_jaxtyping_patch.py b/src/lcm/_jaxtyping_patch.py new file mode 100644 index 000000000..e9dced877 --- /dev/null +++ b/src/lcm/_jaxtyping_patch.py @@ -0,0 +1,36 @@ +"""Make jaxtyping's anonymous-variadic-dim sentinel survive pickling. + +jaxtyping marks a `"..."` axis with a module-level `object()` sentinel +(`_anonymous_variadic_dim`). A plain `object()` does not keep its identity +across a pickle round-trip, so cloudpickling a value whose type annotations +reference a `Foo[Array, "..."]` type — which the beartype claw makes +pervasive — yields a type whose variadic-dim marker no longer matches the +live module global. jaxtyping's shape check then trips +`assert type(variadic_dim) is _NamedVariadicDim`. + +Replacing the sentinel with a `__reduce__`-backed singleton makes it +round-trip to the same object, so unpickled annotation types stay valid. +This module must be imported before any `jaxtyping`-subscripted type is +created — `lcm/__init__.py` imports it before every other `lcm` submodule. +""" + +from typing import Self + +from jaxtyping import _array_types + + +class _AnonymousVariadicDim: + """Picklable singleton for jaxtyping's `"..."` axis marker.""" + + _instance: Self | None = None + + def __new__(cls) -> Self: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __reduce__(self) -> tuple[type[_AnonymousVariadicDim], tuple[()]]: + return (_AnonymousVariadicDim, ()) + + +_array_types._anonymous_variadic_dim = _AnonymousVariadicDim() # noqa: SLF001 diff --git a/src/lcm/regime_building/diagnostics.py b/src/lcm/regime_building/diagnostics.py index 9d57868dc..e5321cdbd 100644 --- a/src/lcm/regime_building/diagnostics.py +++ b/src/lcm/regime_building/diagnostics.py @@ -166,6 +166,9 @@ def _wrap_with_reduction( """ + # `kwargs` carries the wrapped function's full input map: the + # `next_regime_to_V_arr` mapping alongside the Float/Int/Bool-valued + # state/action inputs. def reduced( **kwargs: MappingProxyType[RegimeName, FloatND] | FloatND | IntND | BoolND, ) -> dict[str, Any]: diff --git a/src/lcm/regime_building/processing.py b/src/lcm/regime_building/processing.py index 7dd1d610c..d5285d96c 100644 --- a/src/lcm/regime_building/processing.py +++ b/src/lcm/regime_building/processing.py @@ -890,7 +890,11 @@ def _get_weights_func_for_shock(*, name: str, grid: _ShockGrid) -> UserFunction: @with_signature(args=args, return_annotation="FloatND", enforce=False) def weights_func_runtime(*a: FloatND, **kwargs: FloatND) -> Float1D: # noqa: ARG001 - shock_kw: dict[str, float] = { # ty: ignore[invalid-assignment] + # `float` here covers Python floats from fixed_params; under + # JIT tracing, the runtime values forwarded through `kwargs` + # arrive as JAX tracers (`FloatND`), which are accepted by the + # shock grid's `compute_gridpoints` / `compute_transition_probs`. + shock_kw: dict[str, float | FloatND] = { **fixed_params, **{raw: kwargs[qn] for qn, raw in runtime_param_names.items()}, } diff --git a/src/lcm/regime_building/validation.py b/src/lcm/regime_building/validation.py index 5457e03f5..471d16548 100644 --- a/src/lcm/regime_building/validation.py +++ b/src/lcm/regime_building/validation.py @@ -241,7 +241,7 @@ def _find_function_output_grid_indexing( def collect_state_transitions( - states: Mapping[StateName, Grid], + states: Mapping[StateName, Grid | None], state_transitions: Mapping[ StateName, UserFunction | Callable | None | Mapping[RegimeName, UserFunction | Callable], diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 9e531ebb5..441023d52 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -12,7 +12,7 @@ from lcm.ages import AgeGrid from lcm.interfaces import InternalRegime, _build_regime_sharding -from lcm.typing import FloatND, InternalParams, RegimeName, StateName +from lcm.typing import BoolND, FloatND, InternalParams, RegimeName, StateName from lcm.utils.error_handling import validate_V from lcm.utils.logging import ( format_duration, @@ -107,8 +107,8 @@ def solve( diagnostic_min: list[FloatND] = [] diagnostic_max: list[FloatND] = [] diagnostic_mean: list[FloatND] = [] - running_any_nan: FloatND = jnp.zeros((), dtype=bool) - running_any_inf: FloatND = jnp.zeros((), dtype=bool) + running_any_nan: BoolND = jnp.zeros((), dtype=bool) + running_any_inf: BoolND = jnp.zeros((), dtype=bool) logger.info("Starting solution") total_start = time.monotonic() @@ -471,8 +471,8 @@ def _emit_post_loop_diagnostics( solution: MappingProxyType[int, MappingProxyType[RegimeName, FloatND]], internal_regimes: MappingProxyType[RegimeName, InternalRegime], internal_params: InternalParams, - running_any_nan: FloatND, - running_any_inf: FloatND, + running_any_nan: BoolND, + running_any_inf: BoolND, diagnostic_min: list[FloatND] | None, diagnostic_max: list[FloatND] | None, diagnostic_mean: list[FloatND] | None, diff --git a/tests/test_ndimage_unit.py b/tests/test_ndimage_unit.py index 516a5aec2..f70903acc 100644 --- a/tests/test_ndimage_unit.py +++ b/tests/test_ndimage_unit.py @@ -12,8 +12,11 @@ def test_map_coordinates_wrong_input_dimensions(): - values = jnp.arange(2) # ndim = 1 - coordinates = [jnp.array([0]), jnp.array([1])] # len = 2 + values = jnp.arange(2, dtype=jnp.int32) # ndim = 1 + coordinates = [ + jnp.array([0], dtype=jnp.int32), + jnp.array([1], dtype=jnp.int32), + ] # len = 2 with pytest.raises(ValueError, match="coordinates must be a sequence of length"): map_coordinates(values, coordinates) @@ -29,7 +32,7 @@ def test_map_coordinates_extrapolation(): def test_nonempty_sum(): - a = jnp.arange(3) + a = jnp.arange(3, dtype=jnp.int32) expected = a + a + a got = _sum_all([a, a, a]) @@ -38,7 +41,7 @@ def test_nonempty_sum(): def test_nonempty_prod(): - a = jnp.arange(3) + a = jnp.arange(3, dtype=jnp.int32) expected = a * a * a got = _multiply_all([a, a, a]) @@ -75,7 +78,7 @@ def test_linear_indices_and_weights_inside_domain(): def test_linear_indices_and_weights_outside_domain(): - coordinates = jnp.array([-1, 2]) + coordinates = jnp.array([-1.0, 2.0]) (idx_low, weight_low), (idx_high, weight_high) = _compute_indices_and_weights( coordinates, input_size=2 diff --git a/tests/test_next_state.py b/tests/test_next_state.py index f89bad620..08f51649b 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -102,7 +102,7 @@ class MockCategory: def test_create_stochastic_next_func(): - labels = jnp.arange(2) + labels = jnp.arange(2, dtype=jnp.int32) got_func = _create_discrete_stochastic_next_func( target="t", next_state_name="next_a", labels=labels )