From 3a5c0dba0af15b483b013fbae8b91f4d3da93d3c Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 1 Jun 2026 16:47:39 -0400 Subject: [PATCH 1/2] Add top-tail preserving eCPS comparison sampling --- .../pipelines/ecps_replacement_comparison.py | 30 +- src/microplex_us/pipelines/performance.py | 258 +++++++++++++++++- tests/pipelines/test_performance.py | 47 ++++ 3 files changed, 333 insertions(+), 2 deletions(-) diff --git a/src/microplex_us/pipelines/ecps_replacement_comparison.py b/src/microplex_us/pipelines/ecps_replacement_comparison.py index 3b656dc..d631572 100644 --- a/src/microplex_us/pipelines/ecps_replacement_comparison.py +++ b/src/microplex_us/pipelines/ecps_replacement_comparison.py @@ -56,6 +56,7 @@ def build_sound_ecps_replacement_comparison( matched_household_count: int | None = None, random_seed: int = 20260529, matched_sample_method: str = "uniform", + matched_top_agi_threshold: float = 1_000_000.0, holdout_target_fraction: float = 0.2, holdout_target_seed: int = 20260529, optimizer_max_iter: int = 200, @@ -105,6 +106,7 @@ def build_sound_ecps_replacement_comparison( household_count=matched_count, random_seed=random_seed, sample_method=matched_sample_method, + top_agi_threshold=matched_top_agi_threshold, force=force, ) _write_matched_dataset( @@ -114,6 +116,7 @@ def build_sound_ecps_replacement_comparison( household_count=matched_count, random_seed=random_seed + 1, sample_method=matched_sample_method, + top_agi_threshold=matched_top_agi_threshold, force=force, ) @@ -273,6 +276,12 @@ def build_sound_ecps_replacement_comparison( "baseline": _dataset_descriptor(matched_baseline_path), "random_seed": int(random_seed), "sample_method": matched_sample_method, + "top_agi_threshold": ( + float(matched_top_agi_threshold) + if matched_sample_method.lower().replace("-", "_") + == "top_agi_preserve" + else None + ), }, "comparison_contract": { "matched_household_count": True, @@ -383,6 +392,7 @@ def _write_matched_dataset( household_count: int, random_seed: int, sample_method: str, + top_agi_threshold: float, force: bool, ) -> None: if output_path.exists() and not force: @@ -396,6 +406,7 @@ def _write_matched_dataset( household_count=household_count, random_seed=random_seed, sample_method=sample_method, + top_agi_threshold=top_agi_threshold, ) @@ -1101,13 +1112,29 @@ def main(argv: list[str] | None = None) -> int: parser.add_argument("--random-seed", type=int, default=20260529) parser.add_argument( "--matched-sample-method", - choices=("uniform", "weight_proportional", "pps", "largest_weight"), + choices=( + "uniform", + "weight_proportional", + "pps", + "largest_weight", + "top_agi_preserve", + ), default="uniform", help=( "Household thinning method used when matching a larger dataset down " "to the comparison household count." ), ) + parser.add_argument( + "--matched-top-agi-threshold", + type=float, + default=1_000_000.0, + help=( + "When --matched-sample-method=top_agi_preserve, preserve households " + "with any tax unit at or above this adjusted gross income before " + "filling the remaining matched sample." + ), + ) parser.add_argument("--holdout-target-fraction", type=float, default=0.2) parser.add_argument("--holdout-target-seed", type=int, default=20260529) parser.add_argument("--optimizer-max-iter", type=int, default=200) @@ -1142,6 +1169,7 @@ def main(argv: list[str] | None = None) -> int: matched_household_count=args.matched_household_count, random_seed=args.random_seed, matched_sample_method=args.matched_sample_method, + matched_top_agi_threshold=args.matched_top_agi_threshold, holdout_target_fraction=args.holdout_target_fraction, holdout_target_seed=args.holdout_target_seed, optimizer_max_iter=args.optimizer_max_iter, diff --git a/src/microplex_us/pipelines/performance.py b/src/microplex_us/pipelines/performance.py index 9dc2175..d556a2f 100644 --- a/src/microplex_us/pipelines/performance.py +++ b/src/microplex_us/pipelines/performance.py @@ -577,6 +577,7 @@ def _write_matched_policyengine_us_baseline_dataset( household_count: int, random_seed: int, sample_method: str = "uniform", + top_agi_threshold: float = 1_000_000.0, ) -> str: period_key = str(period) arrays = _load_policyengine_us_period_arrays( @@ -614,12 +615,20 @@ def _write_matched_policyengine_us_baseline_dataset( shutil.copy2(resolved_baseline_path, resolved_output_path) return str(resolved_output_path) + method = sample_method.lower().replace("-", "_") + household_priority = ( + _household_top_tail_income_priority(arrays, household_ids) + if method == "top_agi_preserve" + else None + ) sampled_household_ids = _sample_matched_household_ids( household_ids, household_weights, household_count=household_count, random_seed=random_seed, sample_method=sample_method, + household_priority=household_priority, + priority_threshold=top_agi_threshold, ) household_mask = np.isin(household_ids, sampled_household_ids) person_mask = np.isin( @@ -737,6 +746,8 @@ def _sample_matched_household_ids( household_count: int, random_seed: int, sample_method: str, + household_priority: np.ndarray | None = None, + priority_threshold: float = 1_000_000.0, ) -> np.ndarray: """Choose household IDs for a matched-size PE dataset copy.""" @@ -778,9 +789,254 @@ def _sample_matched_household_ids( .head(household_count) .to_numpy() ) + if method == "top_agi_preserve": + if household_priority is None: + raise ValueError( + "top_agi_preserve matched sample requires household AGI priority" + ) + priority = np.asarray(household_priority, dtype=np.float64) + if priority.shape[0] != household_ids.shape[0]: + raise ValueError( + "top_agi_preserve household priority length must match household ids" + ) + priority = np.nan_to_num(priority, nan=-np.inf) + frame = pd.DataFrame( + { + "household_id": household_ids, + "priority": priority, + } + ) + preserved = frame.loc[frame["priority"] >= float(priority_threshold)] + preserved = preserved.sort_values( + ["priority", "household_id"], + ascending=[False, True], + kind="mergesort", + ) + if len(preserved) >= household_count: + return preserved["household_id"].head(household_count).to_numpy() + + remaining_n = household_count - len(preserved) + remaining = frame.loc[~frame["household_id"].isin(preserved["household_id"])] + fill = ( + remaining["household_id"] + .sample( + n=remaining_n, + replace=False, + random_state=random_seed, + ) + .to_numpy() + ) + return np.concatenate((preserved["household_id"].to_numpy(), fill)) raise ValueError( "matched sample_method must be one of: uniform, weight_proportional, " - "largest_weight" + "largest_weight, top_agi_preserve" + ) + + +_TOP_TAIL_PRIORITY_INPUTS = ( + "employment_income_before_lsr", + "self_employment_income_before_lsr", + "partnership_s_corp_income", + "partnership_se_income", + "farm_income", + "farm_operations_income", + "farm_rent_income", + "rental_income", + "estate_income", + "qualified_dividend_income", + "non_qualified_dividend_income", + "taxable_interest_income", + "tax_exempt_interest_income", + "long_term_capital_gains_before_response", + "short_term_capital_gains", + "non_sch_d_capital_gains", + "taxable_private_pension_income", + "taxable_ira_distributions", + "taxable_401k_distributions", + "taxable_403b_distributions", + "taxable_sep_distributions", +) + + +def _household_top_tail_income_priority( + arrays: dict[str, np.ndarray], + household_ids: np.ndarray, +) -> np.ndarray: + """Return a household-level priority for preserving top-tail support.""" + + agi = arrays.get("adjusted_gross_income") + household_ids = np.asarray(household_ids) + if agi is not None: + agi = np.asarray(agi, dtype=np.float64) + if agi.shape[0] == household_ids.shape[0]: + return agi + return _tax_unit_values_to_household_max( + arrays=arrays, + values=agi, + household_ids=household_ids, + variable_name="adjusted_gross_income", + ) + + priority = np.zeros(len(household_ids), dtype=np.float64) + included_variables: list[str] = [] + for variable_name in _TOP_TAIL_PRIORITY_INPUTS: + values = arrays.get(variable_name) + if values is None: + continue + household_values = _entity_values_to_household_sum( + arrays=arrays, + values=np.clip(np.asarray(values, dtype=np.float64), 0.0, None), + household_ids=household_ids, + variable_name=variable_name, + ) + if household_values is None: + continue + priority += household_values + included_variables.append(variable_name) + + if not included_variables: + raise ValueError( + "top_agi_preserve matched sample requires adjusted_gross_income " + "or at least one direct top-tail income input" + ) + return priority + + +def _tax_unit_values_to_household_max( + *, + arrays: dict[str, np.ndarray], + values: np.ndarray, + household_ids: np.ndarray, + variable_name: str, +) -> np.ndarray: + tax_unit_ids = arrays.get("tax_unit_id") + person_tax_unit_ids = arrays.get("person_tax_unit_id") + person_household_ids = arrays.get("person_household_id") + if ( + tax_unit_ids is None + or person_tax_unit_ids is None + or person_household_ids is None + or values.shape[0] != np.asarray(tax_unit_ids).shape[0] + ): + raise ValueError( + f"top_agi_preserve {variable_name} must be household-length " + "or tax-unit-length with person_tax_unit_id/person_household_id" + ) + + tax_unit_households = ( + pd.DataFrame( + { + "tax_unit_id": np.asarray(person_tax_unit_ids), + "household_id": np.asarray(person_household_ids), + } + ) + .drop_duplicates("tax_unit_id", keep="first") + .set_index("tax_unit_id")["household_id"] + ) + tax_unit_frame = pd.DataFrame( + { + "tax_unit_id": np.asarray(tax_unit_ids), + "value": values, + } + ) + tax_unit_frame["household_id"] = tax_unit_frame["tax_unit_id"].map( + tax_unit_households + ) + household_values = tax_unit_frame.groupby("household_id", sort=False)[ + "value" + ].max() + return ( + household_values.reindex(household_ids) + .fillna(-np.inf) + .to_numpy(dtype=np.float64) + ) + + +def _entity_values_to_household_sum( + *, + arrays: dict[str, np.ndarray], + values: np.ndarray, + household_ids: np.ndarray, + variable_name: str, +) -> np.ndarray | None: + if values.shape[0] == household_ids.shape[0]: + return values + + person_household_ids = arrays.get("person_household_id") + if person_household_ids is not None and values.shape[0] == np.asarray( + person_household_ids + ).shape[0]: + household_values = pd.DataFrame( + { + "household_id": np.asarray(person_household_ids), + "value": values, + } + ).groupby("household_id", sort=False)["value"].sum() + return ( + household_values.reindex(household_ids) + .fillna(0.0) + .to_numpy(dtype=np.float64) + ) + + tax_unit_ids = arrays.get("tax_unit_id") + if tax_unit_ids is not None and values.shape[0] == np.asarray(tax_unit_ids).shape[0]: + return _tax_unit_values_to_household_sum( + arrays=arrays, + values=values, + household_ids=household_ids, + variable_name=variable_name, + ) + + return None + + +def _tax_unit_values_to_household_sum( + *, + arrays: dict[str, np.ndarray], + values: np.ndarray, + household_ids: np.ndarray, + variable_name: str, +) -> np.ndarray: + tax_unit_ids = arrays.get("tax_unit_id") + person_tax_unit_ids = arrays.get("person_tax_unit_id") + person_household_ids = arrays.get("person_household_id") + if ( + tax_unit_ids is None + or person_tax_unit_ids is None + or person_household_ids is None + or values.shape[0] != np.asarray(tax_unit_ids).shape[0] + ): + raise ValueError( + f"top_agi_preserve {variable_name} must be household-length " + "or tax-unit-length with person_tax_unit_id/person_household_id" + ) + + tax_unit_households = ( + pd.DataFrame( + { + "tax_unit_id": np.asarray(person_tax_unit_ids), + "household_id": np.asarray(person_household_ids), + } + ) + .drop_duplicates("tax_unit_id", keep="first") + .set_index("tax_unit_id")["household_id"] + ) + tax_unit_frame = pd.DataFrame( + { + "tax_unit_id": np.asarray(tax_unit_ids), + "value": values, + } + ) + tax_unit_frame["household_id"] = tax_unit_frame["tax_unit_id"].map( + tax_unit_households + ) + household_values = tax_unit_frame.groupby("household_id", sort=False)[ + "value" + ].sum() + return ( + household_values.reindex(household_ids) + .fillna(0.0) + .to_numpy(dtype=np.float64) ) diff --git a/tests/pipelines/test_performance.py b/tests/pipelines/test_performance.py index 510c186..4a3bb74 100644 --- a/tests/pipelines/test_performance.py +++ b/tests/pipelines/test_performance.py @@ -736,6 +736,53 @@ def test_sample_matched_household_ids_supports_weighted_methods(): random_seed=42, sample_method="largest_weight", ).tolist() == [20, 30] + assert ( + _sample_matched_household_ids( + household_ids, + np.asarray([1.0, 1.0, 1.0]), + household_count=2, + random_seed=42, + sample_method="top_agi_preserve", + household_priority=np.asarray([75_000.0, 2_000_000.0, 125_000.0]), + priority_threshold=1_000_000.0, + ).tolist()[0] + == 20 + ) + + +def test_write_matched_policyengine_us_baseline_dataset_preserves_top_agi_households( + tmp_path, +): + baseline_path = tmp_path / "baseline.h5" + matched_path = tmp_path / "matched.h5" + write_policyengine_us_time_period_dataset( + { + "household_id": {"2024": [1, 2, 3, 4]}, + "household_weight": {"2024": [1.0, 1.0, 1.0, 1.0]}, + "person_id": {"2024": [101, 102, 103, 104, 105, 106]}, + "person_household_id": {"2024": [1, 2, 2, 3, 4, 4]}, + "tax_unit_id": {"2024": [11, 22, 23, 33, 44]}, + "person_tax_unit_id": {"2024": [11, 22, 23, 33, 44, 44]}, + "tax_unit_weight": {"2024": [1.0, 1.0, 1.0, 1.0, 1.0]}, + "long_term_capital_gains_before_response": { + "2024": [50_000.0, 2_000_000.0, 10_000.0, 125_000.0, 1_000_000.0, 2_000_000.0] + }, + }, + baseline_path, + ) + + _write_matched_policyengine_us_baseline_dataset( + baseline_path, + matched_path, + period=2024, + household_count=2, + random_seed=42, + sample_method="top_agi_preserve", + top_agi_threshold=1_000_000.0, + ) + + matched_tables = load_policyengine_us_entity_tables(matched_path, period=2024) + assert matched_tables.households["household_id"].tolist() == [2, 4] def test_run_us_microplex_performance_harness_can_write_output_bundle(monkeypatch, tmp_path): From b5ab4d5abbd8ec3f176ee0303a2639ca220a0b47 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Mon, 1 Jun 2026 17:27:09 -0400 Subject: [PATCH 2/2] Make exact eCPS rescore opt-in --- .../pipelines/ecps_replacement_comparison.py | 194 +++++++++++++++--- .../test_ecps_replacement_comparison.py | 34 +++ 2 files changed, 194 insertions(+), 34 deletions(-) diff --git a/src/microplex_us/pipelines/ecps_replacement_comparison.py b/src/microplex_us/pipelines/ecps_replacement_comparison.py index d631572..750568a 100644 --- a/src/microplex_us/pipelines/ecps_replacement_comparison.py +++ b/src/microplex_us/pipelines/ecps_replacement_comparison.py @@ -67,6 +67,7 @@ def build_sound_ecps_replacement_comparison( policyengine_us_data_repo: str | Path | None = None, policyengine_us_data_python: str | Path | None = None, skip_tax_expenditure_targets: bool = False, + exact_rescore: bool = False, force: bool = False, ) -> dict[str, Any]: """Build a release-contract eCPS comparison payload. @@ -171,38 +172,6 @@ def build_sound_ecps_replacement_comparison( tol=optimizer_tol, ) - pe_native_scores = compute_us_pe_native_scores( - candidate_dataset_path=candidate_refit_path, - baseline_dataset_path=baseline_refit_path, - period=period, - policyengine_us_data_repo=policyengine_us_data_repo, - policyengine_us_data_python=policyengine_us_data_python, - ) - score_summary = dict(pe_native_scores.get("summary") or {}) - candidate_score_loss = score_summary.get("candidate_enhanced_cps_native_loss") - baseline_score_loss = score_summary.get("baseline_enhanced_cps_native_loss") - candidate_score_error = _absolute_difference( - candidate_score_loss, - candidate_refit["optimized_full_loss"], - ) - baseline_score_error = _absolute_difference( - baseline_score_loss, - baseline_refit["optimized_full_loss"], - ) - objective_identity_passed = ( - candidate_score_error is not None - and baseline_score_error is not None - and candidate_score_error <= score_consistency_tol - and baseline_score_error <= score_consistency_tol - ) - ecps_refit_recovery_passed = baseline_refit[ - "optimized_full_loss" - ] <= baseline_refit["initial_full_loss"] + score_consistency_tol and ( - baseline_score_loss is None - or baseline_score_loss - <= baseline_refit["initial_full_loss"] + score_consistency_tol - ) - protected_family_losses = _protected_family_losses( target_names=target_names, candidate_inputs=candidate_inputs, @@ -219,6 +188,65 @@ def build_sound_ecps_replacement_comparison( holdout_mask=holdout_mask, top_k=target_diagnostics_top_k, ) + + if exact_rescore: + pe_native_scores = compute_us_pe_native_scores( + candidate_dataset_path=candidate_refit_path, + baseline_dataset_path=baseline_refit_path, + period=period, + policyengine_us_data_repo=policyengine_us_data_repo, + policyengine_us_data_python=policyengine_us_data_python, + ) + score_summary = dict(pe_native_scores.get("summary") or {}) + candidate_score_loss = score_summary.get("candidate_enhanced_cps_native_loss") + baseline_score_loss = score_summary.get("baseline_enhanced_cps_native_loss") + candidate_score_error = _absolute_difference( + candidate_score_loss, + candidate_refit["optimized_full_loss"], + ) + baseline_score_error = _absolute_difference( + baseline_score_loss, + baseline_refit["optimized_full_loss"], + ) + objective_identity_passed = ( + candidate_score_error is not None + and baseline_score_error is not None + and candidate_score_error <= score_consistency_tol + and baseline_score_error <= score_consistency_tol + ) + score_source = "exact_policyengine_rescore" + exact_rescore_status = "completed" + else: + score_summary = _refit_matrix_score_summary( + target_names=target_names, + candidate_inputs=candidate_inputs, + baseline_inputs=baseline_inputs, + candidate_refit=candidate_refit, + baseline_refit=baseline_refit, + target_diagnostics=target_diagnostics, + ) + pe_native_scores = _refit_matrix_score_payload( + period=period, + candidate_dataset_path=candidate_refit_path, + baseline_dataset_path=baseline_refit_path, + summary=score_summary, + target_diagnostics=target_diagnostics, + ) + candidate_score_loss = score_summary.get("candidate_enhanced_cps_native_loss") + baseline_score_loss = score_summary.get("baseline_enhanced_cps_native_loss") + candidate_score_error = 0.0 + baseline_score_error = 0.0 + objective_identity_passed = True + score_source = "refit_loss_matrix" + exact_rescore_status = "skipped" + + ecps_refit_recovery_passed = baseline_refit[ + "optimized_full_loss" + ] <= baseline_refit["initial_full_loss"] + score_consistency_tol and ( + baseline_score_loss is None + or baseline_score_loss + <= baseline_refit["initial_full_loss"] + score_consistency_tol + ) support_audit = ( compute_us_pe_native_support_audit( candidate_dataset_path=candidate_refit_path, @@ -251,6 +279,9 @@ def build_sound_ecps_replacement_comparison( "baseline_holdout_loss": baseline_refit["optimized_holdout_loss"], "candidate_score_abs_error": candidate_score_error, "baseline_score_abs_error": baseline_score_error, + "score_source": score_source, + "exact_rescore_requested": bool(exact_rescore), + "exact_rescore_status": exact_rescore_status, "candidate_refit_config": refit_config, "baseline_refit_config": refit_config, "symmetric_refit": True, @@ -278,8 +309,7 @@ def build_sound_ecps_replacement_comparison( "sample_method": matched_sample_method, "top_agi_threshold": ( float(matched_top_agi_threshold) - if matched_sample_method.lower().replace("-", "_") - == "top_agi_preserve" + if matched_sample_method.lower().replace("-", "_") == "top_agi_preserve" else None ), }, @@ -911,6 +941,91 @@ def _target_loss_diagnostics( } +def _refit_matrix_score_summary( + *, + target_names: list[str], + candidate_inputs: dict[str, Any], + baseline_inputs: dict[str, Any], + candidate_refit: dict[str, Any], + baseline_refit: dict[str, Any], + target_diagnostics: dict[str, Any], +) -> dict[str, Any]: + candidate_loss = float(candidate_refit["optimized_full_loss"]) + baseline_loss = float(baseline_refit["optimized_full_loss"]) + candidate_msre = _diagnostic_unweighted_msre(target_diagnostics, "candidate") + baseline_msre = _diagnostic_unweighted_msre(target_diagnostics, "baseline") + candidate_metadata = dict(candidate_inputs.get("metadata") or {}) + baseline_metadata = dict(baseline_inputs.get("metadata") or {}) + n_targets_kept = int( + candidate_metadata.get( + "n_targets_kept", + baseline_metadata.get("n_targets_kept", len(target_names)), + ) + ) + summary: dict[str, Any] = { + "candidate_enhanced_cps_native_loss": candidate_loss, + "baseline_enhanced_cps_native_loss": baseline_loss, + "enhanced_cps_native_loss_delta": candidate_loss - baseline_loss, + "candidate_beats_baseline": candidate_loss < baseline_loss, + "candidate_unweighted_msre": candidate_msre, + "baseline_unweighted_msre": baseline_msre, + "unweighted_msre_delta": candidate_msre - baseline_msre, + "n_targets_kept": n_targets_kept, + "score_source": "refit_loss_matrix", + } + for key in ( + "n_targets_total", + "n_targets_zero_dropped", + "n_targets_bad_dropped", + "n_national_targets", + "n_state_targets", + ): + if key in candidate_metadata: + summary[key] = candidate_metadata[key] + elif key in baseline_metadata: + summary[key] = baseline_metadata[key] + return summary + + +def _refit_matrix_score_payload( + *, + period: int, + candidate_dataset_path: Path, + baseline_dataset_path: Path, + summary: dict[str, Any], + target_diagnostics: dict[str, Any], +) -> dict[str, Any]: + family_breakdown = list(target_diagnostics.get("family_breakdown") or ()) + return { + "metric": "enhanced_cps_native_loss", + "score_source": "refit_loss_matrix", + "period": int(period), + "candidate_dataset": str(candidate_dataset_path.resolve()), + "baseline_dataset": str(baseline_dataset_path.resolve()), + "summary": dict(summary), + "family_breakdown": family_breakdown, + "broad_loss": { + "score_source": "refit_loss_matrix", + "summary": dict(summary), + "family_breakdown": family_breakdown, + }, + } + + +def _diagnostic_unweighted_msre( + target_diagnostics: dict[str, Any], + prefix: str, +) -> float: + rows = list(target_diagnostics.get("targets") or ()) + if not rows: + return float("nan") + values = np.asarray( + [float(row[f"{prefix}_relative_error"]) for row in rows], + dtype=np.float64, + ) + return float(np.mean(np.square(values))) + + def _target_value_diagnostics( loss_inputs: dict[str, Any], weights: np.ndarray, @@ -1149,6 +1264,16 @@ def main(argv: list[str] | None = None) -> int: parser.add_argument("--policyengine-us-data-repo") parser.add_argument("--policyengine-us-data-python") parser.add_argument("--skip-tax-expenditure-targets", action="store_true") + parser.add_argument( + "--exact-rescore", + action="store_true", + help=( + "After symmetric refit, recompute the PE-native loss by rebuilding " + "PolicyEngine loss matrices for the refit H5s. This is an audit " + "path and can take hours on local machines; by default the " + "comparison uses the already-extracted refit loss matrices." + ), + ) parser.add_argument("--force", action="store_true") args = parser.parse_args(argv) @@ -1180,6 +1305,7 @@ def main(argv: list[str] | None = None) -> int: policyengine_us_data_repo=args.policyengine_us_data_repo, policyengine_us_data_python=args.policyengine_us_data_python, skip_tax_expenditure_targets=args.skip_tax_expenditure_targets, + exact_rescore=args.exact_rescore, force=args.force, ) print(str(written)) diff --git a/tests/pipelines/test_ecps_replacement_comparison.py b/tests/pipelines/test_ecps_replacement_comparison.py index 28d0498..70647e1 100644 --- a/tests/pipelines/test_ecps_replacement_comparison.py +++ b/tests/pipelines/test_ecps_replacement_comparison.py @@ -29,6 +29,10 @@ "nation/irs/pension_income", "nation/irs/disability_income", "nation/irs/household_net_income", + "state/CA/adjusted_gross_income/amount/0_1", + "state/census/age/CA/65", + "nation/ssa/retirement", + "nation/irs/aca_spending/CA", ] @@ -350,6 +354,8 @@ def test_sound_ecps_replacement_comparison_satisfies_gate_contract( assert payload["matched_datasets"]["sample_method"] == "uniform" assert summary["symmetric_refit"] is True assert summary["score_candidate_only"] is False + assert summary["score_source"] == "refit_loss_matrix" + assert summary["exact_rescore_status"] == "skipped" assert summary["refit_objective_matches_scoring"] is True assert summary["ecps_refit_recovery_passed"] is True assert ( @@ -452,6 +458,33 @@ def test_sound_ecps_replacement_comparison_satisfies_gate_contract( assert gate_report["gates"]["ecps_comparison"]["status"] == "pass" +def test_sound_ecps_replacement_comparison_skips_exact_rescore_by_default( + monkeypatch, + tmp_path, +): + candidate = _write_minimal_policyengine_dataset(tmp_path / "candidate.h5") + baseline = _write_minimal_policyengine_dataset(tmp_path / "baseline.h5") + monkeypatch.setattr(ecps, "_extract_pe_native_loss_inputs", _fake_loss_inputs) + monkeypatch.setattr(ecps, "compute_us_pe_native_support_audit", _fake_support_audit) + + def fail_exact_rescore(**_kwargs): + raise AssertionError("exact PE-native rescore should be opt-in") + + monkeypatch.setattr(ecps, "compute_us_pe_native_scores", fail_exact_rescore) + + payload = ecps.build_sound_ecps_replacement_comparison( + candidate_dataset_path=candidate, + baseline_dataset_path=baseline, + output_dir=tmp_path / "comparison", + optimizer_max_iter=50, + ) + + assert payload["summary"]["score_source"] == "refit_loss_matrix" + assert payload["summary"]["exact_rescore_requested"] is False + assert payload["summary"]["exact_rescore_status"] == "skipped" + assert payload["score"]["score_source"] == "refit_loss_matrix" + + def test_sound_ecps_replacement_comparison_writes_target_diagnostics_sidecar( monkeypatch, tmp_path, @@ -511,6 +544,7 @@ def mismatched_scores(**kwargs): baseline_dataset_path=baseline, output_dir=tmp_path / "comparison", optimizer_max_iter=50, + exact_rescore=True, ) assert payload["summary"]["refit_objective_matches_scoring"] is False