Skip to content

Commit 4038ad8

Browse files
committed
test(unit): add unit tests to check parameter validation, and alter test_create_distributions to just check that type of each distribution
1 parent f533872 commit 4038ad8

File tree

1 file changed

+92
-25
lines changed

1 file changed

+92
-25
lines changed

tests/test_unittest.py

Lines changed: 92 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
work as intended.
77
"""
88

9-
from collections import namedtuple
109
import numpy as np
1110
import pytest
12-
from sim_tools.distributions import Exponential
11+
from sim_tools.distributions import Exponential, Lognormal, Discrete
1312

1413
from simulation.parameters import (
1514
ASUArrivals, RehabArrivals, ASULOS, RehabLOS,
@@ -42,43 +41,111 @@ def test_new_attribute(class_to_test):
4241
setattr(instance, "new_entry", 3)
4342

4443

44+
def test_param_valid():
45+
"""
46+
Check that all default model parameters are valid.
47+
"""
48+
try:
49+
Param().check_param_validity()
50+
except Exception as exc:
51+
pytest.fail(
52+
f"check_param_validity() raised an unexpected exception: {exc}")
53+
54+
55+
@pytest.mark.parametrize("param, value, msg", [
56+
("warm_up_period", -1,
57+
"Parameter 'warm_up_period' must be greater than or equal to 0"),
58+
("data_collection_period", -5,
59+
"Parameter 'data_collection_period' must be greater than or equal to 0"),
60+
("number_of_runs", 0,
61+
"Parameter 'number_of_runs' must be greater than 0"),
62+
("audit_interval", -2,
63+
"Parameter 'audit_interval' must be greater than 0")])
64+
def test_param_errors(param, value, msg):
65+
"""
66+
Check that `check_param_validity()` catches parameter issues.
67+
"""
68+
model_param = Param()
69+
setattr(model_param, param, value)
70+
with pytest.raises(ValueError, match=msg):
71+
model_param.check_param_validity()
72+
73+
74+
def test_arrival_params():
75+
"""
76+
Test validation of arrival parameters.
77+
"""
78+
model_param = Param(asu_arrivals=ASUArrivals(stroke=-5))
79+
with pytest.raises(
80+
ValueError,
81+
match="Parameter 'stroke' from 'asu_arrivals' must be greater than 0"
82+
):
83+
model_param.check_param_validity()
84+
85+
86+
def test_los_params():
87+
"""
88+
Test validation of length of stay parameters.
89+
"""
90+
model_param = Param(asu_los=ASULOS(neuro_mean=-2, neuro_sd=1))
91+
with pytest.raises(
92+
ValueError,
93+
match=("Parameter 'mean' for 'neuro' in 'asu_los' must be greater " +
94+
"than 0")
95+
):
96+
model_param.check_param_validity()
97+
98+
99+
def test_routing_sum():
100+
"""
101+
Test validation of routing probabilities sum.
102+
"""
103+
model_param = Param(asu_routing=ASURouting(
104+
tia_rehab=0.6, tia_esd=0.2, tia_other=0.1))
105+
with pytest.raises(
106+
ValueError,
107+
match=("Routing probabilities for 'tia' in 'asu_routing' should sum " +
108+
"to apx. 1")
109+
):
110+
model_param.check_param_validity()
111+
112+
113+
def test_routing_range():
114+
"""
115+
Test validation of routing probability ranges.
116+
"""
117+
model_param = Param(asu_routing=ASURouting(
118+
neuro_rehab=1.1, neuro_esd=0.1, neuro_other=-0.2))
119+
with pytest.raises(ValueError, match="must be between 0 and 1"):
120+
model_param.check_param_validity()
121+
122+
45123
# -----------------------------------------------------------------------------
46124
# Model
47125
# -----------------------------------------------------------------------------
48126

49127
def test_create_distributions():
50128
"""
51-
Test that distributions are created correctly for all units and patient
52-
types specified.
129+
Check that distributions are all the correct type.
53130
"""
54-
param = Param(
55-
asu_arrivals=namedtuple(
56-
"ASUArrivals", ["stroke", "tia", "neuro", "other"])(
57-
stroke=5, tia=7, neuro=10, other=15),
58-
rehab_arrivals=namedtuple(
59-
"RehabArrivals", ["stroke", "tia", "other"])(
60-
stroke=8, tia=12, other=20))
131+
param = Param()
61132
model = Model(param, run_number=42)
62133

63-
# Check ASU arrival distributions
64-
assert len(model.arrival_dist["asu"]) == 4
65-
assert "stroke" in model.arrival_dist["asu"]
66-
assert "tia" in model.arrival_dist["asu"]
67-
assert "neuro" in model.arrival_dist["asu"]
68-
assert "other" in model.arrival_dist["asu"]
69-
70-
# Check Rehab arrival distributions
71-
assert len(model.arrival_dist["rehab"]) == 3
72-
assert "stroke" in model.arrival_dist["rehab"]
73-
assert "tia" in model.arrival_dist["rehab"]
74-
assert "other" in model.arrival_dist["rehab"]
75-
assert "neuro" not in model.arrival_dist["rehab"]
76-
77134
# Check that all arrival distributions are Exponential
78135
for _, unit_dict in model.arrival_dist.items():
79136
for patient_type in unit_dict:
80137
assert isinstance(unit_dict[patient_type], Exponential)
81138

139+
# Check that all length of stay distributions are Lognormal
140+
for _, unit_dict in model.los_dist.items():
141+
for patient_type in unit_dict:
142+
assert isinstance(unit_dict[patient_type], Lognormal)
143+
144+
# Check that all routing distributions are Discrete
145+
for _, unit_dict in model.routing_dist.items():
146+
for patient_type in unit_dict:
147+
assert isinstance(unit_dict[patient_type], Discrete)
148+
82149

83150
def test_sampling_seed_reproducibility():
84151
"""

0 commit comments

Comments
 (0)