From 464c43af6a8a77f11208c6928a2ad480d10d2473 Mon Sep 17 00:00:00 2001
From: Anthony Volk
Date: Wed, 3 Jun 2026 19:10:10 +0200
Subject: [PATCH] Allow checkpoint-only version artifact directories

Reuse an explicitly requested version directory when everything it already
contains belongs to the checkpoint trees named in the build config. A
non-directory sitting where a checkpoint parent directory is needed still
counts as unrelated content, so the directory is rejected in that case.

---
 .../pipelines/versioned_artifacts.py        |  94 ++++++++++
 tests/pipelines/test_versioned_artifacts.py | 163 ++++++++++++++++++
 2 files changed, 257 insertions(+)

diff --git a/src/microplex_us/pipelines/versioned_artifacts.py b/src/microplex_us/pipelines/versioned_artifacts.py
index b1ea51b..6e7341d 100644
--- a/src/microplex_us/pipelines/versioned_artifacts.py
+++ b/src/microplex_us/pipelines/versioned_artifacts.py
@@ -88,6 +88,7 @@ def save_versioned_us_microplex_artifacts(
     child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None,
     allow_stage_input_overrides: bool = False,
     stage_input_overrides: tuple[USStageInputOverride, ...] = (),
+    stage_runtime_writer: USStageRuntimeWriter | None = None,
 ) -> USMicroplexArtifactPaths:
     """Persist a build under a stable versioned directory beneath one output root."""
     output_root = Path(output_root)
@@ -118,6 +119,7 @@ def save_versioned_us_microplex_artifacts(
         child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables,
         allow_stage_input_overrides=allow_stage_input_overrides,
         stage_input_overrides=stage_input_overrides,
+        stage_runtime_writer=stage_runtime_writer,
     )
     return replace(paths, version_id=resolved_version_id)
 
@@ -560,6 +562,8 @@ def _allocate_versioned_output_dir_for_config(
     if version_id is not None:
         output_dir = output_root / version_id
         if output_dir.exists():
+            if _version_dir_contains_only_configured_checkpoints(output_dir, config):
+                return version_id, output_dir
             raise FileExistsError(
                 f"Versioned artifact directory already exists: {output_dir}"
             )
@@ -578,6 +582,96 @@ def _allocate_versioned_output_dir_for_config(
     return candidate_version_id, output_dir
 
 
+def _version_dir_contains_only_configured_checkpoints(
+    output_dir: Path,
+    config: Mapping[str, Any],
+) -> bool:
+    """Return whether an existing version directory holds only checkpoints.
+
+    Reuse is allowed only when ``output_dir`` is a directory, the config names
+    at least one checkpoint save path strictly inside it, and every entry found
+    beneath it belongs to (or is a directory leading to) such a checkpoint
+    tree. Any other content makes the directory ineligible.
+    """
+    if not output_dir.is_dir():
+        return False
+
+    # Resolve first so containment checks compare like with like.
+    resolved_output_dir = output_dir.expanduser().resolve(strict=False)
+    checkpoint_roots = _configured_checkpoint_roots_inside_version_dir(
+        resolved_output_dir,
+        config,
+    )
+    if not checkpoint_roots:
+        return False
+
+    return all(
+        _path_is_allowed_checkpoint_tree_member(path, checkpoint_roots)
+        for path in output_dir.rglob("*")
+    )
+
+
+def _configured_checkpoint_roots_inside_version_dir(
+    output_dir: Path,
+    config: Mapping[str, Any],
+) -> tuple[Path, ...]:
+    """Collect configured checkpoint save paths located inside ``output_dir``.
+
+    ``output_dir`` is expected to be resolved already. Checkpoint paths that
+    are unset, equal to ``output_dir`` itself, or outside it are ignored.
+    """
+    roots: list[Path] = []
+    for key in (
+        "pipeline_checkpoint_save_post_imputation_path",
+        "pipeline_checkpoint_save_post_microsim_path",
+    ):
+        checkpoint_path = config.get(key)
+        if checkpoint_path is None:
+            continue
+        resolved_checkpoint_path = Path(checkpoint_path).expanduser().resolve(
+            strict=False
+        )
+        if (
+            resolved_checkpoint_path != output_dir
+            and _path_is_relative_to(resolved_checkpoint_path, output_dir)
+        ):
+            roots.append(resolved_checkpoint_path)
+    return tuple(roots)
+
+
+def _path_is_allowed_checkpoint_tree_member(
+    path: Path,
+    checkpoint_roots: tuple[Path, ...],
+) -> bool:
+    """Return whether ``path`` belongs to a configured checkpoint tree.
+
+    A path is allowed when it is a checkpoint root or lies beneath one, or when
+    it is a directory on the way down to a checkpoint root. A non-directory
+    sitting where such a parent directory is needed is unrelated content, so it
+    is rejected instead of being treated as part of the tree.
+    """
+    resolved_path = path.expanduser().resolve(strict=False)
+    if any(
+        _path_is_relative_to(resolved_path, checkpoint_root)
+        for checkpoint_root in checkpoint_roots
+    ):
+        return True
+    # Only a directory can act as the parent of a checkpoint root.
+    return resolved_path.is_dir() and any(
+        _path_is_relative_to(checkpoint_root, resolved_path)
+        for checkpoint_root in checkpoint_roots
+    )
+
+
+def _path_is_relative_to(path: Path, other: Path) -> bool:
+    """Return whether ``path`` equals ``other`` or is located beneath it."""
+    try:
+        path.relative_to(other)
+    except ValueError:
+        return False
+    return True
+
+
 def _short_config_hash(config: dict[str, Any]) -> str:
     import hashlib
     import json
diff --git a/tests/pipelines/test_versioned_artifacts.py b/tests/pipelines/test_versioned_artifacts.py
index c89cc9c..a62ec6e 100644
--- a/tests/pipelines/test_versioned_artifacts.py
+++ b/tests/pipelines/test_versioned_artifacts.py
@@ -2,10 +2,12 @@
 
 import json
 import sqlite3
+from dataclasses import replace
 from pathlib import Path
 
 import duckdb
 import pandas as pd
+import pytest
 from microplex.core import (
     EntityObservation,
     EntityRelationship,
@@ -18,6 +20,7 @@
     TimeStructure,
 )
 
+import microplex_us.pipelines.versioned_artifacts as versioned_artifacts_module
 from microplex_us.pipelines import (
     build_and_save_versioned_us_microplex,
     build_and_save_versioned_us_microplex_from_data_dir,
@@ -405,6 +408,166 @@ def test_save_versioned_us_microplex_artifacts_uses_explicit_version(tmp_path):
     assert conn.execute("SELECT COUNT(*) FROM target_metrics").fetchone()[0] == 2
 
 
+def test_save_versioned_artifacts_allows_configured_checkpoint_contents(tmp_path):
+    result = _make_result(
+        targets_db=tmp_path / "targets.db",
+        baseline_dataset=tmp_path / "baseline.h5",
+        snap_values=(100.0, 50.0),
+    )
+    root = tmp_path / "builds"
+    output_dir = root / "run-1"
+    post_imputation_checkpoint = output_dir / "checkpoints" / "post_imputation"
+    post_microsim_checkpoint = output_dir / "checkpoints" / "post_microsim"
+    post_imputation_checkpoint.mkdir(parents=True)
+    post_microsim_checkpoint.mkdir(parents=True)
+    (post_imputation_checkpoint / "metadata.json").write_text("{}")
+    (post_microsim_checkpoint / "metadata.json").write_text("{}")
+    result.config = replace(
+        result.config,
+        pipeline_checkpoint_save_post_imputation_path=str(post_imputation_checkpoint),
+        pipeline_checkpoint_save_post_microsim_path=str(post_microsim_checkpoint),
+    )
+
+    version_id, allocated_output_dir = (
+        versioned_artifacts_module._allocate_versioned_output_dir(
+            root,
+            version_id="run-1",
+            result=result,
+        )
+    )
+
+    assert version_id == "run-1"
+    assert allocated_output_dir == output_dir
+
+
+def test_save_versioned_artifacts_allows_configured_checkpoint_parent(tmp_path):
+    result = _make_result(
+        targets_db=tmp_path / "targets.db",
+        baseline_dataset=tmp_path / "baseline.h5",
+        snap_values=(100.0, 50.0),
+    )
+    root = tmp_path / "builds"
+    output_dir = root / "run-1"
+    checkpoint_parent = output_dir / "checkpoints"
+    checkpoint_parent.mkdir(parents=True)
+    result.config = replace(
+        result.config,
+        pipeline_checkpoint_save_post_imputation_path=str(
+            checkpoint_parent / "post_imputation"
+        ),
+        pipeline_checkpoint_save_post_microsim_path=str(
+            checkpoint_parent / "post_microsim"
+        ),
+    )
+
+    version_id, allocated_output_dir = (
+        versioned_artifacts_module._allocate_versioned_output_dir(
+            root,
+            version_id="run-1",
+            result=result,
+        )
+    )
+
+    assert version_id == "run-1"
+    assert allocated_output_dir == output_dir
+
+
+def test_save_versioned_artifacts_rejects_unrelated_existing_version_dir(tmp_path):
+    result = _make_result(
+        targets_db=tmp_path / "targets.db",
+        baseline_dataset=tmp_path / "baseline.h5",
+        snap_values=(100.0, 50.0),
+    )
+    root = tmp_path / "builds"
+    output_dir = root / "run-1"
+    post_imputation_checkpoint = output_dir / "checkpoints" / "post_imputation"
+    post_imputation_checkpoint.mkdir(parents=True)
+    (post_imputation_checkpoint / "metadata.json").write_text("{}")
+    (output_dir / "manifest.json").write_text("{}")
+    result.config = replace(
+        result.config,
+        pipeline_checkpoint_save_post_imputation_path=str(post_imputation_checkpoint),
+    )
+
+    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
+        versioned_artifacts_module._allocate_versioned_output_dir(
+            root,
+            version_id="run-1",
+            result=result,
+        )
+
+
+def test_save_versioned_artifacts_rejects_unconfigured_checkpoint_contents(tmp_path):
+    result = _make_result(
+        targets_db=tmp_path / "targets.db",
+        baseline_dataset=tmp_path / "baseline.h5",
+        snap_values=(100.0, 50.0),
+    )
+    root = tmp_path / "builds"
+    output_dir = root / "run-1"
+    checkpoint_dir = output_dir / "checkpoints" / "post_imputation"
+    checkpoint_dir.mkdir(parents=True)
+    (checkpoint_dir / "metadata.json").write_text("{}")
+
+    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
+        versioned_artifacts_module._allocate_versioned_output_dir(
+            root,
+            version_id="run-1",
+            result=result,
+        )
+
+
+def test_save_versioned_artifacts_rejects_existing_version_file(tmp_path):
+    result = _make_result(
+        targets_db=tmp_path / "targets.db",
+        baseline_dataset=tmp_path / "baseline.h5",
+        snap_values=(100.0, 50.0),
+    )
+    root = tmp_path / "builds"
+    output_path = root / "run-1"
+    output_path.parent.mkdir(parents=True)
+    output_path.write_text("not a directory")
+    result.config = replace(
+        result.config,
+        pipeline_checkpoint_save_post_imputation_path=str(
+            output_path / "checkpoints" / "post_imputation"
+        ),
+    )
+
+    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
+        versioned_artifacts_module._allocate_versioned_output_dir(
+            root,
+            version_id="run-1",
+            result=result,
+        )
+
+
+def test_save_versioned_artifacts_rejects_file_at_checkpoint_parent(tmp_path):
+    result = _make_result(
+        targets_db=tmp_path / "targets.db",
+        baseline_dataset=tmp_path / "baseline.h5",
+        snap_values=(100.0, 50.0),
+    )
+    root = tmp_path / "builds"
+    output_dir = root / "run-1"
+    output_dir.mkdir(parents=True)
+    checkpoint_parent = output_dir / "checkpoints"
+    checkpoint_parent.write_text("not a directory")
+    result.config = replace(
+        result.config,
+        pipeline_checkpoint_save_post_imputation_path=str(
+            checkpoint_parent / "post_imputation"
+        ),
+    )
+
+    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
+        versioned_artifacts_module._allocate_versioned_output_dir(
+            root,
+            version_id="run-1",
+            result=result,
+        )
+
+
 def test_frontier_helpers_select_best_versioned_run(tmp_path):
     targets_db = tmp_path / "policyengine_targets.db"
     _create_policyengine_targets_db(targets_db)