From 356f789c83330481eeda0eb3ab04c989b6b73523 Mon Sep 17 00:00:00 2001
From: Duhyeong Kim
Date: Mon, 9 Feb 2026 13:14:15 -0800
Subject: [PATCH] Update target TPR of prod AIA

Summary:
This diff adds support for configurable TPR (True Positive Rate) targets in MIA.

**Key Changes:**
- Added `tpr_target` parameter: Allows users to specify a custom TPR threshold (default behavior remains 1% if not specified)
- Added `tpr_threshold_width` parameter: Controls the granularity of TPR thresholds (default: 0.25%)
- Added input validation for `tpr_threshold_width` (must be positive and evenly divide 99%, e.g., 0.1%, 0.25%, 0.3%, 0.5%, etc.)
- Added unit tests for the changes in this diff
  - `test_get_tpr_index_none_target` to verify legacy behavior (returns index 0 when `tpr_target=None`)
  - `test_get_tpr_index_with_target` to verify `get_tpr_index` returns the correct index
  - `test_tpr_threshold_width_positive_validation` to verify ValueError is raised when `tpr_threshold_width <= 0`
  - `test_tpr_threshold_width_divisibility_validation` to verify ValueError is raised when `tpr_threshold_width` doesn't evenly divide 0.99

**Backward Compatibility:**
- Fully backward compatible: existing configs without `tpr_target` continue to use the original 1% TPR with 100 thresholds

Differential Revision: D92439701
---
 privacy_guard/analysis/mia/analysis_node.py  |  68 ++++++++++-
 .../analysis/mia/parallel_analysis_node.py   |  17 ++-
 .../analysis/tests/test_analysis_node.py     | 114 ++++++++++++++++++
 .../tests/test_parallel_analysis_node.py     |  19 +++
 4 files changed, 208 insertions(+), 10 deletions(-)

diff --git a/privacy_guard/analysis/mia/analysis_node.py b/privacy_guard/analysis/mia/analysis_node.py
index 7ff5910..abf9ef7 100644
--- a/privacy_guard/analysis/mia/analysis_node.py
+++ b/privacy_guard/analysis/mia/analysis_node.py
@@ -16,7 +16,7 @@
 from collections.abc import Generator
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional

 import numpy as np
 import torch
@@ -53,6 +53,7 @@ class AnalysisNodeOutput(BaseAnalysisOutput):
         auc (float): Mean area under the curve (AUC) of the attack.
         auc_ci (List[float]): Confidence interval for the AUC, represented as [lower_bound, upper_bound].
         data_size (dict[str, int]): Size of the training, test dataset and bootstrap sample size.
+        tpr_target (float): Target TPR used for computing epsilon.
     """

     # Empirical epsilons
@@ -73,6 +74,9 @@ class AnalysisNodeOutput(BaseAnalysisOutput):
     auc_ci: List[float]
     # Dataset sizes
     data_size: dict[str, int]
+    # TPR target and index (only set when custom tpr_target is provided)
+    tpr_target: Optional[float]
+    tpr_idx: Optional[int]


 class AnalysisNode(BaseAnalysisNode):
@@ -93,6 +97,11 @@ class AnalysisNode(BaseAnalysisNode):
         use_fnr_tnr: boolean for whether to use FNR and TNR in addition to FPR and TPR error thresholds in eps_max_array computation.
         show_progress: boolean for whether to show tqdm progress bar
         with_timer: boolean for whether to show timer for analysis node
+        tpr_target: Optional target TPR for computing epsilon. If None (default), uses legacy 1% intervals.
+            If specified, uses fine-grained intervals determined by tpr_threshold_width.
+        tpr_threshold_width: Width (step size) between TPR thresholds for fine-grained mode.
+            Only used when tpr_target is specified. Default 0.0025 (0.25%).
+            Start is always fixed at 0.01. num_thresholds = int((1.0 - 0.01) / width) + 1.
""" def __init__( @@ -106,6 +115,8 @@ def __init__( use_fnr_tnr: bool = False, show_progress: bool = False, with_timer: bool = False, + tpr_target: Optional[float] = None, + tpr_threshold_width: float = 0.0025, ) -> None: self._delta = delta self._n_users_for_eval = n_users_for_eval @@ -117,6 +128,10 @@ def __init__( self._use_upper_bound = use_upper_bound + self._tpr_target = tpr_target + self._tpr_threshold_width = tpr_threshold_width + self._num_thresholds: int + self._timer_stats: TimerStats = {} if self._n_users_for_eval < 0: @@ -124,8 +139,48 @@ def __init__( 'Input to AnalysisNode "n_users_for_eval" must be nonnegative' ) + if self._tpr_target is not None: + assert isinstance(self._tpr_target, float) + if self._tpr_target < 0.01 or self._tpr_target > 1.0: + raise ValueError( + 'Input to AnalysisNode "tpr_target" must be between 0.01 and 1.0' + ) + + if self._tpr_threshold_width <= 0: + raise ValueError( + 'Input to AnalysisNode "tpr_threshold_width" must be positive' + ) + + if not np.isclose( + 0.99 / self._tpr_threshold_width, + round(0.99 / self._tpr_threshold_width), + ): + raise ValueError( + 'Input to AnalysisNode "tpr_threshold_width" must evenly divide 0.99. ' + "Valid examples: 0.001, 0.0025, 0.003, 0.005, 0.01" + ) + + # Determine num_thresholds based on tpr_target + if self._tpr_target is None: + # Legacy: 1% intervals (0.01 to 1.0, 100 thresholds) + self._num_thresholds = 100 + self._tpr_threshold_width = 0.01 + else: + self._tpr_threshold_width = tpr_threshold_width + self._num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1 + + self._error_thresholds: NDArray[np.floating] = np.linspace( + 0.01, 1.0, self._num_thresholds + ) + super().__init__(analysis_input=analysis_input) + def _get_tpr_index(self) -> int: + """Convert TPR target to array index.""" + if self._tpr_target is None: + return 0 # Legacy behavior: TPR = 1% is at index 0 + return int(np.where(self._error_thresholds == self._tpr_target)[0][0]) + def _calculate_one_off_eps(self) -> float: df_train_user = self.analysis_input.df_train_user df_test_user = self.analysis_input.df_test_user @@ -253,9 +308,10 @@ def run_analysis(self) -> BaseAnalysisOutput: eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb + tpr_idx = self._get_tpr_index() outputs = AnalysisNodeOutput( - eps=eps_tpr_boundary[0], # epsilon at TPR=1% UB threshold - eps_lb=eps_tpr_lb[0], + eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold + eps_lb=eps_tpr_lb[tpr_idx], eps_fpr_max_ub=np.nanmax(eps_fpr_ub), eps_fpr_lb=list(eps_fpr_lb), eps_fpr_ub=list(eps_fpr_ub), @@ -273,6 +329,8 @@ def run_analysis(self) -> BaseAnalysisOutput: "test_size": test_size, "bootstrap_size": bootstrap_sample_size, }, + tpr_target=self._tpr_target, + tpr_idx=tpr_idx, ) if self._with_timer: @@ -313,8 +371,6 @@ def _make_metrics_array( bootstrap_sample_size = min(train_size, test_size) - error_thresholds = np.linspace(0.01, 1, 100) - metrics_array = [ MIAResults( loss_train[ @@ -329,7 +385,7 @@ def _make_metrics_array( ], ).compute_metrics_at_error_threshold( self._delta, - error_threshold=error_thresholds, + error_threshold=self._error_thresholds, cap_eps=self._cap_eps, use_fnr_tnr=self._use_fnr_tnr, verbose=False, diff --git a/privacy_guard/analysis/mia/parallel_analysis_node.py b/privacy_guard/analysis/mia/parallel_analysis_node.py index f63c1fe..8e2db7f 100644 --- a/privacy_guard/analysis/mia/parallel_analysis_node.py +++ b/privacy_guard/analysis/mia/parallel_analysis_node.py @@ -15,6 +15,7 @@ import os import tempfile from 
 from concurrent.futures import ProcessPoolExecutor
+from typing import Optional

 import numpy as np
 import torch
@@ -41,6 +42,8 @@ class ParallelAnalysisNode(AnalysisNode):
         use_upper_bound: boolean for whether to compute epsilon at the upper-bound of CI
         use_fnr_tnr: boolean for whether to use FNR and TNR in addition to FPR and TPR error thresholds in eps_max_array computation.
         with_timer: boolean for whether to show timer for analysis node
+        tpr_target: Optional target TPR for computing epsilon. If None (default), uses legacy 1% intervals.
+        tpr_threshold_width: Width between TPR thresholds for fine-grained mode. Default 0.0025.
     """

     def __init__(
@@ -53,6 +56,8 @@ def __init__(
         num_bootstrap_resampling_times: int = 1000,
         use_fnr_tnr: bool = False,
         with_timer: bool = False,
+        tpr_target: Optional[float] = None,
+        tpr_threshold_width: float = 0.0025,
     ) -> None:
         super().__init__(
             analysis_input=analysis_input,
@@ -62,6 +67,8 @@ def __init__(
             num_bootstrap_resampling_times=num_bootstrap_resampling_times,
             use_fnr_tnr=use_fnr_tnr,
             with_timer=with_timer,
+            tpr_target=tpr_target,
+            tpr_threshold_width=tpr_threshold_width,
         )

         self._eps_computation_tasks_num = eps_computation_tasks_num
@@ -87,7 +94,6 @@ def _compute_metrics_array(
         loss_test: torch.Tensor = torch.load(test_filename, weights_only=True)
         train_size, test_size = loss_train.shape[0], loss_test.shape[0]
         bootstrap_sample_size = min(train_size, test_size)
-        error_thresholds = np.linspace(0.01, 1, 100)
         metrics_results = []

         try:
@@ -107,7 +113,7 @@

                 metrics_result = mia_results.compute_metrics_at_error_threshold(
                     self._delta,
-                    error_threshold=error_thresholds,
+                    error_threshold=self._error_thresholds,
                     use_fnr_tnr=self._use_fnr_tnr,
                     verbose=False,
                 )
@@ -221,9 +227,10 @@ def run_analysis(self) -> AnalysisNodeOutput:

         eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb

+        tpr_idx = self._get_tpr_index()
         outputs = AnalysisNodeOutput(
-            eps=eps_tpr_boundary[0],  # epsilon at TPR=1% UB threshold
-            eps_lb=eps_tpr_lb[0],
+            eps=eps_tpr_boundary[tpr_idx],  # epsilon at specified TPR threshold
+            eps_lb=eps_tpr_lb[tpr_idx],
             eps_fpr_max_ub=np.nanmax(eps_fpr_ub),
             eps_fpr_lb=list(eps_fpr_lb),
             eps_fpr_ub=list(eps_fpr_ub),
@@ -243,6 +250,8 @@ def run_analysis(self) -> AnalysisNodeOutput:
                 "test_size": test_size,
                 "bootstrap_size": bootstrap_sample_size,
             },
+            tpr_target=self._tpr_target,
+            tpr_idx=tpr_idx,
         )

         if self._with_timer:
diff --git a/privacy_guard/analysis/tests/test_analysis_node.py b/privacy_guard/analysis/tests/test_analysis_node.py
index 2c5e1ac..c99fbf2 100644
--- a/privacy_guard/analysis/tests/test_analysis_node.py
+++ b/privacy_guard/analysis/tests/test_analysis_node.py
@@ -496,3 +496,117 @@ def test_use_fnr_tnr_parameter_comparison(self) -> None:
             outputs_false["accuracy"], outputs_true["accuracy"], places=10
         )
         self.assertAlmostEqual(outputs_false["auc"], outputs_true["auc"], places=10)
+
+    def test_get_tpr_index_none_target(self) -> None:
+        """Test that _get_tpr_index returns 0 when tpr_target is None (legacy behavior)."""
+        analysis_node = AnalysisNode(
+            analysis_input=self.analysis_input,
+            delta=0.000001,
+            n_users_for_eval=100,
+            num_bootstrap_resampling_times=10,
+            tpr_target=None,
+        )
+        self.assertEqual(analysis_node._get_tpr_index(), 0)
+
+    def test_get_tpr_index_with_target(self) -> None:
+        """Test that _get_tpr_index returns correct index that points to tpr_target."""
+        # Create error_thresholds array to get actual values
+        num_thresholds = int((1.0 - 0.01) / 0.0025) + 1
+        error_thresholds = np.linspace(0.01, 1.0, num_thresholds)
+
+        # Test with actual values from the array at various indices
+        test_indices = [0, 6, 36, 196, num_thresholds - 1]
+
+        for idx in test_indices:
+            tpr_target = error_thresholds[idx]
+            analysis_node = AnalysisNode(
+                analysis_input=self.analysis_input,
+                delta=0.000001,
+                n_users_for_eval=100,
+                num_bootstrap_resampling_times=10,
+                tpr_target=tpr_target,
+                tpr_threshold_width=0.0025,
+            )
+            tpr_idx = analysis_node._get_tpr_index()
+            self.assertEqual(
+                tpr_idx,
+                idx,
+                msg=f"tpr_target={tpr_target}: expected index {idx}, got {tpr_idx}",
+            )
+
+    def test_tpr_threshold_width_positive_validation(self) -> None:
+        """Test that tpr_threshold_width must be positive."""
+        with self.assertRaisesRegex(ValueError, "must be positive"):
+            AnalysisNode(
+                analysis_input=self.analysis_input,
+                delta=0.000001,
+                n_users_for_eval=100,
+                num_bootstrap_resampling_times=10,
+                tpr_threshold_width=0,
+            )
+
+        with self.assertRaisesRegex(ValueError, "must be positive"):
+            AnalysisNode(
+                analysis_input=self.analysis_input,
+                delta=0.000001,
+                n_users_for_eval=100,
+                num_bootstrap_resampling_times=10,
+                tpr_threshold_width=-0.01,
+            )
+
+    def test_tpr_threshold_width_divisibility_validation(self) -> None:
+        """Test that tpr_threshold_width must evenly divide 0.99."""
+        with self.assertRaisesRegex(ValueError, "must evenly divide 0.99"):
+            AnalysisNode(
+                analysis_input=self.analysis_input,
+                delta=0.000001,
+                n_users_for_eval=100,
+                num_bootstrap_resampling_times=10,
+                tpr_threshold_width=0.02,
+            )
+
+    def test_tpr_target_range_validation(self) -> None:
+        """Test that tpr_target must be between 0.01 and 1.0."""
+        with self.assertRaisesRegex(ValueError, "must be between 0.01 and 1.0"):
+            AnalysisNode(
+                analysis_input=self.analysis_input,
+                delta=0.000001,
+                n_users_for_eval=100,
+                num_bootstrap_resampling_times=10,
+                tpr_target=0.005,
+            )
+
+        with self.assertRaisesRegex(ValueError, "must be between 0.01 and 1.0"):
+            AnalysisNode(
+                analysis_input=self.analysis_input,
+                delta=0.000001,
+                n_users_for_eval=100,
+                num_bootstrap_resampling_times=10,
+                tpr_target=1.5,
+            )
+
+    def test_error_thresholds_array_creation(self) -> None:
+        """Test that _error_thresholds array is correctly created."""
+        # Legacy mode: 100 thresholds
+        analysis_node_legacy = AnalysisNode(
+            analysis_input=self.analysis_input,
+            delta=0.000001,
+            n_users_for_eval=100,
+            num_bootstrap_resampling_times=10,
+            tpr_target=None,
+        )
+        self.assertEqual(len(analysis_node_legacy._error_thresholds), 100)
+
+        # Fine-grained mode
+        analysis_node_fine = AnalysisNode(
+            analysis_input=self.analysis_input,
+            delta=0.000001,
+            n_users_for_eval=100,
+            num_bootstrap_resampling_times=10,
+            tpr_target=0.01,
+            tpr_threshold_width=0.0025,
+        )
+        expected_num_thresholds = int(0.99 / 0.0025) + 1
+        self.assertEqual(
+            len(analysis_node_fine._error_thresholds), expected_num_thresholds
+        )
diff --git a/privacy_guard/analysis/tests/test_parallel_analysis_node.py b/privacy_guard/analysis/tests/test_parallel_analysis_node.py
index 0ce8392..68aff6d 100644
--- a/privacy_guard/analysis/tests/test_parallel_analysis_node.py
+++ b/privacy_guard/analysis/tests/test_parallel_analysis_node.py
@@ -308,3 +308,22 @@ def test_use_fnr_tnr_parameter(self) -> None:
         self.assertGreater(
             len(outputs_false["eps_tpr_ub"]), len(outputs_true["eps_tpr_ub"])
         )
+
+    def test_tpr_target_parameter(self) -> None:
+        """Test that tpr_target parameter works correctly in ParallelAnalysisNode."""
+        parallel_node = ParallelAnalysisNode(
+            analysis_input=self.analysis_input,
+            delta=0.000001,
+            n_users_for_eval=100,
+            num_bootstrap_resampling_times=10,
+            eps_computation_tasks_num=2,
+            tpr_target=0.025,
+            tpr_threshold_width=0.0025,
+        )
+        # Verify _get_tpr_index returns correct index
+        tpr_idx = parallel_node._get_tpr_index()
+        self.assertAlmostEqual(
+            parallel_node._error_thresholds[tpr_idx],
+            0.025,
+            places=10,
+        )
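
For reference, a minimal usage sketch of the new parameters (not part of the patch). It assumes an `analysis_input` object prepared the same way the existing unit tests prepare theirs; the constructor arguments and output fields mirror those added in this diff.

```python
# Illustrative only: request epsilon at a custom TPR target instead of the
# legacy 1% threshold. `analysis_input` is assumed to be built elsewhere
# (e.g., as in the existing test setup).
from privacy_guard.analysis.mia.analysis_node import AnalysisNode

node = AnalysisNode(
    analysis_input=analysis_input,  # assumed: analysis input prepared as in the tests
    delta=0.000001,
    n_users_for_eval=100,
    num_bootstrap_resampling_times=10,
    tpr_target=0.025,          # report epsilon at TPR = 2.5%
    tpr_threshold_width=0.0025,  # 0.25% threshold grid from 0.01 to 1.0
)
outputs = node.run_analysis()
# eps is taken at the requested TPR; tpr_target and tpr_idx echo the request.
print(outputs.eps, outputs.tpr_target, outputs.tpr_idx)
```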