From abf67fca5d86a50a4bcc6e999fd8048c2755eb65 Mon Sep 17 00:00:00 2001 From: Luca Lanzilao Date: Fri, 12 Jun 2026 12:00:03 +0200 Subject: [PATCH 1/6] support forecaster checkpoint from model registry --- workflow/rules/common.smk | 4 ++++ workflow/scripts/inference_get_checkpoint_mlflow.py | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 020ad956..238a56f7 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -143,6 +143,10 @@ def model_id(checkpoint_uri: str) -> str: """Generate a model ID based on the checkpoint URI.""" ckpt_type = _checkpoint_uri_type(checkpoint_uri) if ckpt_type == "mlflow": + fragment = checkpoint_uri.split("#")[-1] + if "/models/" in fragment: + parts = fragment.strip("/").split("/") + return f"{parts[1]}-v{parts[3]}"[:HASH_LENGTH] return checkpoint_uri.split("/")[-1][:HASH_LENGTH] elif ckpt_type == "huggingface": return checkpoint_uri.split("/")[-1].split(".")[0] diff --git a/workflow/scripts/inference_get_checkpoint_mlflow.py b/workflow/scripts/inference_get_checkpoint_mlflow.py index 219f1480..c0f5984b 100644 --- a/workflow/scripts/inference_get_checkpoint_mlflow.py +++ b/workflow/scripts/inference_get_checkpoint_mlflow.py @@ -17,12 +17,18 @@ def main(args): parsed_url = urlparse(run_uri) if parsed_url.netloc in KNOWN_MLFLOW_TRACKING_URI: uri, fragment = run_uri.split("#") - run_id = fragment.split("/")[-1] if parsed_url.netloc == "mlflow.ecmwf.int": TokenAuth(uri).login() client = AnemoiMlflowClient(uri, authentication=True) else: client = MlflowClient(tracking_uri=uri) + if "/models/" in fragment: + parts = fragment.strip("/").split("/") + model_name, version = parts[1], parts[3] + model_version = client.get_model_version(model_name, version) + run_id = model_version.run_id + else: + run_id = fragment.split("/")[-1] run = client.get_run(run_id) path = run.data.params.get("config.hardware.paths.checkpoints") path = path or run.data.params.get("config.system.output.checkpoints.root") From b6c71b460f334862d78d1ed0afee43d752add642 Mon Sep 17 00:00:00 2001 From: Luca Lanzilao Date: Wed, 17 Jun 2026 09:11:02 +0200 Subject: [PATCH 2/6] download artifacts locally when checkpoint is in model registry --- workflow/rules/common.smk | 4 +- workflow/rules/inference.smk | 12 +++- .../inference_get_checkpoint_mlflow.py | 57 +++++++++++++++++-- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 238a56f7..c75da889 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -146,7 +146,9 @@ def model_id(checkpoint_uri: str) -> str: fragment = checkpoint_uri.split("#")[-1] if "/models/" in fragment: parts = fragment.strip("/").split("/") - return f"{parts[1]}-v{parts[3]}"[:HASH_LENGTH] + if len(parts) >= 4 and parts[2] == "versions": + return f"{parts[1]}-v{parts[3]}"[:HASH_LENGTH] + return f"{parts[1]}-latest"[:HASH_LENGTH] return checkpoint_uri.split("/")[-1][:HASH_LENGTH] elif ckpt_type == "huggingface": return checkpoint_uri.split("/")[-1].split(".")[0] diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index 2d770556..c7fb0c6a 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -19,14 +19,20 @@ rule inference_get_checkpoint: checkpoint_type=lambda wc: _checkpoint_uri_type( ENV_CONFIGS[wc.env_id]["checkpoint"] ), + checkpoint_is_registry=lambda wc: "/models/" in ENV_CONFIGS[wc.env_id]["checkpoint"], shell: r""" ( mkdir -p $(dirname {output.checkpoint}) if [ "{params.checkpoint_type}" = "mlflow" ]; then - ln -s $(python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint}) {output.checkpoint} - echo "Located checkpoint from MLFlow log." - echo "Created symlink: {output.checkpoint} -> $(readlink {output.checkpoint})" + if [ "{params.checkpoint_is_registry}" = "True" ]; then + python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint} --output {output.checkpoint} + echo "Downloaded checkpoint from MLFlow model registry: {output.checkpoint}" + else + ln -s $(python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint}) {output.checkpoint} + echo "Located checkpoint from MLFlow log." + echo "Created symlink: {output.checkpoint} -> $(readlink {output.checkpoint})" + fi elif [ "{params.checkpoint_type}" = "huggingface" ]; then repo_id=$(python -c "import re; print(re.search(r'huggingface\.co/([^/]+/[^/]+)', '{params.checkpoint}').group(1))") file_path=$(python -c "import re; print(re.search(r'huggingface\.co/[^/]+/[^/]+/blob/[^/]+/(.*)', '{params.checkpoint}').group(1))") diff --git a/workflow/scripts/inference_get_checkpoint_mlflow.py b/workflow/scripts/inference_get_checkpoint_mlflow.py index c0f5984b..d4b3fb82 100644 --- a/workflow/scripts/inference_get_checkpoint_mlflow.py +++ b/workflow/scripts/inference_get_checkpoint_mlflow.py @@ -1,16 +1,35 @@ import argparse +import logging +import shutil from pathlib import Path from urllib.parse import urlparse + from anemoi.utils.mlflow.auth import TokenAuth from anemoi.utils.mlflow.client import AnemoiMlflowClient from mlflow.tracking import MlflowClient +LOG = logging.getLogger(__name__) + KNOWN_MLFLOW_TRACKING_URI = [ "mlflow.ecmwf.int", "service.meteoswiss.ch", "servicedepl.meteoswiss.ch", ] +CHECKPOINT_FILENAME = "inference-last.ckpt" + + +def _find_artifact_path(client, run_id, filename, path=""): + """Recursively search a run's artifacts for a file by name, returning its artifact path.""" + for artifact in client.list_artifacts(run_id, path): + if artifact.is_dir: + result = _find_artifact_path(client, run_id, filename, artifact.path) + if result is not None: + return result + elif Path(artifact.path).name == filename: + return artifact.path + return None + def main(args): run_uri = args.run_uri @@ -24,15 +43,33 @@ def main(args): client = MlflowClient(tracking_uri=uri) if "/models/" in fragment: parts = fragment.strip("/").split("/") - model_name, version = parts[1], parts[3] - model_version = client.get_model_version(model_name, version) - run_id = model_version.run_id + model_name = parts[1] + if len(parts) >= 4 and parts[2] == "versions": + model_version = client.get_model_version(model_name, parts[3]) + else: + versions = client.search_model_versions(f"name='{model_name}'") + if not versions: + raise ValueError(f"No versions found for model '{model_name}' in the registry") + model_version = max(versions, key=lambda v: int(v.version)) + LOG.info("Found model version: %s (run ID: %s)", model_version.version, model_version.run_id) + output_path = Path(args.output) + artifact_path = _find_artifact_path(client, model_version.run_id, CHECKPOINT_FILENAME) + if artifact_path is None: + raise FileNotFoundError( + f"Could not find '{CHECKPOINT_FILENAME}' in MLflow artifacts for run {model_version.run_id}" + ) + local_path = Path( + client.download_artifacts(model_version.run_id, artifact_path, str(output_path.parent)) + ) + if local_path != output_path: + shutil.move(str(local_path), str(output_path)) + return else: run_id = fragment.split("/")[-1] run = client.get_run(run_id) path = run.data.params.get("config.hardware.paths.checkpoints") path = path or run.data.params.get("config.system.output.checkpoints.root") - path = Path(path) / "inference-last.ckpt" + path = Path(path) / CHECKPOINT_FILENAME if not path.exists(): raise FileNotFoundError(f"Checkpoint path does not exist: {path}") print(str(path)) @@ -43,6 +80,16 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Get local checkpoint location.") - parser.add_argument("run_uri", type=str, help="MLFlow run URI") + parser.add_argument("run_uri", + type=str, + help="MLFlow run URI" + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Destination path for the downloaded checkpoint (required for model registry URLs).", + ) args = parser.parse_args() + logging.basicConfig(level=logging.INFO) main(args) From c0f67e75b10419a5ca2d0a24bbe8e0b4c790b1e6 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Thu, 18 Jun 2026 14:03:02 +0200 Subject: [PATCH 3/6] Linting --- workflow/rules/inference.smk | 3 ++- .../inference_get_checkpoint_mlflow.py | 23 ++++++++++++------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index c7fb0c6a..fd059312 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -19,7 +19,8 @@ rule inference_get_checkpoint: checkpoint_type=lambda wc: _checkpoint_uri_type( ENV_CONFIGS[wc.env_id]["checkpoint"] ), - checkpoint_is_registry=lambda wc: "/models/" in ENV_CONFIGS[wc.env_id]["checkpoint"], + checkpoint_is_registry=lambda wc: "/models/" + in ENV_CONFIGS[wc.env_id]["checkpoint"], shell: r""" ( diff --git a/workflow/scripts/inference_get_checkpoint_mlflow.py b/workflow/scripts/inference_get_checkpoint_mlflow.py index d4b3fb82..26e1a7a6 100644 --- a/workflow/scripts/inference_get_checkpoint_mlflow.py +++ b/workflow/scripts/inference_get_checkpoint_mlflow.py @@ -49,17 +49,27 @@ def main(args): else: versions = client.search_model_versions(f"name='{model_name}'") if not versions: - raise ValueError(f"No versions found for model '{model_name}' in the registry") + raise ValueError( + f"No versions found for model '{model_name}' in the registry" + ) model_version = max(versions, key=lambda v: int(v.version)) - LOG.info("Found model version: %s (run ID: %s)", model_version.version, model_version.run_id) + LOG.info( + "Found model version: %s (run ID: %s)", + model_version.version, + model_version.run_id, + ) output_path = Path(args.output) - artifact_path = _find_artifact_path(client, model_version.run_id, CHECKPOINT_FILENAME) + artifact_path = _find_artifact_path( + client, model_version.run_id, CHECKPOINT_FILENAME + ) if artifact_path is None: raise FileNotFoundError( f"Could not find '{CHECKPOINT_FILENAME}' in MLflow artifacts for run {model_version.run_id}" ) local_path = Path( - client.download_artifacts(model_version.run_id, artifact_path, str(output_path.parent)) + client.download_artifacts( + model_version.run_id, artifact_path, str(output_path.parent) + ) ) if local_path != output_path: shutil.move(str(local_path), str(output_path)) @@ -80,10 +90,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Get local checkpoint location.") - parser.add_argument("run_uri", - type=str, - help="MLFlow run URI" - ) + parser.add_argument("run_uri", type=str, help="MLFlow run URI") parser.add_argument( "--output", type=str, From 84f84feeb56a7db50b9b46b6fe655a2d7352f3c1 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Thu, 18 Jun 2026 14:14:25 +0200 Subject: [PATCH 4/6] Update config using model registry --- config/temporal-downscalers-ich1.yaml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/config/temporal-downscalers-ich1.yaml b/config/temporal-downscalers-ich1.yaml index 18d83704..b9515033 100644 --- a/config/temporal-downscalers-ich1.yaml +++ b/config/temporal-downscalers-ich1.yaml @@ -1,7 +1,8 @@ # yaml-language-server: $schema=../workflow/tools/config.schema.json description: | - Evaluate skill of temporal downscaler driven by ICON-CH1 Stage E forecaster with - subgrid orography. + Evaluate skill of Varda-single-1.0 against ground observations. + +config_label: varda-single-1.0 dates: start: 2025-03-01T00:00 @@ -10,15 +11,15 @@ dates: runs: - temporal_downscaler: - checkpoint: /scratch/mch/miccatta/ICON_interpolator_checkpoints/checkpoint_stage-C-interpolator-n320-6hto1h-reduced-variables/f9279244ed6f4c458597bdcf335ab36f/inference-last.ckpt - label: Varda-Single + checkpoint: https://service.meteoswiss.ch/mlstore#/models/sruc-m-2-interpolator/versions/3 + label: Varda-single-1.0 steps: 0/120/1 config: resources/inference/configs/sgm-temporal-downscaler-global_trimedge_multi.yaml extra_requirements: - anemoi-datasets==0.5.35 # - anemoi-inference==0.11.0 forecaster: - checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad + checkpoint: https://service.meteoswiss.ch/mlstore#/models/sruc-m-1-forecaster/versions/4 config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml steps: 0/120/6 - baseline: From 4c8efcd60a927bc1eb29aba592677f142be7d0be Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Thu, 18 Jun 2026 14:17:47 +0200 Subject: [PATCH 5/6] Rename config --- config/{temporal-downscalers-ich1.yaml => varda-single-1.0.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename config/{temporal-downscalers-ich1.yaml => varda-single-1.0.yaml} (100%) diff --git a/config/temporal-downscalers-ich1.yaml b/config/varda-single-1.0.yaml similarity index 100% rename from config/temporal-downscalers-ich1.yaml rename to config/varda-single-1.0.yaml From 51a7224e1a8bb826c49a51c2e5de2a039cf22f7b Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Thu, 18 Jun 2026 14:24:16 +0200 Subject: [PATCH 6/6] Fix tests --- tests/conftest.py | 12 ++---------- tests/unit/test_config.py | 20 ++++---------------- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2ef90ddc..a12df00b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,16 +7,8 @@ @pytest.fixture -def example_forecasters_config(): - configfile = PROJECT_ROOT / "config/forecasters-ich1.yaml" - with open(configfile, "r") as f: - config = yaml.safe_load(f) - return config - - -@pytest.fixture -def example_temporal_downscalers_config(): - configfile = PROJECT_ROOT / "config/temporal-downscalers-ich1.yaml" +def example_config(): + configfile = PROJECT_ROOT / "config/varda-single-1.0.yaml" with open(configfile, "r") as f: config = yaml.safe_load(f) return config diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ed6b99c4..20c4e7a6 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -3,25 +3,13 @@ from evalml.config import ConfigModel -def test_example_forecasters_config(example_forecasters_config): +def test_example_config(example_config): """Test that the example config loads correctly.""" # this shoudd not raise an error - _ = ConfigModel.model_validate(example_forecasters_config) + _ = ConfigModel.model_validate(example_config) # this should raise an error - del example_forecasters_config["runs"] + del example_config["runs"] with pytest.raises(ValueError, match="Field required"): - _ = ConfigModel.model_validate(example_forecasters_config) - - -def test_example_temporal_downscalers_config(example_temporal_downscalers_config): - """Test that the example config loads correctly.""" - - # this shoudd not raise an error - _ = ConfigModel.model_validate(example_temporal_downscalers_config) - - # this should raise an error - del example_temporal_downscalers_config["runs"] - with pytest.raises(ValueError, match="Field required"): - _ = ConfigModel.model_validate(example_temporal_downscalers_config) + _ = ConfigModel.model_validate(example_config)