22Unit tests
33"""
44
5+ from collections import namedtuple
6+ import numpy as np
57import pytest
8+ from sim_tools .distributions import Exponential
69
710from simulation .parameters import (
811 ASUArrivals , RehabArrivals , ASULOS , RehabLOS ,
912 ASURouting , RehabRouting , Param )
13+ from simulation .model import Model
1014
1115
16+ # -----------------------------------------------------------------------------
17+ # Parameters
18+ # -----------------------------------------------------------------------------
19+
1220@pytest .mark .parametrize ("class_to_test" , [
1321 ASUArrivals , RehabArrivals , ASULOS , RehabLOS ,
1422 ASURouting , RehabRouting , Param ])
@@ -27,3 +35,107 @@ def test_new_attribute(class_to_test):
2735 with pytest .raises (AttributeError ,
2836 match = "only possible to modify existing attributes" ):
2937 setattr (instance , "new_entry" , 3 )
38+
39+
40+ # -----------------------------------------------------------------------------
41+ # Model
42+ # -----------------------------------------------------------------------------
43+
44+ class MockParam :
45+ """
46+ Mock parameter class.
47+ """
48+ def __init__ (self ):
49+ """
50+ Initialise with specific run periods and arrival parameters.
51+ """
52+ self .warm_up_period = 10
53+ self .data_collection_period = 20
54+
55+ self .asu_arrivals = namedtuple (
56+ "ASUArrivals" , ["stroke" , "tia" , "neuro" , "other" ])(
57+ stroke = 5 , tia = 7 , neuro = 10 , other = 15
58+ )
59+ self .rehab_arrivals = namedtuple (
60+ "RehabArrivals" , ["stroke" , "tia" , "other" ])(
61+ stroke = 8 , tia = 12 , other = 20
62+ )
63+
64+
65+ def test_create_distributions ():
66+ """
67+ Test that distributions are created correctly for all units and patient
68+ types specified in MockParam.
69+ """
70+ param = MockParam ()
71+ model = Model (param , run_number = 42 )
72+
73+ # Check ASU distributions
74+ assert len (model .distributions ["asu" ]) == 4
75+ assert "stroke" in model .distributions ["asu" ]
76+ assert "tia" in model .distributions ["asu" ]
77+ assert "neuro" in model .distributions ["asu" ]
78+ assert "other" in model .distributions ["asu" ]
79+
80+ # Check Rehab distributions
81+ assert len (model .distributions ["rehab" ]) == 3
82+ assert "stroke" in model .distributions ["rehab" ]
83+ assert "tia" in model .distributions ["rehab" ]
84+ assert "other" in model .distributions ["rehab" ]
85+ assert "neuro" not in model .distributions ["rehab" ]
86+
87+ # Check that all distributions are Exponential
88+ for _ , unit_dict in model .distributions .items ():
89+ for patient_type in unit_dict :
90+ assert isinstance (unit_dict [patient_type ], Exponential )
91+
92+
93+ def test_sampling_seed_reproducibility ():
94+ """
95+ Test that using the same seed produces the same results when sampling
96+ from one of the arrival distributions.
97+ """
98+ param = MockParam ()
99+
100+ # Create two models with the same seed
101+ model1 = Model (param , run_number = 123 )
102+ model2 = Model (param , run_number = 123 )
103+
104+ # Sample from a distribution in both models
105+ samples1 = [model1 .distributions ["asu" ]["stroke" ].sample ()
106+ for _ in range (10 )]
107+ samples2 = [model2 .distributions ["asu" ]["stroke" ].sample ()
108+ for _ in range (10 )]
109+
110+ # Check that the samples are the same
111+ np .testing .assert_array_almost_equal (samples1 , samples2 )
112+
113+
114+ def test_run_time ():
115+ """
116+ Check that the run length is correct with varying warm-up and data
117+ collection periods.
118+ """
119+ param = MockParam ()
120+
121+ # Test with zero warm-up period
122+ param .warm_up_period = 0
123+ model = Model (param , run_number = 42 )
124+ model .run ()
125+ assert model .env .now == param .data_collection_period
126+
127+ # Test with zero data collection period
128+ param .warm_up_period = 10
129+ param .data_collection_period = 0
130+ model = Model (param , run_number = 42 )
131+ model .run ()
132+ assert model .env .now == 10
133+ # assert len(model.patients) == 0
134+
135+ # Test with warm-up and data collection period
136+ param .warm_up_period = 12
137+ param .data_collection_period = 9
138+ model = Model (param , run_number = 42 )
139+ model .run ()
140+ assert model .env .now == 21
141+ assert len (model .patients ) > 0
0 commit comments