Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions src/policyengine/core/release_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
}


class DataReleaseManifestUnavailableError(ValueError):
    """Signal that a data release manifest could not be obtained.

    Used when the manifest either cannot be fetched or has not been
    published. It derives from ``ValueError``, so code that catches
    ``ValueError`` also catches this error.
    """


class PackageVersion(BaseModel):
name: str
version: str
Expand Down Expand Up @@ -161,10 +165,14 @@ def get_data_release_manifest(country_id: str) -> DataReleaseManifest:
timeout=HF_REQUEST_TIMEOUT_SECONDS,
)
if response.status_code in (401, 403):
raise ValueError(
raise DataReleaseManifestUnavailableError(
"Could not fetch the data release manifest from Hugging Face. "
"If this country uses a private data repo, set HUGGING_FACE_TOKEN."
)
if response.status_code == 404:
raise DataReleaseManifestUnavailableError(
"No data release manifest was published for this data package."
)
response.raise_for_status()
return DataReleaseManifest.model_validate_json(response.text)

Expand All @@ -183,13 +191,23 @@ def certify_data_release_compatibility(
country_manifest = get_release_manifest(country_id)
try:
data_release_manifest = get_data_release_manifest(country_id)
except Exception as exc:
except DataReleaseManifestUnavailableError as exc:
bundled_certification = country_manifest.certification
if (
bundled_certification is not None
and bundled_certification.certified_for_model_version
== runtime_model_version
):
if (
runtime_data_build_fingerprint is not None
and bundled_certification.data_build_fingerprint is not None
and runtime_data_build_fingerprint
!= bundled_certification.data_build_fingerprint
):
raise ValueError(
"Runtime data build fingerprint does not match the bundled "
"data certification."
)
return bundled_certification
raise exc
built_with_model = (
Expand Down Expand Up @@ -339,15 +357,20 @@ def resolve_managed_dataset_reference(
return resolve_dataset_reference(country_id, dataset)


def resolve_local_managed_dataset_source(country_id: str, dataset_uri: str) -> str:
def resolve_local_managed_dataset_source(
country_id: str,
dataset_uri: str,
*,
allow_local_mirror: bool = True,
) -> str:
"""Resolve a local mirror of a managed dataset when available.

This preserves the bundled dataset URI for provenance while allowing local
development environments with sibling data-repo checkouts to load the
exact certified artifact from disk rather than re-downloading it.
"""

if not dataset_uri.startswith("hf://"):
if not allow_local_mirror or not dataset_uri.startswith("hf://"):
return dataset_uri

local_hint = LOCAL_DATA_REPO_HINTS.get(country_id)
Expand Down
10 changes: 8 additions & 2 deletions src/policyengine/tax_benefit_models/uk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _managed_release_bundle(
bundle = dict(uk_latest.release_bundle)
bundle["runtime_dataset"] = dataset_logical_name(dataset_uri)
bundle["runtime_dataset_uri"] = dataset_uri
if dataset_source and dataset_source != dataset_uri:
if dataset_source:
bundle["runtime_dataset_source"] = dataset_source
bundle["managed_by"] = "policyengine.py"
return bundle
Expand Down Expand Up @@ -467,7 +467,13 @@ def managed_microsimulation(
dataset,
allow_unmanaged=allow_unmanaged,
)
dataset_source = resolve_local_managed_dataset_source("uk", dataset_uri)
dataset_source = resolve_local_managed_dataset_source(
"uk",
dataset_uri,
allow_local_mirror=not (
allow_unmanaged and dataset is not None and "://" in dataset
),
)
runtime_dataset = dataset_source
if isinstance(dataset_source, str) and "hf://" not in dataset_source:
from policyengine_uk.data.dataset_schema import (
Expand Down
10 changes: 8 additions & 2 deletions src/policyengine/tax_benefit_models/us/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def _managed_release_bundle(
bundle = dict(us_latest.release_bundle)
bundle["runtime_dataset"] = dataset_logical_name(dataset_uri)
bundle["runtime_dataset_uri"] = dataset_uri
if dataset_source and dataset_source != dataset_uri:
if dataset_source:
bundle["runtime_dataset_source"] = dataset_source
bundle["managed_by"] = "policyengine.py"
return bundle
Expand Down Expand Up @@ -632,7 +632,13 @@ def managed_microsimulation(
dataset,
allow_unmanaged=allow_unmanaged,
)
dataset_source = resolve_local_managed_dataset_source("us", dataset_uri)
dataset_source = resolve_local_managed_dataset_source(
"us",
dataset_uri,
allow_local_mirror=not (
allow_unmanaged and dataset is not None and "://" in dataset
),
)
microsim = Microsimulation(dataset=dataset_source, **kwargs)
microsim.policyengine_bundle = _managed_release_bundle(
dataset_uri,
Expand Down
142 changes: 132 additions & 10 deletions tests/test_release_manifests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import json
from unittest.mock import MagicMock, patch

from requests import Timeout

from policyengine.core.release_manifest import (
DataCertification,
DataReleaseManifestUnavailableError,
certify_data_release_compatibility,
dataset_logical_name,
get_data_release_manifest,
Expand Down Expand Up @@ -194,6 +198,22 @@ def test__given_country__then_can_fetch_data_release_manifest(self):
)
assert mock_get.call_count == 1

def test__given_missing_data_release_manifest__then_fetch_raises_unavailable(self):
    """A 404 from the manifest request surfaces as an unavailable-manifest error.

    ``requests.get`` is patched to return a response whose ``status_code`` is
    404; fetching the US data release manifest must then raise
    ``DataReleaseManifestUnavailableError`` with the "No data release
    manifest" message.
    """
    get_data_release_manifest.cache_clear()
    not_found = MagicMock(status_code=404)

    caught = None
    with patch(
        "policyengine.core.release_manifest.requests.get",
        return_value=not_found,
    ):
        try:
            get_data_release_manifest("us")
        except DataReleaseManifestUnavailableError as error:
            caught = error

    # No exception at all means the 404 was silently accepted.
    if caught is None:
        raise AssertionError("Expected missing manifest to be reported")
    assert "No data release manifest" in str(caught)

def test__given_matching_fingerprint__then_certification_allows_reuse(self):
get_data_release_manifest.cache_clear()
payload = {
Expand Down Expand Up @@ -231,6 +251,73 @@ def test__given_matching_fingerprint__then_certification_allows_reuse(self):
assert certification.built_with_model_version == "1.601.0"
assert certification.certified_for_model_version == "1.602.0"

def test__given_private_manifest_unavailable__then_bundled_certification_is_used(
    self,
):
    """Certification falls back to the bundled one when the manifest is unavailable.

    ``get_data_release_manifest`` is patched to raise
    ``DataReleaseManifestUnavailableError``; the certification returned for
    runtime model version ``1.602.0`` must equal the certification carried by
    the US release manifest.
    """
    get_data_release_manifest.cache_clear()

    unavailable_manifest = patch(
        "policyengine.core.release_manifest.get_data_release_manifest",
        side_effect=DataReleaseManifestUnavailableError("private repo"),
    )
    with unavailable_manifest:
        result = certify_data_release_compatibility(
            "us",
            runtime_model_version="1.602.0",
        )

    assert result == get_release_manifest("us").certification

def test__given_private_manifest_unavailable_and_fingerprint_mismatch__then_fails(
    self,
):
    """The bundled fallback still rejects a mismatching data build fingerprint.

    The data release manifest is made unavailable and the release manifest is
    replaced by a stub whose certification expects ``sha256:expected``.
    Certifying with a different runtime fingerprint must raise a
    ``ValueError`` mentioning the bundled data certification.
    """
    get_data_release_manifest.cache_clear()

    bundled = DataCertification(
        compatibility_basis="matching_data_build_fingerprint",
        certified_for_model_version="1.602.0",
        data_build_fingerprint="sha256:expected",
    )
    unavailable_manifest = patch(
        "policyengine.core.release_manifest.get_data_release_manifest",
        side_effect=DataReleaseManifestUnavailableError("private repo"),
    )
    stubbed_release_manifest = patch(
        "policyengine.core.release_manifest.get_release_manifest",
        return_value=MagicMock(certification=bundled),
    )

    failure = None
    with unavailable_manifest, stubbed_release_manifest:
        try:
            certify_data_release_compatibility(
                "us",
                runtime_model_version="1.602.0",
                runtime_data_build_fingerprint="sha256:not-a-match",
            )
        except ValueError as error:
            failure = error

    # Returning normally would mean the mismatch was accepted.
    if failure is None:
        raise AssertionError("Expected fingerprint mismatch to fail")
    assert "does not match the bundled data certification" in str(failure)

def test__given_manifest_fetch_failure__then_certification_does_not_fallback(
    self,
):
    """A network timeout propagates instead of triggering the bundled fallback.

    ``get_data_release_manifest`` is patched to raise ``requests.Timeout``;
    certification must re-raise that same error rather than return a
    certification.
    """
    get_data_release_manifest.cache_clear()

    timing_out_manifest = patch(
        "policyengine.core.release_manifest.get_data_release_manifest",
        side_effect=Timeout("network timeout"),
    )

    propagated = None
    with timing_out_manifest:
        try:
            certify_data_release_compatibility(
                "us",
                runtime_model_version="1.602.0",
            )
        except Timeout as error:
            propagated = error

    # A normal return would mean the timeout was swallowed by a fallback.
    if propagated is None:
        raise AssertionError("Expected timeout to propagate")
    assert "network timeout" in str(propagated)

def test__given_mismatched_version_and_fingerprint__then_certification_fails(self):
get_data_release_manifest.cache_clear()
payload = {
Expand Down Expand Up @@ -327,18 +414,30 @@ def test__given_us_managed_microsimulation__then_passes_certified_dataset_and_bu
microsim = managed_us_microsimulation()

dataset = mock_microsimulation.call_args.kwargs["dataset"]
assert str(dataset).endswith(
"policyengine_us_data/storage/enhanced_cps_2024.h5"
)
assert dataset == microsim.policyengine_bundle["runtime_dataset_source"]
assert microsim.policyengine_bundle["policyengine_version"] == "3.4.0"
assert microsim.policyengine_bundle["runtime_dataset"] == "enhanced_cps_2024"
assert (
microsim.policyengine_bundle["runtime_dataset_uri"]
== us_latest.default_dataset_uri
)
assert str(microsim.policyengine_bundle["runtime_dataset_source"]).endswith(
"policyengine_us_data/storage/enhanced_cps_2024.h5"
)
dataset_source = microsim.policyengine_bundle["runtime_dataset_source"]
assert dataset_source == us_latest.default_dataset_uri or str(
dataset_source
).endswith("policyengine_us_data/storage/enhanced_cps_2024.h5")

def test__given_us_unmanaged_dataset_uri__then_source_is_not_rewritten(self):
    """An explicit unmanaged US ``hf://`` URI is passed through untouched.

    With ``allow_unmanaged=True`` the URI must reach ``Microsimulation`` as
    given and be recorded as both ``runtime_dataset_uri`` and
    ``runtime_dataset_source`` in the bundle.
    """
    unmanaged_uri = "hf://policyengine/policyengine-us-data/cps_2023.h5@1.73.0"

    with patch("policyengine_us.Microsimulation") as microsimulation_cls:
        simulation = managed_us_microsimulation(
            dataset=unmanaged_uri,
            allow_unmanaged=True,
        )

    assert microsimulation_cls.call_args.kwargs["dataset"] == unmanaged_uri
    bundle = simulation.policyengine_bundle
    assert bundle["runtime_dataset_uri"] == unmanaged_uri
    assert bundle["runtime_dataset_source"] == unmanaged_uri

def test__given_uk_managed_dataset_name__then_resolves_within_bundle(self):
with patch("policyengine_uk.Microsimulation") as mock_microsimulation:
Expand All @@ -347,13 +446,36 @@ def test__given_uk_managed_dataset_name__then_resolves_within_bundle(self):
dataset = mock_microsimulation.call_args.kwargs["dataset"]
from policyengine_uk.data.dataset_schema import UKSingleYearDataset

assert isinstance(dataset, UKSingleYearDataset)
assert getattr(dataset, "time_period", None) == "2023"
if isinstance(dataset, UKSingleYearDataset):
assert getattr(dataset, "time_period", None) == "2023"
else:
assert dataset == (
"hf://policyengine/policyengine-uk-data-private/"
"enhanced_frs_2023_24.h5@1.40.4"
)
assert microsim.policyengine_bundle["policyengine_version"] == "3.4.0"
assert microsim.policyengine_bundle["runtime_dataset"] == "enhanced_frs_2023_24"
assert microsim.policyengine_bundle["runtime_dataset_uri"] == (
"hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.4"
)
assert str(microsim.policyengine_bundle["runtime_dataset_source"]).endswith(
"policyengine_uk_data/storage/enhanced_frs_2023_24.h5"
dataset_source = microsim.policyengine_bundle["runtime_dataset_source"]
assert (
dataset_source
== "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.4"
or str(dataset_source).endswith(
"policyengine_uk_data/storage/enhanced_frs_2023_24.h5"
)
)

def test__given_uk_unmanaged_dataset_uri__then_source_is_not_rewritten(self):
    """An explicit unmanaged UK ``hf://`` URI is passed through untouched.

    With ``allow_unmanaged=True`` the URI must reach ``Microsimulation`` as
    given and be recorded as both ``runtime_dataset_uri`` and
    ``runtime_dataset_source`` in the bundle.
    """
    unmanaged_uri = (
        "hf://policyengine/policyengine-uk-data-private/frs_2022_23.h5@1.40.4"
    )

    with patch("policyengine_uk.Microsimulation") as microsimulation_cls:
        simulation = managed_uk_microsimulation(
            dataset=unmanaged_uri,
            allow_unmanaged=True,
        )

    assert microsimulation_cls.call_args.kwargs["dataset"] == unmanaged_uri
    bundle = simulation.policyengine_bundle
    assert bundle["runtime_dataset_uri"] == unmanaged_uri
    assert bundle["runtime_dataset_source"] == unmanaged_uri
Loading