diff --git a/config/temporal-downscalers-ich1.yaml b/config/varda-single-1.0.yaml similarity index 87% rename from config/temporal-downscalers-ich1.yaml rename to config/varda-single-1.0.yaml index 18d83704..b9515033 100644 --- a/config/temporal-downscalers-ich1.yaml +++ b/config/varda-single-1.0.yaml @@ -1,7 +1,8 @@ # yaml-language-server: $schema=../workflow/tools/config.schema.json description: | - Evaluate skill of temporal downscaler driven by ICON-CH1 Stage E forecaster with - subgrid orography. + Evaluate skill of Varda-single-1.0 against ground observations. + +config_label: varda-single-1.0 dates: start: 2025-03-01T00:00 @@ -10,15 +11,15 @@ dates: runs: - temporal_downscaler: - checkpoint: /scratch/mch/miccatta/ICON_interpolator_checkpoints/checkpoint_stage-C-interpolator-n320-6hto1h-reduced-variables/f9279244ed6f4c458597bdcf335ab36f/inference-last.ckpt - label: Varda-Single + checkpoint: https://service.meteoswiss.ch/mlstore#/models/sruc-m-2-interpolator/versions/3 + label: Varda-single-1.0 steps: 0/120/1 config: resources/inference/configs/sgm-temporal-downscaler-global_trimedge_multi.yaml extra_requirements: - anemoi-datasets==0.5.35 # - anemoi-inference==0.11.0 forecaster: - checkpoint: https://service.meteoswiss.ch/mlstore#/experiments/602/runs/c30490b6ba064e4db03b430f3a2595ad + checkpoint: https://service.meteoswiss.ch/mlstore#/models/sruc-m-1-forecaster/versions/4 config: resources/inference/configs/sgm-multidataset-forecaster-global-ich1-oper.yaml steps: 0/120/6 - baseline: diff --git a/tests/conftest.py b/tests/conftest.py index 2ef90ddc..a12df00b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,16 +7,8 @@ @pytest.fixture -def example_forecasters_config(): - configfile = PROJECT_ROOT / "config/forecasters-ich1.yaml" - with open(configfile, "r") as f: - config = yaml.safe_load(f) - return config - - -@pytest.fixture -def example_temporal_downscalers_config(): - configfile = PROJECT_ROOT / "config/temporal-downscalers-ich1.yaml" +def example_config(): + configfile = PROJECT_ROOT / "config/varda-single-1.0.yaml" with open(configfile, "r") as f: config = yaml.safe_load(f) return config diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ed6b99c4..20c4e7a6 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -3,25 +3,13 @@ from evalml.config import ConfigModel -def test_example_forecasters_config(example_forecasters_config): +def test_example_config(example_config): """Test that the example config loads correctly.""" # this shoudd not raise an error - _ = ConfigModel.model_validate(example_forecasters_config) + _ = ConfigModel.model_validate(example_config) # this should raise an error - del example_forecasters_config["runs"] + del example_config["runs"] with pytest.raises(ValueError, match="Field required"): - _ = ConfigModel.model_validate(example_forecasters_config) - - -def test_example_temporal_downscalers_config(example_temporal_downscalers_config): - """Test that the example config loads correctly.""" - - # this shoudd not raise an error - _ = ConfigModel.model_validate(example_temporal_downscalers_config) - - # this should raise an error - del example_temporal_downscalers_config["runs"] - with pytest.raises(ValueError, match="Field required"): - _ = ConfigModel.model_validate(example_temporal_downscalers_config) + _ = ConfigModel.model_validate(example_config) diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index 020ad956..c75da889 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -143,6 +143,12 @@ def model_id(checkpoint_uri: str) -> str: """Generate a model ID based on the checkpoint URI.""" ckpt_type = _checkpoint_uri_type(checkpoint_uri) if ckpt_type == "mlflow": + fragment = checkpoint_uri.split("#")[-1] + if "/models/" in fragment: + parts = fragment.strip("/").split("/") + if len(parts) >= 4 and parts[2] == "versions": + return f"{parts[1]}-v{parts[3]}"[:HASH_LENGTH] + return f"{parts[1]}-latest"[:HASH_LENGTH] return checkpoint_uri.split("/")[-1][:HASH_LENGTH] elif ckpt_type == "huggingface": return checkpoint_uri.split("/")[-1].split(".")[0] diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index f805ecd3..2d784774 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -19,14 +19,21 @@ rule inference_get_checkpoint: checkpoint_type=lambda wc: _checkpoint_uri_type( ENV_CONFIGS[wc.env_id]["checkpoint"] ), + checkpoint_is_registry=lambda wc: "/models/" + in ENV_CONFIGS[wc.env_id]["checkpoint"], shell: r""" ( mkdir -p $(dirname {output.checkpoint}) if [ "{params.checkpoint_type}" = "mlflow" ]; then - ln -s $(python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint}) {output.checkpoint} - echo "Located checkpoint from MLFlow log." - echo "Created symlink: {output.checkpoint} -> $(readlink {output.checkpoint})" + if [ "{params.checkpoint_is_registry}" = "True" ]; then + python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint} --output {output.checkpoint} + echo "Downloaded checkpoint from MLFlow model registry: {output.checkpoint}" + else + ln -s $(python workflow/scripts/inference_get_checkpoint_mlflow.py {params.checkpoint}) {output.checkpoint} + echo "Located checkpoint from MLFlow log." + echo "Created symlink: {output.checkpoint} -> $(readlink {output.checkpoint})" + fi elif [ "{params.checkpoint_type}" = "huggingface" ]; then repo_id=$(python -c "import re; print(re.search(r'huggingface\.co/([^/]+/[^/]+)', '{params.checkpoint}').group(1))") file_path=$(python -c "import re; print(re.search(r'huggingface\.co/[^/]+/[^/]+/blob/[^/]+/(.*)', '{params.checkpoint}').group(1))") diff --git a/workflow/scripts/inference_get_checkpoint_mlflow.py b/workflow/scripts/inference_get_checkpoint_mlflow.py index 219f1480..26e1a7a6 100644 --- a/workflow/scripts/inference_get_checkpoint_mlflow.py +++ b/workflow/scripts/inference_get_checkpoint_mlflow.py @@ -1,32 +1,85 @@ import argparse +import logging +import shutil from pathlib import Path from urllib.parse import urlparse + from anemoi.utils.mlflow.auth import TokenAuth from anemoi.utils.mlflow.client import AnemoiMlflowClient from mlflow.tracking import MlflowClient +LOG = logging.getLogger(__name__) + KNOWN_MLFLOW_TRACKING_URI = [ "mlflow.ecmwf.int", "service.meteoswiss.ch", "servicedepl.meteoswiss.ch", ] +CHECKPOINT_FILENAME = "inference-last.ckpt" + + +def _find_artifact_path(client, run_id, filename, path=""): + """Recursively search a run's artifacts for a file by name, returning its artifact path.""" + for artifact in client.list_artifacts(run_id, path): + if artifact.is_dir: + result = _find_artifact_path(client, run_id, filename, artifact.path) + if result is not None: + return result + elif Path(artifact.path).name == filename: + return artifact.path + return None + def main(args): run_uri = args.run_uri parsed_url = urlparse(run_uri) if parsed_url.netloc in KNOWN_MLFLOW_TRACKING_URI: uri, fragment = run_uri.split("#") - run_id = fragment.split("/")[-1] if parsed_url.netloc == "mlflow.ecmwf.int": TokenAuth(uri).login() client = AnemoiMlflowClient(uri, authentication=True) else: client = MlflowClient(tracking_uri=uri) + if "/models/" in fragment: + parts = fragment.strip("/").split("/") + model_name = parts[1] + if len(parts) >= 4 and parts[2] == "versions": + model_version = client.get_model_version(model_name, parts[3]) + else: + versions = client.search_model_versions(f"name='{model_name}'") + if not versions: + raise ValueError( + f"No versions found for model '{model_name}' in the registry" + ) + model_version = max(versions, key=lambda v: int(v.version)) + LOG.info( + "Found model version: %s (run ID: %s)", + model_version.version, + model_version.run_id, + ) + output_path = Path(args.output) + artifact_path = _find_artifact_path( + client, model_version.run_id, CHECKPOINT_FILENAME + ) + if artifact_path is None: + raise FileNotFoundError( + f"Could not find '{CHECKPOINT_FILENAME}' in MLflow artifacts for run {model_version.run_id}" + ) + local_path = Path( + client.download_artifacts( + model_version.run_id, artifact_path, str(output_path.parent) + ) + ) + if local_path != output_path: + shutil.move(str(local_path), str(output_path)) + return + else: + run_id = fragment.split("/")[-1] run = client.get_run(run_id) path = run.data.params.get("config.hardware.paths.checkpoints") path = path or run.data.params.get("config.system.output.checkpoints.root") - path = Path(path) / "inference-last.ckpt" + path = Path(path) / CHECKPOINT_FILENAME if not path.exists(): raise FileNotFoundError(f"Checkpoint path does not exist: {path}") print(str(path)) @@ -38,5 +91,12 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Get local checkpoint location.") parser.add_argument("run_uri", type=str, help="MLFlow run URI") + parser.add_argument( + "--output", + type=str, + default=None, + help="Destination path for the downloaded checkpoint (required for model registry URLs).", + ) args = parser.parse_args() + logging.basicConfig(level=logging.INFO) main(args)