Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions src/microplex_us/pipelines/pe_us_data_rebuild_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass, replace
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import h5py
import numpy as np
Expand Down Expand Up @@ -80,6 +80,7 @@
DEFAULT_CHECKPOINT_IMPUTATION_ABLATION_EVAL_FRACTION = 0.25
MIN_CHECKPOINT_IMPUTATION_ABLATION_HOUSEHOLDS = 8
LOGGER = logging.getLogger(__name__)
DEFAULT_ARCH_CALIBRATION_TARGET_PROFILE = "pe_native_broad_source_backed"


def _root_logger_has_handlers() -> bool:
Expand Down Expand Up @@ -146,11 +147,23 @@ def _normalize_path_value(value: str | Path | None) -> str | None:
return str(Path(value).expanduser())


def _normalize_arch_targets_db_value(
value: str | Path | tuple[str | Path, ...] | None,
) -> str | tuple[str, ...] | None:
if value is None:
return None
if isinstance(value, (str, Path)):
return str(Path(value).expanduser())
return tuple(str(Path(path).expanduser()) for path in value)


def _validate_checkpoint_config_context(
config: USMicroplexBuildConfig,
*,
policyengine_baseline_dataset: str | Path,
policyengine_targets_db: str | Path,
arch_targets_db: str | Path | tuple[str | Path, ...] | None,
calibration_target_source: Literal["policyengine", "arch"],
target_period: int,
target_profile: str,
calibration_target_profile: str | None,
Expand All @@ -166,11 +179,18 @@ def _validate_checkpoint_config_context(
policyengine_baseline_dataset
),
"policyengine_targets_db": _normalize_path_value(policyengine_targets_db),
"arch_targets_db": _normalize_arch_targets_db_value(arch_targets_db),
"calibration_target_source": calibration_target_source,
"policyengine_dataset_year": int(target_period),
"policyengine_target_period": int(target_period),
"policyengine_target_profile": target_profile,
"policyengine_calibration_target_profile": (
calibration_target_profile or target_profile
calibration_target_profile
or (
DEFAULT_ARCH_CALIBRATION_TARGET_PROFILE
if calibration_target_source == "arch"
else target_profile
)
),
"policyengine_target_variables": tuple(target_variables),
"policyengine_target_domains": tuple(target_domains),
Expand Down Expand Up @@ -1831,6 +1851,8 @@ def default_policyengine_us_data_rebuild_checkpoint_config(
*,
policyengine_baseline_dataset: str | Path,
policyengine_targets_db: str | Path,
arch_targets_db: str | Path | tuple[str | Path, ...] | None = None,
calibration_target_source: Literal["policyengine", "arch"] = "policyengine",
target_period: int = 2024,
target_profile: str = "pe_native_broad",
calibration_target_profile: str | None = None,
Expand All @@ -1845,6 +1867,21 @@ def default_policyengine_us_data_rebuild_checkpoint_config(
"""Return the canonical rebuild config with required PE comparison context."""

resolved_target_period = int(target_period)
if calibration_target_source not in {"policyengine", "arch"}:
raise ValueError(
"calibration_target_source must be 'policyengine' or 'arch', "
f"got {calibration_target_source!r}"
)
resolved_arch_targets_db = _normalize_arch_targets_db_value(arch_targets_db)
if calibration_target_source == "arch" and resolved_arch_targets_db is None:
raise ValueError(
"arch_targets_db is required when calibration_target_source='arch'"
)
resolved_calibration_target_profile = calibration_target_profile or (
DEFAULT_ARCH_CALIBRATION_TARGET_PROFILE
if calibration_target_source == "arch"
else target_profile
)
resolved_baseline_weight_sum = _infer_policyengine_baseline_household_weight_sum(
policyengine_baseline_dataset,
target_period=resolved_target_period,
Expand Down Expand Up @@ -1874,12 +1911,12 @@ def default_policyengine_us_data_rebuild_checkpoint_config(
return default_policyengine_us_data_rebuild_config(
policyengine_baseline_dataset=str(policyengine_baseline_dataset),
policyengine_targets_db=str(policyengine_targets_db),
arch_targets_db=resolved_arch_targets_db,
calibration_target_source=calibration_target_source,
policyengine_dataset_year=resolved_target_period,
policyengine_target_period=resolved_target_period,
policyengine_target_profile=target_profile,
policyengine_calibration_target_profile=(
calibration_target_profile or target_profile
),
policyengine_calibration_target_profile=resolved_calibration_target_profile,
policyengine_target_variables=tuple(target_variables),
policyengine_target_domains=tuple(target_domains),
policyengine_target_geo_levels=tuple(target_geo_levels),
Expand Down Expand Up @@ -1956,6 +1993,8 @@ def run_policyengine_us_data_rebuild_checkpoint(
*,
policyengine_baseline_dataset: str | Path,
policyengine_targets_db: str | Path,
arch_targets_db: str | Path | tuple[str | Path, ...] | None = None,
calibration_target_source: Literal["policyengine", "arch"] = "policyengine",
target_period: int = 2024,
target_profile: str = "pe_native_broad",
calibration_target_profile: str | None = None,
Expand Down Expand Up @@ -2022,6 +2061,8 @@ def run_policyengine_us_data_rebuild_checkpoint(
resolved_config = config or default_policyengine_us_data_rebuild_checkpoint_config(
policyengine_baseline_dataset=policyengine_baseline_dataset,
policyengine_targets_db=policyengine_targets_db,
arch_targets_db=arch_targets_db,
calibration_target_source=calibration_target_source,
target_period=target_period,
target_profile=target_profile,
calibration_target_profile=calibration_target_profile,
Expand All @@ -2038,6 +2079,8 @@ def run_policyengine_us_data_rebuild_checkpoint(
resolved_config,
policyengine_baseline_dataset=policyengine_baseline_dataset,
policyengine_targets_db=policyengine_targets_db,
arch_targets_db=arch_targets_db,
calibration_target_source=calibration_target_source,
target_period=target_period,
target_profile=target_profile,
calibration_target_profile=calibration_target_profile,
Expand Down Expand Up @@ -2123,6 +2166,10 @@ def run_policyengine_us_data_rebuild_checkpoint(
output_root=Path(output_root).expanduser(),
version_id=version_id or "auto",
target_profile=resolved_config.policyengine_target_profile,
calibration_target_profile=(
resolved_config.policyengine_calibration_target_profile
),
calibration_target_source=resolved_config.calibration_target_source,
donor_condition_selection=resolved_config.donor_imputer_condition_selection,
providers=",".join(provider_names),
)
Expand Down Expand Up @@ -2247,6 +2294,25 @@ def main(argv: list[str] | None = None) -> None:
parser.add_argument("--target-period", type=int, default=2024)
parser.add_argument("--target-profile", default="pe_native_broad")
parser.add_argument("--calibration-target-profile")
parser.add_argument(
"--calibration-target-source",
choices=["policyengine", "arch"],
default="policyengine",
help=(
"Target provider used for calibration. Use 'arch' with "
"--arch-targets-db for MP production calibration while keeping "
"--target-profile on the PE/eCPS comparison surface."
),
)
parser.add_argument(
"--arch-targets-db",
action="append",
default=[],
help=(
"Arch targets SQLite DB or consumer_facts.jsonl path for "
"calibration. May be supplied more than once."
),
)
parser.add_argument("--n-synthetic", type=int, default=100_000)
parser.add_argument("--random-seed", type=int, default=42)
parser.add_argument("--donor-imputer-condition-selection")
Expand Down Expand Up @@ -2457,6 +2523,8 @@ def main(argv: list[str] | None = None) -> None:
output_root=args.output_root,
policyengine_baseline_dataset=args.baseline_dataset,
policyengine_targets_db=args.targets_db,
arch_targets_db=(tuple(args.arch_targets_db) if args.arch_targets_db else None),
calibration_target_source=args.calibration_target_source,
target_period=args.target_period,
target_profile=args.target_profile,
calibration_target_profile=args.calibration_target_profile,
Expand Down
95 changes: 95 additions & 0 deletions tests/pipelines/test_pe_us_data_rebuild_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,46 @@ def test_default_policyengine_us_data_rebuild_checkpoint_config_preserves_explic
assert config.policyengine_calibration_target_variables == ("snap",)


def test_default_policyengine_us_data_rebuild_checkpoint_config_uses_arch_source_backed_calibration_scope() -> (
    None
):
    """Arch calibration defaults to the source-backed calibration profile.

    With no explicit calibration profile, the comparison profile stays
    ``pe_native_broad`` while the calibration profile, source, and arch
    target paths reflect the arch configuration.
    """
    arch_target_paths = (
        "/tmp/arch/fixtures/consumer_facts.jsonl",
        "/tmp/arch/macro/targets.db",
    )

    rebuilt = default_policyengine_us_data_rebuild_checkpoint_config(
        policyengine_baseline_dataset="/tmp/enhanced_cps_2024.h5",
        policyengine_targets_db="/tmp/policy_data.db",
        arch_targets_db=arch_target_paths,
        calibration_target_source="arch",
    )

    calibration_profile = rebuilt.policyengine_calibration_target_profile
    assert rebuilt.policyengine_target_profile == "pe_native_broad"
    assert calibration_profile == "pe_native_broad_source_backed"
    assert rebuilt.calibration_target_source == "arch"
    assert rebuilt.arch_targets_db == (
        "/tmp/arch/fixtures/consumer_facts.jsonl",
        "/tmp/arch/macro/targets.db",
    )


def test_default_policyengine_us_data_rebuild_checkpoint_config_requires_arch_targets_for_arch_calibration() -> (
    None
):
    """Selecting arch calibration without ``arch_targets_db`` raises ``ValueError``."""
    failure = None
    try:
        default_policyengine_us_data_rebuild_checkpoint_config(
            policyengine_baseline_dataset="/tmp/enhanced_cps_2024.h5",
            policyengine_targets_db="/tmp/policy_data.db",
            calibration_target_source="arch",
        )
    except ValueError as exc:
        # Keep the exception so its message can be checked after the handler.
        failure = exc

    if failure is None:
        raise AssertionError("Expected arch calibration without targets DB to fail")
    assert "arch_targets_db is required" in str(failure)


def test_default_policyengine_us_data_rebuild_checkpoint_config_infers_total_weight_targets(
monkeypatch,
) -> None:
Expand Down Expand Up @@ -725,6 +765,61 @@ def fake_run_policyengine_us_data_rebuild_checkpoint(**kwargs):
assert "hasRealPolicyEngineComparison" in stdout


def test_main_passes_arch_calibration_target_source(monkeypatch, capsys) -> None:
    """CLI arch flags are forwarded to the checkpoint runner as keyword args.

    The runner is replaced with a recorder, so only argument parsing and
    forwarding in ``main`` are exercised. Repeated ``--arch-targets-db`` flags
    must arrive as a tuple in the order given.
    """
    recorded_kwargs: dict[str, Any] = {}
    run_dir = Path("/tmp/artifacts/run-1")

    def fake_run_policyengine_us_data_rebuild_checkpoint(**kwargs):
        recorded_kwargs.update(kwargs)
        artifact_paths = SimpleNamespace(output_dir=run_dir)
        return SimpleNamespace(
            artifacts=SimpleNamespace(artifact_paths=artifact_paths),
            parity_path=run_dir / "pe_us_data_rebuild_parity.json",
            parity_payload={
                "verdict": {"hasRealPolicyEngineComparison": True},
            },
        )

    monkeypatch.setattr(
        checkpoint_module,
        "run_policyengine_us_data_rebuild_checkpoint",
        fake_run_policyengine_us_data_rebuild_checkpoint,
    )

    cli_args = [
        "--output-root",
        "/tmp/artifacts",
        "--baseline-dataset",
        "/tmp/enhanced_cps_2024.h5",
        "--targets-db",
        "/tmp/policy_data.db",
        "--version-id",
        "run-1",
        "--calibration-target-source",
        "arch",
        "--arch-targets-db",
        "/tmp/arch/fixtures/consumer_facts.jsonl",
        "--arch-targets-db",
        "/tmp/arch/macro/targets.db",
        "--defer-native-audit",
        "--defer-imputation-ablation",
    ]
    checkpoint_module.main(cli_args)

    expected_arch_targets = (
        "/tmp/arch/fixtures/consumer_facts.jsonl",
        "/tmp/arch/macro/targets.db",
    )
    assert recorded_kwargs["target_profile"] == "pe_native_broad"
    # No --calibration-target-profile flag was given, so main forwards None.
    assert recorded_kwargs["calibration_target_profile"] is None
    assert recorded_kwargs["calibration_target_source"] == "arch"
    assert recorded_kwargs["arch_targets_db"] == expected_arch_targets

    printed = capsys.readouterr().out
    assert "/tmp/artifacts/run-1" in printed


def test_run_policyengine_us_data_rebuild_checkpoint_rejects_empty_provider_sequence(
tmp_path,
) -> None:
Expand Down
Loading