Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/microplex_us/pipelines/versioned_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def save_versioned_us_microplex_artifacts(
child_tax_unit_agi_drift_variables: tuple[str, ...] | None = None,
allow_stage_input_overrides: bool = False,
stage_input_overrides: tuple[USStageInputOverride, ...] = (),
stage_runtime_writer: USStageRuntimeWriter | None = None,
) -> USMicroplexArtifactPaths:
"""Persist a build under a stable versioned directory beneath one output root."""
output_root = Path(output_root)
Expand Down Expand Up @@ -118,6 +119,7 @@ def save_versioned_us_microplex_artifacts(
child_tax_unit_agi_drift_variables=child_tax_unit_agi_drift_variables,
allow_stage_input_overrides=allow_stage_input_overrides,
stage_input_overrides=stage_input_overrides,
stage_runtime_writer=stage_runtime_writer,
)
return replace(paths, version_id=resolved_version_id)

Expand Down Expand Up @@ -560,6 +562,8 @@ def _allocate_versioned_output_dir_for_config(
if version_id is not None:
output_dir = output_root / version_id
if output_dir.exists():
if _version_dir_contains_only_configured_checkpoints(output_dir, config):
return version_id, output_dir
raise FileExistsError(
f"Versioned artifact directory already exists: {output_dir}"
)
Expand All @@ -578,6 +582,70 @@ def _allocate_versioned_output_dir_for_config(
return candidate_version_id, output_dir


def _version_dir_contains_only_configured_checkpoints(
output_dir: Path,
config: Mapping[str, Any],
) -> bool:
if not output_dir.is_dir():
return False

resolved_output_dir = output_dir.expanduser().resolve(strict=False)
checkpoint_roots = _configured_checkpoint_roots_inside_version_dir(
resolved_output_dir,
config,
)
if not checkpoint_roots:
return False

return all(
_path_is_allowed_checkpoint_tree_member(path, checkpoint_roots)
for path in output_dir.rglob("*")
)


def _configured_checkpoint_roots_inside_version_dir(
output_dir: Path,
config: Mapping[str, Any],
) -> tuple[Path, ...]:
roots: list[Path] = []
for key in (
"pipeline_checkpoint_save_post_imputation_path",
"pipeline_checkpoint_save_post_microsim_path",
):
checkpoint_path = config.get(key)
if checkpoint_path is None:
continue
resolved_checkpoint_path = Path(checkpoint_path).expanduser().resolve(
strict=False
)
if (
resolved_checkpoint_path != output_dir
and _path_is_relative_to(resolved_checkpoint_path, output_dir)
):
roots.append(resolved_checkpoint_path)
return tuple(roots)


def _path_is_allowed_checkpoint_tree_member(
path: Path,
checkpoint_roots: tuple[Path, ...],
) -> bool:
resolved_path = path.expanduser().resolve(strict=False)
return any(
_path_is_relative_to(resolved_path, checkpoint_root)
or _path_is_relative_to(checkpoint_root, resolved_path)
for checkpoint_root in checkpoint_roots
)


def _path_is_relative_to(path: Path, other: Path) -> bool:
try:
path.relative_to(other)
except ValueError:
return False
return True


def _short_config_hash(config: dict[str, Any]) -> str:
import hashlib
import json
Expand Down
137 changes: 137 additions & 0 deletions tests/pipelines/test_versioned_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import json
import sqlite3
from dataclasses import replace
from pathlib import Path

import duckdb
import pandas as pd
import pytest
from microplex.core import (
EntityObservation,
EntityRelationship,
Expand All @@ -18,6 +20,7 @@
TimeStructure,
)

import microplex_us.pipelines.versioned_artifacts as versioned_artifacts_module
from microplex_us.pipelines import (
build_and_save_versioned_us_microplex,
build_and_save_versioned_us_microplex_from_data_dir,
Expand Down Expand Up @@ -405,6 +408,140 @@ def test_save_versioned_us_microplex_artifacts_uses_explicit_version(tmp_path):
assert conn.execute("SELECT COUNT(*) FROM target_metrics").fetchone()[0] == 2


def test_save_versioned_artifacts_allows_configured_checkpoint_contents(tmp_path):
    """An explicit version dir holding only configured checkpoints is reused."""
    result = _make_result(
        targets_db=tmp_path / "targets.db",
        baseline_dataset=tmp_path / "baseline.h5",
        snap_values=(100.0, 50.0),
    )
    builds_root = tmp_path / "builds"
    version_dir = builds_root / "run-1"
    imputation_dir = version_dir / "checkpoints" / "post_imputation"
    microsim_dir = version_dir / "checkpoints" / "post_microsim"

    # Pre-populate both checkpoint trees before allocation is attempted.
    for checkpoint_dir in (imputation_dir, microsim_dir):
        checkpoint_dir.mkdir(parents=True)
    for checkpoint_dir in (imputation_dir, microsim_dir):
        (checkpoint_dir / "metadata.json").write_text("{}")

    result.config = replace(
        result.config,
        pipeline_checkpoint_save_post_imputation_path=str(imputation_dir),
        pipeline_checkpoint_save_post_microsim_path=str(microsim_dir),
    )

    allocated = versioned_artifacts_module._allocate_versioned_output_dir(
        builds_root,
        version_id="run-1",
        result=result,
    )
    version_id, allocated_output_dir = allocated

    assert version_id == "run-1"
    assert allocated_output_dir == version_dir


def test_save_versioned_artifacts_allows_configured_checkpoint_parent(tmp_path):
    """A version dir holding only the checkpoints' shared parent dir is reused."""
    result = _make_result(
        targets_db=tmp_path / "targets.db",
        baseline_dataset=tmp_path / "baseline.h5",
        snap_values=(100.0, 50.0),
    )
    builds_root = tmp_path / "builds"
    version_dir = builds_root / "run-1"

    # Only the parent exists; the configured checkpoint dirs are not created.
    checkpoints_dir = version_dir / "checkpoints"
    checkpoints_dir.mkdir(parents=True)

    imputation_target = str(checkpoints_dir / "post_imputation")
    microsim_target = str(checkpoints_dir / "post_microsim")
    result.config = replace(
        result.config,
        pipeline_checkpoint_save_post_imputation_path=imputation_target,
        pipeline_checkpoint_save_post_microsim_path=microsim_target,
    )

    allocated = versioned_artifacts_module._allocate_versioned_output_dir(
        builds_root,
        version_id="run-1",
        result=result,
    )
    version_id, allocated_output_dir = allocated

    assert version_id == "run-1"
    assert allocated_output_dir == version_dir


def test_save_versioned_artifacts_rejects_unrelated_existing_version_dir(tmp_path):
    """A file outside the configured checkpoint tree blocks reuse of the dir."""
    result = _make_result(
        targets_db=tmp_path / "targets.db",
        baseline_dataset=tmp_path / "baseline.h5",
        snap_values=(100.0, 50.0),
    )
    builds_root = tmp_path / "builds"
    version_dir = builds_root / "run-1"
    imputation_dir = version_dir / "checkpoints" / "post_imputation"
    imputation_dir.mkdir(parents=True)
    (imputation_dir / "metadata.json").write_text("{}")

    # This top-level file does not belong to any configured checkpoint tree.
    stray_manifest = version_dir / "manifest.json"
    stray_manifest.write_text("{}")

    result.config = replace(
        result.config,
        pipeline_checkpoint_save_post_imputation_path=str(imputation_dir),
    )

    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
        versioned_artifacts_module._allocate_versioned_output_dir(
            builds_root,
            version_id="run-1",
            result=result,
        )


def test_save_versioned_artifacts_rejects_unconfigured_checkpoint_contents(tmp_path):
    """Checkpoint-like contents are rejected when this test configures no paths."""
    result = _make_result(
        targets_db=tmp_path / "targets.db",
        baseline_dataset=tmp_path / "baseline.h5",
        snap_values=(100.0, 50.0),
    )
    builds_root = tmp_path / "builds"
    version_dir = builds_root / "run-1"

    # The directory layout resembles a checkpoint, but result.config is left as is.
    existing_checkpoint = version_dir / "checkpoints" / "post_imputation"
    existing_checkpoint.mkdir(parents=True)
    (existing_checkpoint / "metadata.json").write_text("{}")

    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
        versioned_artifacts_module._allocate_versioned_output_dir(
            builds_root,
            version_id="run-1",
            result=result,
        )


def test_save_versioned_artifacts_rejects_existing_version_file(tmp_path):
    """A regular file at the version path is rejected despite checkpoint config."""
    result = _make_result(
        targets_db=tmp_path / "targets.db",
        baseline_dataset=tmp_path / "baseline.h5",
        snap_values=(100.0, 50.0),
    )
    builds_root = tmp_path / "builds"
    version_path = builds_root / "run-1"

    # Put a plain file where the version directory would normally go.
    builds_root.mkdir(parents=True)
    version_path.write_text("not a directory")

    configured_checkpoint = version_path / "checkpoints" / "post_imputation"
    result.config = replace(
        result.config,
        pipeline_checkpoint_save_post_imputation_path=str(configured_checkpoint),
    )

    with pytest.raises(FileExistsError, match="Versioned artifact directory"):
        versioned_artifacts_module._allocate_versioned_output_dir(
            builds_root,
            version_id="run-1",
            result=result,
        )


def test_frontier_helpers_select_best_versioned_run(tmp_path):
targets_db = tmp_path / "policyengine_targets.db"
_create_policyengine_targets_db(targets_db)
Expand Down
Loading