Skip to content

Commit 54e7043

Browse files
committed
test(distributions): add unit tests for the distribution classes (+ set random seed default to None)
1 parent 6dc7e5b commit 54e7043

File tree

2 files changed

+189
-4
lines changed

2 files changed

+189
-4
lines changed

simulation/distributions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Exponential:
1818
1919
This class is from Heather and Monks 2025, who adapted from Monks 2021.
2020
"""
21-
def __init__(self, mean, random_seed):
21+
def __init__(self, mean, random_seed=None):
2222
"""
2323
Initialises a new distribution.
2424
@@ -57,7 +57,7 @@ class LogNormal:
5757
5858
This class is adapted from Monks 2021.
5959
"""
60-
def __init__(self, mean, stdev, random_seed):
60+
def __init__(self, mean, stdev, random_seed=None):
6161
"""
6262
Initialises a new distribution.
6363
@@ -117,7 +117,7 @@ class Discrete:
117117
118118
This class is adapted from Monks 2021.
119119
"""
120-
def __init__(self, values, freq, random_seed):
120+
def __init__(self, values, freq, random_seed=None):
121121
"""
122122
Initialises a new distribution.
123123
@@ -137,7 +137,7 @@ def __init__(self, values, freq, random_seed):
137137
self.probabilities = self.freq / self.freq.sum()
138138
self.rand = np.random.default_rng(random_seed)
139139

140-
def sample(self, size):
140+
def sample(self, size=None):
141141
"""
142142
Generate sample.
143143

tests/test_unittest.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,19 @@
22
Unit tests
33
"""
44

5+
import numpy as np
56
import pytest
67

78
from simulation.parameters import (
89
ASUArrivals, RehabArrivals, ASULOS, RehabLOS,
910
ASURouting, RehabRouting, Param)
11+
from simulation.distributions import Exponential, LogNormal, Discrete
1012

1113

14+
# -----------------------------------------------------------------------------
15+
# Parameter classes
16+
# -----------------------------------------------------------------------------
17+
1218
@pytest.mark.parametrize('class_to_test', [
1319
ASUArrivals, RehabArrivals, ASULOS, RehabLOS,
1420
ASURouting, RehabRouting, Param])
@@ -27,3 +33,182 @@ def test_new_attribute(class_to_test):
2733
with pytest.raises(AttributeError,
2834
match='only possible to modify existing attributes'):
2935
setattr(instance, 'new_entry', 3)
36+
37+
38+
# -----------------------------------------------------------------------------
39+
# Distributions
40+
# -----------------------------------------------------------------------------
41+
42+
@pytest.mark.parametrize('dist_class, params, expected_mean, expected_type', [
43+
(Exponential, {'mean': 5}, 5, float),
44+
(LogNormal, {'mean': 1, 'stdev': 0.5}, 1, float),
45+
(Discrete, {'values': [1, 2, 3], 'freq': [0.2, 0.5, 0.3]}, 2.1, int)
46+
])
47+
def test_samples(dist_class, params, expected_mean, expected_type):
48+
"""
49+
Test that generated samples match the expected mean and requested size,
50+
and that the random seed is working.
51+
52+
Arguments:
53+
dist_class (class):
54+
Distribution class to test.
55+
params (dict):
56+
Parameters for initialising the distribution.
57+
expected_mean (float):
58+
Expected mean of the distribution.
59+
expected_type (type):
60+
Expected type of the sample (float for continuous distributions,
61+
int for discrete).
62+
"""
63+
# Initialise the distribution
64+
dist = dist_class(random_seed=42, **params)
65+
66+
# Check that sample is a float
67+
x = dist.sample()
68+
assert isinstance(x, expected_type), (
69+
f'Expected sample() to return a {expected_type} - instead: {type(x)}'
70+
)
71+
72+
# Check that the mean of generated samples is close to the expected mean
73+
samples = dist.sample(size=10000)
74+
assert np.isclose(np.mean(samples), expected_mean, rtol=0.1)
75+
76+
# Check that the sample size matches the requested size
77+
assert len(samples) == 10000
78+
79+
# Check that the same seed returns the same sample
80+
sample1 = dist_class(random_seed=5, **params).sample(size=5)
81+
sample2 = dist_class(random_seed=5, **params).sample(size=5)
82+
assert np.array_equal(sample1, sample2), (
83+
'Samples with the same random seeds should be equal.'
84+
)
85+
86+
# Check that different seeds return different samples
87+
sample3 = dist_class(random_seed=89, **params).sample(size=5)
88+
assert not np.array_equal(sample1, sample3), (
89+
'Samples with different random seeds should not be equal.'
90+
)
91+
92+
93+
def test_invalid_exponential():
94+
"""
95+
Ensure that Exponential distribution cannot be created with a negative
96+
or zero mean.
97+
"""
98+
# Negative mean
99+
with pytest.raises(ValueError):
100+
Exponential(mean=-5, random_seed=42)
101+
102+
# Zero mean
103+
with pytest.raises(ValueError):
104+
Exponential(mean=0, random_seed=42)
105+
106+
# Check that no negative values are sampled
107+
d = Exponential(mean=10, random_seed=42)
108+
bigsample = d.sample(size=100000)
109+
assert all(x > 0 for x in bigsample), (
110+
'Sample contains non-positive values.'
111+
)
112+
113+
114+
def test_lognormal_moments():
115+
"""
116+
Test the calculation of normal distribution parameters (mu, sigma) from
117+
lognormal parameters.
118+
119+
This test verifies that:
120+
1. The normal_moments_from_lognormal method correctly converts lognormal
121+
parameters (mean, variance) to normal distribution parameters (mu, sigma).
122+
2. The calculated values match the expected mathematical formulas.
123+
"""
124+
# Define lognormal parameters
125+
mean, stdev = 2.0, 0.5
126+
127+
# Initialise distribution and get calculated parameters
128+
dist = LogNormal(mean=mean, stdev=stdev, random_seed=42)
129+
calculated_mu, calculated_sigma = (
130+
dist.normal_moments_from_lognormal(mean, stdev**2))
131+
132+
# Verify calculated parameters match expected mathematical formulas
133+
# Formula for mu: ln(mean²/√(stdev² + mean²))
134+
assert np.isclose(calculated_mu,
135+
np.log(mean**2 / np.sqrt(stdev**2 + mean**2)),
136+
rtol=1e-5)
137+
# Formula for sigma: √ln(1 + stdev²/mean²)
138+
assert np.isclose(calculated_sigma,
139+
np.sqrt(np.log(1 + stdev**2 / mean**2)),
140+
rtol=1e-5)
141+
142+
143+
def test_discrete_probabilities():
144+
"""
145+
Test correct calculation of probabilities for Discrete distribution.
146+
147+
This test verifies that:
148+
1. The Discrete class correctly normalizes frequency values to
149+
probabilities.
150+
2. The sum of probabilities equals 1
151+
3. The relative proportions match the input frequencies
152+
"""
153+
# Define discrete distribution parameters
154+
values = [1, 2, 3]
155+
freq = [10, 20, 30]
156+
157+
# Initialise distribution
158+
dist = Discrete(values=values, freq=freq, random_seed=42)
159+
160+
# Calculate expected probabilities by normalising frequencies
161+
expected_probs = np.array(freq) / np.sum(freq)
162+
163+
# Verify calculated probabilities match expected values
164+
assert np.allclose(dist.probabilities, expected_probs, rtol=1e-5)
165+
166+
# Verify probabilities sum to 1
167+
assert np.isclose(np.sum(dist.probabilities), 1.0, rtol=1e-10)
168+
169+
170+
def test_discrete_value_error():
171+
"""
172+
Test if Discrete raises ValueError for mismatched inputs.
173+
174+
This test verifies that the Discrete class correctly validates that
175+
the values and frequencies arrays have the same length.
176+
"""
177+
# Attempt to initialise with mismatched array lengths
178+
with pytest.raises(ValueError):
179+
Discrete(values=[1, 2], freq=[0.5], random_seed=42)
180+
181+
182+
def test_invalid_input_types():
183+
"""
184+
Test error handling for invalid string input types.
185+
186+
This test verifies that appropriate errors are raised when
187+
string values are provided instead of numeric values for
188+
distribution parameters.
189+
"""
190+
with pytest.raises(TypeError):
191+
Exponential(mean='5')
192+
with pytest.raises(TypeError):
193+
LogNormal(mean='4', stdev=1)
194+
with pytest.raises(TypeError):
195+
Discrete(values=[1, 2, 3], freq=['0.2', '0.5', '0.3'])
196+
197+
198+
def test_discrete_uneven_probabilities():
199+
"""
200+
Test behavior of Discrete distribution with highly uneven probabilities.
201+
202+
This test verifies that when one probability is much larger than another,
203+
the sampling correctly reflects this imbalance by rarely selecting the
204+
low-probability value.
205+
"""
206+
# Create distribution with extremely uneven probabilities
207+
dist = Discrete(values=[1, 2], freq=[1, 1e-10], random_seed=42)
208+
209+
# Generate a large sample
210+
samples = dist.sample(size=10000)
211+
212+
# Verify that the low-probability value (2) appears very rarely
213+
# With p ≈ 1e-10, we expect virtually no occurrences of value 2
214+
assert np.sum(samples == 2) < 5

0 commit comments

Comments
 (0)