From 3beaa231081d6f4d8bbc6e74a5cb0e69984436c7 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 9 Feb 2026 08:57:06 +0100 Subject: [PATCH] CI: Update libpetab, fix failing petab v2 tests Enable previously skipped tests. Skip gradient check for nx=0 models. --- .github/workflows/test_petab_test_suite.yml | 4 +- tests/petab_test_suite/test_petab_v2_suite.py | 38 +++++++++---------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/.github/workflows/test_petab_test_suite.yml b/.github/workflows/test_petab_test_suite.yml index 86d5fe30b0..2daee87a71 100644 --- a/.github/workflows/test_petab_test_suite.yml +++ b/.github/workflows/test_petab_test_suite.yml @@ -87,7 +87,7 @@ jobs: run: | source ./venv/bin/activate \ && python3 -m pip uninstall -y petab \ - && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@8dc6c1c4b801fba5acc35fcd25308a659d01050e \ + && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@44c8062ce1b87a74a0ba1bd2551de0cdc2a13ff1 \ && python3 -m pip install git+https://github.com/pysb/pysb@master \ && python3 -m pip install sympy>=1.12.1 @@ -186,7 +186,7 @@ jobs: run: | source ./venv/bin/activate \ && python3 -m pip uninstall -y petab \ - && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@d57d9fed8d8d5f8592e76d0b15676e05397c3b4b \ + && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@44c8062ce1b87a74a0ba1bd2551de0cdc2a13ff1 \ && python3 -m pip install git+https://github.com/pysb/pysb@master \ && python3 -m pip install sympy>=1.12.1 diff --git a/tests/petab_test_suite/test_petab_v2_suite.py b/tests/petab_test_suite/test_petab_v2_suite.py index d43cec15c2..6895da6149 100755 --- a/tests/petab_test_suite/test_petab_v2_suite.py +++ b/tests/petab_test_suite/test_petab_v2_suite.py @@ -5,6 +5,7 @@ import sys import diffrax +import jax import pandas as pd import petabtests import pytest @@ -21,7 +22,6 @@ ) from amici.sim.sundials.petab import PetabSimulator from petab import v2 -import jax logger = get_logger(__name__, logging.DEBUG) set_log_level(get_logger("amici.petab_import"), logging.DEBUG) @@ -70,7 +70,6 @@ def _test_case(case, model_type, version, jax): f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}" ) - if jax: from amici.jax import petab_simulate, run_simulations from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS @@ -90,28 +89,25 @@ def _test_case(case, model_type, version, jax): if case.startswith("0016"): controller = diffrax.PIDController( - **DEFAULT_CONTROLLER_SETTINGS, - dtmax=0.5 + **DEFAULT_CONTROLLER_SETTINGS, dtmax=0.5 ) else: - controller = diffrax.PIDController( - **DEFAULT_CONTROLLER_SETTINGS - ) + controller = diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS) llh, _ = run_simulations( - jax_problem, - steady_state_event=steady_state_event, + jax_problem, + steady_state_event=steady_state_event, controller=controller, ) chi2, _ = run_simulations( - jax_problem, - ret="chi2", - steady_state_event=steady_state_event, + jax_problem, + ret="chi2", + steady_state_event=steady_state_event, controller=controller, ) simulation_df = petab_simulate( - jax_problem, - steady_state_event=steady_state_event, + jax_problem, + steady_state_event=steady_state_event, controller=controller, ) else: @@ -137,7 +133,9 @@ def _test_case(case, model_type, version, jax): ) chi2 = sum(rdata.chi2 for rdata in rdatas) llh = res.llh - simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem) + simulation_df = rdatas_to_simulation_df( + rdatas, ps.model, pi.petab_problem + ) solution = petabtests.load_solution(case, model_type, version=version) gt_chi2 = solution[petabtests.CHI2] @@ -198,13 +196,13 @@ def _test_case(case, model_type, version, jax): else: if (case, model_type, version) in ( ("0016", "sbml", "v2.0.0"), - ("0024", "sbml", "v2.0.0"), - ("0025", "sbml", "v2.0.0"), ("0013", "pysb", "v2.0.0"), ): # FIXME: issue with events and sensitivities ... - else: + elif ps.model.nx_solver > 0: + # sensitivity calculation is currently only supported for models + # with state variables check_derivatives(ps, problem_parameters) if not all([llhs_match, simulations_match]) or not chi2s_match: @@ -247,12 +245,12 @@ def run(): n_total = 0 version = "v2.0.0" - for jax in (False, True): + for jax_ in (False, True): cases = list(petabtests.get_cases("sbml", version=version)) n_total += len(cases) for case in cases: try: - test_case(case, "sbml", version=version, jax=jax) + test_case(case, "sbml", version=version, jax=jax_) n_success += 1 except Skipped: n_skipped += 1