diff --git a/privacy_guard/analysis/tests/base_test_analysis_node.py b/privacy_guard/analysis/tests/base_test_analysis_node.py index 1b40ccc..876d934 100644 --- a/privacy_guard/analysis/tests/base_test_analysis_node.py +++ b/privacy_guard/analysis/tests/base_test_analysis_node.py @@ -12,7 +12,7 @@ # pyre-strict import unittest -from typing import Tuple +from typing import Any, Dict, List, Tuple import numpy as np import pandas as pd @@ -23,6 +23,35 @@ class BaseTestAnalysisNode(unittest.TestCase): Util test class which sets up common dataframes for use in testing. """ + def assertIsListOfFloats(self, value: Any, msg: str = "") -> None: + """Assert that value is a list containing only float or np.floating elements.""" + self.assertIsInstance(value, list, msg or "Expected a list") + self.assertTrue( + all(isinstance(x, (float, np.floating)) for x in value), + msg + or f"Expected all elements to be float, got types: {[type(x).__name__ for x in value]}", + ) + + def assertIsListOfFloatsWithLength( + self, value: Any, expected_length: int, msg: str = "" + ) -> None: + """Assert that value is a list of floats with a specific length.""" + self.assertIsListOfFloats(value, msg) + self.assertEqual( + len(value), + expected_length, + msg or f"Expected list of length {expected_length}, got {len(value)}", + ) + + def assertAllKeysPresent( + self, d: Dict[str, Any], keys: List[str], msg: str = "" + ) -> None: + """Assert that all specified keys are present in dictionary.""" + self.assertTrue( + set(keys).issubset(d.keys()), + msg or f"Missing keys: {set(keys) - set(d.keys())}", + ) + def sample_normal_distribution( self, mean: float = 0.0, std_dev: float = 1.0, num_samples: int = 20000 ) -> pd.DataFrame: diff --git a/privacy_guard/analysis/tests/test_analysis_node.py b/privacy_guard/analysis/tests/test_analysis_node.py index c99fbf2..e35b2df 100644 --- a/privacy_guard/analysis/tests/test_analysis_node.py +++ b/privacy_guard/analysis/tests/test_analysis_node.py @@ -109,7 +109,7 @@ def test_timer_disabled(self) -> None: def test_turn_cap_eps_on(self) -> None: """ - Tests capping of computed epsilons. Under cap_eps=True and a seprable setting with two users, the max eps should be log(2) = 0.693. + Tests capping of computed epsilons. Under cap_eps=True and a separable setting with two users, the max eps should be log(2) = 0.693. """ analysis_node = AnalysisNode( self.separable_base_analysis_input, @@ -121,12 +121,12 @@ def test_turn_cap_eps_on(self) -> None: eps_tpr_ub = max( outputs["eps_tpr_ub"] ) # max eps over all TPR thresholds, should be log(2) ~ 0.693 - assert abs(eps_tpr_ub - np.log(2)) < 1e-6 + self.assertAlmostEqual(eps_tpr_ub, np.log(2), places=6) eps_fpr_ub = max( outputs["eps_fpr_ub"] ) # max eps over all FPR thresholds, should be log(2) ~ 0.693 - assert abs(eps_fpr_ub - np.log(2)) < 1e-6 + self.assertAlmostEqual(eps_fpr_ub, np.log(2), places=6) def test_turn_cap_eps_off(self) -> None: """ @@ -142,13 +142,12 @@ def test_turn_cap_eps_off(self) -> None: eps_tpr_ub = max( outputs["eps_tpr_ub"] ) # max eps over all TPR thresholds, should be inf - assert eps_tpr_ub == float("inf") + self.assertEqual(eps_tpr_ub, float("inf")) eps_fpr_ub = max( outputs["eps_fpr_ub"] ) # max eps over all FPR thresholds, should be inf - print(outputs["eps_tpr_ub"], outputs["eps_fpr_ub"]) - assert eps_fpr_ub == float("inf") + self.assertEqual(eps_fpr_ub, float("inf")) def test_num_bootstrap_resampling(self) -> None: """ @@ -237,80 +236,34 @@ def test_compute_output_types(self) -> None: self.assertIsInstance(analysis_outputs, AnalysisNodeOutput) analysis_outputs_dict = self.analysis_node.compute_outputs() self.assertIsInstance(analysis_outputs_dict, dict) + + # Scalar float fields self.assertIsInstance(analysis_outputs_dict["eps"], (float, np.floating)) self.assertIsInstance(analysis_outputs_dict["eps_lb"], (float, np.floating)) self.assertIsInstance( analysis_outputs_dict["eps_fpr_max_ub"], (float, np.floating) ) - self.assertIsInstance(analysis_outputs_dict["eps_fpr_lb"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_fpr_lb"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_fpr_ub"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_fpr_ub"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_tpr_lb"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_tpr_lb"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_tpr_ub"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_tpr_ub"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_max_lb"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_max_lb"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_max_ub"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_max_ub"] - ) - ) self.assertIsInstance(analysis_outputs_dict["eps_cp"], (float, np.floating)) - self.assertIsInstance(analysis_outputs_dict["accuracy"], (float, np.floating)) - self.assertIsInstance(analysis_outputs_dict["accuracy_ci"], list) - self.assertEqual(len(analysis_outputs_dict["accuracy_ci"]), 2) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["accuracy_ci"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["auc"], (float, np.floating)) - self.assertIsInstance(analysis_outputs_dict["auc_ci"], list) - self.assertEqual(len(analysis_outputs_dict["auc_ci"]), 2) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["auc_ci"] - ) - ) + # List of floats fields + self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_lb"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_ub"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_lb"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_ub"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_max_lb"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_max_ub"]) + + # Confidence intervals (list of 2 floats) + self.assertIsListOfFloatsWithLength(analysis_outputs_dict["accuracy_ci"], 2) + self.assertIsListOfFloatsWithLength(analysis_outputs_dict["auc_ci"], 2) + + # Data size dictionary self.assertIsInstance(analysis_outputs_dict["data_size"], dict) - self.assertTrue( - {"train_size", "test_size", "bootstrap_size"}.issubset( - analysis_outputs_dict["data_size"] - ) + self.assertAllKeysPresent( + analysis_outputs_dict["data_size"], + ["train_size", "test_size", "bootstrap_size"], ) self.assertTrue( all(isinstance(x, int) for x in analysis_outputs_dict["data_size"].values()) diff --git a/privacy_guard/analysis/tests/test_parallel_analysis_node.py b/privacy_guard/analysis/tests/test_parallel_analysis_node.py index 68aff6d..8e810ef 100644 --- a/privacy_guard/analysis/tests/test_parallel_analysis_node.py +++ b/privacy_guard/analysis/tests/test_parallel_analysis_node.py @@ -129,66 +129,32 @@ def test_compute_output_types(self) -> None: self.assertIsInstance(analysis_outputs, AnalysisNodeOutput) analysis_outputs_dict = self.parallel_analysis_node.compute_outputs() self.assertIsInstance(analysis_outputs_dict, dict) + + # Scalar float fields self.assertIsInstance(analysis_outputs_dict["eps"], (float, np.floating)) self.assertIsInstance(analysis_outputs_dict["eps_lb"], (float, np.floating)) self.assertIsInstance( analysis_outputs_dict["eps_fpr_max_ub"], (float, np.floating) ) - self.assertIsInstance(analysis_outputs_dict["eps_fpr_lb"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_fpr_lb"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_fpr_ub"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_fpr_ub"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_tpr_lb"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_tpr_lb"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["eps_tpr_ub"], list) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["eps_tpr_ub"] - ) - ) self.assertIsInstance(analysis_outputs_dict["eps_cp"], (float, np.floating)) - self.assertIsInstance(analysis_outputs_dict["accuracy"], (float, np.floating)) - self.assertIsInstance(analysis_outputs_dict["accuracy_ci"], list) - self.assertEqual(len(analysis_outputs_dict["accuracy_ci"]), 2) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["accuracy_ci"] - ) - ) - self.assertIsInstance(analysis_outputs_dict["auc"], (float, np.floating)) - self.assertIsInstance(analysis_outputs_dict["auc_ci"], list) - self.assertEqual(len(analysis_outputs_dict["auc_ci"]), 2) - self.assertTrue( - all( - isinstance(x, (float, np.floating)) - for x in analysis_outputs_dict["auc_ci"] - ) - ) + # List of floats fields + self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_lb"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_ub"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_lb"]) + self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_ub"]) + + # Confidence intervals (list of 2 floats) + self.assertIsListOfFloatsWithLength(analysis_outputs_dict["accuracy_ci"], 2) + self.assertIsListOfFloatsWithLength(analysis_outputs_dict["auc_ci"], 2) + + # Data size dictionary self.assertIsInstance(analysis_outputs_dict["data_size"], dict) - self.assertTrue( - {"train_size", "test_size", "bootstrap_size"}.issubset( - analysis_outputs_dict["data_size"] - ) + self.assertAllKeysPresent( + analysis_outputs_dict["data_size"], + ["train_size", "test_size", "bootstrap_size"], ) self.assertTrue( all(isinstance(x, int) for x in analysis_outputs_dict["data_size"].values())