Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# yaml-language-server: $schema=../workflow/tools/config.schema.json
description: |
Evaluate skill of temporal downscaler driven by ICON-CH1 Stage E forecaster with
subgrid orography.
Evaluate skill of Varda-single-1.0 against ground observations.

config_label: varda-single-1.0

dates:
start: 2025-03-01T00:00
Expand All @@ -10,15 +11,15 @@ dates:

runs:
- temporal_downscaler:
checkpoint: /scratch/mch/miccatta/ICON_interpolator_checkpoints/checkpoint_stage-C-interpolator-n320-6hto1h-reduced-variables/f9279244ed6f4c458597bdcf335ab36f/inference-last.ckpt
label: Varda-Single
checkpoint: https://service.meteoswiss.ch/mlstore#/models/sruc-m-2-interpolator/versions/3
label: Varda-single-1.0
steps: 0/120/1
config: resources/inference/configs/sgm-temporal-downscaler-global_trimedge_multi.yaml
extra_requirements:
- anemoi-datasets==0.5.35
# - anemoi-inference==0.11.0
forecaster:
checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad
checkpoint: https://service.meteoswiss.ch/mlstore#/models/sruc-m-1-forecaster/versions/4
config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml
steps: 0/120/6
- baseline:
Expand Down
12 changes: 2 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@


@pytest.fixture
def example_forecasters_config():
configfile = PROJECT_ROOT / "config/forecasters-ich1.yaml"
with open(configfile, "r") as f:
config = yaml.safe_load(f)
return config


@pytest.fixture
def example_temporal_downscalers_config():
configfile = PROJECT_ROOT / "config/temporal-downscalers-ich1.yaml"
def example_config():
configfile = PROJECT_ROOT / "config/varda-single-1.0.yaml"
with open(configfile, "r") as f:
config = yaml.safe_load(f)
return config
20 changes: 4 additions & 16 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,13 @@
from evalml.config import ConfigModel


def test_example_forecasters_config(example_forecasters_config):
def test_example_config(example_config):
"""Test that the example config loads correctly."""

# this shoudd not raise an error
_ = ConfigModel.model_validate(example_forecasters_config)
_ = ConfigModel.model_validate(example_config)

# this should raise an error
del example_forecasters_config["runs"]
del example_config["runs"]
with pytest.raises(ValueError, match="Field required"):
_ = ConfigModel.model_validate(example_forecasters_config)


def test_example_temporal_downscalers_config(example_temporal_downscalers_config):
"""Test that the example config loads correctly."""

# this shoudd not raise an error
_ = ConfigModel.model_validate(example_temporal_downscalers_config)

# this should raise an error
del example_temporal_downscalers_config["runs"]
with pytest.raises(ValueError, match="Field required"):
_ = ConfigModel.model_validate(example_temporal_downscalers_config)
_ = ConfigModel.model_validate(example_config)
6 changes: 6 additions & 0 deletions workflow/rules/common.smk
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def model_id(checkpoint_uri: str) -> str:
"""Generate a model ID based on the checkpoint URI."""
ckpt_type = _checkpoint_uri_type(checkpoint_uri)
if ckpt_type == "mlflow":
fragment = checkpoint_uri.split("#")[-1]
if "/models/" in fragment:
parts = fragment.strip("/").split("/")
if len(parts) >= 4 and parts[2] == "versions":
return f"{parts[1]}-v{parts[3]}"[:HASH_LENGTH]
return f"{parts[1]}-latest"[:HASH_LENGTH]
return checkpoint_uri.split("/")[-1][:HASH_LENGTH]
elif ckpt_type == "huggingface":
return checkpoint_uri.split("/")[-1].split(".")[0]
Expand Down
13 changes: 10 additions & 3 deletions workflow/rules/inference.smk
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@ rule inference_get_checkpoint:
checkpoint_type=lambda wc: _checkpoint_uri_type(
ENV_CONFIGS[wc.env_id]["checkpoint"]
),
checkpoint_is_registry=lambda wc: "/models/"
in ENV_CONFIGS[wc.env_id]["checkpoint"],
shell:
r"""
(
mkdir -p $(dirname {output.checkpoint})
if [ "{params.checkpoint_type}" = "mlflow" ]; then
ln -s $(python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint}) {output.checkpoint}
echo "Located checkpoint from MLFlow log."
echo "Created symlink: {output.checkpoint} -> $(readlink {output.checkpoint})"
if [ "{params.checkpoint_is_registry}" = "True" ]; then
python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint} --output {output.checkpoint}
echo "Downloaded checkpoint from MLFlow model registry: {output.checkpoint}"
else
ln -s $(python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint}) {output.checkpoint}
echo "Located checkpoint from MLFlow log."
echo "Created symlink: {output.checkpoint} -> $(readlink {output.checkpoint})"
fi
elif [ "{params.checkpoint_type}" = "huggingface" ]; then
repo_id=$(python -c "import re; print(re.search(r'huggingface\.co/([^/]+/[^/]+)', '{params.checkpoint}').group(1))")
file_path=$(python -c "import re; print(re.search(r'huggingface\.co/[^/]+/[^/]+/blob/[^/]+/(.*)', '{params.checkpoint}').group(1))")
Expand Down
64 changes: 62 additions & 2 deletions workflow/scripts/inference_get_checkpoint_mlflow.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,85 @@
import argparse
import logging
import shutil
from pathlib import Path
from urllib.parse import urlparse

from anemoi.utils.mlflow.auth import TokenAuth
from anemoi.utils.mlflow.client import AnemoiMlflowClient
from mlflow.tracking import MlflowClient

LOG = logging.getLogger(__name__)

KNOWN_MLFLOW_TRACKING_URI = [
"mlflow.ecmwf.int",
"service.meteoswiss.ch",
"servicedepl.meteoswiss.ch",
]

CHECKPOINT_FILENAME = "inference-last.ckpt"


def _find_artifact_path(client, run_id, filename, path=""):
"""Recursively search a run's artifacts for a file by name, returning its artifact path."""
for artifact in client.list_artifacts(run_id, path):
if artifact.is_dir:
result = _find_artifact_path(client, run_id, filename, artifact.path)
if result is not None:
return result
elif Path(artifact.path).name == filename:
return artifact.path
return None


def main(args):
run_uri = args.run_uri
parsed_url = urlparse(run_uri)
if parsed_url.netloc in KNOWN_MLFLOW_TRACKING_URI:
uri, fragment = run_uri.split("#")
run_id = fragment.split("/")[-1]
if parsed_url.netloc == "mlflow.ecmwf.int":
TokenAuth(uri).login()
client = AnemoiMlflowClient(uri, authentication=True)
else:
client = MlflowClient(tracking_uri=uri)
if "/models/" in fragment:
parts = fragment.strip("/").split("/")
model_name = parts[1]
if len(parts) >= 4 and parts[2] == "versions":
model_version = client.get_model_version(model_name, parts[3])
else:
versions = client.search_model_versions(f"name='{model_name}'")
if not versions:
raise ValueError(
f"No versions found for model '{model_name}' in the registry"
)
model_version = max(versions, key=lambda v: int(v.version))
LOG.info(
"Found model version: %s (run ID: %s)",
model_version.version,
model_version.run_id,
)
output_path = Path(args.output)
artifact_path = _find_artifact_path(
client, model_version.run_id, CHECKPOINT_FILENAME
)
if artifact_path is None:
raise FileNotFoundError(
f"Could not find '{CHECKPOINT_FILENAME}' in MLflow artifacts for run {model_version.run_id}"
)
local_path = Path(
client.download_artifacts(
model_version.run_id, artifact_path, str(output_path.parent)
)
)
if local_path != output_path:
shutil.move(str(local_path), str(output_path))
return
else:
run_id = fragment.split("/")[-1]
run = client.get_run(run_id)
path = run.data.params.get("config.hardware.paths.checkpoints")
path = path or run.data.params.get("config.system.output.checkpoints.root")
path = Path(path) / "inference-last.ckpt"
path = Path(path) / CHECKPOINT_FILENAME
if not path.exists():
raise FileNotFoundError(f"Checkpoint path does not exist: {path}")
print(str(path))
Expand All @@ -38,5 +91,12 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Get local checkpoint location.")
parser.add_argument("run_uri", type=str, help="MLFlow run URI")
parser.add_argument(
"--output",
type=str,
default=None,
help="Destination path for the downloaded checkpoint (required for model registry URLs).",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
main(args)
Loading