68 changes: 62 additions & 6 deletions privacy_guard/analysis/mia/analysis_node.py
@@ -16,7 +16,7 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List
from typing import List, Optional

import numpy as np
import torch
@@ -53,6 +53,7 @@ class AnalysisNodeOutput(BaseAnalysisOutput):
auc (float): Mean area under the curve (AUC) of the attack.
auc_ci (List[float]): Confidence interval for the AUC, represented as [lower_bound, upper_bound].
data_size (dict[str, int]): Sizes of the training and test datasets and the bootstrap sample.
tpr_target (Optional[float]): Target TPR used for computing epsilon; None when the legacy 1% grid is used.
tpr_idx (Optional[int]): Index of the TPR target within the error-threshold grid.
"""

# Empirical epsilons
@@ -73,6 +74,9 @@ class AnalysisNodeOutput(BaseAnalysisOutput):
auc_ci: List[float]
# Dataset sizes
data_size: dict[str, int]
# TPR target and index (tpr_target is None and tpr_idx is 0 when no custom target is provided)
tpr_target: Optional[float]
tpr_idx: Optional[int]


class AnalysisNode(BaseAnalysisNode):
@@ -93,6 +97,11 @@ class AnalysisNode(BaseAnalysisNode):
use_fnr_tnr: boolean for whether to use FNR and TNR in addition to FPR and TPR error thresholds in eps_max_array computation.
show_progress: boolean for whether to show tqdm progress bar
with_timer: boolean for whether to show timer for analysis node
tpr_target: Optional target TPR for computing epsilon. If None (default), uses legacy 1% intervals.
If specified, uses fine-grained intervals determined by tpr_threshold_width.
tpr_threshold_width: Width (step size) between TPR thresholds for fine-grained mode.
Only used when tpr_target is specified. Default 0.0025 (0.25%).
The threshold grid always starts at 0.01; num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1.
"""

def __init__(
@@ -106,6 +115,8 @@ def __init__(
use_fnr_tnr: bool = False,
show_progress: bool = False,
with_timer: bool = False,
tpr_target: Optional[float] = None,
tpr_threshold_width: float = 0.0025,
) -> None:
self._delta = delta
self._n_users_for_eval = n_users_for_eval
@@ -117,15 +128,59 @@ def __init__(

self._use_upper_bound = use_upper_bound

self._tpr_target = tpr_target
self._tpr_threshold_width = tpr_threshold_width
self._num_thresholds: int

self._timer_stats: TimerStats = {}

if self._n_users_for_eval < 0:
raise ValueError(
'Input to AnalysisNode "n_users_for_eval" must be nonnegative'
)

if self._tpr_target is not None:
assert isinstance(self._tpr_target, float)
if self._tpr_target < 0.01 or self._tpr_target > 1.0:
raise ValueError(
'Input to AnalysisNode "tpr_target" must be between 0.01 and 1.0'
)

if self._tpr_threshold_width <= 0:
raise ValueError(
'Input to AnalysisNode "tpr_threshold_width" must be positive'
)

if not np.isclose(
0.99 / self._tpr_threshold_width,
round(0.99 / self._tpr_threshold_width),
):
raise ValueError(
'Input to AnalysisNode "tpr_threshold_width" must evenly divide 0.99. '
"Valid examples: 0.001, 0.0025, 0.003, 0.005, 0.01"
)

# Determine num_thresholds based on tpr_target
if self._tpr_target is None:
# Legacy: 1% intervals (0.01 to 1.0, 100 thresholds)
self._num_thresholds = 100
self._tpr_threshold_width = 0.01
else:
self._tpr_threshold_width = tpr_threshold_width
self._num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1

self._error_thresholds: NDArray[np.floating] = np.linspace(
0.01, 1.0, self._num_thresholds
)

super().__init__(analysis_input=analysis_input)

def _get_tpr_index(self) -> int:
"""Convert TPR target to array index."""
if self._tpr_target is None:
return 0 # Legacy behavior: TPR = 1% is at index 0
return int(np.where(self._error_thresholds == self._tpr_target)[0][0])

def _calculate_one_off_eps(self) -> float:
df_train_user = self.analysis_input.df_train_user
df_test_user = self.analysis_input.df_test_user
@@ -253,9 +308,10 @@ def run_analysis(self) -> BaseAnalysisOutput:

eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb

tpr_idx = self._get_tpr_index()
outputs = AnalysisNodeOutput(
eps=eps_tpr_boundary[0], # epsilon at TPR=1% UB threshold
eps_lb=eps_tpr_lb[0],
eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold
eps_lb=eps_tpr_lb[tpr_idx],
eps_fpr_max_ub=np.nanmax(eps_fpr_ub),
eps_fpr_lb=list(eps_fpr_lb),
eps_fpr_ub=list(eps_fpr_ub),
@@ -273,6 +329,8 @@
"test_size": test_size,
"bootstrap_size": bootstrap_sample_size,
},
tpr_target=self._tpr_target,
tpr_idx=tpr_idx,
)

if self._with_timer:
@@ -313,8 +371,6 @@ def _make_metrics_array(

bootstrap_sample_size = min(train_size, test_size)

error_thresholds = np.linspace(0.01, 1, 100)

metrics_array = [
MIAResults(
loss_train[
@@ -329,7 +385,7 @@
],
).compute_metrics_at_error_threshold(
self._delta,
error_threshold=error_thresholds,
error_threshold=self._error_thresholds,
cap_eps=self._cap_eps,
use_fnr_tnr=self._use_fnr_tnr,
verbose=False,
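
For orientation, here is a minimal standalone sketch of the threshold-grid construction and index lookup that this diff introduces. It is not part of the change itself; the values are illustrative, and the target is taken from the grid (as the new unit tests do) because _get_tpr_index matches by exact float equality.

import numpy as np

# Fine-grained mode with the default width of 0.0025: the grid runs from
# 0.01 to 1.0, so it spans 0.99 and has int(0.99 / width) + 1 points.
tpr_threshold_width = 0.0025
num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1  # 397
error_thresholds = np.linspace(0.01, 1.0, num_thresholds)

# Pick a target that is guaranteed to lie on the grid, mirroring the tests,
# since the lookup below uses exact float equality.
tpr_target = error_thresholds[6]  # approximately 0.025

tpr_idx = int(np.where(error_thresholds == tpr_target)[0][0])
assert tpr_idx == 6

# Legacy mode for comparison: 100 thresholds at 1% steps; index 0 is TPR = 1%.
legacy_thresholds = np.linspace(0.01, 1.0, 100)
assert abs(legacy_thresholds[0] - 0.01) < 1e-12
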
17 changes: 13 additions & 4 deletions privacy_guard/analysis/mia/parallel_analysis_node.py
@@ -15,6 +15,7 @@
import os
import tempfile
from concurrent.futures import ProcessPoolExecutor
from typing import Optional

import numpy as np
import torch
@@ -41,6 +42,8 @@ class ParallelAnalysisNode(AnalysisNode):
use_upper_bound: boolean for whether to compute epsilon at the upper-bound of CI
use_fnr_tnr: boolean for whether to use FNR and TNR in addition to FPR and TPR error thresholds in eps_max_array computation.
with_timer: boolean for whether to show timer for analysis node
tpr_target: Optional target TPR for computing epsilon. If None (default), uses legacy 1% intervals.
tpr_threshold_width: Width between TPR thresholds for fine-grained mode. Default 0.0025.
"""

def __init__(
@@ -53,6 +56,8 @@ def __init__(
num_bootstrap_resampling_times: int = 1000,
use_fnr_tnr: bool = False,
with_timer: bool = False,
tpr_target: Optional[float] = None,
tpr_threshold_width: float = 0.0025,
) -> None:
super().__init__(
analysis_input=analysis_input,
@@ -62,6 +67,8 @@
num_bootstrap_resampling_times=num_bootstrap_resampling_times,
use_fnr_tnr=use_fnr_tnr,
with_timer=with_timer,
tpr_target=tpr_target,
tpr_threshold_width=tpr_threshold_width,
)
self._eps_computation_tasks_num = eps_computation_tasks_num

@@ -87,7 +94,6 @@ def _compute_metrics_array(
loss_test: torch.Tensor = torch.load(test_filename, weights_only=True)
train_size, test_size = loss_train.shape[0], loss_test.shape[0]
bootstrap_sample_size = min(train_size, test_size)
error_thresholds = np.linspace(0.01, 1, 100)
metrics_results = []

try:
@@ -107,7 +113,7 @@

metrics_result = mia_results.compute_metrics_at_error_threshold(
self._delta,
error_threshold=error_thresholds,
error_threshold=self._error_thresholds,
use_fnr_tnr=self._use_fnr_tnr,
verbose=False,
)
@@ -221,9 +227,10 @@ def run_analysis(self) -> AnalysisNodeOutput:

eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb

tpr_idx = self._get_tpr_index()
outputs = AnalysisNodeOutput(
eps=eps_tpr_boundary[0], # epsilon at TPR=1% UB threshold
eps_lb=eps_tpr_lb[0],
eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold
eps_lb=eps_tpr_lb[tpr_idx],
eps_fpr_max_ub=np.nanmax(eps_fpr_ub),
eps_fpr_lb=list(eps_fpr_lb),
eps_fpr_ub=list(eps_fpr_ub),
@@ -243,6 +250,8 @@
"test_size": test_size,
"bootstrap_size": bootstrap_sample_size,
},
tpr_target=self._tpr_target,
tpr_idx=tpr_idx,
)

if self._with_timer:
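
As a worked illustration of the width validation that AnalysisNode performs in __init__ (and ParallelAnalysisNode inherits): this is a sketch only, and divides_span is a hypothetical helper name, not part of the library.

import numpy as np

# The grid spans 0.99 (0.01 to 1.0). If the width does not divide that span
# evenly, the linspace spacing would silently differ from the requested
# width, so both nodes reject such widths up front.
def divides_span(width: float) -> bool:
    ratio = 0.99 / width
    return bool(np.isclose(ratio, round(ratio)))

print(divides_span(0.0025))  # True:  0.99 / 0.0025 = 396 steps
print(divides_span(0.005))   # True:  0.99 / 0.005  = 198 steps
print(divides_span(0.02))    # False: 0.99 / 0.02   = 49.5 steps
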
114 changes: 114 additions & 0 deletions privacy_guard/analysis/tests/test_analysis_node.py
@@ -496,3 +496,117 @@ def test_use_fnr_tnr_parameter_comparison(self) -> None:
outputs_false["accuracy"], outputs_true["accuracy"], places=10
)
self.assertAlmostEqual(outputs_false["auc"], outputs_true["auc"], places=10)

def test_get_tpr_index_none_target(self) -> None:
"""Test that _get_tpr_index returns 0 when tpr_target is None (legacy behavior)."""
analysis_node = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=None,
)
self.assertEqual(analysis_node._get_tpr_index(), 0)

def test_get_tpr_index_with_target(self) -> None:
"""Test that _get_tpr_index returns correct index that points to tpr_target."""
# Create error_thresholds array to get actual values
num_thresholds = int((1.0 - 0.01) / 0.0025) + 1
error_thresholds = np.linspace(0.01, 1.0, num_thresholds)

# Test with actual values from the array at various indices
test_indices = [0, 6, 36, 196, num_thresholds - 1]

for idx in test_indices:
tpr_target = error_thresholds[idx]
analysis_node = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=tpr_target,
tpr_threshold_width=0.0025,
)
tpr_idx = analysis_node._get_tpr_index()
self.assertEqual(
tpr_idx,
idx,
msg=f"tpr_target={tpr_target}: expected index {idx}, got {tpr_idx}",
)

def test_tpr_threshold_width_positive_validation(self) -> None:
"""Test that tpr_threshold_width must be positive."""
with self.assertRaisesRegex(ValueError, "must be positive"):
AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_threshold_width=0,
)

with self.assertRaisesRegex(ValueError, "must be positive"):
AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_threshold_width=-0.01,
)

def test_tpr_threshold_width_divisibility_validation(self) -> None:
"""Test that tpr_threshold_width must evenly divide 0.99."""
with self.assertRaisesRegex(ValueError, "must evenly divide 0.99"):
AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_threshold_width=0.02,
)

def test_tpr_target_range_validation(self) -> None:
"""Test that tpr_target must be between 0.01 and 1.0."""
with self.assertRaisesRegex(ValueError, "must be between 0.01 and 1.0"):
AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=0.005,
)

with self.assertRaisesRegex(ValueError, "must be between 0.01 and 1.0"):
AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=1.5,
)

def test_error_thresholds_array_creation(self) -> None:
"""Test that _error_thresholds array is correctly created."""
# Legacy mode: 100 thresholds
analysis_node_legacy = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=None,
)
self.assertEqual(len(analysis_node_legacy._error_thresholds), 100)

# Fine-grained mode
analysis_node_fine = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=0.01,
tpr_threshold_width=0.0025,
)
expected_num_thresholds = int(0.99 / 0.0025) + 1
self.assertEqual(
len(analysis_node_fine._error_thresholds), expected_num_thresholds
)
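
The indices exercised in test_get_tpr_index_with_target above map back to TPR values as 0.01 + index * width; a quick check of that arithmetic, assuming the default width of 0.0025:

width = 0.0025
for i in (0, 6, 36, 196, 396):
    print(i, round(0.01 + i * width, 4))
# 0 -> 0.01, 6 -> 0.025, 36 -> 0.1, 196 -> 0.5, 396 -> 1.0
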
19 changes: 19 additions & 0 deletions privacy_guard/analysis/tests/test_parallel_analysis_node.py
@@ -308,3 +308,22 @@ def test_use_fnr_tnr_parameter(self) -> None:
self.assertGreater(
len(outputs_false["eps_tpr_ub"]), len(outputs_true["eps_tpr_ub"])
)

def test_tpr_target_parameter(self) -> None:
"""Test that tpr_target parameter works correctly in ParallelAnalysisNode."""
parallel_node = ParallelAnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
eps_computation_tasks_num=2,
tpr_target=0.025,
tpr_threshold_width=0.0025,
)
# Verify _get_tpr_index returns correct index
tpr_idx = parallel_node._get_tpr_index()
self.assertAlmostEqual(
parallel_node._error_thresholds[tpr_idx],
0.025,
places=10,
)