Skip to content

Commit 34059c9

Browse files
committed
test/docs(model): add some unit tests for Model + update log with thoughts
1 parent 7b4df7c commit 34059c9

File tree

2 files changed

+227
-0
lines changed

2 files changed

+227
-0
lines changed

docs/log.md

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,4 +612,119 @@ class Model:
612612
613613
# Run the simulation
614614
self.env.run(until=run_length)
615+
```
616+
617+
## Tests
618+
619+
We could add some basic tests now, e.g.
620+
621+
* Check that distributions are created for all patients (eg. "stroke" in distributions asu, "tia" in ...)
622+
* Run with no warm-up and check env.now == param.data_collection_period
623+
* Run with no data collection and check env.now == param.warm_up
624+
625+
Although as not actually doing warm-up logic yet, going down that path a bit premature.
626+
627+
> 💡 Start with basic run time, then change to warm up + data collection
628+
629+
> 💡 Maybe don't need to be mentioning tests at this stage yet.
630+
631+
```
632+
class MockParam:
633+
"""
634+
Mock parameter class.
635+
"""
636+
def __init__(self):
637+
"""
638+
Initialise with specific run periods and arrival parameters.
639+
"""
640+
self.warm_up_period = 10
641+
self.data_collection_period = 20
642+
643+
self.asu_arrivals = namedtuple(
644+
"ASUArrivals", ["stroke", "tia", "neuro", "other"])(
645+
stroke=5, tia=7, neuro=10, other=15
646+
)
647+
self.rehab_arrivals = namedtuple(
648+
"RehabArrivals", ["stroke", "tia", "other"])(
649+
stroke=8, tia=12, other=20
650+
)
651+
652+
653+
def test_create_distributions():
654+
"""
655+
Test that distributions are created correctly for all units and patient
656+
types specified in MockParam.
657+
"""
658+
param = MockParam()
659+
model = Model(param, run_number=42)
660+
661+
# Check ASU distributions
662+
assert len(model.distributions["asu"]) == 4
663+
assert "stroke" in model.distributions["asu"]
664+
assert "tia" in model.distributions["asu"]
665+
assert "neuro" in model.distributions["asu"]
666+
assert "other" in model.distributions["asu"]
667+
668+
# Check Rehab distributions
669+
assert len(model.distributions["rehab"]) == 3
670+
assert "stroke" in model.distributions["rehab"]
671+
assert "tia" in model.distributions["rehab"]
672+
assert "other" in model.distributions["rehab"]
673+
assert "neuro" not in model.distributions["rehab"]
674+
675+
# Check that all distributions are Exponential
676+
for _, unit_dict in model.distributions.items():
677+
for patient_type in unit_dict:
678+
assert isinstance(unit_dict[patient_type], Exponential)
679+
680+
681+
def test_sampling_seed_reproducibility():
682+
"""
683+
Test that using the same seed produces the same results when sampling
684+
from one of the arrival distributions.
685+
"""
686+
param = MockParam()
687+
688+
# Create two models with the same seed
689+
model1 = Model(param, run_number=123)
690+
model2 = Model(param, run_number=123)
691+
692+
# Sample from a distribution in both models
693+
samples1 = [model1.distributions["asu"]["stroke"].sample()
694+
for _ in range(10)]
695+
samples2 = [model2.distributions["asu"]["stroke"].sample()
696+
for _ in range(10)]
697+
698+
# Check that the samples are the same
699+
np.testing.assert_array_almost_equal(samples1, samples2)
700+
701+
702+
def test_run_time():
703+
"""
704+
Check that the run length is correct with varying warm-up and data
705+
collection periods.
706+
"""
707+
param = MockParam()
708+
709+
# Test with zero warm-up period
710+
param.warm_up_period = 0
711+
model = Model(param, run_number=42)
712+
model.run()
713+
assert model.env.now == param.data_collection_period
714+
715+
# Test with zero data collection period
716+
param.warm_up_period = 10
717+
param.data_collection_period = 0
718+
model = Model(param, run_number=42)
719+
model.run()
720+
assert model.env.now == 10
721+
# assert len(model.patients) == 0
722+
723+
# Test with warm-up and data collection period
724+
param.warm_up_period = 12
725+
param.data_collection_period = 9
726+
model = Model(param, run_number=42)
727+
model.run()
728+
assert model.env.now == 21
729+
assert len(model.patients) > 0
615730
```

tests/test_unittest.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,21 @@
22
Unit tests
33
"""
44

5+
from collections import namedtuple
6+
import numpy as np
57
import pytest
8+
from sim_tools.distributions import Exponential
69

710
from simulation.parameters import (
811
ASUArrivals, RehabArrivals, ASULOS, RehabLOS,
912
ASURouting, RehabRouting, Param)
13+
from simulation.model import Model
1014

1115

16+
# -----------------------------------------------------------------------------
17+
# Parameters
18+
# -----------------------------------------------------------------------------
19+
1220
@pytest.mark.parametrize("class_to_test", [
1321
ASUArrivals, RehabArrivals, ASULOS, RehabLOS,
1422
ASURouting, RehabRouting, Param])
@@ -27,3 +35,107 @@ def test_new_attribute(class_to_test):
2735
with pytest.raises(AttributeError,
2836
match="only possible to modify existing attributes"):
2937
setattr(instance, "new_entry", 3)
38+
39+
40+
# -----------------------------------------------------------------------------
41+
# Model
42+
# -----------------------------------------------------------------------------
43+
44+
class MockParam:
45+
"""
46+
Mock parameter class.
47+
"""
48+
def __init__(self):
49+
"""
50+
Initialise with specific run periods and arrival parameters.
51+
"""
52+
self.warm_up_period = 10
53+
self.data_collection_period = 20
54+
55+
self.asu_arrivals = namedtuple(
56+
"ASUArrivals", ["stroke", "tia", "neuro", "other"])(
57+
stroke=5, tia=7, neuro=10, other=15
58+
)
59+
self.rehab_arrivals = namedtuple(
60+
"RehabArrivals", ["stroke", "tia", "other"])(
61+
stroke=8, tia=12, other=20
62+
)
63+
64+
65+
def test_create_distributions():
66+
"""
67+
Test that distributions are created correctly for all units and patient
68+
types specified in MockParam.
69+
"""
70+
param = MockParam()
71+
model = Model(param, run_number=42)
72+
73+
# Check ASU distributions
74+
assert len(model.distributions["asu"]) == 4
75+
assert "stroke" in model.distributions["asu"]
76+
assert "tia" in model.distributions["asu"]
77+
assert "neuro" in model.distributions["asu"]
78+
assert "other" in model.distributions["asu"]
79+
80+
# Check Rehab distributions
81+
assert len(model.distributions["rehab"]) == 3
82+
assert "stroke" in model.distributions["rehab"]
83+
assert "tia" in model.distributions["rehab"]
84+
assert "other" in model.distributions["rehab"]
85+
assert "neuro" not in model.distributions["rehab"]
86+
87+
# Check that all distributions are Exponential
88+
for _, unit_dict in model.distributions.items():
89+
for patient_type in unit_dict:
90+
assert isinstance(unit_dict[patient_type], Exponential)
91+
92+
93+
def test_sampling_seed_reproducibility():
94+
"""
95+
Test that using the same seed produces the same results when sampling
96+
from one of the arrival distributions.
97+
"""
98+
param = MockParam()
99+
100+
# Create two models with the same seed
101+
model1 = Model(param, run_number=123)
102+
model2 = Model(param, run_number=123)
103+
104+
# Sample from a distribution in both models
105+
samples1 = [model1.distributions["asu"]["stroke"].sample()
106+
for _ in range(10)]
107+
samples2 = [model2.distributions["asu"]["stroke"].sample()
108+
for _ in range(10)]
109+
110+
# Check that the samples are the same
111+
np.testing.assert_array_almost_equal(samples1, samples2)
112+
113+
114+
def test_run_time():
115+
"""
116+
Check that the run length is correct with varying warm-up and data
117+
collection periods.
118+
"""
119+
param = MockParam()
120+
121+
# Test with zero warm-up period
122+
param.warm_up_period = 0
123+
model = Model(param, run_number=42)
124+
model.run()
125+
assert model.env.now == param.data_collection_period
126+
127+
# Test with zero data collection period
128+
param.warm_up_period = 10
129+
param.data_collection_period = 0
130+
model = Model(param, run_number=42)
131+
model.run()
132+
assert model.env.now == 10
133+
# assert len(model.patients) == 0
134+
135+
# Test with warm-up and data collection period
136+
param.warm_up_period = 12
137+
param.data_collection_period = 9
138+
model = Model(param, run_number=42)
139+
model.run()
140+
assert model.env.now == 21
141+
assert len(model.patients) > 0

0 commit comments

Comments
 (0)