Skip to content

Commit f511168

Browse files
committed
feat(validation): add validation
1 parent c49c72e commit f511168

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

simulation/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(self, param, run_number):
9090
run_number: int
9191
Replication / run number.
9292
"""
93+
# Check parameter validity
94+
param.check_param_validity()
95+
9396
# Set parameters
9497
self.param = param
9598
self.run_number = run_number

simulation/parameters.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,3 +477,147 @@ def __init__(
477477
# Set up logger
478478
self.logger = SimLogger(log_to_console=log_to_console,
479479
log_to_file=log_to_file)
480+
481+
def check_param_validity(self):
482+
"""
483+
Check the validity of the provided parameters.
484+
485+
Validates all simulation parameters to ensure they meet requirements:
486+
- Warm-up period and data collection period must be >= 0
487+
- Number of runs and audit interval must be > 0
488+
- Arrival rates must be >= 0
489+
- Length of stay parameters must be >= 0
490+
- Routing probabilities must sum to 1 and be between 0 and 1
491+
492+
Raises
493+
------
494+
ValueError
495+
If any parameter fails validation with a descriptive error message.
496+
"""
497+
# Validate parameters that must be >= 0
498+
for param in ["warm_up_period", "data_collection_period"]:
499+
self.validate_param(
500+
param, lambda x: x >= 0,
501+
"must be greater than or equal to 0")
502+
503+
# Validate parameters that must be > 0
504+
for param in ["number_of_runs", "audit_interval"]:
505+
self.validate_param(
506+
param, lambda x: x > 0,
507+
"must be greater than 0")
508+
509+
# Validate arrival parameters
510+
for param in ["asu_arrivals", "rehab_arrivals"]:
511+
self.validate_nested_param(
512+
param, lambda x: x >= 0,
513+
"must be greater than 0")
514+
515+
# Validate length of stay parameters
516+
for param in ["asu_los", "rehab_los"]:
517+
self.validate_nested_param(
518+
param, lambda x: x >= 0,
519+
"must be greater than 0", nested=True)
520+
521+
# Validate routing parameters
522+
for param in ["asu_routing", "rehab_routing"]:
523+
self.validate_routing(param)
524+
525+
def validate_param(self, param_name, condition, error_msg):
526+
"""
527+
Validate a single parameter against a condition.
528+
529+
Parameters
530+
----------
531+
param_name: str
532+
Name of the parameter being validated.
533+
condition: callable
534+
A function that returns True if the value is valid.
535+
error_msg: str
536+
Error message to display if validation fails.
537+
538+
Raises
539+
------
540+
ValueError:
541+
If the parameter fails the validation condition.
542+
"""
543+
value = getattr(self, param_name)
544+
if not condition(value):
545+
raise ValueError(
546+
f"Parameter '{param_name}' {error_msg}, but is: {value}")
547+
548+
def validate_nested_param(
549+
self, obj_name, condition, error_msg, nested=False
550+
):
551+
"""
552+
Validate parameters within a nested object structure.
553+
554+
Parameters
555+
----------
556+
obj_name: str
557+
Name of the object containing parameters.
558+
condition: callable
559+
A function that returns True if the value is valid.
560+
error_msg: str
561+
Error message to display if validation fails.
562+
nested: bool, optional
563+
If True, validates parameters in a doubly-nested structure. If
564+
False, validates parameters in a singly-nested structure.
565+
566+
Raises
567+
------
568+
ValueError:
569+
If any nested parameter fails the validation condition.
570+
"""
571+
obj = getattr(self, obj_name)
572+
for key, value in vars(obj).items():
573+
if key == "_initialised":
574+
continue
575+
if nested:
576+
for sub_key, sub_value in value.items():
577+
if not condition(sub_value):
578+
raise ValueError(
579+
f"Parameter '{sub_key}' for '{key}' in " +
580+
f"'{obj_name}' {error_msg}, but is: {sub_value}")
581+
else:
582+
if not condition(value):
583+
raise ValueError(
584+
f"Parameter '{key}' from '{obj_name}' {error_msg}, " +
585+
f"but is: {value}")
586+
587+
def validate_routing(self, obj_name):
588+
"""
589+
Validate routing probability parameters.
590+
591+
Performs two validations:
592+
1. Checks that all probabilities for each routing option sum to 1.
593+
2. Checks that individual probabilities are between 0 and 1 inclusive.
594+
595+
Parameters
596+
----------
597+
obj_name: str
598+
Name of the routing object.
599+
600+
Raises
601+
------
602+
ValueError:
603+
If the probabilities don't sum to 1, or if any probability is
604+
outside [0,1].
605+
"""
606+
obj = getattr(self, obj_name)
607+
for key, value in vars(obj).items():
608+
if key == "_initialised":
609+
continue
610+
611+
# Check that probabilities sum to 1
612+
total_prob = sum(value.values())
613+
if total_prob != 1:
614+
raise ValueError(
615+
f"Routing probabilities for '{key}' in '{obj_name}' " +
616+
f"should sum to 1 but sum to: {total_prob}")
617+
618+
# Check that probabilities are between 0 and 1
619+
for sub_key, sub_value in value.items():
620+
if sub_value < 0 or sub_value > 1:
621+
raise ValueError(
622+
f"Parameter '{sub_key}' for '{key}' in '{obj_name}'" +
623+
f"must be between 0 and 1, but is: {sub_value}")

0 commit comments

Comments
 (0)