From 53cead4b47ef5abfa66e96902706f9f8e63e7869 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 10:00:47 +0000 Subject: [PATCH 1/7] adjust test_jax tols --- python/tests/test_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 838a9f8144..34088a7942 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -204,7 +204,7 @@ def check_fields_jax( "iy_trafos": jnp.array(iy_trafos), "x_preeq": jnp.array([]), "solver": diffrax.Kvaerno5(), - "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), + "controller": diffrax.PIDController(atol=1e-8, rtol=1e-8), "root_finder": optimistix.Newton(atol=ATOL_SIM, rtol=RTOL_SIM), "adjoint": diffrax.RecursiveCheckpointAdjoint(), "steady_state_event": diffrax.steady_state_event(), From 0221c4619ba701aa6c4ee87afc1ddef865ffccb2 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 11:04:04 +0000 Subject: [PATCH 2/7] increase max steps for jax benchmarks --- tests/benchmark_models/test_petab_benchmark_jax.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index a3389498dd..f89bcefd5a 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -10,7 +10,7 @@ from amici.importers.petab.v1 import ( import_petab_problem, ) -from amici.jax.petab import run_simulations +from amici.jax.petab import run_simulations, DEFAULT_CONTROLLER_SETTINGS from amici.sim.sundials import SensitivityMethod, SensitivityOrder from amici.sim.sundials.petab.v1 import ( LLH, @@ -117,7 +117,7 @@ def test_jax_llh(benchmark_problem): else: atol = 1e-8 rtol = 1e-8 - max_steps = 1024 + max_steps = 2 * 10**5 beartype(run_simulations)(jax_problem) (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True @@ -130,7 +130,10 @@ def test_jax_llh(benchmark_problem): ), ) else: - llh_jax, _ = beartype(run_simulations)(jax_problem) + llh_jax, _ = beartype(run_simulations)( + jax_problem, + max_steps=2 * 10**5, + ) np.testing.assert_allclose( llh_jax, From 5790bbe38e7567856c9cf75cfdb887a485494436 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 11:04:39 +0000 Subject: [PATCH 3/7] pin optax for notebook test --- scripts/installAmiciSource.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/installAmiciSource.sh b/scripts/installAmiciSource.sh index 4c416a5788..436ac86f97 100755 --- a/scripts/installAmiciSource.sh +++ b/scripts/installAmiciSource.sh @@ -39,7 +39,7 @@ python -m pip install --upgrade pip wheel python -m pip install --upgrade pip setuptools cmake_build_extension==0.6.0 numpy petab swig python -m pip install git+https://github.com/pysb/pysb@master # for SPM with compartments python -m pip install git+https://github.com/patrick-kidger/diffrax@main # for events with direction -python -m pip install optax # for jax petab notebook +python -m pip install 'optax<0.2.7' # for jax petab notebook AMICI_BUILD_TEMP="${AMICI_PATH}/python/sdist/build/temp" \ python -m pip install --verbose -e "${AMICI_PATH}/python/sdist[petab,test,vis,jax]" --no-build-isolation deactivate From 26ccdf60d2d9221b58e187c29b4be4c7d670c87a Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 11:09:53 +0000 Subject: [PATCH 4/7] increase max steps for petab tests --- python/sdist/amici/jax/petab.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 29425cf6c5..33e9490b87 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1550,7 +1550,7 @@ def run_simulations( steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ] = diffrax.steady_state_event(), - max_steps: int = 2**10, + max_steps: int = 2**11, ret: ReturnValue | str = ReturnValue.llh, ): """ @@ -1653,7 +1653,7 @@ def petab_simulate( steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ] = diffrax.steady_state_event(), - max_steps: int = 2**10, + max_steps: int = 2**11, ): """ Run simulations for a problem and return the results as a petab simulation dataframe. From 61aa312623e93e03c7167623b38a0711641c8a03 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 11:59:27 +0000 Subject: [PATCH 5/7] increase petab max steps again --- python/sdist/amici/jax/petab.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 33e9490b87..732984cb16 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1550,7 +1550,7 @@ def run_simulations( steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ] = diffrax.steady_state_event(), - max_steps: int = 2**11, + max_steps: int = 2**13, ret: ReturnValue | str = ReturnValue.llh, ): """ @@ -1653,7 +1653,7 @@ def petab_simulate( steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ] = diffrax.steady_state_event(), - max_steps: int = 2**11, + max_steps: int = 2**13, ): """ Run simulations for a problem and return the results as a petab simulation dataframe. From 6a433e90c03002ef4e7f235fcf62e6af246c54e3 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 12:28:17 +0000 Subject: [PATCH 6/7] even higher max steps for Weber benchmark --- tests/benchmark_models/test_petab_benchmark_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index f89bcefd5a..f769b52851 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -113,7 +113,7 @@ def test_jax_llh(benchmark_problem): if problem_id == "Weber_BMC2015": atol = cur_settings.atol_sim rtol = cur_settings.rtol_sim - max_steps = 2 * 10**5 + max_steps = 4 * 10**7 else: atol = 1e-8 rtol = 1e-8 From 11b9c906ff5a569caf691a4fcd677c7b10566fdd Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 23 Feb 2026 13:34:34 +0000 Subject: [PATCH 7/7] pin optax in docs build --- doc/rtd_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index 2c1e4784a6..4211b95cc9 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -9,7 +9,7 @@ setuptools>=67.7.2 git+https://github.com/jmuhlich/pysb@22d69a350b472f33d85ba64ffb10b190483c1c98 # For forward type definition in generate_equinox matplotlib>=3.7.1 -optax +optax==0.2.6 nbsphinx nbformat myst-parser