22Unit tests
33"""
44
5+ import numpy as np
56import pytest
67
78from simulation .parameters import (
89 ASUArrivals , RehabArrivals , ASULOS , RehabLOS ,
910 ASURouting , RehabRouting , Param )
11+ from simulation .distributions import Exponential , LogNormal , Discrete
1012
1113
14+ # -----------------------------------------------------------------------------
15+ # Parameter classes
16+ # -----------------------------------------------------------------------------
17+
1218@pytest .mark .parametrize ('class_to_test' , [
1319 ASUArrivals , RehabArrivals , ASULOS , RehabLOS ,
1420 ASURouting , RehabRouting , Param ])
@@ -27,3 +33,182 @@ def test_new_attribute(class_to_test):
2733 with pytest .raises (AttributeError ,
2834 match = 'only possible to modify existing attributes' ):
2935 setattr (instance , 'new_entry' , 3 )
36+
37+
38+ # -----------------------------------------------------------------------------
39+ # Distributions
40+ # -----------------------------------------------------------------------------
41+
42+ @pytest .mark .parametrize ('dist_class, params, expected_mean, expected_type' , [
43+ (Exponential , {'mean' : 5 }, 5 , float ),
44+ (LogNormal , {'mean' : 1 , 'stdev' : 0.5 }, 1 , float ),
45+ (Discrete , {'values' : [1 , 2 , 3 ], 'freq' : [0.2 , 0.5 , 0.3 ]}, 2.1 , int )
46+ ])
47+ def test_samples (dist_class , params , expected_mean , expected_type ):
48+ """
49+ Test that generated samples match the expected mean and requested size,
50+ and that the random seed is working.
51+
52+ Arguments:
53+ dist_class (class):
54+ Distribution class to test.
55+ params (dict):
56+ Parameters for initialising the distribution.
57+ expected_mean (float):
58+ Expected mean of the distribution.
59+ expected_type (type):
60+ Expected type of the sample (float for continuous distributions,
61+ int for discrete).
62+ """
63+ # Initialise the distribution
64+ dist = dist_class (random_seed = 42 , ** params )
65+
66+ # Check that sample is a float
67+ x = dist .sample ()
68+ assert isinstance (x , expected_type ), (
69+ f'Expected sample() to return a { expected_type } - instead: { type (x )} '
70+ )
71+
72+ # Check that the mean of generated samples is close to the expected mean
73+ samples = dist .sample (size = 10000 )
74+ assert np .isclose (np .mean (samples ), expected_mean , rtol = 0.1 )
75+
76+ # Check that the sample size matches the requested size
77+ assert len (samples ) == 10000
78+
79+ # Check that the same seed returns the same sample
80+ sample1 = dist_class (random_seed = 5 , ** params ).sample (size = 5 )
81+ sample2 = dist_class (random_seed = 5 , ** params ).sample (size = 5 )
82+ assert np .array_equal (sample1 , sample2 ), (
83+ 'Samples with the same random seeds should be equal.'
84+ )
85+
86+ # Check that different seeds return different samples
87+ sample3 = dist_class (random_seed = 89 , ** params ).sample (size = 5 )
88+ assert not np .array_equal (sample1 , sample3 ), (
89+ 'Samples with different random seeds should not be equal.'
90+ )
91+
92+
93+ def test_invalid_exponential ():
94+ """
95+ Ensure that Exponential distribution cannot be created with a negative
96+ or zero mean.
97+ """
98+ # Negative mean
99+ with pytest .raises (ValueError ):
100+ Exponential (mean = - 5 , random_seed = 42 )
101+
102+ # Zero mean
103+ with pytest .raises (ValueError ):
104+ Exponential (mean = 0 , random_seed = 42 )
105+
106+ # Check that no negative values are sampled
107+ d = Exponential (mean = 10 , random_seed = 42 )
108+ bigsample = d .sample (size = 100000 )
109+ assert all (x > 0 for x in bigsample ), (
110+ 'Sample contains non-positive values.'
111+ )
112+
113+
114+ def test_lognormal_moments ():
115+ """
116+ Test the calculation of normal distribution parameters (mu, sigma) from
117+ lognormal parameters.
118+
119+ This test verifies that:
120+ 1. The normal_moments_from_lognormal method correctly converts lognormal
121+ parameters (mean, variance) to normal distribution parameters (mu, sigma).
122+ 2. The calculated values match the expected mathematical formulas.
123+ """
124+ # Define lognormal parameters
125+ mean , stdev = 2.0 , 0.5
126+
127+ # Initialise distribution and get calculated parameters
128+ dist = LogNormal (mean = mean , stdev = stdev , random_seed = 42 )
129+ calculated_mu , calculated_sigma = (
130+ dist .normal_moments_from_lognormal (mean , stdev ** 2 ))
131+
132+ # Verify calculated parameters match expected mathematical formulas
133+ # Formula for mu: ln(mean²/√(stdev² + mean²))
134+ assert np .isclose (calculated_mu ,
135+ np .log (mean ** 2 / np .sqrt (stdev ** 2 + mean ** 2 )),
136+ rtol = 1e-5 )
137+ # Formula for sigma: √ln(1 + stdev²/mean²)
138+ assert np .isclose (calculated_sigma ,
139+ np .sqrt (np .log (1 + stdev ** 2 / mean ** 2 )),
140+ rtol = 1e-5 )
141+
142+
143+ def test_discrete_probabilities ():
144+ """
145+ Test correct calculation of probabilities for Discrete distribution.
146+
147+ This test verifies that:
148+ 1. The Discrete class correctly normalizes frequency values to
149+ probabilities.
150+ 2. The sum of probabilities equals 1
151+ 3. The relative proportions match the input frequencies
152+ """
153+ # Define discrete distribution parameters
154+ values = [1 , 2 , 3 ]
155+ freq = [10 , 20 , 30 ]
156+
157+ # Initialise distribution
158+ dist = Discrete (values = values , freq = freq , random_seed = 42 )
159+
160+ # Calculate expected probabilities by normalising frequencies
161+ expected_probs = np .array (freq ) / np .sum (freq )
162+
163+ # Verify calculated probabilities match expected values
164+ assert np .allclose (dist .probabilities , expected_probs , rtol = 1e-5 )
165+
166+ # Verify probabilities sum to 1
167+ assert np .isclose (np .sum (dist .probabilities ), 1.0 , rtol = 1e-10 )
168+
169+
170+ def test_discrete_value_error ():
171+ """
172+ Test if Discrete raises ValueError for mismatched inputs.
173+
174+ This test verifies that the Discrete class correctly validates that
175+ the values and frequencies arrays have the same length.
176+ """
177+ # Attempt to initialise with mismatched array lengths
178+ with pytest .raises (ValueError ):
179+ Discrete (values = [1 , 2 ], freq = [0.5 ], random_seed = 42 )
180+
181+
182+ def test_invalid_input_types ():
183+ """
184+ Test error handling for invalid string input types.
185+
186+ This test verifies that appropriate errors are raised when
187+ string values are provided instead of numeric values for
188+ distribution parameters.
189+ """
190+ with pytest .raises (TypeError ):
191+ Exponential (mean = '5' )
192+ with pytest .raises (TypeError ):
193+ LogNormal (mean = '4' , stdev = 1 )
194+ with pytest .raises (TypeError ):
195+ Discrete (values = [1 , 2 , 3 ], freq = ['0.2' , '0.5' , '0.3' ])
196+
197+
198+ def test_discrete_uneven_probabilities ():
199+ """
200+ Test behavior of Discrete distribution with highly uneven probabilities.
201+
202+ This test verifies that when one probability is much larger than another,
203+ the sampling correctly reflects this imbalance by rarely selecting the
204+ low-probability value.
205+ """
206+ # Create distribution with extremely uneven probabilities
207+ dist = Discrete (values = [1 , 2 ], freq = [1 , 1e-10 ], random_seed = 42 )
208+
209+ # Generate a large sample
210+ samples = dist .sample (size = 10000 )
211+
212+ # Verify that the low-probability value (2) appears very rarely
213+ # With p ≈ 1e-10, we expect virtually no occurrences of value 2
214+ assert np .sum (samples == 2 ) < 5
0 commit comments