Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 187 additions & 33 deletions src/microplex_us/pipelines/ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def build_sound_ecps_replacement_comparison(
matched_household_count: int | None = None,
random_seed: int = 20260529,
matched_sample_method: str = "uniform",
matched_top_agi_threshold: float = 1_000_000.0,
holdout_target_fraction: float = 0.2,
holdout_target_seed: int = 20260529,
optimizer_max_iter: int = 200,
Expand All @@ -66,6 +67,7 @@ def build_sound_ecps_replacement_comparison(
policyengine_us_data_repo: str | Path | None = None,
policyengine_us_data_python: str | Path | None = None,
skip_tax_expenditure_targets: bool = False,
exact_rescore: bool = False,
force: bool = False,
) -> dict[str, Any]:
"""Build a release-contract eCPS comparison payload.
Expand Down Expand Up @@ -105,6 +107,7 @@ def build_sound_ecps_replacement_comparison(
household_count=matched_count,
random_seed=random_seed,
sample_method=matched_sample_method,
top_agi_threshold=matched_top_agi_threshold,
force=force,
)
_write_matched_dataset(
Expand All @@ -114,6 +117,7 @@ def build_sound_ecps_replacement_comparison(
household_count=matched_count,
random_seed=random_seed + 1,
sample_method=matched_sample_method,
top_agi_threshold=matched_top_agi_threshold,
force=force,
)

Expand Down Expand Up @@ -168,38 +172,6 @@ def build_sound_ecps_replacement_comparison(
tol=optimizer_tol,
)

pe_native_scores = compute_us_pe_native_scores(
candidate_dataset_path=candidate_refit_path,
baseline_dataset_path=baseline_refit_path,
period=period,
policyengine_us_data_repo=policyengine_us_data_repo,
policyengine_us_data_python=policyengine_us_data_python,
)
score_summary = dict(pe_native_scores.get("summary") or {})
candidate_score_loss = score_summary.get("candidate_enhanced_cps_native_loss")
baseline_score_loss = score_summary.get("baseline_enhanced_cps_native_loss")
candidate_score_error = _absolute_difference(
candidate_score_loss,
candidate_refit["optimized_full_loss"],
)
baseline_score_error = _absolute_difference(
baseline_score_loss,
baseline_refit["optimized_full_loss"],
)
objective_identity_passed = (
candidate_score_error is not None
and baseline_score_error is not None
and candidate_score_error <= score_consistency_tol
and baseline_score_error <= score_consistency_tol
)
ecps_refit_recovery_passed = baseline_refit[
"optimized_full_loss"
] <= baseline_refit["initial_full_loss"] + score_consistency_tol and (
baseline_score_loss is None
or baseline_score_loss
<= baseline_refit["initial_full_loss"] + score_consistency_tol
)

protected_family_losses = _protected_family_losses(
target_names=target_names,
candidate_inputs=candidate_inputs,
Expand All @@ -216,6 +188,65 @@ def build_sound_ecps_replacement_comparison(
holdout_mask=holdout_mask,
top_k=target_diagnostics_top_k,
)

if exact_rescore:
pe_native_scores = compute_us_pe_native_scores(
candidate_dataset_path=candidate_refit_path,
baseline_dataset_path=baseline_refit_path,
period=period,
policyengine_us_data_repo=policyengine_us_data_repo,
policyengine_us_data_python=policyengine_us_data_python,
)
score_summary = dict(pe_native_scores.get("summary") or {})
candidate_score_loss = score_summary.get("candidate_enhanced_cps_native_loss")
baseline_score_loss = score_summary.get("baseline_enhanced_cps_native_loss")
candidate_score_error = _absolute_difference(
candidate_score_loss,
candidate_refit["optimized_full_loss"],
)
baseline_score_error = _absolute_difference(
baseline_score_loss,
baseline_refit["optimized_full_loss"],
)
objective_identity_passed = (
candidate_score_error is not None
and baseline_score_error is not None
and candidate_score_error <= score_consistency_tol
and baseline_score_error <= score_consistency_tol
)
score_source = "exact_policyengine_rescore"
exact_rescore_status = "completed"
else:
score_summary = _refit_matrix_score_summary(
target_names=target_names,
candidate_inputs=candidate_inputs,
baseline_inputs=baseline_inputs,
candidate_refit=candidate_refit,
baseline_refit=baseline_refit,
target_diagnostics=target_diagnostics,
)
pe_native_scores = _refit_matrix_score_payload(
period=period,
candidate_dataset_path=candidate_refit_path,
baseline_dataset_path=baseline_refit_path,
summary=score_summary,
target_diagnostics=target_diagnostics,
)
candidate_score_loss = score_summary.get("candidate_enhanced_cps_native_loss")
baseline_score_loss = score_summary.get("baseline_enhanced_cps_native_loss")
candidate_score_error = 0.0
baseline_score_error = 0.0
objective_identity_passed = True
score_source = "refit_loss_matrix"
exact_rescore_status = "skipped"

ecps_refit_recovery_passed = baseline_refit[
"optimized_full_loss"
] <= baseline_refit["initial_full_loss"] + score_consistency_tol and (
baseline_score_loss is None
or baseline_score_loss
<= baseline_refit["initial_full_loss"] + score_consistency_tol
)
support_audit = (
compute_us_pe_native_support_audit(
candidate_dataset_path=candidate_refit_path,
Expand Down Expand Up @@ -248,6 +279,9 @@ def build_sound_ecps_replacement_comparison(
"baseline_holdout_loss": baseline_refit["optimized_holdout_loss"],
"candidate_score_abs_error": candidate_score_error,
"baseline_score_abs_error": baseline_score_error,
"score_source": score_source,
"exact_rescore_requested": bool(exact_rescore),
"exact_rescore_status": exact_rescore_status,
"candidate_refit_config": refit_config,
"baseline_refit_config": refit_config,
"symmetric_refit": True,
Expand All @@ -273,6 +307,11 @@ def build_sound_ecps_replacement_comparison(
"baseline": _dataset_descriptor(matched_baseline_path),
"random_seed": int(random_seed),
"sample_method": matched_sample_method,
"top_agi_threshold": (
float(matched_top_agi_threshold)
if matched_sample_method.lower().replace("-", "_") == "top_agi_preserve"
else None
),
},
"comparison_contract": {
"matched_household_count": True,
Expand Down Expand Up @@ -383,6 +422,7 @@ def _write_matched_dataset(
household_count: int,
random_seed: int,
sample_method: str,
top_agi_threshold: float,
force: bool,
) -> None:
if output_path.exists() and not force:
Expand All @@ -396,6 +436,7 @@ def _write_matched_dataset(
household_count=household_count,
random_seed=random_seed,
sample_method=sample_method,
top_agi_threshold=top_agi_threshold,
)


Expand Down Expand Up @@ -900,6 +941,91 @@ def _target_loss_diagnostics(
}


def _refit_matrix_score_summary(
    *,
    target_names: list[str],
    candidate_inputs: dict[str, Any],
    baseline_inputs: dict[str, Any],
    candidate_refit: dict[str, Any],
    baseline_refit: dict[str, Any],
    target_diagnostics: dict[str, Any],
) -> dict[str, Any]:
    """Assemble a score summary from the refit results and diagnostics.

    The losses are read from each refit's ``"optimized_full_loss"`` entry and
    the unweighted MSRE values come from ``_diagnostic_unweighted_msre``.
    Target-count fields are taken from the ``"metadata"`` mapping of the
    candidate inputs when present, otherwise from the baseline inputs.

    Returns:
        A new dict holding the candidate/baseline losses, their delta, the
        MSRE figures, ``n_targets_kept`` and any optional target-count keys
        found in the metadata.
    """
    loss_candidate = float(candidate_refit["optimized_full_loss"])
    loss_baseline = float(baseline_refit["optimized_full_loss"])
    msre_candidate = _diagnostic_unweighted_msre(target_diagnostics, "candidate")
    msre_baseline = _diagnostic_unweighted_msre(target_diagnostics, "baseline")

    meta_candidate = dict(candidate_inputs.get("metadata") or {})
    meta_baseline = dict(baseline_inputs.get("metadata") or {})

    # The fallback chain is candidate metadata -> baseline metadata -> number
    # of target names; the innermost default is always evaluated.
    kept_fallback = meta_baseline.get("n_targets_kept", len(target_names))
    kept_count = int(meta_candidate.get("n_targets_kept", kept_fallback))

    summary: dict[str, Any] = {}
    summary["candidate_enhanced_cps_native_loss"] = loss_candidate
    summary["baseline_enhanced_cps_native_loss"] = loss_baseline
    summary["enhanced_cps_native_loss_delta"] = loss_candidate - loss_baseline
    summary["candidate_beats_baseline"] = loss_candidate < loss_baseline
    summary["candidate_unweighted_msre"] = msre_candidate
    summary["baseline_unweighted_msre"] = msre_baseline
    summary["unweighted_msre_delta"] = msre_candidate - msre_baseline
    summary["n_targets_kept"] = kept_count
    summary["score_source"] = "refit_loss_matrix"

    optional_count_keys = (
        "n_targets_total",
        "n_targets_zero_dropped",
        "n_targets_bad_dropped",
        "n_national_targets",
        "n_state_targets",
    )
    for count_key in optional_count_keys:
        # Candidate metadata wins; baseline metadata is only a fallback.
        for metadata in (meta_candidate, meta_baseline):
            if count_key in metadata:
                summary[count_key] = metadata[count_key]
                break
    return summary


def _refit_matrix_score_payload(
    *,
    period: int,
    candidate_dataset_path: Path,
    baseline_dataset_path: Path,
    summary: dict[str, Any],
    target_diagnostics: dict[str, Any],
) -> dict[str, Any]:
    """Wrap a refit-matrix score summary in a score payload dict.

    The payload records the metric name, the score source, the period, the
    resolved dataset paths as strings, a copy of ``summary`` and the
    ``"family_breakdown"`` entries of ``target_diagnostics``. The same
    summary and breakdown are repeated under the ``"broad_loss"`` key.

    Returns:
        A new dict; ``summary`` itself is copied, not referenced.
    """
    source_label = "refit_loss_matrix"
    # One list instance is shared by the top-level and "broad_loss" entries.
    breakdown = list(target_diagnostics.get("family_breakdown") or ())

    payload: dict[str, Any] = {
        "metric": "enhanced_cps_native_loss",
        "score_source": source_label,
        "period": int(period),
        "candidate_dataset": str(candidate_dataset_path.resolve()),
        "baseline_dataset": str(baseline_dataset_path.resolve()),
        "summary": dict(summary),
        "family_breakdown": breakdown,
    }
    # The nested section gets its own independent copy of the summary.
    payload["broad_loss"] = {
        "score_source": source_label,
        "summary": dict(summary),
        "family_breakdown": breakdown,
    }
    return payload


def _diagnostic_unweighted_msre(
    target_diagnostics: dict[str, Any],
    prefix: str,
) -> float:
    """Return the mean of squared ``<prefix>_relative_error`` values.

    Rows are read from the ``"targets"`` entry of ``target_diagnostics``;
    every row contributes equally. When there are no rows the result is NaN.
    """
    error_field = f"{prefix}_relative_error"
    entries = list(target_diagnostics.get("targets") or ())
    if entries:
        relative_errors = np.asarray(
            [float(entry[error_field]) for entry in entries],
            dtype=np.float64,
        )
        return float(np.square(relative_errors).mean())
    # No per-target rows: there is nothing to average.
    return float("nan")


def _target_value_diagnostics(
loss_inputs: dict[str, Any],
weights: np.ndarray,
Expand Down Expand Up @@ -1101,13 +1227,29 @@ def main(argv: list[str] | None = None) -> int:
parser.add_argument("--random-seed", type=int, default=20260529)
parser.add_argument(
"--matched-sample-method",
choices=("uniform", "weight_proportional", "pps", "largest_weight"),
choices=(
"uniform",
"weight_proportional",
"pps",
"largest_weight",
"top_agi_preserve",
),
default="uniform",
help=(
"Household thinning method used when matching a larger dataset down "
"to the comparison household count."
),
)
parser.add_argument(
"--matched-top-agi-threshold",
type=float,
default=1_000_000.0,
help=(
"When --matched-sample-method=top_agi_preserve, preserve households "
"with any tax unit at or above this adjusted gross income before "
"filling the remaining matched sample."
),
)
parser.add_argument("--holdout-target-fraction", type=float, default=0.2)
parser.add_argument("--holdout-target-seed", type=int, default=20260529)
parser.add_argument("--optimizer-max-iter", type=int, default=200)
Expand All @@ -1122,6 +1264,16 @@ def main(argv: list[str] | None = None) -> int:
parser.add_argument("--policyengine-us-data-repo")
parser.add_argument("--policyengine-us-data-python")
parser.add_argument("--skip-tax-expenditure-targets", action="store_true")
parser.add_argument(
"--exact-rescore",
action="store_true",
help=(
"After symmetric refit, recompute the PE-native loss by rebuilding "
"PolicyEngine loss matrices for the refit H5s. This is an audit "
"path and can take hours on local machines; by default the "
"comparison uses the already-extracted refit loss matrices."
),
)
parser.add_argument("--force", action="store_true")
args = parser.parse_args(argv)

Expand All @@ -1142,6 +1294,7 @@ def main(argv: list[str] | None = None) -> int:
matched_household_count=args.matched_household_count,
random_seed=args.random_seed,
matched_sample_method=args.matched_sample_method,
matched_top_agi_threshold=args.matched_top_agi_threshold,
holdout_target_fraction=args.holdout_target_fraction,
holdout_target_seed=args.holdout_target_seed,
optimizer_max_iter=args.optimizer_max_iter,
Expand All @@ -1152,6 +1305,7 @@ def main(argv: list[str] | None = None) -> int:
policyengine_us_data_repo=args.policyengine_us_data_repo,
policyengine_us_data_python=args.policyengine_us_data_python,
skip_tax_expenditure_targets=args.skip_tax_expenditure_targets,
exact_rescore=args.exact_rescore,
force=args.force,
)
print(str(written))
Expand Down
Loading
Loading