diff --git a/src/microplex_us/pipelines/pe_us_recalibrate_from_checkpoint.py b/src/microplex_us/pipelines/pe_us_recalibrate_from_checkpoint.py index bf24997..ee53bc9 100644 --- a/src/microplex_us/pipelines/pe_us_recalibrate_from_checkpoint.py +++ b/src/microplex_us/pipelines/pe_us_recalibrate_from_checkpoint.py @@ -19,9 +19,10 @@ import argparse import json +import os import sys +from collections.abc import Sequence from pathlib import Path -from typing import Sequence from microplex_us.pipelines.us import ( USMicroplexBuildConfig, @@ -29,6 +30,16 @@ ) +def _prepare_output_root(output_root: Path) -> Path: + if not output_root.exists(): + raise FileNotFoundError(f"--output-root does not exist: {output_root}") + if not output_root.is_dir(): + raise NotADirectoryError(f"--output-root is not a directory: {output_root}") + if not os.access(output_root, os.W_OK | os.X_OK): + raise PermissionError(f"--output-root is not writable: {output_root}") + return output_root + + def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser( description=( @@ -51,7 +62,7 @@ def main(argv: Sequence[str] | None = None) -> int: "--output-root", type=Path, required=True, - help="Output directory for the recalibrated bundle and summary.", + help="Existing output directory for the recalibrated bundle and summary.", ) parser.add_argument( "--targets-db", @@ -96,6 +107,7 @@ def main(argv: Sequence[str] | None = None) -> int: ), ) args = parser.parse_args(argv) + output_root = _prepare_output_root(args.output_root) config_kwargs: dict[str, object] = { "calibration_backend": args.calibration_backend, @@ -116,20 +128,19 @@ def main(argv: Sequence[str] | None = None) -> int: config = USMicroplexBuildConfig(**config_kwargs) result = recalibrate_policyengine_us_from_checkpoint(config, args.checkpoint_path) - args.output_root.mkdir(parents=True, exist_ok=True) - result.calibrated_data.to_parquet(args.output_root / "calibrated_data.parquet") + result.calibrated_data.to_parquet(output_root / "calibrated_data.parquet") result.policyengine_tables.households.to_parquet( - args.output_root / "households.parquet" + output_root / "households.parquet" ) if result.policyengine_tables.persons is not None: result.policyengine_tables.persons.to_parquet( - args.output_root / "persons.parquet" + output_root / "persons.parquet" ) - (args.output_root / "calibration_summary.json").write_text( + (output_root / "calibration_summary.json").write_text( json.dumps(result.calibration_summary, indent=2, default=str) ) print( - f"Recalibrated from {args.checkpoint_path} → {args.output_root} " + f"Recalibrated from {args.checkpoint_path} → {output_root} " f"(stage={result.loaded_stage}, " f"rows={len(result.calibrated_data)})" ) diff --git a/tests/pipelines/test_recalibrate_from_checkpoint.py b/tests/pipelines/test_recalibrate_from_checkpoint.py index f13b3c7..2f5c44d 100644 --- a/tests/pipelines/test_recalibrate_from_checkpoint.py +++ b/tests/pipelines/test_recalibrate_from_checkpoint.py @@ -11,17 +11,17 @@ 1. The helper loads a post-imputation checkpoint and dispatches the bundle to a fresh pipeline's calibrate method. -2. The helper rejects post-microsim checkpoints in v1 (resume from that - stage needs pickled constraints, which is a follow-up). +2. The helper also accepts post-microsim checkpoints, where materialized + target columns already exist on the bundle. 3. The helper raises a clear error if the checkpoint directory is missing. """ from __future__ import annotations +import os from pathlib import Path from typing import Any -from unittest.mock import MagicMock import numpy as np import pandas as pd @@ -71,7 +71,9 @@ def test_checkpoint_dispatches_to_calibrate( orchestrates the load and hand-off, so the parametrized test covers both paths. """ - from microplex_us.pipelines.us import recalibrate_policyengine_us_from_checkpoint + from microplex_us.pipelines.us import ( + recalibrate_policyengine_us_from_checkpoint, + ) bundle = _make_bundle(n=40) save_us_pipeline_checkpoint( @@ -114,7 +116,9 @@ def _fake_calibrate( def test_unsupported_stage_raises(self, tmp_path: Path) -> None: """A metadata.json with an unknown stage is rejected.""" - from microplex_us.pipelines.us import recalibrate_policyengine_us_from_checkpoint + from microplex_us.pipelines.us import ( + recalibrate_policyengine_us_from_checkpoint, + ) (tmp_path / "checkpoint").mkdir() import json @@ -127,8 +131,133 @@ def test_unsupported_stage_raises(self, tmp_path: Path) -> None: recalibrate_policyengine_us_from_checkpoint(cfg, tmp_path / "checkpoint") def test_missing_checkpoint_raises(self, tmp_path: Path) -> None: - from microplex_us.pipelines.us import recalibrate_policyengine_us_from_checkpoint + from microplex_us.pipelines.us import ( + recalibrate_policyengine_us_from_checkpoint, + ) cfg = USMicroplexBuildConfig(policyengine_targets_db=tmp_path / "targets.db") with pytest.raises(FileNotFoundError): recalibrate_policyengine_us_from_checkpoint(cfg, tmp_path / "nope") + + +class TestRecalibrateFromCheckpointCli: + def test_prepare_output_root_accepts_existing_empty_directory( + self, + tmp_path: Path, + ) -> None: + from microplex_us.pipelines.pe_us_recalibrate_from_checkpoint import ( + _prepare_output_root, + ) + + output_root = tmp_path / "output" + output_root.mkdir() + + assert _prepare_output_root(output_root) == output_root + assert output_root.is_dir() + assert list(output_root.iterdir()) == [] + + def test_prepare_output_root_rejects_missing_directory( + self, + tmp_path: Path, + ) -> None: + from microplex_us.pipelines.pe_us_recalibrate_from_checkpoint import ( + _prepare_output_root, + ) + + output_root = tmp_path / "output" + + with pytest.raises(FileNotFoundError, match="--output-root does not exist"): + _prepare_output_root(output_root) + assert not output_root.exists() + + def test_prepare_output_root_rejects_unwritable_directory( + self, + tmp_path: Path, + ) -> None: + from microplex_us.pipelines.pe_us_recalibrate_from_checkpoint import ( + _prepare_output_root, + ) + + output_root = tmp_path / "output" + output_root.mkdir() + original_mode = output_root.stat().st_mode + try: + output_root.chmod(0o500) + if os.access(output_root, os.W_OK | os.X_OK): + pytest.skip("current platform still reports chmod 0500 as writable") + with pytest.raises(PermissionError, match="--output-root is not writable"): + _prepare_output_root(output_root) + finally: + output_root.chmod(original_mode) + + def test_main_rejects_output_file_before_recalibration( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + import microplex_us.pipelines.pe_us_recalibrate_from_checkpoint as cli + + called = False + + def _fail_if_called(*args: Any, **kwargs: Any) -> None: + nonlocal called + called = True + raise AssertionError("recalibration should not start") + + monkeypatch.setattr( + cli, + "recalibrate_policyengine_us_from_checkpoint", + _fail_if_called, + ) + output_root = tmp_path / "output" + output_root.write_text("not a directory") + + with pytest.raises(NotADirectoryError, match="--output-root is not a directory"): + cli.main( + [ + "--checkpoint-path", + str(tmp_path / "checkpoint"), + "--output-root", + str(output_root), + "--targets-db", + str(tmp_path / "targets.db"), + ] + ) + + assert called is False + + def test_main_rejects_missing_output_directory_before_recalibration( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + import microplex_us.pipelines.pe_us_recalibrate_from_checkpoint as cli + + called = False + + def _fail_if_called(*args: Any, **kwargs: Any) -> None: + nonlocal called + called = True + raise AssertionError("recalibration should not start") + + monkeypatch.setattr( + cli, + "recalibrate_policyengine_us_from_checkpoint", + _fail_if_called, + ) + output_root = tmp_path / "output" + + with pytest.raises(FileNotFoundError, match="--output-root does not exist"): + cli.main( + [ + "--checkpoint-path", + str(tmp_path / "checkpoint"), + "--output-root", + str(output_root), + "--targets-db", + str(tmp_path / "targets.db"), + ] + ) + + assert called is False + assert not output_root.exists()