Skip to content

Commit 216a6df

Browse files
committed
fix(test)/style(lint): fixed the tests (removing dependence on MockParam) and lint
1 parent dcb9e8b commit 216a6df

File tree

3 files changed

+62
-66
lines changed

3 files changed

+62
-66
lines changed

notebooks/analysis.ipynb

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Analysis"
8+
]
9+
},
310
{
411
"cell_type": "code",
512
"execution_count": 1,
@@ -11,13 +18,11 @@
1118
"%autoreload 1\n",
1219
"%aimport simulation\n",
1320
"\n",
14-
"from collections import Counter\n",
21+
"# pylint: disable=wrong-import-position\n",
1522
"import os\n",
16-
"import pandas as pd\n",
1723
"import plotly.express as px\n",
1824
"\n",
1925
"from simulation.parameters import Param\n",
20-
"from simulation.model import Model\n",
2126
"from simulation.runner import Runner"
2227
]
2328
},
@@ -195,6 +200,9 @@
195200
" unit_lab = \"acute\"\n",
196201
" elif unit == \"rehab\":\n",
197202
" unit_lab = \"rehabilitation\"\n",
203+
" else:\n",
204+
" raise ValueError(\"unit must be either 'acute' or 'rehab'\")\n",
205+
"\n",
198206
" fig.update_layout(\n",
199207
" xaxis_title=f\"No. patients in {unit_lab} unit\",\n",
200208
" yaxis_title=\"% observations\",\n",
@@ -2116,6 +2124,9 @@
21162124
" unit_lab = \"acute\"\n",
21172125
" elif unit == \"rehab\":\n",
21182126
" unit_lab = \"rehabilitation\"\n",
2127+
" else:\n",
2128+
" raise ValueError(\"unit must be either 'acute' or 'rehab'\")\n",
2129+
"\n",
21192130
" fig.update_layout(\n",
21202131
" xaxis_title=f\"No. of {unit_lab} beds available\",\n",
21212132
" yaxis_title=\"Probability of delay\",\n",

simulation/model.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -129,26 +129,6 @@ def __init__(self, param, run_number):
129129
rehab_param=self.param.rehab_los,
130130
distribution_type="lognormal")
131131

132-
# Add model initialisation details to the log
133-
self.param.logger.log(sim_time=self.env.now, msg="Initialise model:\n")
134-
self.param.logger.log(vars(self))
135-
self.param.logger.log(msg="Parameters:\n ")
136-
self.param.logger.log(vars(self.param))
137-
self.param.logger.log(msg="Logger:\n ")
138-
self.param.logger.log(vars(self.param.logger))
139-
self.param.logger.log(msg="ASU arrivals:\n ")
140-
self.param.logger.log(vars(self.param.asu_arrivals))
141-
self.param.logger.log(msg="ASU LOS:\n ")
142-
self.param.logger.log(vars(self.param.asu_los))
143-
self.param.logger.log(msg="ASU routing:\n ")
144-
self.param.logger.log(vars(self.param.asu_routing))
145-
self.param.logger.log(msg="Rehab arrivals:\n ")
146-
self.param.logger.log(vars(self.param.rehab_arrivals))
147-
self.param.logger.log(msg="Rehab LOS:\n ")
148-
self.param.logger.log(vars(self.param.rehab_los))
149-
self.param.logger.log(msg="Rehab routing:\n ")
150-
self.param.logger.log(vars(self.param.rehab_routing))
151-
152132
def create_distributions(self, asu_param, rehab_param, distribution_type):
153133
"""
154134
Create a nested dictionary with two items: "asu" and "rehab". Each
@@ -377,6 +357,26 @@ def run(self):
377357
"""
378358
Run the simulation.
379359
"""
360+
# Add model initialisation details to the log
361+
self.param.logger.log(sim_time=self.env.now, msg="Initialise model:\n")
362+
self.param.logger.log(vars(self))
363+
self.param.logger.log(msg="Parameters:\n ")
364+
self.param.logger.log(vars(self.param))
365+
self.param.logger.log(msg="Logger:\n ")
366+
self.param.logger.log(vars(self.param.logger))
367+
self.param.logger.log(msg="ASU arrivals:\n ")
368+
self.param.logger.log(vars(self.param.asu_arrivals))
369+
self.param.logger.log(msg="ASU LOS:\n ")
370+
self.param.logger.log(vars(self.param.asu_los))
371+
self.param.logger.log(msg="ASU routing:\n ")
372+
self.param.logger.log(vars(self.param.asu_routing))
373+
self.param.logger.log(msg="Rehab arrivals:\n ")
374+
self.param.logger.log(vars(self.param.rehab_arrivals))
375+
self.param.logger.log(msg="Rehab LOS:\n ")
376+
self.param.logger.log(vars(self.param.rehab_los))
377+
self.param.logger.log(msg="Rehab routing:\n ")
378+
self.param.logger.log(vars(self.param.rehab_routing))
379+
380380
# Calculate the total run length
381381
run_length = (self.param.warm_up_period +
382382
self.param.data_collection_period)

tests/test_unittest.py

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -41,51 +41,36 @@ def test_new_attribute(class_to_test):
4141
# Model
4242
# -----------------------------------------------------------------------------
4343

44-
class MockParam:
45-
"""
46-
Mock parameter class.
47-
"""
48-
def __init__(self):
49-
"""
50-
Initialise with specific run periods and arrival parameters.
51-
"""
52-
self.warm_up_period = 10
53-
self.data_collection_period = 20
54-
55-
self.asu_arrivals = namedtuple(
56-
"ASUArrivals", ["stroke", "tia", "neuro", "other"])(
57-
stroke=5, tia=7, neuro=10, other=15
58-
)
59-
self.rehab_arrivals = namedtuple(
60-
"RehabArrivals", ["stroke", "tia", "other"])(
61-
stroke=8, tia=12, other=20
62-
)
63-
64-
6544
def test_create_distributions():
6645
"""
6746
Test that distributions are created correctly for all units and patient
68-
types specified in MockParam.
47+
types specified.
6948
"""
70-
param = MockParam()
49+
param = Param(
50+
asu_arrivals=namedtuple(
51+
"ASUArrivals", ["stroke", "tia", "neuro", "other"])(
52+
stroke=5, tia=7, neuro=10, other=15),
53+
rehab_arrivals=namedtuple(
54+
"RehabArrivals", ["stroke", "tia", "other"])(
55+
stroke=8, tia=12, other=20))
7156
model = Model(param, run_number=42)
7257

73-
# Check ASU distributions
74-
assert len(model.distributions["asu"]) == 4
75-
assert "stroke" in model.distributions["asu"]
76-
assert "tia" in model.distributions["asu"]
77-
assert "neuro" in model.distributions["asu"]
78-
assert "other" in model.distributions["asu"]
79-
80-
# Check Rehab distributions
81-
assert len(model.distributions["rehab"]) == 3
82-
assert "stroke" in model.distributions["rehab"]
83-
assert "tia" in model.distributions["rehab"]
84-
assert "other" in model.distributions["rehab"]
85-
assert "neuro" not in model.distributions["rehab"]
86-
87-
# Check that all distributions are Exponential
88-
for _, unit_dict in model.distributions.items():
58+
# Check ASU arrival distributions
59+
assert len(model.arrival_dist["asu"]) == 4
60+
assert "stroke" in model.arrival_dist["asu"]
61+
assert "tia" in model.arrival_dist["asu"]
62+
assert "neuro" in model.arrival_dist["asu"]
63+
assert "other" in model.arrival_dist["asu"]
64+
65+
# Check Rehab arrival distributions
66+
assert len(model.arrival_dist["rehab"]) == 3
67+
assert "stroke" in model.arrival_dist["rehab"]
68+
assert "tia" in model.arrival_dist["rehab"]
69+
assert "other" in model.arrival_dist["rehab"]
70+
assert "neuro" not in model.arrival_dist["rehab"]
71+
72+
# Check that all arrival distributions are Exponential
73+
for _, unit_dict in model.arrival_dist.items():
8974
for patient_type in unit_dict:
9075
assert isinstance(unit_dict[patient_type], Exponential)
9176

@@ -95,16 +80,16 @@ def test_sampling_seed_reproducibility():
9580
Test that using the same seed produces the same results when sampling
9681
from one of the arrival distributions.
9782
"""
98-
param = MockParam()
83+
param = Param()
9984

10085
# Create two models with the same seed
10186
model1 = Model(param, run_number=123)
10287
model2 = Model(param, run_number=123)
10388

10489
# Sample from a distribution in both models
105-
samples1 = [model1.distributions["asu"]["stroke"].sample()
90+
samples1 = [model1.arrival_dist["asu"]["stroke"].sample()
10691
for _ in range(10)]
107-
samples2 = [model2.distributions["asu"]["stroke"].sample()
92+
samples2 = [model2.arrival_dist["asu"]["stroke"].sample()
10893
for _ in range(10)]
10994

11095
# Check that the samples are the same
@@ -116,7 +101,7 @@ def test_run_time():
116101
Check that the run length is correct with varying warm-up and data
117102
collection periods.
118103
"""
119-
param = MockParam()
104+
param = Param(warm_up_period=10, data_collection_period=20)
120105

121106
# Test with zero warm-up period
122107
param.warm_up_period = 0

0 commit comments

Comments
 (0)