From 3c82c215b1c366842c8e05d5e7a9f6ab7bb8e5fa Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 3 Jun 2026 23:06:27 +0100 Subject: [PATCH] Allow Arch calibration targets in rebuild checkpoints --- .../pe_us_data_rebuild_checkpoint.py | 78 ++++++++++++++- .../test_pe_us_data_rebuild_checkpoint.py | 95 +++++++++++++++++++ 2 files changed, 168 insertions(+), 5 deletions(-) diff --git a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py index d62b7db..4aed1cd 100644 --- a/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py +++ b/src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, replace from datetime import UTC, datetime from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import h5py import numpy as np @@ -80,6 +80,7 @@ DEFAULT_CHECKPOINT_IMPUTATION_ABLATION_EVAL_FRACTION = 0.25 MIN_CHECKPOINT_IMPUTATION_ABLATION_HOUSEHOLDS = 8 LOGGER = logging.getLogger(__name__) +DEFAULT_ARCH_CALIBRATION_TARGET_PROFILE = "pe_native_broad_source_backed" def _root_logger_has_handlers() -> bool: @@ -146,11 +147,23 @@ def _normalize_path_value(value: str | Path | None) -> str | None: return str(Path(value).expanduser()) +def _normalize_arch_targets_db_value( + value: str | Path | tuple[str | Path, ...] | None, +) -> str | tuple[str, ...] | None: + if value is None: + return None + if isinstance(value, (str, Path)): + return str(Path(value).expanduser()) + return tuple(str(Path(path).expanduser()) for path in value) + + def _validate_checkpoint_config_context( config: USMicroplexBuildConfig, *, policyengine_baseline_dataset: str | Path, policyengine_targets_db: str | Path, + arch_targets_db: str | Path | tuple[str | Path, ...] | None, + calibration_target_source: Literal["policyengine", "arch"], target_period: int, target_profile: str, calibration_target_profile: str | None, @@ -166,11 +179,18 @@ def _validate_checkpoint_config_context( policyengine_baseline_dataset ), "policyengine_targets_db": _normalize_path_value(policyengine_targets_db), + "arch_targets_db": _normalize_arch_targets_db_value(arch_targets_db), + "calibration_target_source": calibration_target_source, "policyengine_dataset_year": int(target_period), "policyengine_target_period": int(target_period), "policyengine_target_profile": target_profile, "policyengine_calibration_target_profile": ( - calibration_target_profile or target_profile + calibration_target_profile + or ( + DEFAULT_ARCH_CALIBRATION_TARGET_PROFILE + if calibration_target_source == "arch" + else target_profile + ) ), "policyengine_target_variables": tuple(target_variables), "policyengine_target_domains": tuple(target_domains), @@ -1831,6 +1851,8 @@ def default_policyengine_us_data_rebuild_checkpoint_config( *, policyengine_baseline_dataset: str | Path, policyengine_targets_db: str | Path, + arch_targets_db: str | Path | tuple[str | Path, ...] | None = None, + calibration_target_source: Literal["policyengine", "arch"] = "policyengine", target_period: int = 2024, target_profile: str = "pe_native_broad", calibration_target_profile: str | None = None, @@ -1845,6 +1867,21 @@ def default_policyengine_us_data_rebuild_checkpoint_config( """Return the canonical rebuild config with required PE comparison context.""" resolved_target_period = int(target_period) + if calibration_target_source not in {"policyengine", "arch"}: + raise ValueError( + "calibration_target_source must be 'policyengine' or 'arch', " + f"got {calibration_target_source!r}" + ) + resolved_arch_targets_db = _normalize_arch_targets_db_value(arch_targets_db) + if calibration_target_source == "arch" and resolved_arch_targets_db is None: + raise ValueError( + "arch_targets_db is required when calibration_target_source='arch'" + ) + resolved_calibration_target_profile = calibration_target_profile or ( + DEFAULT_ARCH_CALIBRATION_TARGET_PROFILE + if calibration_target_source == "arch" + else target_profile + ) resolved_baseline_weight_sum = _infer_policyengine_baseline_household_weight_sum( policyengine_baseline_dataset, target_period=resolved_target_period, @@ -1874,12 +1911,12 @@ def default_policyengine_us_data_rebuild_checkpoint_config( return default_policyengine_us_data_rebuild_config( policyengine_baseline_dataset=str(policyengine_baseline_dataset), policyengine_targets_db=str(policyengine_targets_db), + arch_targets_db=resolved_arch_targets_db, + calibration_target_source=calibration_target_source, policyengine_dataset_year=resolved_target_period, policyengine_target_period=resolved_target_period, policyengine_target_profile=target_profile, - policyengine_calibration_target_profile=( - calibration_target_profile or target_profile - ), + policyengine_calibration_target_profile=resolved_calibration_target_profile, policyengine_target_variables=tuple(target_variables), policyengine_target_domains=tuple(target_domains), policyengine_target_geo_levels=tuple(target_geo_levels), @@ -1956,6 +1993,8 @@ def run_policyengine_us_data_rebuild_checkpoint( *, policyengine_baseline_dataset: str | Path, policyengine_targets_db: str | Path, + arch_targets_db: str | Path | tuple[str | Path, ...] | None = None, + calibration_target_source: Literal["policyengine", "arch"] = "policyengine", target_period: int = 2024, target_profile: str = "pe_native_broad", calibration_target_profile: str | None = None, @@ -2022,6 +2061,8 @@ def run_policyengine_us_data_rebuild_checkpoint( resolved_config = config or default_policyengine_us_data_rebuild_checkpoint_config( policyengine_baseline_dataset=policyengine_baseline_dataset, policyengine_targets_db=policyengine_targets_db, + arch_targets_db=arch_targets_db, + calibration_target_source=calibration_target_source, target_period=target_period, target_profile=target_profile, calibration_target_profile=calibration_target_profile, @@ -2038,6 +2079,8 @@ def run_policyengine_us_data_rebuild_checkpoint( resolved_config, policyengine_baseline_dataset=policyengine_baseline_dataset, policyengine_targets_db=policyengine_targets_db, + arch_targets_db=arch_targets_db, + calibration_target_source=calibration_target_source, target_period=target_period, target_profile=target_profile, calibration_target_profile=calibration_target_profile, @@ -2123,6 +2166,10 @@ def run_policyengine_us_data_rebuild_checkpoint( output_root=Path(output_root).expanduser(), version_id=version_id or "auto", target_profile=resolved_config.policyengine_target_profile, + calibration_target_profile=( + resolved_config.policyengine_calibration_target_profile + ), + calibration_target_source=resolved_config.calibration_target_source, donor_condition_selection=resolved_config.donor_imputer_condition_selection, providers=",".join(provider_names), ) @@ -2247,6 +2294,25 @@ def main(argv: list[str] | None = None) -> None: parser.add_argument("--target-period", type=int, default=2024) parser.add_argument("--target-profile", default="pe_native_broad") parser.add_argument("--calibration-target-profile") + parser.add_argument( + "--calibration-target-source", + choices=["policyengine", "arch"], + default="policyengine", + help=( + "Target provider used for calibration. Use 'arch' with " + "--arch-targets-db for MP production calibration while keeping " + "--target-profile on the PE/eCPS comparison surface." + ), + ) + parser.add_argument( + "--arch-targets-db", + action="append", + default=[], + help=( + "Arch targets SQLite DB or consumer_facts.jsonl path for " + "calibration. May be supplied more than once." + ), + ) parser.add_argument("--n-synthetic", type=int, default=100_000) parser.add_argument("--random-seed", type=int, default=42) parser.add_argument("--donor-imputer-condition-selection") @@ -2457,6 +2523,8 @@ def main(argv: list[str] | None = None) -> None: output_root=args.output_root, policyengine_baseline_dataset=args.baseline_dataset, policyengine_targets_db=args.targets_db, + arch_targets_db=(tuple(args.arch_targets_db) if args.arch_targets_db else None), + calibration_target_source=args.calibration_target_source, target_period=args.target_period, target_profile=args.target_profile, calibration_target_profile=args.calibration_target_profile, diff --git a/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py b/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py index 5b7b2f6..aab5aa1 100644 --- a/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py +++ b/tests/pipelines/test_pe_us_data_rebuild_checkpoint.py @@ -88,6 +88,46 @@ def test_default_policyengine_us_data_rebuild_checkpoint_config_preserves_explic assert config.policyengine_calibration_target_variables == ("snap",) +def test_default_policyengine_us_data_rebuild_checkpoint_config_uses_arch_source_backed_calibration_scope() -> ( + None +): + config = default_policyengine_us_data_rebuild_checkpoint_config( + policyengine_baseline_dataset="/tmp/enhanced_cps_2024.h5", + policyengine_targets_db="/tmp/policy_data.db", + arch_targets_db=( + "/tmp/arch/fixtures/consumer_facts.jsonl", + "/tmp/arch/macro/targets.db", + ), + calibration_target_source="arch", + ) + + assert config.policyengine_target_profile == "pe_native_broad" + assert ( + config.policyengine_calibration_target_profile + == "pe_native_broad_source_backed" + ) + assert config.calibration_target_source == "arch" + assert config.arch_targets_db == ( + "/tmp/arch/fixtures/consumer_facts.jsonl", + "/tmp/arch/macro/targets.db", + ) + + +def test_default_policyengine_us_data_rebuild_checkpoint_config_requires_arch_targets_for_arch_calibration() -> ( + None +): + try: + default_policyengine_us_data_rebuild_checkpoint_config( + policyengine_baseline_dataset="/tmp/enhanced_cps_2024.h5", + policyengine_targets_db="/tmp/policy_data.db", + calibration_target_source="arch", + ) + except ValueError as exc: + assert "arch_targets_db is required" in str(exc) + else: + raise AssertionError("Expected arch calibration without targets DB to fail") + + def test_default_policyengine_us_data_rebuild_checkpoint_config_infers_total_weight_targets( monkeypatch, ) -> None: @@ -725,6 +765,61 @@ def fake_run_policyengine_us_data_rebuild_checkpoint(**kwargs): assert "hasRealPolicyEngineComparison" in stdout +def test_main_passes_arch_calibration_target_source(monkeypatch, capsys) -> None: + captured: dict[str, Any] = {} + artifact_dir = Path("/tmp/artifacts/run-1") + parity_path = artifact_dir / "pe_us_data_rebuild_parity.json" + + def fake_run_policyengine_us_data_rebuild_checkpoint(**kwargs): + captured.update(kwargs) + return SimpleNamespace( + artifacts=SimpleNamespace( + artifact_paths=SimpleNamespace(output_dir=artifact_dir) + ), + parity_path=parity_path, + parity_payload={ + "verdict": {"hasRealPolicyEngineComparison": True}, + }, + ) + + monkeypatch.setattr( + checkpoint_module, + "run_policyengine_us_data_rebuild_checkpoint", + fake_run_policyengine_us_data_rebuild_checkpoint, + ) + + checkpoint_module.main( + [ + "--output-root", + "/tmp/artifacts", + "--baseline-dataset", + "/tmp/enhanced_cps_2024.h5", + "--targets-db", + "/tmp/policy_data.db", + "--version-id", + "run-1", + "--calibration-target-source", + "arch", + "--arch-targets-db", + "/tmp/arch/fixtures/consumer_facts.jsonl", + "--arch-targets-db", + "/tmp/arch/macro/targets.db", + "--defer-native-audit", + "--defer-imputation-ablation", + ] + ) + + assert captured["target_profile"] == "pe_native_broad" + assert captured["calibration_target_profile"] is None + assert captured["calibration_target_source"] == "arch" + assert captured["arch_targets_db"] == ( + "/tmp/arch/fixtures/consumer_facts.jsonl", + "/tmp/arch/macro/targets.db", + ) + stdout = capsys.readouterr().out + assert "/tmp/artifacts/run-1" in stdout + + def test_run_policyengine_us_data_rebuild_checkpoint_rejects_empty_provider_sequence( tmp_path, ) -> None: