diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 52d0d053ab..a3389498dd 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -1,6 +1,7 @@ import logging from functools import partial +import diffrax import equinox as eqx import jax import jax.numpy as jnp @@ -23,8 +24,6 @@ settings, ) -import diffrax - jax.config.update("jax_enable_x64", True) @@ -38,7 +37,12 @@ def test_jax_llh(benchmark_problem): benchmark_problem ) - to_skip = ["Smith_BMCSystBiol2013", "Oliveira_NatCommun2021", "SalazarCavazos_MBoC2020"] + to_skip = [ + "Liu_IFACPapersOnLine2025", + "Oliveira_NatCommun2021", + "SalazarCavazos_MBoC2020", + "Smith_BMCSystBiol2013", + ] if problem_id in to_skip: pytest.skip( f"Skipping {problem_id} due to non-supported events in JAX." @@ -118,12 +122,12 @@ def test_jax_llh(benchmark_problem): (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True )( - jax_problem, + jax_problem, max_steps=max_steps, controller=diffrax.PIDController( atol=atol, rtol=rtol, - ) + ), ) else: llh_jax, _ = beartype(run_simulations)(jax_problem)