Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test_petab_test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ jobs:
run: |
source ./venv/bin/activate \
&& python3 -m pip uninstall -y petab \
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@8dc6c1c4b801fba5acc35fcd25308a659d01050e \
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@44c8062ce1b87a74a0ba1bd2551de0cdc2a13ff1 \
&& python3 -m pip install git+https://github.com/pysb/pysb@master \
&& python3 -m pip install sympy>=1.12.1

Expand Down Expand Up @@ -186,7 +186,7 @@ jobs:
run: |
source ./venv/bin/activate \
&& python3 -m pip uninstall -y petab \
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@d57d9fed8d8d5f8592e76d0b15676e05397c3b4b \
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@44c8062ce1b87a74a0ba1bd2551de0cdc2a13ff1 \
&& python3 -m pip install git+https://github.com/pysb/pysb@master \
&& python3 -m pip install sympy>=1.12.1

Expand Down
38 changes: 18 additions & 20 deletions tests/petab_test_suite/test_petab_v2_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys

import diffrax
import jax
import pandas as pd
import petabtests
import pytest
Expand All @@ -21,7 +22,6 @@
)
from amici.sim.sundials.petab import PetabSimulator
from petab import v2
import jax

logger = get_logger(__name__, logging.DEBUG)
set_log_level(get_logger("amici.petab_import"), logging.DEBUG)
Expand Down Expand Up @@ -70,7 +70,6 @@ def _test_case(case, model_type, version, jax):
f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}"
)


if jax:
from amici.jax import petab_simulate, run_simulations
from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS
Expand All @@ -90,28 +89,25 @@ def _test_case(case, model_type, version, jax):

if case.startswith("0016"):
controller = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS,
dtmax=0.5
**DEFAULT_CONTROLLER_SETTINGS, dtmax=0.5
)
else:
controller = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
)
controller = diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS)

llh, _ = run_simulations(
jax_problem,
steady_state_event=steady_state_event,
jax_problem,
steady_state_event=steady_state_event,
controller=controller,
)
chi2, _ = run_simulations(
jax_problem,
ret="chi2",
steady_state_event=steady_state_event,
jax_problem,
ret="chi2",
steady_state_event=steady_state_event,
controller=controller,
)
simulation_df = petab_simulate(
jax_problem,
steady_state_event=steady_state_event,
jax_problem,
steady_state_event=steady_state_event,
controller=controller,
)
else:
Expand All @@ -137,7 +133,9 @@ def _test_case(case, model_type, version, jax):
)
chi2 = sum(rdata.chi2 for rdata in rdatas)
llh = res.llh
simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem)
simulation_df = rdatas_to_simulation_df(
rdatas, ps.model, pi.petab_problem
)

solution = petabtests.load_solution(case, model_type, version=version)
gt_chi2 = solution[petabtests.CHI2]
Expand Down Expand Up @@ -198,13 +196,13 @@ def _test_case(case, model_type, version, jax):
else:
if (case, model_type, version) in (
("0016", "sbml", "v2.0.0"),
("0024", "sbml", "v2.0.0"),
("0025", "sbml", "v2.0.0"),
("0013", "pysb", "v2.0.0"),
):
# FIXME: issue with events and sensitivities
...
else:
elif ps.model.nx_solver > 0:
# sensitivity calculation is currently only supported for models
# with state variables
check_derivatives(ps, problem_parameters)

if not all([llhs_match, simulations_match]) or not chi2s_match:
Expand Down Expand Up @@ -247,12 +245,12 @@ def run():
n_total = 0
version = "v2.0.0"

for jax in (False, True):
for jax_ in (False, True):
cases = list(petabtests.get_cases("sbml", version=version))
n_total += len(cases)
for case in cases:
try:
test_case(case, "sbml", version=version, jax=jax)
test_case(case, "sbml", version=version, jax=jax_)
n_success += 1
except Skipped:
n_skipped += 1
Expand Down
Loading