Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion privacy_guard/analysis/tests/base_test_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# pyre-strict

import unittest
from typing import Tuple
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
Expand All @@ -23,6 +23,35 @@ class BaseTestAnalysisNode(unittest.TestCase):
Util test class which sets up common dataframes for use in testing.
"""

def assertIsListOfFloats(self, value: Any, msg: str = "") -> None:
    """Assert that value is a list containing only float or np.floating elements.

    Args:
        value: Object under test.
        msg: Optional failure message. When empty, a default message is
            generated for whichever check fails.

    An empty list passes: there is no element that violates the check.
    """
    self.assertIsInstance(value, list, msg or "Expected a list")
    # Report only the offending elements (with their positions) so the
    # failure message stays readable even when the list is very long.
    offenders = [
        (index, type(element).__name__)
        for index, element in enumerate(value)
        if not isinstance(element, (float, np.floating))
    ]
    self.assertFalse(
        offenders,
        msg
        or "Expected all elements to be float, "
        f"got non-float elements as (index, type): {offenders}",
    )

def assertIsListOfFloatsWithLength(
    self, value: Any, expected_length: int, msg: str = ""
) -> None:
    """Assert that value is a list of floats holding exactly expected_length items.

    The element-type check is delegated to assertIsListOfFloats and runs
    first; the length is compared only after that check has passed. A
    non-empty msg replaces the default failure message of both checks.
    """
    self.assertIsListOfFloats(value, msg)
    actual_length = len(value)
    failure_text = (
        msg or f"Expected list of length {expected_length}, got {actual_length}"
    )
    self.assertEqual(actual_length, expected_length, failure_text)

def assertAllKeysPresent(
    self, d: Dict[str, Any], keys: List[str], msg: str = ""
) -> None:
    """Assert that all specified keys are present in dictionary.

    Args:
        d: Dictionary to inspect.
        keys: Keys that must all be present in d.
        msg: Optional failure message. When empty, the default message
            lists the missing keys.
    """
    # Compute the missing keys once, de-duplicated and in the order the
    # caller listed them, so the failure message is identical on every run
    # (interpolating a set gives an arbitrary order).
    missing = [key for key in dict.fromkeys(keys) if key not in d]
    self.assertFalse(missing, msg or f"Missing keys: {missing}")

def sample_normal_distribution(
self, mean: float = 0.0, std_dev: float = 1.0, num_samples: int = 20000
) -> pd.DataFrame:
Expand Down
93 changes: 23 additions & 70 deletions privacy_guard/analysis/tests/test_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_timer_disabled(self) -> None:

def test_turn_cap_eps_on(self) -> None:
"""
Tests capping of computed epsilons. Under cap_eps=True and a seprable setting with two users, the max eps should be log(2) = 0.693.
Tests capping of computed epsilons. Under cap_eps=True and a separable setting with two users, the max eps should be log(2) = 0.693.
"""
analysis_node = AnalysisNode(
self.separable_base_analysis_input,
Expand All @@ -121,12 +121,12 @@ def test_turn_cap_eps_on(self) -> None:
eps_tpr_ub = max(
outputs["eps_tpr_ub"]
) # max eps over all TPR thresholds, should be log(2) ~ 0.693
assert abs(eps_tpr_ub - np.log(2)) < 1e-6
self.assertAlmostEqual(eps_tpr_ub, np.log(2), places=6)

eps_fpr_ub = max(
outputs["eps_fpr_ub"]
) # max eps over all FPR thresholds, should be log(2) ~ 0.693
assert abs(eps_fpr_ub - np.log(2)) < 1e-6
self.assertAlmostEqual(eps_fpr_ub, np.log(2), places=6)

def test_turn_cap_eps_off(self) -> None:
"""
Expand All @@ -142,13 +142,12 @@ def test_turn_cap_eps_off(self) -> None:
eps_tpr_ub = max(
outputs["eps_tpr_ub"]
) # max eps over all TPR thresholds, should be inf
assert eps_tpr_ub == float("inf")
self.assertEqual(eps_tpr_ub, float("inf"))

eps_fpr_ub = max(
outputs["eps_fpr_ub"]
) # max eps over all FPR thresholds, should be inf
print(outputs["eps_tpr_ub"], outputs["eps_fpr_ub"])
assert eps_fpr_ub == float("inf")
self.assertEqual(eps_fpr_ub, float("inf"))

def test_num_bootstrap_resampling(self) -> None:
"""
Expand Down Expand Up @@ -237,80 +236,34 @@ def test_compute_output_types(self) -> None:
self.assertIsInstance(analysis_outputs, AnalysisNodeOutput)
analysis_outputs_dict = self.analysis_node.compute_outputs()
self.assertIsInstance(analysis_outputs_dict, dict)

# Scalar float fields
self.assertIsInstance(analysis_outputs_dict["eps"], (float, np.floating))
self.assertIsInstance(analysis_outputs_dict["eps_lb"], (float, np.floating))
self.assertIsInstance(
analysis_outputs_dict["eps_fpr_max_ub"], (float, np.floating)
)
self.assertIsInstance(analysis_outputs_dict["eps_fpr_lb"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_fpr_lb"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_fpr_ub"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_fpr_ub"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_tpr_lb"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_tpr_lb"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_tpr_ub"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_tpr_ub"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_max_lb"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_max_lb"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_max_ub"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_max_ub"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_cp"], (float, np.floating))

self.assertIsInstance(analysis_outputs_dict["accuracy"], (float, np.floating))
self.assertIsInstance(analysis_outputs_dict["accuracy_ci"], list)
self.assertEqual(len(analysis_outputs_dict["accuracy_ci"]), 2)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["accuracy_ci"]
)
)

self.assertIsInstance(analysis_outputs_dict["auc"], (float, np.floating))
self.assertIsInstance(analysis_outputs_dict["auc_ci"], list)
self.assertEqual(len(analysis_outputs_dict["auc_ci"]), 2)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["auc_ci"]
)
)

# List of floats fields
self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_lb"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_ub"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_lb"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_ub"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_max_lb"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_max_ub"])

# Confidence intervals (list of 2 floats)
self.assertIsListOfFloatsWithLength(analysis_outputs_dict["accuracy_ci"], 2)
self.assertIsListOfFloatsWithLength(analysis_outputs_dict["auc_ci"], 2)

# Data size dictionary
self.assertIsInstance(analysis_outputs_dict["data_size"], dict)
self.assertTrue(
{"train_size", "test_size", "bootstrap_size"}.issubset(
analysis_outputs_dict["data_size"]
)
self.assertAllKeysPresent(
analysis_outputs_dict["data_size"],
["train_size", "test_size", "bootstrap_size"],
)
self.assertTrue(
all(isinstance(x, int) for x in analysis_outputs_dict["data_size"].values())
Expand Down
66 changes: 16 additions & 50 deletions privacy_guard/analysis/tests/test_parallel_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,66 +129,32 @@ def test_compute_output_types(self) -> None:
self.assertIsInstance(analysis_outputs, AnalysisNodeOutput)
analysis_outputs_dict = self.parallel_analysis_node.compute_outputs()
self.assertIsInstance(analysis_outputs_dict, dict)

# Scalar float fields
self.assertIsInstance(analysis_outputs_dict["eps"], (float, np.floating))
self.assertIsInstance(analysis_outputs_dict["eps_lb"], (float, np.floating))
self.assertIsInstance(
analysis_outputs_dict["eps_fpr_max_ub"], (float, np.floating)
)
self.assertIsInstance(analysis_outputs_dict["eps_fpr_lb"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_fpr_lb"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_fpr_ub"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_fpr_ub"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_tpr_lb"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_tpr_lb"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_tpr_ub"], list)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["eps_tpr_ub"]
)
)
self.assertIsInstance(analysis_outputs_dict["eps_cp"], (float, np.floating))

self.assertIsInstance(analysis_outputs_dict["accuracy"], (float, np.floating))
self.assertIsInstance(analysis_outputs_dict["accuracy_ci"], list)
self.assertEqual(len(analysis_outputs_dict["accuracy_ci"]), 2)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["accuracy_ci"]
)
)

self.assertIsInstance(analysis_outputs_dict["auc"], (float, np.floating))
self.assertIsInstance(analysis_outputs_dict["auc_ci"], list)
self.assertEqual(len(analysis_outputs_dict["auc_ci"]), 2)
self.assertTrue(
all(
isinstance(x, (float, np.floating))
for x in analysis_outputs_dict["auc_ci"]
)
)

# List of floats fields
self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_lb"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_fpr_ub"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_lb"])
self.assertIsListOfFloats(analysis_outputs_dict["eps_tpr_ub"])

# Confidence intervals (list of 2 floats)
self.assertIsListOfFloatsWithLength(analysis_outputs_dict["accuracy_ci"], 2)
self.assertIsListOfFloatsWithLength(analysis_outputs_dict["auc_ci"], 2)

# Data size dictionary
self.assertIsInstance(analysis_outputs_dict["data_size"], dict)
self.assertTrue(
{"train_size", "test_size", "bootstrap_size"}.issubset(
analysis_outputs_dict["data_size"]
)
self.assertAllKeysPresent(
analysis_outputs_dict["data_size"],
["train_size", "test_size", "bootstrap_size"],
)
self.assertTrue(
all(isinstance(x, int) for x in analysis_outputs_dict["data_size"].values())
Expand Down
Loading