@@ -41,51 +41,36 @@ def test_new_attribute(class_to_test):
4141# Model
4242# -----------------------------------------------------------------------------
4343
44- class MockParam :
45- """
46- Mock parameter class.
47- """
48- def __init__ (self ):
49- """
50- Initialise with specific run periods and arrival parameters.
51- """
52- self .warm_up_period = 10
53- self .data_collection_period = 20
54-
55- self .asu_arrivals = namedtuple (
56- "ASUArrivals" , ["stroke" , "tia" , "neuro" , "other" ])(
57- stroke = 5 , tia = 7 , neuro = 10 , other = 15
58- )
59- self .rehab_arrivals = namedtuple (
60- "RehabArrivals" , ["stroke" , "tia" , "other" ])(
61- stroke = 8 , tia = 12 , other = 20
62- )
63-
64-
6544def test_create_distributions ():
6645 """
6746 Test that distributions are created correctly for all units and patient
68- types specified in MockParam .
47+ types specified.
6948 """
70- param = MockParam ()
49+ param = Param (
50+ asu_arrivals = namedtuple (
51+ "ASUArrivals" , ["stroke" , "tia" , "neuro" , "other" ])(
52+ stroke = 5 , tia = 7 , neuro = 10 , other = 15 ),
53+ rehab_arrivals = namedtuple (
54+ "RehabArrivals" , ["stroke" , "tia" , "other" ])(
55+ stroke = 8 , tia = 12 , other = 20 ))
7156 model = Model (param , run_number = 42 )
7257
73- # Check ASU distributions
74- assert len (model .distributions ["asu" ]) == 4
75- assert "stroke" in model .distributions ["asu" ]
76- assert "tia" in model .distributions ["asu" ]
77- assert "neuro" in model .distributions ["asu" ]
78- assert "other" in model .distributions ["asu" ]
79-
80- # Check Rehab distributions
81- assert len (model .distributions ["rehab" ]) == 3
82- assert "stroke" in model .distributions ["rehab" ]
83- assert "tia" in model .distributions ["rehab" ]
84- assert "other" in model .distributions ["rehab" ]
85- assert "neuro" not in model .distributions ["rehab" ]
86-
87- # Check that all distributions are Exponential
88- for _ , unit_dict in model .distributions .items ():
58+ # Check ASU arrival distributions
59+ assert len (model .arrival_dist ["asu" ]) == 4
60+ assert "stroke" in model .arrival_dist ["asu" ]
61+ assert "tia" in model .arrival_dist ["asu" ]
62+ assert "neuro" in model .arrival_dist ["asu" ]
63+ assert "other" in model .arrival_dist ["asu" ]
64+
65+ # Check Rehab arrival distributions
66+ assert len (model .arrival_dist ["rehab" ]) == 3
67+ assert "stroke" in model .arrival_dist ["rehab" ]
68+ assert "tia" in model .arrival_dist ["rehab" ]
69+ assert "other" in model .arrival_dist ["rehab" ]
70+ assert "neuro" not in model .arrival_dist ["rehab" ]
71+
72+ # Check that all arrival distributions are Exponential
73+ for _ , unit_dict in model .arrival_dist .items ():
8974 for patient_type in unit_dict :
9075 assert isinstance (unit_dict [patient_type ], Exponential )
9176
@@ -95,16 +80,16 @@ def test_sampling_seed_reproducibility():
9580 Test that using the same seed produces the same results when sampling
9681 from one of the arrival distributions.
9782 """
98- param = MockParam ()
83+ param = Param ()
9984
10085 # Create two models with the same seed
10186 model1 = Model (param , run_number = 123 )
10287 model2 = Model (param , run_number = 123 )
10388
10489 # Sample from a distribution in both models
105- samples1 = [model1 .distributions ["asu" ]["stroke" ].sample ()
90+ samples1 = [model1 .arrival_dist ["asu" ]["stroke" ].sample ()
10691 for _ in range (10 )]
107- samples2 = [model2 .distributions ["asu" ]["stroke" ].sample ()
92+ samples2 = [model2 .arrival_dist ["asu" ]["stroke" ].sample ()
10893 for _ in range (10 )]
10994
11095 # Check that the samples are the same
@@ -116,7 +101,7 @@ def test_run_time():
116101 Check that the run length is correct with varying warm-up and data
117102 collection periods.
118103 """
119- param = MockParam ( )
104+ param = Param ( warm_up_period = 10 , data_collection_period = 20 )
120105
121106 # Test with zero warm-up period
122107 param .warm_up_period = 0
0 commit comments