diff --git a/.gitignore b/.gitignore
index 8732a470..a352b891 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,6 +175,8 @@ pyrightconfig.json
*.torch
plots/*
*.npz
+outputs/*
+logs/*
# conda
.conda/*
@@ -183,7 +185,6 @@ plots/*
temp.*
# local script
-interpolate.sh
cosmo-grid.npz
*.out
.conda/*
@@ -194,7 +195,9 @@ temp.zarr.sync*
src/hirad/eval/__pycache__/*
interpolate_basic.log
interpolated.torch
+mlruns/
+.secrets.env
out
core
-*.png
-*.nc
\ No newline at end of file
+*.nc
+*.err
diff --git a/README.md b/README.md
index b0dbd2eb..72e4b9d1 100644
--- a/README.md
+++ b/README.md
@@ -2,121 +2,211 @@
HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model.
-## Installation (Alps)
+[Showcase](#Showcase)
+[Setup - clariden/santis](#setup-claridensantis)
+[Inference - clariden/santis](#running-inference-on-alps)
+[Regression training - clariden/santis](#run-regression-model-training-alps)
+[Diffusion training - clariden/santis](#run-diffusion-model-training-alps)
+
+## Showcase
+
+
+
+ |
+ Input ERA5 |
+ Prediction |
+ Target REAL-CH1 |
+
+
+ | 2t |
+  |
+  |
+  |
+
+
+ | 10u |
+  |
+  |
+  |
+
+
+ | 10v |
+  |
+  |
+  |
+
+
+ | tp |
+  |
+  |
+  |
+
+
+
+### Ensemble Total Preceipitatin 1h
+
+
+
+## Setup clariden/santis container environment
+Container environment setup needed to run training and inference experiments on clariden/santis is contained in this repository under `ci/edf/modulus_env.toml`. Image squash is on clariden/alps under `/capstor/scratch/cscs/pstamenk/corr_diff.sqsh`. All the jobs can be run using this environment without additional installations and setup.
-To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps:
-
-1. **Start the PyTorch user environment**:
- ```bash
- uenv start pytorch/v2.6.0:v1 --view=default
- ```
-
-2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name):
- ```bash
- python -m venv ./{env_name}
- ```
-
-3. **Activate the virtual environment**:
- ```bash
- source ./{env_name}/bin/activate
- ```
-
-4. **Install project dependencies**:
- ```bash
- pip install -e .
- ```
-
-This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure.
-
-## Training
+## Inference
-### Run regression model training (Alps)
+### Running inference on Alps
-1. Script for running the training of regression model is in `src/hirad/train_regression.sh`.
+1. Script for running the inference is in `src/hirad/generate.sh`.
Inside this script set the following:
```bash
### OUTPUT ###
-#SBATCH --output=your_path_to_output_log
-#SBATCH --error=your_path_to_output_error
+#SBATCH --output=path_to_output_log
+#SBATCH --error=path_to_output_error
```
```bash
-#SBATCH -A your_compute_group
+#SBATCH -A compute_group
```
```bash
-srun bash -c "
- . ./{your_env_name}/bin/activate
- python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/inference/generate.py --config-name=main-config-file-in-src/hirad/conf.yaml
"
```
2. Set up the following config files in `src/hirad/conf`:
-- In `training_era_cosmo_regression.yaml` set:
+- In main config file (by default `generate_era_real.yaml`) set:
```
hydra:
run:
- dir: your_path_to_save_training_output
+ dir: your_path_to_save_inference_output
```
-- In `training/era_cosmo_regression.yaml` set:
+- In generation config file (by default `generation/era_real.yaml`):
+Choose the inference mode:
```
-hp:
- training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4)
+inference_mode: all/regression/diffusion
+```
+by default `all` does both regression and diffusion. Depending on mode, regression and/or diffusion model pretrained weights should be provided:
```
-- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+io:
+ res_ckpt_path: path_to_directory_containing_diffusion_training_model_checkpoints
+ reg_ckpt_path: path_to_directory_containing_regression_training_model_checkpoints
+```
+Finally, from the dataset, subset of time steps can be chosen to do inference for.
+
+One way is to list steps under `times:` in format `%Y%m%d-%H%M` for era5_cosmo dataset.
+
+The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset.
3. Submit the job with:
```bash
-sbatch src/hirad/train_regression.sh
+sbatch src/hirad/generate.sh
```
-### Run diffusion model training (Alps)
-Before training diffusion model, checkpoint for regression model has to exist.
+### Visualizing results
-1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`.
+After generation is finished, visualization of results can be done using `src/hirad/snapshots.sh`. Set:
+```bash
+### OUTPUT ###
+#SBATCH --output=path_to_output_log
+```
+```bash
+### ENVIRONMENT ####
+#SBATCH -A compute_group
+```
+```bash
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/eval/snapshots.py --config-name=src/hirad/conf/config-file-in-src/hirad/conf.yaml
+"
+```
+In config file (by default `eval_real.yaml`) set:
+```bash
+# Path to the inference output directory
+inference_output_dir: '/path/to/generated/results/directory'
+results_dir_name: 'name_of_directory_to_save_output_plots'
+```
+If you want to generate plots for subset of times from inference set (follow same convection as in generate config):
+```
+times: list of times to visualize
+times_range: [start time, end time, time step] to visualize
+```
+
+Other setting can be changed according to output grid.
+
+Submit the job with:
+```bash
+sbatch src/hirad/snapshots.sh
+```
+
+### Evaluation of generated data
+
+Evaluation of generated samples can be done using `src/hirad/eval_precip.sh` and `src/hirad/eval_wind.sh`. Set:
+```bash
+### OUTPUT ###
+#SBATCH --output=path_to_output_log
+
+### ENVIRONMENT ####
+#SBATCH -A compute_group
+
+### CONFIG ###
+CONFIG_NAME="src/hirad/conf/config_file.yaml"
+```
+Default config file is the same as for visualization `eval_real.yaml`, and requires to set the same fileds. In both `eval_precip.sh` and `eval_wind.sh` there are several python scripts called. They are all commented out by default. Comment out the ones you want to run.
+
+Submit jobs with:
+```bash
+sbatch src/hirad/eval_precip.sh
+sbatch src/hirad/eval_wind.sh
+```
+
+## Training
+
+### Run regression model training (Alps)
+
+1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. Here, you can change the sbatch settings.
Inside this script set the following:
```bash
### OUTPUT ###
-#SBATCH --output=your_path_to_output_log
-#SBATCH --error=your_path_to_output_error
+#SBATCH --output=path_to_output_log
+#SBATCH --error=path_to_output_error
```
```bash
-#SBATCH -A your_compute_group
+#SBATCH -A compute_group
```
```bash
-srun bash -c "
- . ./{your_env_name}/bin/activate
- python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/training/train.py --config-name=main-config-file-in-src/hirad/conf.yaml
"
```
2. Set up the following config files in `src/hirad/conf`:
-- In `training_era_cosmo_diffusion.yaml` set:
+- In main config file (by default `training_era_real_regression.yaml`) set:
```
hydra:
run:
- dir: your_path_to_save_training_output
-```
-- In `training/era_cosmo_regression.yaml` set:
-```
-hp:
- training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4)
-io:
- regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints
+ dir: your_path_to_save_training_outputs
```
-- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+- All other parameters for training regression can be changed in the main config file and config files the main config is referencing (default values are working for debugging purposes).
3. Submit the job with:
```bash
-sbatch src/hirad/train_diffusion.sh
+sbatch src/hirad/train_regression.sh
```
-## Inference
-
-### Running inference on Alps
+### Run diffusion model training (Alps)
+Before training diffusion model, checkpoint for regression model has to exist.
-1. Script for running the inference is in `src/hirad/generate.sh`.
-Inside this script set the following:
+1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. Here, you can change the sbatch settings. Inside this script set the following:
```bash
### OUTPUT ###
#SBATCH --output=your_path_to_output_log
@@ -126,42 +216,33 @@ Inside this script set the following:
#SBATCH -A your_compute_group
```
```bash
-srun bash -c "
- . ./{your_env_name}/bin/activate
- python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/training/train.py --config-name=main-config-file-in-src/hirad/conf.yaml
"
```
2. Set up the following config files in `src/hirad/conf`:
-- In `generate_era_cosmo.yaml` set:
+- In main config file (by default `training_era_real_diffusion_patched.yaml`) set:
```
hydra:
run:
- dir: your_path_to_save_inference_output
-```
-- In `generation/era_cosmo.yaml`:
-Choose the inference mode:
-```
-inference_mode: all/regression/diffusion
+ dir: your_path_to_save_training_output
```
-by default `all` does both regression and diffusion. Depending on mode, regression and/or diffusion model pretrained weights should be provided:
+- In training config file (by default `training/era_real_diffusion_patched.yaml`) set:
```
io:
- res_ckpt_path: path_to_directory_containing_diffusion_training_model_checkpoints
- reg_ckpt_path: path_to_directory_containing_regression_training_model_checkpoints
+ regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints
```
-Finally, from the dataset, subset of time steps can be chosen to do inference for.
-
-One way is to list steps under `times:` in format `%Y%m%d-%H%M` for era5_cosmo dataset.
-
-The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset (6 for era_cosmo).
-
-By default, inference is done for one time step `20160101-0000`
-
-- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default.
+- All other parameters for training regression can be changed in the main config file and config files the main config is referencing (default values are working for debugging purposes).
3. Submit the job with:
```bash
-sbatch src/hirad/generate.sh
-```
\ No newline at end of file
+sbatch src/hirad/train_diffusion.sh
+```
+
+## MLflow logging
+
+During training MLflow can be used to log metrics.
+Logging config files for regression and diffusion are located in `src/hirad/conf/logging/`. Set `method` to `mlflow` and specify `uri` if you want to log on remote server, otherwise run will be logged locally in output directory. Other options can also be modified here.
diff --git a/ci/cscs.yml b/ci/cscs.yml
index fc926459..732a4497 100644
--- a/ci/cscs.yml
+++ b/ci/cscs.yml
@@ -12,14 +12,20 @@ build_job:
stage: build
extends: .container-builder-cscs-gh200
variables:
- DOCKERFILE: ci/docker/Dockerfile
+ DOCKERFILE: ci/docker/Dockerfile.ci
+ KUBERNETES_MEMORY_REQUEST: '64Gi'
+ KUBERNETES_MEMORY_LIMIT: '64Gi'
-#test_job:
-# stage: test
-# extends: .container-runner-clariden-gh200
-# image: $PERSIST_IMAGE_NAME
-# script:
-# - /opt/helloworld/bin/hello
-# variables:
-# SLURM_JOB_NUM_NODES: 2
-# SLURM_NTASKS: 2
+test_job:
+ stage: test
+ extends: .container-runner-santis-gh200
+ image: $PERSIST_IMAGE_NAME
+ script:
+ - pytest /opt/hirad-gen/tests -v
+ variables:
+ USE_MPI: NO
+ SLURM_MPI_TYPE: pmix
+ SLURM_NETWORK: disable_rdzv_get
+ PMIX_MCA_psec: native
+ SLURM_JOB_NUM_NODES: 1
+ SLURM_NTASKS: 1
diff --git a/ci/docker/Dockerfile b/ci/docker/Dockerfile
index 4772d76d..696fe0b7 100644
--- a/ci/docker/Dockerfile
+++ b/ci/docker/Dockerfile
@@ -1,41 +1,13 @@
-# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup
-
-#FROM ubuntu:22.04 as builder
-FROM nvcr.io/nvidia/pytorch:25.01-py3
-
-COPY . /src
+#FROM physicsnemo-cscs as builder
+FROM jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-physicsnemo:25.11-alps4-dev
# setup
-RUN apt-get update && apt-get install python3-pip python3-venv -y
RUN pip install --upgrade \
pip
- #ninja
- #wheel
- #packaging
- #setuptools
-
-# update flash-attn
-RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \
- flash-attn==2.7.4.post1 -v
-# install the rest of dependencies
-# TODO: Factor pydeps into a separate file(s)
-# TODO: Add versions for things
+# install dependencies
RUN pip install \
anemoi-datasets \
cartopy \
- matplotlib \
- numpy \
- pandas \
- scipy \
- torch
-
-
-# replace pynvml with nvidia-ml-py
-RUN pip uninstall -y pynvml && pip install nvidia-ml-py
-
-#CMD ["python3.11" "src/input_data/interpolate_basic_test.py"]
-
-
-
-
+ mlflow \
+ xskillscore
diff --git a/ci/docker/Dockerfile.ci b/ci/docker/Dockerfile.ci
new file mode 100644
index 00000000..f496fd44
--- /dev/null
+++ b/ci/docker/Dockerfile.ci
@@ -0,0 +1,14 @@
+FROM jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-physicsnemo:25.11-alps4-dev as builder
+
+# setup
+RUN pip install --upgrade pip
+
+# install dependencies
+RUN pip install mlflow \
+ anemoi-datasets
+
+COPY . /opt/hirad-gen
+
+WORKDIR /opt/hirad-gen
+
+RUN pip install /opt/hirad-gen --no-dependencies
\ No newline at end of file
diff --git a/ci/docker/Dockerfile.corrdiff b/ci/docker/Dockerfile.corrdiff
new file mode 100644
index 00000000..f4be0712
--- /dev/null
+++ b/ci/docker/Dockerfile.corrdiff
@@ -0,0 +1,14 @@
+FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.06
+
+# setup
+RUN apt-get update && apt-get install python3-pip python3-venv -y
+RUN pip install --upgrade pip
+
+# Install the rest of dependencies.
+RUN pip install \
+ anemoi.datasets \
+ Cartopy==0.22.0 \
+ xskillscore \
+ scoringrules \
+ mlflow \
+ meteodata-lab
diff --git a/ci/docker/Dockerfile.python b/ci/docker/Dockerfile.python
new file mode 100644
index 00000000..62284af5
--- /dev/null
+++ b/ci/docker/Dockerfile.python
@@ -0,0 +1,63 @@
+# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup
+
+FROM ubuntu:24.04 as builder
+# 24.04 needed for GDAL >=3.5
+#FROM nvcr.io/nvidia/pytorch:25.01-py3
+
+
+# setup
+RUN apt-get update && apt-get install python3-pip python3-venv -y
+RUN apt-get install -y \
+ libgdal-dev \
+ libnetcdff-dev \
+ ecbuild \
+ curl \
+ wget \
+ build-essential \
+ cmake \
+ gfortran\
+ vim
+
+# Install eccodes from source, because we need version 2.38
+RUN wget https://github.com/ecmwf/eccodes/archive/refs/tags/2.38.0.tar.gz &&\
+ tar -xzf 2.38.0.tar.gz &&\
+ mkdir eccodes-2.38.0/build
+WORKDIR eccodes-2.38.0/build
+RUN cmake ../../eccodes-2.38.0 &&\
+ make &&\
+ ctest &&\
+ make install
+WORKDIR ../../
+
+#RUN pip install --upgrade \
+# pip
+ #ninja
+ #wheel
+ #packaging
+ #setuptools
+
+# install the rest of dependencies
+# TODO: Factor pydeps into a separate file(s)
+# TODO: Add versions for things
+#RUN pip install \
+# anemoi-datasets \
+# cartopy \
+# matplotlib \
+# numpy \
+# pandas \
+# scipy \
+# torch
+
+#RUN apt-get install libgdal-dev -y # Old version 3.4.1, needs 3.5
+#RUN apt-get install gdal-bin -y # Insufficient
+#RUN apt-get install curl -y
+#RUN apt-get install sudo -y
+#RUN curl -sL "https://url.geocarpentry.org/gdal-ubuntu" | bash
+
+
+#RUN pip install \
+# rasterio==1.3.9 \
+ # meteodata-lab[regrid]
+
+
+
diff --git a/ci/edf/modulus_env.toml b/ci/edf/modulus_env.toml
new file mode 100644
index 00000000..6227a244
--- /dev/null
+++ b/ci/edf/modulus_env.toml
@@ -0,0 +1,11 @@
+image = "/capstor/scratch/cscs/pstamenk/corr_diff.sqsh"
+
+mounts = ["/capstor", "/iopsstor", "/users"]
+
+# The initial directory in the container.
+workdir = "${PWD}"
+
+[env]
+PMIX_MCA_psec = "native"
+[annotations]
+com.hooks.cxi.enabled = "false"
diff --git a/ci/edf/python_env.toml b/ci/edf/python_env.toml
new file mode 100644
index 00000000..c8dded80
--- /dev/null
+++ b/ci/edf/python_env.toml
@@ -0,0 +1,7 @@
+image = "/capstor/scratch/cscs/mmcgloho/python-ubuntu.sqsh"
+
+mounts = ["/capstor", "/iopsstor", "/users"]
+
+# The initial directory in the container.
+workdir = "${PWD}"
+
diff --git a/docs/images/showcase/10-input.png b/docs/images/showcase/10-input.png
new file mode 100644
index 00000000..022e2718
Binary files /dev/null and b/docs/images/showcase/10-input.png differ
diff --git a/docs/images/showcase/10-target.png b/docs/images/showcase/10-target.png
new file mode 100644
index 00000000..c940b28e
Binary files /dev/null and b/docs/images/showcase/10-target.png differ
diff --git a/docs/images/showcase/10u-pred.png b/docs/images/showcase/10u-pred.png
new file mode 100644
index 00000000..c1e7c735
Binary files /dev/null and b/docs/images/showcase/10u-pred.png differ
diff --git a/docs/images/showcase/10v-input.png b/docs/images/showcase/10v-input.png
new file mode 100644
index 00000000..edc327dc
Binary files /dev/null and b/docs/images/showcase/10v-input.png differ
diff --git a/docs/images/showcase/10v-pred.png b/docs/images/showcase/10v-pred.png
new file mode 100644
index 00000000..5cddcdfc
Binary files /dev/null and b/docs/images/showcase/10v-pred.png differ
diff --git a/docs/images/showcase/10v-target.png b/docs/images/showcase/10v-target.png
new file mode 100644
index 00000000..53ad6e88
Binary files /dev/null and b/docs/images/showcase/10v-target.png differ
diff --git a/docs/images/showcase/2t-input.png b/docs/images/showcase/2t-input.png
new file mode 100644
index 00000000..e74ae328
Binary files /dev/null and b/docs/images/showcase/2t-input.png differ
diff --git a/docs/images/showcase/2t-pred.png b/docs/images/showcase/2t-pred.png
new file mode 100644
index 00000000..6a351040
Binary files /dev/null and b/docs/images/showcase/2t-pred.png differ
diff --git a/docs/images/showcase/2t-target.png b/docs/images/showcase/2t-target.png
new file mode 100644
index 00000000..6425e346
Binary files /dev/null and b/docs/images/showcase/2t-target.png differ
diff --git a/docs/images/showcase/tp-input.png b/docs/images/showcase/tp-input.png
new file mode 100644
index 00000000..9d77b4de
Binary files /dev/null and b/docs/images/showcase/tp-input.png differ
diff --git a/docs/images/showcase/tp-pred1.png b/docs/images/showcase/tp-pred1.png
new file mode 100644
index 00000000..6e54b5f0
Binary files /dev/null and b/docs/images/showcase/tp-pred1.png differ
diff --git a/docs/images/showcase/tp-pred2.png b/docs/images/showcase/tp-pred2.png
new file mode 100644
index 00000000..5c2a335d
Binary files /dev/null and b/docs/images/showcase/tp-pred2.png differ
diff --git a/docs/images/showcase/tp-pred3.png b/docs/images/showcase/tp-pred3.png
new file mode 100644
index 00000000..87598875
Binary files /dev/null and b/docs/images/showcase/tp-pred3.png differ
diff --git a/docs/images/showcase/tp-pred4.png b/docs/images/showcase/tp-pred4.png
new file mode 100644
index 00000000..c6b7645b
Binary files /dev/null and b/docs/images/showcase/tp-pred4.png differ
diff --git a/docs/images/showcase/tp-target.png b/docs/images/showcase/tp-target.png
new file mode 100644
index 00000000..1ed3e9c2
Binary files /dev/null and b/docs/images/showcase/tp-target.png differ
diff --git a/interpolate.sh b/interpolate.sh
deleted file mode 100755
index 607ad66f..00000000
--- a/interpolate.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/bash
-
-#SBATCH --partition=postproc
-#SBATCH --time=23:59:00
-
-python src/input_data/interpolate_basic.py src/input_data/era-all.yaml src/input_data/cosmo-all.yaml /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/
diff --git a/pyproject.toml b/pyproject.toml
index 1477899a..7ffe3f37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,19 +10,10 @@ authors = [
{ name="Petar Stamenkovic", email="petar.stamenkovic@meteoswiss.ch" }
]
readme = "README.md"
-requires-python = ">=3.12"
+#requires-python = ">=3.12"
license = {file = "LICENSE"}
dependencies = [
- "cartopy>=0.24.1",
- "cftime>=1.6.4",
- "hydra-core>=1.3.2",
- "matplotlib>=3.10.1",
- "omegaconf>=2.3.0",
- "tensorboard>=2.19.0",
- "termcolor>=3.1.0",
- "torchinfo>=1.8.0",
- "treelib>=1.7.1"
]
[tool.setuptools]
diff --git a/src/hirad/calculate_stats.sh b/src/hirad/calculate_stats.sh
new file mode 100644
index 00000000..3e8eaf14
--- /dev/null
+++ b/src/hirad/calculate_stats.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+#SBATCH --job-name="corrdiff-first-stage"
+
+### HARDWARE ###
+#SBATCH --partition=normal
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=72
+#SBATCH --time=12:00:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+
+### OUTPUT ###
+#SBATCH --output=/capstor/scratch/cscs/pstamenk/logs/calculate_stats.log
+#SBATCH --error=/capstor/scratch/cscs/pstamenk/logs/calculate_stats.err
+
+### ENVIRONMENT ####
+#SBATCH -A a161
+
+# Get master node.
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+# Get IP for hostname.
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
+export MASTER_ADDR
+export MASTER_PORT=29500
+
+export OMP_NUM_THREADS=1
+
+srun --environment=./ci/edf/modulus_env.toml bash -c "
+ source ../hirad_env/hirad/bin/activate
+ python src/hirad/input_data/calculate_transformed_stats.py
+"
\ No newline at end of file
diff --git a/src/hirad/conf/compute_eval.yaml b/src/hirad/conf/compute_eval.yaml
new file mode 100644
index 00000000..069ad1e0
--- /dev/null
+++ b/src/hirad/conf/compute_eval.yaml
@@ -0,0 +1,12 @@
+hydra:
+ job:
+ chdir: true
+ name: diffusion_era5_cosmo_7500000_test
+ run:
+ dir: ./outputs/generation/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+ # Dataset
+ - dataset/era_cosmo_inference
diff --git a/src/hirad/conf/dataset/anemoi_era_cosmo.yaml b/src/hirad/conf/dataset/anemoi_era_cosmo.yaml
new file mode 100644
index 00000000..e3af9871
--- /dev/null
+++ b/src/hirad/conf/dataset/anemoi_era_cosmo.yaml
@@ -0,0 +1,19 @@
+type: anemoi_era5_cosmo
+input_anemoi_dataset_path: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+corrected_tp_path: '/capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-era-1h/copernicus-interpolated/'
+target_anemoi_dataset_path: '/capstor/store/cscs/pasc/c38/anemoi-datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr'
+input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+output_channel_names: [2t, 10u, 10v, tp]
+static_channel_names: ['hsurf']
+transform_channels: ['tp-box_cox_025']
+transform_input_means: {'tp-box_cox_025': -3.82097061637088}
+transform_input_stdevs: {'tp-box_cox_025': 0.2275895397699994}
+transform_output_means: {'tp-box_cox_025': -3.92281}
+transform_output_stdevs: {'tp-box_cox_025': 0.19150567}
+n_month_hour_channels: 4
+start_date: '2015-11-29'
+end_date: '2019-10-31'
+trim_edge: 19
+validation: True
+validation_start_date: '2019-11-01'
+validation_end_date: '2020-10-28'
\ No newline at end of file
diff --git a/src/hirad/conf/dataset/anemoi_era_cosmo_inference.yaml b/src/hirad/conf/dataset/anemoi_era_cosmo_inference.yaml
new file mode 100644
index 00000000..1651877a
--- /dev/null
+++ b/src/hirad/conf/dataset/anemoi_era_cosmo_inference.yaml
@@ -0,0 +1,16 @@
+type: anemoi_era5_cosmo
+input_anemoi_dataset_path: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+corrected_tp_path: '/capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-era-1h/copernicus-interpolated/'
+target_anemoi_dataset_path: '/capstor/store/cscs/pasc/c38/anemoi-datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr'
+input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+output_channel_names: [2t, 10u, 10v, tp]
+static_channel_names: ['hsurf']
+transform_channels: ['tp-box_cox_025']
+transform_input_means: {'tp-box_cox_025': -3.82097061637088}
+transform_input_stdevs: {'tp-box_cox_025': 0.2275895397699994}
+transform_output_means: {'tp-box_cox_025': -3.92281}
+transform_output_stdevs: {'tp-box_cox_025': 0.19150567}
+n_month_hour_channels: 4
+start_date: '2015-11-29'
+end_date: '2020-10-28'
+trim_edge: 19
\ No newline at end of file
diff --git a/src/hirad/conf/dataset/anemoi_era_real.yaml b/src/hirad/conf/dataset/anemoi_era_real.yaml
new file mode 100644
index 00000000..2010f35e
--- /dev/null
+++ b/src/hirad/conf/dataset/anemoi_era_real.yaml
@@ -0,0 +1,20 @@
+type: anemoi_era5_real
+# input_anemoi_dataset_path: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+input_anemoi_dataset_path: '/capstor/store/mch/msopr/ml/datasets/aifs-ea-an-oper-0001-mars-n320-1979-2024-1h-v2-with-era51.zarr'
+# corrected_tp_path: '/capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-era-1h/copernicus-interpolated/'
+target_anemoi_dataset_path: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+output_channel_names: [2t, 10u, 10v, tp]
+static_channel_names: ['FIS']
+transform_channels: ['tp-box_cox_025']
+transform_input_means: {'tp-box_cox_025': -3.815209745618941}
+transform_input_stdevs: {'tp-box_cox_025': 0.22851179478814418}
+transform_output_means: {'tp-box_cox_025': -3.8121187083242556}
+transform_output_stdevs: {'tp-box_cox_025': 0.3345851858215482}
+n_month_hour_channels: 4
+start_date: '2005-01-02'
+end_date: '2020-12-31'
+trim_edge: 41
+validation: True
+validation_start_date: '2021-01-01'
+validation_end_date: '2021-12-31'
diff --git a/src/hirad/conf/dataset/anemoi_era_real_inference.yaml b/src/hirad/conf/dataset/anemoi_era_real_inference.yaml
new file mode 100644
index 00000000..43bb66b6
--- /dev/null
+++ b/src/hirad/conf/dataset/anemoi_era_real_inference.yaml
@@ -0,0 +1,17 @@
+type: anemoi_era5_real
+# input_anemoi_dataset_path: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+input_anemoi_dataset_path: '/capstor/store/mch/msopr/ml/datasets/aifs-ea-an-oper-0001-mars-n320-1979-2024-1h-v2-with-era51.zarr'
+# corrected_tp_path: '/capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-era-1h/copernicus-interpolated/'
+target_anemoi_dataset_path: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+output_channel_names: [2t, 10u, 10v, tp]
+static_channel_names: ['FIS']
+transform_channels: ['tp-box_cox_025']
+transform_input_means: {'tp-box_cox_025': -3.815209745618941}
+transform_input_stdevs: {'tp-box_cox_025': 0.22851179478814418}
+transform_output_means: {'tp-box_cox_025': -3.8121187083242556}
+transform_output_stdevs: {'tp-box_cox_025': 0.3345851858215482}
+n_month_hour_channels: 4
+start_date: '2021-01-01'
+end_date: '2024-12-31'
+trim_edge: 41
diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml
index b1e21e6f..70ebb579 100644
--- a/src/hirad/conf/dataset/era_cosmo.yaml
+++ b/src/hirad/conf/dataset/era_cosmo.yaml
@@ -1,3 +1,32 @@
type: era5_cosmo
-dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full
-validation_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_full/validation
\ No newline at end of file
+# dataset_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/train/
+# validation_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/validation/
+dataset_path: /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/train/
+validation_path: /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/validation/
+# input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+# output_channel_names: [2t, 10u, 10v, tp]
+input_dir_name: era-copernicus-interpolated
+output_dir_name: cosmo
+input_channel_names: ['10u', '10v', '2d', '2t',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+output_channel_names: ['10u', '10v', '2d', '2t',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'skt', 'sp', 'tcc', 'tp', 'tqv',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+static_channel_names: ['hsurf']
+transform_channels: ['tp-box_cox_025']
+n_month_hour_channels: 4
\ No newline at end of file
diff --git a/src/hirad/conf/dataset/era_cosmo_inference.yaml b/src/hirad/conf/dataset/era_cosmo_inference.yaml
new file mode 100644
index 00000000..9fabeffd
--- /dev/null
+++ b/src/hirad/conf/dataset/era_cosmo_inference.yaml
@@ -0,0 +1,28 @@
+type: era5_cosmo
+# dataset_path: /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/validation/
+dataset_path: /iopsstor/scratch/cscs/mmcgloho/run-1_2/validation/
+input_dir_name: era-interpolated
+output_dir_name: cosmo-orig
+# input_channel_names: ['10u', '10v', '2d', '2t',
+# 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+# 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+# 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+# 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+# 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+# 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+# 'z',
+# 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+# ]
+# output_channel_names: ['10u', '10v', '2d', '2t',
+# 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+# 'skt', 'sp', 'tcc', 'tp', 'tqv',
+# 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+# 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+# 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+# 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+# 'z',
+# 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+# ]
+static_channel_names: ['hsurf']
+transform_channels: ['tp-box_cox_025']
+n_month_hour_channels: 4
\ No newline at end of file
diff --git a/src/hirad/conf/dataset/era_real.yaml b/src/hirad/conf/dataset/era_real.yaml
new file mode 100644
index 00000000..ec1b9508
--- /dev/null
+++ b/src/hirad/conf/dataset/era_real.yaml
@@ -0,0 +1,8 @@
+type: era5_real
+dataset_path: /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/train
+validation_path: /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/validation
+input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+output_channel_names: [2t, 10u, 10v, tp]
+static_channel_names: ['z']
+transform_channels: ['tp-box_cox_025']
+n_month_hour_channels: 4
\ No newline at end of file
diff --git a/src/hirad/conf/dataset/era_real_inference.yaml b/src/hirad/conf/dataset/era_real_inference.yaml
new file mode 100644
index 00000000..847d6046
--- /dev/null
+++ b/src/hirad/conf/dataset/era_real_inference.yaml
@@ -0,0 +1,7 @@
+type: era5_real
+dataset_path: /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/validation
+input_channel_names: [2t, 10u, 10v, tcw, t_850, z_850, u_850, v_850, t_500, z_500, u_500, v_500, tp]
+output_channel_names: [2t, 10u, 10v, tp]
+static_channel_names: ['z']
+transform_channels: ['tp-box_cox_025']
+n_month_hour_channels: 4
\ No newline at end of file
diff --git a/src/hirad/conf/eval_cosmo.yaml b/src/hirad/conf/eval_cosmo.yaml
new file mode 100644
index 00000000..08c55f2d
--- /dev/null
+++ b/src/hirad/conf/eval_cosmo.yaml
@@ -0,0 +1,25 @@
+# Path to the inference output directory
+inference_output_dir: '/path/to/generation/results/directory'
+results_dir_name: 'evaluation_maps'
+
+# Constants for evaluation depend on dataset specifics and should be modified accordingly
+
+conv_factor_hourly: 1000 # Convert precip of ERA5 from meters to mm/h
+conv_factor: 24000 # Convert the precip of ERA5 from meters/h to mm/day
+wet_threshold: 0.1 # Threshold for wet-hour in mm/h
+log_interval: 24 # Log progress every N timesteps
+land_sea_mask_path: '/capstor/store/mch/msopr/hirad-gen/eval/lsm.npy' # Path to land-sea mask numpy file
+
+# Constants describing the grid for plotting
+lat_start: -4.42
+lat_end: 3.36
+lat_step: 0.02
+lon_start: -6.82
+lon_end: 4.80
+lon_step: 0.02
+relax_zone: 19
+height: 352
+width: 544
+
+# List of channels to evaluate/plot - comment out if all channels are to be used
+# plot_channels: ['2t', '10u', '10v', 'tp', 't_700', 'u_700', 'v_700', 'z_700', 'q_700', 'w_700']
\ No newline at end of file
diff --git a/src/hirad/conf/eval_real.yaml b/src/hirad/conf/eval_real.yaml
new file mode 100644
index 00000000..339d9f35
--- /dev/null
+++ b/src/hirad/conf/eval_real.yaml
@@ -0,0 +1,36 @@
+# Path to the inference output directory
+inference_output_dir: '/path/to/generation/results/directory'
+results_dir_name: 'evaluation_maps'
+
+# Constants for evaluation depend on dataset specifics and should be modified accordingly
+
+conv_factor_hourly: 1000 # Convert precip of ERA5 from meters to mm/h
+conv_factor: 24000 # Convert the precip of ERA5 from meters/h to mm/day
+wet_threshold: 0.1 # Threshold for wet-hour in mm/h
+log_interval: 24 # Log progress every N timesteps
+land_sea_mask_path: '/capstor/store/cscs/pasc/c38/real_grid_info/lsm_real.npy' # Path to land-sea mask numpy file
+
+# Constants describing the grid for plotting
+lat_start: -4.42
+lat_end: 3.36
+lat_step: 0.01
+lon_start: -6.82
+lon_end: 4.80
+lon_step: 0.01
+relax_zone: 41
+height: 704
+width: 1088
+
+# If data was generated in several steps, you HAVE TO SPECIFY TIME STEPS BELOW
+# If you don't want to evaluate all generated samples, you can specify a range of time steps to evaluate.
+# This is useful for debugging or if you only want to evaluate a subset of the generated data.
+# Make sure that generated samples are available for the specified time steps.
+# Use times_ranges to combine multiple seasons/years into a single evaluation run.
+times_ranges:
+ - ['20210601-0000', '20210831-2300', 1]
+ - ['20220601-0000', '20220831-2300', 1]
+ - ['20230601-0000', '20230831-2300', 1]
+ - ['20240601-0000', '20240831-2300', 1]
+
+# List of channels to evaluate/plot - comment out if all channels are to be used
+# plot_channels: ['2t', '10u', '10v', 'tp', 't_700', 'u_700', 'v_700', 'z_700', 'q_700', 'w_700']
\ No newline at end of file
diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml
index 5d7649de..b3365aab 100644
--- a/src/hirad/conf/generate_era_cosmo.yaml
+++ b/src/hirad/conf/generate_era_cosmo.yaml
@@ -1,20 +1,19 @@
hydra:
job:
chdir: true
- name: generation_full
+ name: generation_era_cosmo_results
run:
- dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name}
+ dir: ./outputs/generation/${hydra:job.name}
# Get defaults
defaults:
- _self_
# Dataset
- - dataset/era_cosmo
+ - dataset/anemoi_era_cosmo_inference
# Sampler
- sampler/stochastic
#- sampler/deterministic
# Generation
- - generation/era_cosmo
- #- generation/patched_based
\ No newline at end of file
+ - generation/era_cosmo
\ No newline at end of file
diff --git a/src/hirad/conf/generate_era_cosmo_test.yaml b/src/hirad/conf/generate_era_cosmo_test.yaml
new file mode 100644
index 00000000..e0f53a60
--- /dev/null
+++ b/src/hirad/conf/generate_era_cosmo_test.yaml
@@ -0,0 +1,19 @@
+hydra:
+ job:
+ chdir: true
+ name: generation_era5_cosmo_test
+ run:
+ dir: ./outputs/generation/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+ # Dataset
+ - dataset/era_cosmo_inference
+
+ # Sampler
+ - sampler/stochastic
+ #- sampler/deterministic
+
+ # Generation
+ - generation/era_cosmo_test
\ No newline at end of file
diff --git a/src/hirad/conf/generate_era_real.yaml b/src/hirad/conf/generate_era_real.yaml
new file mode 100644
index 00000000..20f48b3b
--- /dev/null
+++ b/src/hirad/conf/generate_era_real.yaml
@@ -0,0 +1,19 @@
+hydra:
+ job:
+ chdir: true
+ name: generation_era_real_results
+ run:
+ dir: ./outputs/generation/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+ # Dataset
+ - dataset/anemoi_era_real_inference
+
+ # Sampler
+ - sampler/stochastic
+ #- sampler/deterministic
+
+ # Generation
+ - generation/era_real
\ No newline at end of file
diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml
index be4219d2..3c19a0a9 100644
--- a/src/hirad/conf/generation/era_cosmo.yaml
+++ b/src/hirad/conf/generation/era_cosmo.yaml
@@ -1,4 +1,4 @@
-num_ensembles: 8
+num_ensembles: 16
# Number of ensembles to generate per input
seed_batch_size: 4
# Size of the batched inference
@@ -6,21 +6,31 @@ inference_mode: all
# Choose between "all" (regression + diffusion), "regression" or "diffusion"
# Patch size. Patch-based sampling will be utilized if these dimensions differ from
# img_shape_x and img_shape_y
-# overlap_pixels: 0
- # Number of overlapping pixels between adjacent patches
-# boundary_pixels: 0
- # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
- # artifact.
+
+randomize: True
+ # Whether to randomize the random seeds for each generation. If false, fixed seeds
+ # from 0 to num_ensembles-1 will be used for each time step in times/times_range.
+random_seed: 2578458
+ # Base random seed. This is only used when randomize is True.
+ # random seed will be set for numpy random module to have reproducible randomized generative process.
+
+# Patching parameters
patching: False
+# patch_shape_x: 128
+# patch_shape_y: 128
+# overlap_pix: 4
+# # Number of overlapping pixels between adjacent patches
+# boundary_pix: 2
+# # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
+# # artifact.
+
hr_mean_conditioning: True
# sample_res: full
# Sampling resolution
-times_range: null
-times:
- - 20160101-0000
- # - 20160101-0600
- # - 20160101-1200
-has_laed_time: False
+times_range: ['20200601-0000','20200831-2300',1]
+ # Start date, end date and time interval (in hours) for the generation
+times: null
+has_lead_time: False
perf:
force_fp16: False
@@ -30,15 +40,13 @@ perf:
# whether to use torch.compile on the diffusion model
# this will make the first time stamp generation very slow due to compilation overheads
# but will significantly speed up subsequent inference runs
- num_writer_workers: 1
+ num_writer_workers: 8
# number of workers to use for writing file
# To support multiple workers a threadsafe version of the netCDF library must be used
io:
- res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_refactoring/checkpoints_diffusion
- # res_ckpt_path: null
- # Checkpoint filename for the diffusion model
- reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_refactoring/checkpoints_regression
- # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression
- # Checkpoint filename for the mean predictor model
- output_path: ./images
\ No newline at end of file
+ res_ckpt_path: /path/to/diffusion/model/weights
+ # Checkpoint filename for the residual predictor model
+ reg_ckpt_path: /oath/to/regression/model/weigths
+ # Checkpoint filename for the mean predictor model
+ output_path: .
\ No newline at end of file
diff --git a/src/hirad/conf/generation/era_cosmo_test.yaml b/src/hirad/conf/generation/era_cosmo_test.yaml
new file mode 100644
index 00000000..9e10eb4a
--- /dev/null
+++ b/src/hirad/conf/generation/era_cosmo_test.yaml
@@ -0,0 +1,42 @@
+# TODO: See if there's a way to inherit from era_cosmo.yaml
+num_ensembles: 8
+ # Number of ensembles to generate per input
+seed_batch_size: 4
+ # Size of the batched inference
+inference_mode: all
+ # Choose between "all" (regression + diffusion), "regression" or "diffusion"
+ # Patch size. Patch-based sampling will be utilized if these dimensions differ from
+ # img_shape_x and img_shape_y
+# overlap_pixels: 0
+ # Number of overlapping pixels between adjacent patches
+# boundary_pixels: 0
+ # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
+ # artifact.
+patching: False
+hr_mean_conditioning: True
+# sample_res: full
+ # Sampling resolution
+times_range: ['20200131-0000','20200131-2300',1]
+times: null
+has_lead_time: False
+
+perf:
+ force_fp16: False
+ # Whether to force fp16 precision for the model. If false, it'll use the precision
+ # specified upon training.
+ use_torch_compile: False
+ # whether to use torch.compile on the diffusion model
+ # this will make the first time stamp generation very slow due to compilation overheads
+ # but will significantly speed up subsequent inference runs
+ num_writer_workers: 8
+ # number of workers to use for writing file
+ # To support multiple workers a threadsafe version of the netCDF library must be used
+
+io:
+ res_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/diffusion_test/checkpoints_diffusion
+ # Checkpoint filename for the diffusion model
+ reg_ckpt_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_test/checkpoints_regression
+ # Checkpoint filename for the mean predictor model
+ output_path: ./outputs/evaluation
+
+
diff --git a/src/hirad/conf/generation/era_cosmo_training.yaml b/src/hirad/conf/generation/era_cosmo_training.yaml
new file mode 100644
index 00000000..54dc8402
--- /dev/null
+++ b/src/hirad/conf/generation/era_cosmo_training.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - ../sampler@sampler: stochastic
+ - ../dataset@dataset: era_cosmo_inference
+
+num_ensembles: 16
+ # Number of ensembles to generate per input
+# overlap_pixels: 0
+ # Number of overlapping pixels between adjacent patches
+# boundary_pixels: 0
+ # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
+ # artifact.
+times_range: null
+times:
+ - 20200721-1900
+ - 20200722-1900
+
+perf:
+ num_writer_workers: 10
\ No newline at end of file
diff --git a/src/hirad/conf/generation/era_cosmo_training_patched.yaml b/src/hirad/conf/generation/era_cosmo_training_patched.yaml
new file mode 100644
index 00000000..e6e2d7c1
--- /dev/null
+++ b/src/hirad/conf/generation/era_cosmo_training_patched.yaml
@@ -0,0 +1,24 @@
+defaults:
+ - ../sampler@sampler: stochastic
+ - ../dataset@dataset: era_cosmo_inference
+
+num_ensembles: 16
+ # Number of ensembles to generate per input
+
+patching: True
+# Use patch-based sampling
+overlap_pix: 4
+# Number of overlapping pixels between adjacent patches
+boundary_pix: 2
+# Number of boundary pixels to be cropped out. 2 is recommended to address the boundary
+# artifact.
+patch_shape_x: 128
+patch_shape_y: 128
+
+times_range: null
+times:
+ - 20200926-1800
+ # - 20200927-0000
+
+perf:
+ num_writer_workers: 10
\ No newline at end of file
diff --git a/src/hirad/conf/generation/era_real.yaml b/src/hirad/conf/generation/era_real.yaml
new file mode 100644
index 00000000..12309ff9
--- /dev/null
+++ b/src/hirad/conf/generation/era_real.yaml
@@ -0,0 +1,56 @@
+num_ensembles: 8
+ # Number of ensembles to generate per input
+seed_batch_size: 4
+ # Size of the batched inference
+inference_mode: all
+ # Choose between "all" (regression + diffusion), "regression" or "diffusion"
+ # Patch size. Patch-based sampling will be utilized if these dimensions differ from
+ # img_shape_x and img_shape_y
+
+randomize: True
+ # Whether to randomize the random seeds for each generation. If false, fixed seeds
+ # from 0 to num_ensembles-1 will be used for each time step in times/times_range.
+random_seed: 2578458
+ # Base random seed. This is only used when randomize is True.
+ # random seed will be set for numpy random module to have reproducible randomized generative process.
+
+# Patching parameters
+patching: True
+patch_shape_x: 384
+patch_shape_y: 384
+overlap_pix: 4
+# # Number of overlapping pixels between adjacent patches
+boundary_pix: 2
+# # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
+# # artifact.
+
+hr_mean_conditioning: True
+
+times_range: ['20211015-0000','20211015-2300',1]
+ # Start date, end date and time interval (in hours) for the generation
+
+times: null
+has_lead_time: False
+
+perf:
+ force_fp16: False
+ # Whether to force fp16 precision for the model. If false, it'll use the precision
+ # specified upon training.
+ use_torch_compile: True
+ # whether to use torch.compile on the diffusion model
+ # this will make the first time stamp generation very slow due to compilation overheads
+ # but will significantly speed up subsequent inference runs
+ num_writer_workers: 8
+ # number of workers to use for writing file
+ # To support multiple workers a threadsafe version of the netCDF library must be used
+ enable_timing: True
+ # Whether to collect and log detailed per-step timing breakdowns.
+ # Disabling removes all torch.cuda.synchronize() calls from the hot path,
+ # which improves throughput for production runs.
+
+io:
+ res_ckpt_path: /path/to/diffusion/checkpoint
+ # Checkpoint filename for the residual predictor model
+ reg_ckpt_path: /path/to/regression/checkpoint
+ # Checkpoint filename for the mean predictor model
+ output_path: .
\ No newline at end of file
diff --git a/src/hirad/conf/logging/era_cosmo_diffusion.yaml b/src/hirad/conf/logging/era_cosmo_diffusion.yaml
new file mode 100644
index 00000000..86ec7fe4
--- /dev/null
+++ b/src/hirad/conf/logging/era_cosmo_diffusion.yaml
@@ -0,0 +1,8 @@
+# set method to mlflow to log with mlflow
+method: mlflow
+experiment_name: hirad-corrdiff-diffusion
+run_name: era-cosmo-1h
+# change uri to remote mlflow server; if null, it is stored locally
+# if uri is remote make sure to have credentials set in ~/.mlflow/credentials
+uri: null
+log_images: false
\ No newline at end of file
diff --git a/src/hirad/conf/logging/era_cosmo_regression.yaml b/src/hirad/conf/logging/era_cosmo_regression.yaml
new file mode 100644
index 00000000..e7a62873
--- /dev/null
+++ b/src/hirad/conf/logging/era_cosmo_regression.yaml
@@ -0,0 +1,8 @@
+# set method to mlflow to log with mlflow
+method: mlflow
+experiment_name: hirad-corrdiff-regression
+run_name: era-cosmo-1h
+# change uri to remote mlflow server; if null, it is stored locally
+# if uri is remote make sure to have credentials set in ~/.mlflow/credentials
+uri: null
+log_images: false
\ No newline at end of file
diff --git a/src/hirad/conf/logging/era_real_diffusion.yaml b/src/hirad/conf/logging/era_real_diffusion.yaml
new file mode 100644
index 00000000..b45ef2cf
--- /dev/null
+++ b/src/hirad/conf/logging/era_real_diffusion.yaml
@@ -0,0 +1,8 @@
+# set method to mlflow to log with mlflow
+method: mlflow
+experiment_name: hirad-corrdiff-diffusion-training
+run_name: era-real
+# change uri to remote mlflow server; if null, it is stored locally
+# if uri is remote make sure to have credentials set in ~/.mlflow/credentials
+uri: null
+log_images: false
\ No newline at end of file
diff --git a/src/hirad/conf/logging/era_real_regression.yaml b/src/hirad/conf/logging/era_real_regression.yaml
new file mode 100644
index 00000000..1b731019
--- /dev/null
+++ b/src/hirad/conf/logging/era_real_regression.yaml
@@ -0,0 +1,8 @@
+# set method to mlflow to log with mlflow
+method: mlflow
+experiment_name: hirad-corrdiff-regression-training
+run_name: era-real
+# change uri to remote mlflow server; if null, it is stored locally
+# if uri is remote make sure to have credentials set in ~/.mlflow/credentials
+uri: null
+log_images: false
\ No newline at end of file
diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml
index 441239e1..7a060e8b 100644
--- a/src/hirad/conf/model/era_cosmo_diffusion.yaml
+++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml
@@ -10,6 +10,6 @@ model_args:
# Controls how positional information is encoded.
N_grid_channels: 4
# Number of channels for positional grid embeddings
- embedding_type: "zero"
+ embedding_type: "positional"
# Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
# 'zero' for none
\ No newline at end of file
diff --git a/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml b/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml
new file mode 100644
index 00000000..9362932b
--- /dev/null
+++ b/src/hirad/conf/model/era_cosmo_diffusion_patched.yaml
@@ -0,0 +1,13 @@
+name: patched_diffusion
+ # Name of the preconditioner
+hr_mean_conditioning: True
+ # High-res mean (regression's output) as additional condition
+
+# Standard model parameters.
+# Standard model parameters.
+model_args:
+ gridtype: "learnable"
+ # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'.
+ # Controls how positional information is encoded.
+ N_grid_channels: 100
+ # Number of channels for positional grid embeddings
\ No newline at end of file
diff --git a/src/hirad/conf/model/era_real_diffusion_patched.yaml b/src/hirad/conf/model/era_real_diffusion_patched.yaml
new file mode 100644
index 00000000..2af0e5ce
--- /dev/null
+++ b/src/hirad/conf/model/era_real_diffusion_patched.yaml
@@ -0,0 +1,16 @@
+name: patched_diffusion
+ # Name of the preconditioner
+hr_mean_conditioning: True
+ # High-res mean (regression's output) as additional condition
+
+# Standard model parameters.
+# Standard model parameters.
+model_args:
+ gridtype: "learnable"
+ # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'.
+ # Controls how positional information is encoded.
+ N_grid_channels: 100
+ # Number of channels for positional grid embeddings
+ embedding_type: "positional"
+ # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
+ # 'zero' for none
\ No newline at end of file
diff --git a/src/hirad/conf/model/era_real_regression.yaml b/src/hirad/conf/model/era_real_regression.yaml
new file mode 100644
index 00000000..29b43e8f
--- /dev/null
+++ b/src/hirad/conf/model/era_real_regression.yaml
@@ -0,0 +1,10 @@
+name: regression
+hr_mean_conditioning: False
+
+# Default regression model parameters. Do not modify.
+model_args:
+ "N_grid_channels": 4
+ # Number of channels for positional grid embeddings
+ "embedding_type": "zero"
+ # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
+ # 'zero' for none
\ No newline at end of file
diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml
index 96c29fbc..e746a2c6 100644
--- a/src/hirad/conf/model_size/normal.yaml
+++ b/src/hirad/conf/model_size/normal.yaml
@@ -24,4 +24,4 @@ model_args:
# Per-resolution multipliers for the number of channels.
channel_mult: [1, 2, 2, 2, 2]
# Resolutions at which self-attention layers are applied.
- attn_resolutions: [28]
\ No newline at end of file
+ attn_resolutions: [22]
\ No newline at end of file
diff --git a/src/hirad/conf/model_size/normal_real.yaml b/src/hirad/conf/model_size/normal_real.yaml
new file mode 100644
index 00000000..0aab33a4
--- /dev/null
+++ b/src/hirad/conf/model_size/normal_real.yaml
@@ -0,0 +1,27 @@
+# @package _global_.model
+
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Normal model size (80 million parameters) should be used by default for full datasets and higher grid size.
+
+model_args:
+ # Base multiplier for the number of channels across the network.
+ model_channels: 128
+ # Per-resolution multipliers for the number of channels.
+ channel_mult: [1, 2, 2, 2, 2]
+ # Resolutions at which self-attention layers are applied.
+ attn_resolutions: [44]
\ No newline at end of file
diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml
index 856906b6..f65e7384 100644
--- a/src/hirad/conf/sampler/deterministic.yaml
+++ b/src/hirad/conf/sampler/deterministic.yaml
@@ -2,7 +2,8 @@
# Deterministic sampler is not implemented correctly in this codebase and shouldn't be used.
type: deterministic
-num_steps: 9
+params:
+ num_steps: 9
# Number of denoising steps
-solver: euler
+ solver: euler
# ODE solver type: euler is the simplest solver
\ No newline at end of file
diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml
index 808270c9..2f3cccbe 100644
--- a/src/hirad/conf/sampler/stochastic.yaml
+++ b/src/hirad/conf/sampler/stochastic.yaml
@@ -1,5 +1,3 @@
# Stochastic sampler is slower, but should give better results than deterministic sampler.
-type: stochastic
-# boundary_pix: 2 # set for patched diffusion
-# overlap_pix: 4 # set for patched diffusion
\ No newline at end of file
+type: stochastic
\ No newline at end of file
diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml
index 07cbb03f..eed75c27 100644
--- a/src/hirad/conf/training/era_cosmo_diffusion.yaml
+++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml
@@ -1,20 +1,20 @@
# Hyperparameters
hp:
- training_duration: 5000000
+ training_duration: 8000000
# Training duration based on the number of processed samples
- total_batch_size: 128
+ total_batch_size: "auto"
# Total batch size
- batch_size_per_gpu: "auto"
+ batch_size_per_gpu: 20
# Batch size per GPU
lr: 0.0002
# Learning rate
- grad_clip_threshold: null
+ grad_clip_threshold: 1e6
# no gradient clipping for defualt non-patch-based training
- lr_decay: 1
+ lr_decay: 0.7
# LR decay rate
- lr_rampup: 0
+ lr_rampup: 1e6
# Rampup for learning rate, in number of samples
- lr_decay_rate: 5e5
+ lr_decay_rate: 1e6
# Learning rate decay threshold in number of samples, applied every lr_decay_rate samples.
# Performance
@@ -22,22 +22,24 @@ perf:
fp_optimizations: amp-bf16
# Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
# "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
- dataloader_workers: 10
+ dataloader_workers: 30
# DataLoader worker processes
songunet_checkpoint_level: 0 # 0 means no checkpointing
# Gradient checkpointing level, value is number of layers to checkpoint
-
+ use_apex_gn: True
+ torch_compile: True
+ profile_mode: False
# I/O
io:
- regression_checkpoint_path: /capstor/scratch/cscs/boeschf/HiRAD-Gen/outputs_full/regression/checkpoints_regression/
+ # regression_checkpoint_path: /capstor/scratch/cscs/pstamenk/outputs/training/regression_era5_cosmo/checkpoints_regression
+ regression_checkpoint_path: /path/to/regression/checkpoint
# Where to load the regression checkpoint
- print_progress_freq: 5000
+ print_progress_freq: 2000
# How often to print progress
- save_checkpoint_freq: 250000
+ save_checkpoint_freq: 200000
# How often to save the checkpoints, measured in number of processed samples
- validation_freq: 25000
+ validation_freq: 50000
# how often to record the validation loss, measured in number of processed samples
- validation_steps: 4
- # how many loss evaluations are used to compute the validation loss per checkpoint
+ validation_steps: 90
# how many loss evaluations are used to compute the validation loss per checkpoint
checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml b/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml
new file mode 100644
index 00000000..6480ec0d
--- /dev/null
+++ b/src/hirad/conf/training/era_cosmo_diffusion_patched.yaml
@@ -0,0 +1,53 @@
+# Hyperparameters
+hp:
+ training_duration: 8000000
+ # Training duration based on the number of processed samples
+ total_batch_size: "auto"
+ # Total batch size
+ batch_size_per_gpu: 10
+ # Batch size per GPU
+ lr: 0.0002
+ # Learning rate
+ grad_clip_threshold: 1e6
+ # no gradient clipping for defualt non-patch-based training
+ lr_decay: 0.7
+ # LR decay rate
+ lr_rampup: 1000000
+ # Rampup for learning rate, in number of samples
+ lr_decay_rate: 5e5
+ # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples.
+ patch_shape_x: 128
+ patch_shape_y: 128
+ # Patch size. Patch training is used if these dimensions differ from
+ # img_shape_x and img_shape_y.
+ patch_num: 15
+ # Number of patches from a single sample. Total number of patches is
+ # patch_num * batch_size_global.
+ max_patch_per_gpu: 300
+ # Maximum number of pataches a gpu can hold
+
+# Performance
+perf:
+ fp_optimizations: amp-bf16
+ # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+ # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+ dataloader_workers: 10
+ # DataLoader worker processes
+ songunet_checkpoint_level: 0 # 0 means no checkpointing
+ # Gradient checkpointing level, value is number of layers to checkpoint
+ use_apex_gn: True
+ torch_compile: True
+ profile_mode: False
+# I/O
+io:
+ regression_checkpoint_path: /path/to/regression/checkpoint
+ # Where to load the regression checkpoint
+ print_progress_freq: 2000
+ # How often to print progress
+ save_checkpoint_freq: 250000
+ # How often to save the checkpoints, measured in number of processed samples
+ validation_freq: 50000
+ # how often to record the validation loss, measured in number of processed samples
+ validation_steps: 50
+ # how many loss evaluations are used to compute the validation loss per checkpoint
+ checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml
index 98c6c249..ad291dcc 100644
--- a/src/hirad/conf/training/era_cosmo_regression.yaml
+++ b/src/hirad/conf/training/era_cosmo_regression.yaml
@@ -1,10 +1,10 @@
# Hyperparameters
hp:
- training_duration: 500000
+ training_duration: 1000000
# Training duration based on the number of processed samples
- total_batch_size: 64
- # Total batch size
- batch_size_per_gpu: "auto"
+ total_batch_size: "auto"
+ # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression.
+ batch_size_per_gpu: 20
# Batch size per GPU
lr: 0.0002
# Learning rate
@@ -22,21 +22,22 @@ perf:
fp_optimizations: amp-bf16
# Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
# "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
- dataloader_workers: 10
+ dataloader_workers: 30
# DataLoader worker processes
songunet_checkpoint_level: 0 # 0 means no checkpointing
# Gradient checkpointing level, value is number of layers to checkpoint
- # torch_compile: True
- # use_apex_gn: True
+ use_apex_gn: True
+ torch_compile: True
+ profile_mode: False
# I/O
io:
- print_progress_freq: 1024
+ print_progress_freq: 1000
# How often to print progress
- save_checkpoint_freq: 25000
+ save_checkpoint_freq: 100000
# How often to save the checkpoints, measured in number of processed samples
- validation_freq: 5000
+ validation_freq: 50000
# how often to record the validation loss, measured in number of processed samples
- validation_steps: 10
+ validation_steps: 55
# how many loss evaluations are used to compute the validation loss per checkpoint
checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training/era_real_diffusion_patched.yaml b/src/hirad/conf/training/era_real_diffusion_patched.yaml
new file mode 100644
index 00000000..69f19ef3
--- /dev/null
+++ b/src/hirad/conf/training/era_real_diffusion_patched.yaml
@@ -0,0 +1,53 @@
+# Hyperparameters
+hp:
+ training_duration: 8000000
+ # Training duration based on the number of processed samples
+ total_batch_size: "auto"
+ # Total batch size
+ batch_size_per_gpu: 8
+ # Batch size per GPU
+ lr: 0.0002
+ # Learning rate
+ grad_clip_threshold: 1e6
+ # no gradient clipping for defualt non-patch-based training
+ lr_decay: 0.7
+ # LR decay rate
+ lr_rampup: 1000000
+ # Rampup for learning rate, in number of samples
+ lr_decay_rate: 1e6
+ # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples.
+ patch_shape_x: 384
+ patch_shape_y: 384
+ # Patch size. Patch training is used if these dimensions differ from
+ # img_shape_x and img_shape_y.
+ patch_num: 2
+ # Number of patches from a single sample. Total number of patches is
+ # patch_num * batch_size_global.
+ max_patch_per_gpu: 300
+ # Maximum number of pataches a gpu can hold
+
+# Performance
+perf:
+ fp_optimizations: amp-bf16
+ # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+ # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+ dataloader_workers: 30
+ # DataLoader worker processes
+ songunet_checkpoint_level: 0 # 0 means no checkpointing
+ # Gradient checkpointing level, value is number of layers to checkpoint
+ use_apex_gn: True
+ torch_compile: True
+ profile_mode: False
+# I/O
+io:
+ regression_checkpoint_path: /path/to/regression/checkpoint
+ # Where to load the regression checkpoint
+ print_progress_freq: 2000
+ # How often to print progress
+ save_checkpoint_freq: 250000
+ # How often to save the checkpoints, measured in number of processed samples
+ validation_freq: 100000
+ # how often to record the validation loss, measured in number of processed samples
+ validation_steps: 45
+ # how many loss evaluations are used to compute the validation loss per checkpoint
+ checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training/era_real_regression.yaml b/src/hirad/conf/training/era_real_regression.yaml
new file mode 100644
index 00000000..3fee0c2e
--- /dev/null
+++ b/src/hirad/conf/training/era_real_regression.yaml
@@ -0,0 +1,43 @@
+# Hyperparameters
+hp:
+ training_duration: 3000000
+ # Training duration based on the number of processed samples
+ total_batch_size: "auto"
+ # Total batch size -- based 8 per GPU -- 2 nodes is 2x8x4 -- see sbatch vars for how many gpus. diffusion need to point to the rgression.
+ batch_size_per_gpu: 4
+ # Batch size per GPU
+ lr: 0.0002
+ # Learning rate
+ grad_clip_threshold: null
+ # no gradient clipping for defualt non-patch-based training
+ lr_decay: 1
+ # LR decay rate
+ lr_rampup: 0
+ # Rampup for learning rate, in number of samples
+ lr_decay_rate: 5e5
+ # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples.
+
+# Performance
+perf:
+ fp_optimizations: amp-bf16
+ # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+ # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+ dataloader_workers: 10
+ # DataLoader worker processes
+ songunet_checkpoint_level: 0 # 0 means no checkpointing
+ # Gradient checkpointing level, value is number of layers to checkpoint
+ use_apex_gn: True
+ torch_compile: True
+ profile_mode: False
+
+# I/O
+io:
+ print_progress_freq: 1000
+ # How often to print progress
+ save_checkpoint_freq: 100000
+ # How often to save the checkpoints, measured in number of processed samples
+ validation_freq: 50000
+ # how often to record the validation loss, measured in number of processed samples
+ validation_steps: 55
+ # how many loss evaluations are used to compute the validation loss per checkpoint
+ checkpoint_dir: .
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml
index 0a069e9c..ef780a7e 100644
--- a/src/hirad/conf/training_era_cosmo_diffusion.yaml
+++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml
@@ -1,16 +1,16 @@
hydra:
job:
chdir: true
- name: diffusion
+ name: diffusion_era_cosmo
run:
- dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name}
+ dir: ./outputs/training/${hydra:job.name}
# Get defaults
defaults:
- _self_
# Dataset
- - dataset/era_cosmo
+ - dataset/anemoi_era_cosmo
# Model
- model/era_cosmo_diffusion
@@ -18,4 +18,7 @@ defaults:
- model_size/normal
# Training
- - training/era_cosmo_diffusion
\ No newline at end of file
+ - training/era_cosmo_diffusion
+
+ # Logging
+ - logging/era_cosmo_diffusion
diff --git a/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml b/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml
new file mode 100644
index 00000000..b4ae6b29
--- /dev/null
+++ b/src/hirad/conf/training_era_cosmo_diffusion_patched.yaml
@@ -0,0 +1,24 @@
+hydra:
+ job:
+ chdir: true
+ name: diffusion_era5_cosmo_patched
+ run:
+ dir: ./outputs/training/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/anemoi_era_cosmo
+
+ # Model
+ - model/era_cosmo_diffusion_patched
+
+ - model_size/normal
+
+ # Training
+ - training/era_cosmo_diffusion_patched
+
+ # Logging
+ - logging/era_cosmo_diffusion
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_cosmo_diffusion_test.yaml b/src/hirad/conf/training_era_cosmo_diffusion_test.yaml
new file mode 100644
index 00000000..27230507
--- /dev/null
+++ b/src/hirad/conf/training_era_cosmo_diffusion_test.yaml
@@ -0,0 +1,24 @@
+hydra:
+ job:
+ chdir: true
+ name: diffusion_era5_cosmo_test
+ run:
+ dir: ./outputs/training/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/era_cosmo
+
+ # Model
+ - model/era_cosmo_diffusion
+
+ - model_size/mini
+
+ # Training
+ - training/era_cosmo_diffusion
+
+ # Inference visualization
+ - generation/era_cosmo_training
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml
index 1de83d91..18ee91da 100644
--- a/src/hirad/conf/training_era_cosmo_regression.yaml
+++ b/src/hirad/conf/training_era_cosmo_regression.yaml
@@ -1,16 +1,17 @@
hydra:
job:
chdir: true
- name: regression
+ name: regression_era_cosmo
run:
- dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name}
+ dir: ./outputs/training/${hydra:job.name}
+
# Get defaults
defaults:
- _self_
# Dataset
- - dataset/era_cosmo
+ - dataset/anemoi_era_cosmo
# Model
- model/era_cosmo_regression
@@ -18,4 +19,7 @@ defaults:
- model_size/normal
# Training
- - training/era_cosmo_regression
\ No newline at end of file
+ - training/era_cosmo_regression
+
+ # Logging
+ - logging/era_cosmo_regression
diff --git a/src/hirad/conf/training_era_cosmo_regression_test.yaml b/src/hirad/conf/training_era_cosmo_regression_test.yaml
new file mode 100644
index 00000000..96e88fa7
--- /dev/null
+++ b/src/hirad/conf/training_era_cosmo_regression_test.yaml
@@ -0,0 +1,25 @@
+hydra:
+ job:
+ chdir: true
+ name: regression_era5_cosmo_test
+ run:
+ dir: ./outputs/training/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/era_cosmo
+
+ # Model
+ - model/era_cosmo_regression
+
+ - model_size/mini
+
+ # Training
+ # Leave same as prod
+ - training/era_cosmo_regression
+
+ # Inference visualization
+ - generation/era_cosmo_training
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_real_diffusion_patched.yaml b/src/hirad/conf/training_era_real_diffusion_patched.yaml
new file mode 100644
index 00000000..b24f95c5
--- /dev/null
+++ b/src/hirad/conf/training_era_real_diffusion_patched.yaml
@@ -0,0 +1,24 @@
+hydra:
+ job:
+ chdir: true
+ name: diffusion_era_real
+ run:
+ dir: ./outputs/training/${hydra:job.name}
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/anemoi_era_real
+
+ # Model
+ - model/era_real_diffusion_patched
+
+ - model_size/normal_real
+
+ # Training
+ - training/era_real_diffusion_patched
+
+ # Logging
+ - logging/era_real_diffusion
\ No newline at end of file
diff --git a/src/hirad/conf/training_era_real_regression.yaml b/src/hirad/conf/training_era_real_regression.yaml
new file mode 100644
index 00000000..23eab0ca
--- /dev/null
+++ b/src/hirad/conf/training_era_real_regression.yaml
@@ -0,0 +1,25 @@
+hydra:
+ job:
+ chdir: true
+ name: regression_era_real
+ run:
+ dir: ./outputs/training/${hydra:job.name}
+
+
+# Get defaults
+defaults:
+ - _self_
+
+ # Dataset
+ - dataset/anemoi_era_real
+
+ # Model
+ - model/era_real_regression
+
+ - model_size/normal_real
+
+ # Training
+ - training/era_real_regression
+
+ # Logging
+ - logging/era_real_regression
diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py
index 53e791e5..6f96a844 100644
--- a/src/hirad/datasets/__init__.py
+++ b/src/hirad/datasets/__init__.py
@@ -1,3 +1,7 @@
-from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config
+from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config, get_dataset_and_sampler_inference, known_datasets
from .era5_cosmo import ERA5_COSMO
-from .base import DownscalingDataset
+from .era5_real import ERA5_REAL
+from .base import DownscalingDataset, ChannelMetadata, get_channels_from_strings, get_strings_from_channels
+from .anemoi_dataset import AnemoiDataset, ANEMOI_ERA5_COSMO, ANEMOI_ERA5_REAL
+from .anemoi_dataset_copernicus_tp import AnemoiDatasetCopernicus, ANEMOI_ERA5COPERNICUSTP_REAL, ANEMOI_ERA5COPERNICUSTP_COSMO
+from .constants import REAL_TO_ERA_CHANNEL_MAP, ERA_TO_REAL_CHANNEL_MAP
\ No newline at end of file
diff --git a/src/hirad/datasets/anemoi_dataset.py b/src/hirad/datasets/anemoi_dataset.py
new file mode 100644
index 00000000..c83dc45d
--- /dev/null
+++ b/src/hirad/datasets/anemoi_dataset.py
@@ -0,0 +1,402 @@
+from .base import DownscalingDataset, ChannelMetadata
+
+from anemoi.datasets import open_dataset
+import datetime
+import os
+import numpy as np
+from pandas import to_datetime
+import torch
+from typing import List, Tuple
+import yaml
+import torch.nn.functional as F
+import time
+from pathlib import Path
+# import zarr
+
+from .constants import REAL_TO_ERA_CHANNEL_MAP, ERA_TO_REAL_CHANNEL_MAP
+from hirad.utils.console import PythonLogger
+from hirad.utils.dataset_utils import GridData, regrid_icon_to_rotlatlon
+
+
+logger = PythonLogger(__name__)
+
+# Margin to use for ERA dataset (to avoid nans from interpolation at boundary)
+INPUT_MARGIN_DEGREES = 0.5
+
+class AnemoiDataset(DownscalingDataset):
+ def __init__(self,
+ type: str,
+ input_anemoi_dataset_path: str,
+ target_anemoi_dataset_path: str,
+ start_date: datetime.datetime = None,
+ end_date: datetime.datetime = None,
+ input_channel_names: List[str] = [],
+ output_channel_names: List[str] = [],
+ static_channel_names: List[str] = [],
+ transform_channels: List[str] = [],
+ transform_input_means: dict = {},
+ transform_input_stdevs: dict = {},
+ transform_output_means: dict = {},
+ transform_output_stdevs: dict = {},
+ n_month_hour_channels: int = None,
+ trim_edge: int = 0,
+ ):
+ super().__init__()
+
+ input_dataset = type.split('_')[-2]
+ target_dataset = type.split('_')[-1]
+ self.real_target = target_dataset == 'real'
+ self.trim_edge = trim_edge
+
+ if input_dataset != 'era5':
+ raise ValueError(f"Input dataset {input_dataset} not supported for AnemoiDataset. Only 'era5' is supported.")
+ if target_dataset != 'cosmo' and target_dataset !='real':
+ raise ValueError(f"Target dataset {target_dataset} not supported for AnemoiDataset. Only 'cosmo' and 'real' are supported.")
+
+ if self.real_target:
+ # Map output channel names from real to era5
+ output_channel_names_real = [ERA_TO_REAL_CHANNEL_MAP[name] for name in output_channel_names]
+ self.lat_lon_real = torch.load("/capstor/store/cscs/pasc/c38/real_grid_info/realch1-lat-lon", weights_only=False)
+ self.regrid_indices_real = torch.from_numpy(np.load("/capstor/store/cscs/pasc/c38/real_grid_info/remap_indices.npy")).long()
+ self.regrid_weights_real = torch.from_numpy(np.load("/capstor/store/cscs/pasc/c38/real_grid_info/remap_weights.npy"))
+
+ #TODO switch hanbdling paths to Path rather than pure strings
+ self._n_month_hour_channels = n_month_hour_channels
+ target_open_dataset_kwargs = {}
+ if start_date is not None and end_date is not None:
+ assert start_date < end_date, "start_date must be before end_date"
+ target_open_dataset_kwargs['start'] = start_date
+ target_open_dataset_kwargs['end'] = end_date
+ if trim_edge > 0 and not self.real_target:
+ target_open_dataset_kwargs['trim_edge'] = trim_edge
+ self._output_dataset = open_dataset(target_anemoi_dataset_path, select=output_channel_names_real if self.real_target else output_channel_names, **target_open_dataset_kwargs)
+ assert self._output_dataset.shape[1] == len(output_channel_names)
+
+ # Load ERA dataset, trimming the area and limiting the dates to the target dataset
+ start_date = self._output_dataset.metadata()['start_date'] if start_date is None else start_date
+ end_date = self._output_dataset.metadata()['end_date'] if end_date is None else end_date
+ latitudes = self.latitude()
+ longitudes = self.longitude()
+ min_lat = min(latitudes) - INPUT_MARGIN_DEGREES
+ max_lat = max(latitudes) + INPUT_MARGIN_DEGREES
+ min_lon = max(0, min(longitudes) - INPUT_MARGIN_DEGREES)
+ max_lon = max(longitudes) + INPUT_MARGIN_DEGREES
+ area=(max_lat, min_lon, min_lat, max_lon)
+
+ self._input_dataset = open_dataset(input_anemoi_dataset_path, select=input_channel_names, start=start_date, end=end_date, area=area)
+ assert self._input_dataset.shape[1] == len(input_channel_names)
+
+ # Check that we have the same number of time points in each dataset
+ assert self._input_dataset.shape[0] == self._output_dataset.shape[0]
+
+ # Load static info and channel names
+ if static_channel_names:
+ self._static_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in static_channel_names]
+ static_open_dataset_kwargs = {}
+ if not self.real_target and trim_edge > 0:
+ static_open_dataset_kwargs['trim_edge'] = trim_edge
+ static_dataset = open_dataset(target_anemoi_dataset_path, select=static_channel_names, start=start_date, end=start_date, **static_open_dataset_kwargs)
+ assert static_dataset.shape[1] == len(static_channel_names)
+ # take first time point, and squeeze() to remove ensemble dimension
+ static_data = static_dataset[0,:,:,:].squeeze()
+ # Could also get these from stats, but one-time calculation is OK.
+ self.static_mean = static_data.mean(axis=-1, keepdims=True)
+ self.static_std = static_data.std(axis=-1, keepdims=True)
+ target_shape = self.image_shape()
+ self.static_data_normalized = (static_data - self.static_mean.reshape((self.static_mean.shape[0],1))) \
+ / self.static_std.reshape((self.static_std.shape[0],1))
+ self.static_data_normalized = torch.from_numpy(self.static_data_normalized)
+ self.static_data_normalized = regrid_icon_to_rotlatlon(self.static_data_normalized, self.regrid_indices_real, self.regrid_weights_real)
+ if trim_edge > 0 and self.real_target:
+ self.static_data_normalized = self.static_data_normalized[:, trim_edge:-trim_edge, trim_edge:-trim_edge]
+ # self.normalize_input(np.flip(static_data.squeeze().reshape(-1, *target_shape), 1))
+ else:
+ self.static_data_normalized = None
+ self._static_channels = []
+
+ # Load target channel names
+ self._output_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in output_channel_names]
+ # Load era5 channel names
+ self._input_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in input_channel_names]
+
+ # Load stats for normalizing channels of input and output
+ target_stats = self._output_dataset.statistics
+ self.output_mean = target_stats['mean'][:]
+ self.output_std = target_stats['stdev'][:]
+
+ input_stats = self._input_dataset.statistics
+ self.input_mean = input_stats['mean'][:]
+ self.input_std = input_stats['stdev'][:]
+
+ assert len(transform_channels) == len(transform_input_means) ==\
+ len(transform_input_stdevs) == len(transform_output_means) == \
+ len(transform_output_stdevs)
+
+ # FEATURE: load the mean and std values for transformed channels and update the normalization statistics
+ self.input_transforms = {}
+ self.input_inverse_transforms = {}
+ self.output_transforms = {}
+ self.output_inverse_transforms = {}
+ for transform_descriptor in transform_channels:
+ channel, transformation = transform_descriptor.split('-')
+ input_channel_idx = input_channel_names.index(channel) if channel in input_channel_names else None
+ output_channel_idx = output_channel_names.index(channel) if channel in output_channel_names else None
+ if transformation.startswith('box_cox'):
+ lmbda_str = transformation.split('_')[-1]
+ lmbda = float(transformation.split('_')[-1])/(10**(len(lmbda_str)-1))
+ print(f"Applying Box-Cox transformation with lambda={lmbda} to channel {channel} (input idx: {input_channel_idx} ({input_channel_names[input_channel_idx]}), output idx: {output_channel_idx} ({output_channel_names[output_channel_idx]}))")
+ if input_channel_idx is not None:
+ self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.input_mean[input_channel_idx] = transform_input_means[transform_descriptor]
+ self.input_std[input_channel_idx] = transform_input_stdevs[transform_descriptor]
+ if output_channel_idx is not None:
+ self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.output_mean[output_channel_idx] = transform_output_means[transform_descriptor]
+ self.output_std[output_channel_idx] = transform_output_stdevs[transform_descriptor]
+ else:
+ raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.")
+
+ # Initialize the interpolator
+ self.interpolator = GridData(
+ self._input_dataset.longitudes,
+ self._input_dataset.latitudes,
+ self.longitude(),
+ self.latitude())
+
+
+ # DO NOT SUBMIT: This is not implemented yet.
+ # Question: Is it OK to change the signature to return 3 items?
+ def __getitem__(self, idx):
+ """Get input and target data. Transform and normalize, but do not interpolate."""
+
+ # Pull input, replacing the corrected tp if applicable
+ date_str = to_datetime(self._input_dataset.dates[idx]).strftime('%Y%m%d-%H%M')
+
+ # Don't reshape, but do squeeze ensemble dimension.
+ input_data = self._input_dataset[idx].squeeze()
+
+ # Pull target data
+ # squeeze the ensemble dimesnsion
+ target_data = self._output_dataset[idx].squeeze()
+
+ # next two steps only if target is cosmo, real has to be regridded first (done in training loop on gpu-s for efficiency)
+ # reshape to image_shape
+ # flip so that it starts in top-left corner (by default it is bottom left)
+ # if not self.real_target:
+ # target_shape = self.image_shape()
+ # target_data = np.flip(target_data \
+ # .reshape(-1,*target_shape),
+ # 1)
+
+ return torch.from_numpy(target_data.copy()),\
+ torch.from_numpy(input_data),\
+ date_str
+
+ def get_static_data(self):
+ return self.static_data_normalized
+
+ def __len__(self):
+ return len(self._output_dataset.dates)
+
+ # Question: Do we need an input longitude as well?
+ def longitude(self) -> np.ndarray:
+ """Get longitude values from the target dataset."""
+ if self.real_target:
+ return self.lat_lon_real[:,1]
+ return self._output_dataset.longitudes
+
+ def latitude(self) -> np.ndarray:
+ """Get latitude values from the target dataset."""
+ if self.real_target:
+ return self.lat_lon_real[:,0]
+ return self._output_dataset.latitudes
+
+ def input_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
+ return self._input_channels
+
+ def output_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the output channels. A list of ChannelMetadata, one for each channel"""
+ return self._output_channels
+
+ def static_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the static channels. A list of ChannelMetadata, one for each channel"""
+ return self._static_channels
+
+ def time(self) -> List:
+ """Get time values from the dataset."""
+ #TODO Choose the time format and convert to that, currently it's a string from a filename
+ return [to_datetime(dt64).strftime('%Y%m%d-%H%M') for dt64 in self._output_dataset.dates]
+
+ def image_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the data."""
+ if self.real_target:
+ return 704,1088
+ return self._output_dataset.field_shape
+
+ def input_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the input data."""
+ return self._input_dataset.field_shape
+
+ def normalization_stats(self):
+ """Get the mean and std stats for normalizing the input and output data."""
+ return {"input_mean": self.input_mean,
+ "input_std": self.input_std,
+ "output_mean": self.output_mean,
+ "output_std": self.output_std}
+
+ def stats_to_torch(self, device: torch.device, dtype: torch.dtype = torch.float32):
+ """Convert the mean and std stats to torch tensors on the specified device."""
+ self.input_mean = torch.from_numpy(self.input_mean).to(device=device, dtype=dtype)
+ self.input_std = torch.from_numpy(self.input_std).to(device=device, dtype=dtype)
+ self.output_mean = torch.from_numpy(self.output_mean).to(device=device, dtype=dtype)
+ self.output_std = torch.from_numpy(self.output_std).to(device=device, dtype=dtype)
+
+ def stats_to_numpy(self):
+ """Convert the mean and std stats to numpy arrays."""
+ self.input_mean = self.input_mean.cpu().numpy() if isinstance(self.input_mean, torch.Tensor) else self.input_mean
+ self.input_std = self.input_std.cpu().numpy() if isinstance(self.input_std, torch.Tensor) else self.input_std
+ self.output_mean = self.output_mean.cpu().numpy() if isinstance(self.output_mean, torch.Tensor) else self.output_mean
+ self.output_std = self.output_std.cpu().numpy() if isinstance(self.output_std, torch.Tensor) else self.output_std
+
+ def normalize_input(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert input from physical units to normalized data."""
+ if mean is None:
+ mean = self.input_mean
+ if std is None:
+ std = self.input_std
+ for channel_idx, transform in self.input_transforms.items():
+ x[:,channel_idx,::] = transform(x[:,channel_idx,::])
+ return (x - self.input_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]) \
+ / self.input_std[(None,) + (...,) + (None,) * (x.ndim - 2)]
+
+
+ def denormalize_input(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert input from normalized data to physical units."""
+ if mean is None:
+ mean = self.input_mean
+ if std is None:
+ std = self.input_std
+ x = x * self.input_std[(None,) + (...,) + (None,) * (x.ndim - 2)] \
+ + self.input_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]
+ for channel_idx, inverse_transform in self.input_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+
+ def normalize_output(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert output from physical units to normalized data."""
+ if mean is None:
+ mean = self.output_mean
+ if std is None:
+ std = self.output_std
+ for channel_idx, transform in self.output_transforms.items():
+ x[:,channel_idx,::] = transform(x[:,channel_idx,::])
+ return (x - self.output_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]) \
+ / self.output_std[(None,) + (...,) + (None,) * (x.ndim - 2)]
+
+
+ def denormalize_output(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert output from normalized data to physical units."""
+ if mean is None:
+ mean = self.output_mean
+ if std is None:
+ std = self.output_std
+ x = x * self.output_std[(None,) + (...,) + (None,) * (x.ndim - 2)] \
+ + self.output_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]
+ for channel_idx, inverse_transform in self.output_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+ def box_cox_transform(self, channel_array: np.ndarray | torch.Tensor, lmbda: float) -> np.ndarray | torch.Tensor:
+ """Apply Box-Cox transformation to the data."""
+ if isinstance(channel_array, torch.Tensor):
+ channel_array = torch.clamp(channel_array, min=0)
+ return (torch.pow(channel_array, lmbda) - 1) / lmbda
+ channel_array = np.clip(channel_array, 0, None)
+ return (np.power(channel_array, lmbda) - 1) / lmbda
+
+ def box_cox_inverse_transform(self, channel_array: np.ndarray | torch.Tensor, lmbda: float) -> np.ndarray | torch.Tensor:
+ """Apply inverse Box-Cox transformation to the data."""
+ if isinstance(channel_array, torch.Tensor):
+ channel_array = torch.clamp(channel_array, min=-1/lmbda)
+ return torch.pow((lmbda * channel_array) + 1, 1 / lmbda)
+ channel_array = np.clip(channel_array, -1/lmbda, None)
+ return np.power((lmbda * channel_array) + 1, 1 / lmbda)
+
+ def make_time_grids(self, dates: list[str], device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """
+ Create multi-frequency cyclic sin/cos feature grids for hour and month (batched).
+
+ Parameters
+ ----------
+ dates : Sequence[str]
+ Date strings in the format 'YYYYMMDD-HHMM', length = B
+
+ Returns
+ -------
+ grid : torch.Tensor, shape (B, C)
+ Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k]
+ """
+
+ B = len(dates)
+
+ # --- parse month and hour ---
+ months = torch.tensor(
+ [int(d.split("-")[0][4:6]) for d in dates],
+ dtype=torch.float32,
+ device=device,
+ )
+ hours = torch.tensor(
+ [int(d.split("-")[1][0:2]) for d in dates],
+ dtype=torch.float32,
+ device=device,
+ )
+
+ # normalize cyclic components
+ hours = (hours % 24) / 24.0 # (B,)
+ months = ((months - 1) % 12) / 12.0 # (B,)
+
+ # frequencies
+ n_freq = self._n_month_hour_channels // 2
+ freqs = torch.arange(
+ 1, n_freq + 1, dtype=torch.float32, device=device
+ ) # (K,)
+
+ # shape helpers
+ hours = hours[:, None] # (B, 1)
+ months = months[:, None] # (B, 1)
+
+ # --- hour encodings ---
+ hour_angles = 2 * torch.pi * hours * freqs # (B, K)
+ hour_feats = torch.stack(
+ [torch.sin(hour_angles), torch.cos(hour_angles)],
+ dim=2
+ ) # (B, K, 2)
+
+ # --- month encodings ---
+ month_angles = 2 * torch.pi * months * freqs
+ month_feats = torch.stack(
+ [torch.sin(month_angles), torch.cos(month_angles)],
+ dim=2
+ ) # (B, K, 2)
+
+ # concatenate and flatten channels
+ feats = torch.cat([hour_feats, month_feats], dim=1) # (B, 2K, 2)
+ feats = feats.reshape(B, -1) # (B, C)
+
+ # expand to spatial grid
+ # grid = feats[:, :, None, None].expand(B, feats.shape[1], H, W)
+
+ return feats
+
+ANEMOI_ERA5_REAL = AnemoiDataset
+ANEMOI_ERA5_COSMO = AnemoiDataset
diff --git a/src/hirad/datasets/anemoi_dataset_copernicus_tp.py b/src/hirad/datasets/anemoi_dataset_copernicus_tp.py
new file mode 100644
index 00000000..c94f0143
--- /dev/null
+++ b/src/hirad/datasets/anemoi_dataset_copernicus_tp.py
@@ -0,0 +1,409 @@
+from .base import DownscalingDataset, ChannelMetadata
+
+from anemoi.datasets import open_dataset
+import datetime
+import os
+import numpy as np
+from pandas import to_datetime
+import torch
+from typing import List, Tuple
+import yaml
+import torch.nn.functional as F
+import time
+from pathlib import Path
+# import zarr
+
+from .constants import REAL_TO_ERA_CHANNEL_MAP, ERA_TO_REAL_CHANNEL_MAP
+from hirad.utils.console import PythonLogger
+from hirad.utils.dataset_utils import GridData, regrid_icon_to_rotlatlon
+
+
+logger = PythonLogger(__name__)
+
+# Margin to use for ERA dataset (to avoid nans from interpolation at boundary)
+INPUT_MARGIN_DEGREES = 0.5
+
+class AnemoiDatasetCopernicus(DownscalingDataset):
+ def __init__(self,
+ type: str,
+ input_anemoi_dataset_path: str,
+ target_anemoi_dataset_path: str,
+ corrected_tp_path: str,
+ start_date: datetime.datetime = None,
+ end_date: datetime.datetime = None,
+ input_channel_names: List[str] = [],
+ output_channel_names: List[str] = [],
+ static_channel_names: List[str] = [],
+ transform_channels: List[str] = [],
+ transform_input_means: dict = {},
+ transform_input_stdevs: dict = {},
+ transform_output_means: dict = {},
+ transform_output_stdevs: dict = {},
+ n_month_hour_channels: int = None,
+ trim_edge: int = 0,
+ ):
+ super().__init__()
+
+ input_dataset = type.split('_')[1]
+ target_dataset = type.split('_')[-1]
+ self.real_target = target_dataset == 'real'
+ self.trim_edge = trim_edge
+
+ if input_dataset != 'era5':
+ raise ValueError(f"Input dataset {input_dataset} not supported for AnemoiDataset. Only 'era5' is supported.")
+ if target_dataset != 'cosmo' and target_dataset !='real':
+ raise ValueError(f"Target dataset {target_dataset} not supported for AnemoiDataset. Only 'cosmo' and 'real' are supported.")
+
+ if self.real_target:
+ # Map output channel names from real to era5
+ output_channel_names_real = [ERA_TO_REAL_CHANNEL_MAP[name] for name in output_channel_names]
+ self.lat_lon_real = torch.load("/capstor/store/cscs/pasc/c38/real_grid_info/realch1-lat-lon", weights_only=False)
+ self.regrid_indices_real = torch.from_numpy(np.load("/capstor/store/cscs/pasc/c38/real_grid_info/remap_indices.npy")).long()
+ self.regrid_weights_real = torch.from_numpy(np.load("/capstor/store/cscs/pasc/c38/real_grid_info/remap_weights.npy"))
+
+ self._corrected_tp_path = corrected_tp_path
+ #TODO switch hanbdling paths to Path rather than pure strings
+ self._n_month_hour_channels = n_month_hour_channels
+ target_open_dataset_kwargs = {}
+ if start_date is not None and end_date is not None:
+ assert start_date < end_date, "start_date must be before end_date"
+ target_open_dataset_kwargs['start'] = start_date
+ target_open_dataset_kwargs['end'] = end_date
+ if trim_edge > 0 and not self.real_target:
+ target_open_dataset_kwargs['trim_edge'] = trim_edge
+ self._output_dataset = open_dataset(target_anemoi_dataset_path, select=output_channel_names_real if self.real_target else output_channel_names, **target_open_dataset_kwargs)
+ assert self._output_dataset.shape[1] == len(output_channel_names)
+
+ # Load ERA dataset, trimming the area and limiting the dates to the target dataset
+ start_date = self._output_dataset.metadata()['start_date'] if start_date is None else start_date
+ end_date = self._output_dataset.metadata()['end_date'] if end_date is None else end_date
+ latitudes = self.latitude()
+ longitudes = self.longitude()
+ min_lat = min(latitudes) - INPUT_MARGIN_DEGREES
+ max_lat = max(latitudes) + INPUT_MARGIN_DEGREES
+ min_lon = max(0, min(longitudes) - INPUT_MARGIN_DEGREES)
+ max_lon = max(longitudes) + INPUT_MARGIN_DEGREES
+ area=(max_lat, min_lon, min_lat, max_lon)
+
+ self._input_dataset = open_dataset(input_anemoi_dataset_path, select=input_channel_names, start=start_date, end=end_date, area=area)
+ assert self._input_dataset.shape[1] == len(input_channel_names)
+
+ # Check that we have the same number of time points in each dataset
+ assert self._input_dataset.shape[0] == self._output_dataset.shape[0]
+
+ # Load static info and channel names
+ if static_channel_names:
+ self._static_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in static_channel_names]
+ static_open_dataset_kwargs = {}
+ if not self.real_target and trim_edge > 0:
+ static_open_dataset_kwargs['trim_edge'] = trim_edge
+ static_dataset = open_dataset(target_anemoi_dataset_path, select=static_channel_names, start=start_date, end=start_date, **static_open_dataset_kwargs)
+ assert static_dataset.shape[1] == len(static_channel_names)
+ # take first time point, and squeeze() to remove ensemble dimension
+ static_data = static_dataset[0,:,:,:].squeeze()
+ # Could also get these from stats, but one-time calculation is OK.
+ self.static_mean = static_data.mean(axis=-1, keepdims=True)
+ self.static_std = static_data.std(axis=-1, keepdims=True)
+ target_shape = self.image_shape()
+ self.static_data_normalized = (static_data - self.static_mean.reshape((self.static_mean.shape[0],1))) \
+ / self.static_std.reshape((self.static_std.shape[0],1))
+ self.static_data_normalized = torch.from_numpy(self.static_data_normalized)
+ self.static_data_normalized = regrid_icon_to_rotlatlon(self.static_data_normalized, self.regrid_indices_real, self.regrid_weights_real)
+ if trim_edge > 0 and self.real_target:
+ self.static_data_normalized = self.static_data_normalized[:, trim_edge:-trim_edge, trim_edge:-trim_edge]
+ # self.normalize_input(np.flip(static_data.squeeze().reshape(-1, *target_shape), 1))
+ else:
+ self.static_data_normalized = None
+ self._static_channels = []
+
+ # Load target channel names
+ self._output_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in output_channel_names]
+ # Load era5 channel names
+ self._input_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in input_channel_names]
+
+ # Load stats for normalizing channels of input and output
+ target_stats = self._output_dataset.statistics
+ self.output_mean = target_stats['mean'][:]
+ self.output_std = target_stats['stdev'][:]
+
+ input_stats = self._input_dataset.statistics
+ self.input_mean = input_stats['mean'][:]
+ self.input_std = input_stats['stdev'][:]
+
+ assert len(transform_channels) == len(transform_input_means) ==\
+ len(transform_input_stdevs) == len(transform_output_means) == \
+ len(transform_output_stdevs)
+
+ # FEATURE: load the mean and std values for transformed channels and update the normalization statistics
+ self.input_transforms = {}
+ self.input_inverse_transforms = {}
+ self.output_transforms = {}
+ self.output_inverse_transforms = {}
+ for transform_descriptor in transform_channels:
+ channel, transformation = transform_descriptor.split('-')
+ input_channel_idx = input_channel_names.index(channel) if channel in input_channel_names else None
+ output_channel_idx = output_channel_names.index(channel) if channel in output_channel_names else None
+ if transformation.startswith('box_cox'):
+ lmbda_str = transformation.split('_')[-1]
+ lmbda = float(transformation.split('_')[-1])/(10**(len(lmbda_str)-1))
+ print(f"Applying Box-Cox transformation with lambda={lmbda} to channel {channel} (input idx: {input_channel_idx} ({input_channel_names[input_channel_idx]}), output idx: {output_channel_idx} ({output_channel_names[output_channel_idx]}))")
+ if input_channel_idx is not None:
+ self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.input_mean[input_channel_idx] = transform_input_means[transform_descriptor]
+ self.input_std[input_channel_idx] = transform_input_stdevs[transform_descriptor]
+ if output_channel_idx is not None:
+ self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.output_mean[output_channel_idx] = transform_output_means[transform_descriptor]
+ self.output_std[output_channel_idx] = transform_output_stdevs[transform_descriptor]
+ else:
+ raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.")
+
+ # Initialize the interpolator
+ self.interpolator = GridData(
+ self._input_dataset.longitudes,
+ self._input_dataset.latitudes,
+ self.longitude(),
+ self.latitude())
+
+
+ # DO NOT SUBMIT: This is not implemented yet.
+ # Question: Is it OK to change the signature to return 3 items?
+ def __getitem__(self, idx):
+ """Get input and target data. Transform and normalize, but do not interpolate."""
+
+ # Pull input, replacing the corrected tp if applicable
+ date_str = to_datetime(self._input_dataset.dates[idx]).strftime('%Y%m%d-%H%M')
+ # Don't reshape, but do squeeze ensemble dimension.
+ input_data = self._input_dataset[idx].squeeze()
+ # TODO: Consider generalizing this to other channels, in case we have cp.
+ if ChannelMetadata('tp') in self._input_channels:
+ tp_idx = self._input_channels.index(ChannelMetadata('tp'))
+ corrected_tp_data = np.load(os.path.join(self._corrected_tp_path, f'{date_str}.npy'))
+ input_data[tp_idx,::] = corrected_tp_data
+ # input_data = self.normalize_input(input_data)
+
+ # Pull target data
+ # squeeze the ensemble dimesnsion
+ # next two steps only if target is cosmo, real has to be regridded first (done in training loop on gpu-s for efficiency)
+ # reshape to image_shape
+ # flip so that it starts in top-left corner (by default it is bottom left)
+ target_shape = self.image_shape()
+ target_data = self._output_dataset[idx].squeeze()
+ if not self.real_target:
+ target_data = np.flip(target_data \
+ .reshape(-1,*target_shape),
+ 1)
+ # target_data = self.normalize_output(target_data)
+
+ return torch.from_numpy(target_data.copy()),\
+ torch.from_numpy(input_data),\
+ date_str
+
+ def get_static_data(self):
+ return self.static_data_normalized
+
+ def __len__(self):
+ return len(self._output_dataset.dates)
+
+ # Question: Do we need an input longitude as well?
+ def longitude(self) -> np.ndarray:
+ """Get longitude values from the target dataset."""
+ if self.real_target:
+ return self.lat_lon_real[:,1]
+ return self._output_dataset.longitudes
+
+ def latitude(self) -> np.ndarray:
+ """Get latitude values from the target dataset."""
+ if self.real_target:
+ return self.lat_lon_real[:,0]
+ return self._output_dataset.latitudes
+
+ def input_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
+ return self._input_channels
+
+ def output_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the output channels. A list of ChannelMetadata, one for each channel"""
+ return self._output_channels
+
+ def static_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the static channels. A list of ChannelMetadata, one for each channel"""
+ return self._static_channels
+
+ def time(self) -> List:
+ """Get time values from the dataset."""
+ #TODO Choose the time format and convert to that, currently it's a string from a filename
+ return [to_datetime(dt64).strftime('%Y%m%d-%H%M') for dt64 in self._output_dataset.dates]
+
+ def image_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the data."""
+ if self.real_target:
+ return 704,1088
+ return self._output_dataset.field_shape
+
+ def input_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the input data."""
+ return self._input_dataset.field_shape
+
+ def normalization_stats(self):
+ """Get the mean and std stats for normalizing the input and output data."""
+ return {"input_mean": self.input_mean,
+ "input_std": self.input_std,
+ "output_mean": self.output_mean,
+ "output_std": self.output_std}
+
+ def stats_to_torch(self, device: torch.device, dtype: torch.dtype = torch.float32):
+ """Convert the mean and std stats to torch tensors on the specified device."""
+ self.input_mean = torch.from_numpy(self.input_mean).to(device=device, dtype=dtype)
+ self.input_std = torch.from_numpy(self.input_std).to(device=device, dtype=dtype)
+ self.output_mean = torch.from_numpy(self.output_mean).to(device=device, dtype=dtype)
+ self.output_std = torch.from_numpy(self.output_std).to(device=device, dtype=dtype)
+
+ def stats_to_numpy(self):
+ """Convert the mean and std stats to numpy arrays."""
+ self.input_mean = self.input_mean.cpu().numpy() if isinstance(self.input_mean, torch.Tensor) else self.input_mean
+ self.input_std = self.input_std.cpu().numpy() if isinstance(self.input_std, torch.Tensor) else self.input_std
+ self.output_mean = self.output_mean.cpu().numpy() if isinstance(self.output_mean, torch.Tensor) else self.output_mean
+ self.output_std = self.output_std.cpu().numpy() if isinstance(self.output_std, torch.Tensor) else self.output_std
+
+ def normalize_input(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert input from physical units to normalized data."""
+ if mean is None:
+ mean = self.input_mean
+ if std is None:
+ std = self.input_std
+ for channel_idx, transform in self.input_transforms.items():
+ x[:,channel_idx,::] = transform(x[:,channel_idx,::])
+ return (x - self.input_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]) \
+ / self.input_std[(None,) + (...,) + (None,) * (x.ndim - 2)]
+
+
+ def denormalize_input(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert input from normalized data to physical units."""
+ if mean is None:
+ mean = self.input_mean
+ if std is None:
+ std = self.input_std
+ x = x * self.input_std[(None,) + (...,) + (None,) * (x.ndim - 2)] \
+ + self.input_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]
+ for channel_idx, inverse_transform in self.input_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+
+ def normalize_output(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert output from physical units to normalized data."""
+ if mean is None:
+ mean = self.output_mean
+ if std is None:
+ std = self.output_std
+ for channel_idx, transform in self.output_transforms.items():
+ x[:,channel_idx,::] = transform(x[:,channel_idx,::])
+ return (x - self.output_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]) \
+ / self.output_std[(None,) + (...,) + (None,) * (x.ndim - 2)]
+
+
+ def denormalize_output(self, x: np.ndarray | torch.Tensor, mean: np.ndarray | torch.Tensor = None, std: np.ndarray | torch.Tensor = None) -> np.ndarray | torch.Tensor:
+ """Convert output from normalized data to physical units."""
+ if mean is None:
+ mean = self.output_mean
+ if std is None:
+ std = self.output_std
+ x = x * self.output_std[(None,) + (...,) + (None,) * (x.ndim - 2)] \
+ + self.output_mean[(None,) + (...,) + (None,) * (x.ndim - 2)]
+ for channel_idx, inverse_transform in self.output_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+ def box_cox_transform(self, channel_array: np.ndarray | torch.Tensor, lmbda: float) -> np.ndarray | torch.Tensor:
+ """Apply Box-Cox transformation to the data."""
+ if isinstance(channel_array, torch.Tensor):
+ channel_array = torch.clamp(channel_array, min=0)
+ return (torch.pow(channel_array, lmbda) - 1) / lmbda
+ channel_array = np.clip(channel_array, 0, None)
+ return (np.power(channel_array, lmbda) - 1) / lmbda
+
+ def box_cox_inverse_transform(self, channel_array: np.ndarray | torch.Tensor, lmbda: float) -> np.ndarray | torch.Tensor:
+ """Apply inverse Box-Cox transformation to the data."""
+ if isinstance(channel_array, torch.Tensor):
+ channel_array = torch.clamp(channel_array, min=-1/lmbda)
+ return torch.pow((lmbda * channel_array) + 1, 1 / lmbda)
+ channel_array = np.clip(channel_array, -1/lmbda, None)
+ return np.power((lmbda * channel_array) + 1, 1 / lmbda)
+
+ def make_time_grids(self, dates: list[str], device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """
+ Create multi-frequency cyclic sin/cos feature grids for hour and month (batched).
+
+ Parameters
+ ----------
+ dates : Sequence[str]
+ Date strings in the format 'YYYYMMDD-HHMM', length = B
+
+ Returns
+ -------
+ grid : torch.Tensor, shape (B, C, H, W)
+ Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k]
+ """
+
+ B = len(dates)
+
+ # --- parse month and hour ---
+ months = torch.tensor(
+ [int(d.split("-")[0][4:6]) for d in dates],
+ dtype=torch.float32,
+ device=device,
+ )
+ hours = torch.tensor(
+ [int(d.split("-")[1][0:2]) for d in dates],
+ dtype=torch.float32,
+ device=device,
+ )
+
+ # normalize cyclic components
+ hours = (hours % 24) / 24.0 # (B,)
+ months = ((months - 1) % 12) / 12.0 # (B,)
+
+ # frequencies
+ n_freq = self._n_month_hour_channels // 2
+ freqs = torch.arange(
+ 1, n_freq + 1, dtype=torch.float32, device=device
+ ) # (K,)
+
+ # shape helpers
+ hours = hours[:, None] # (B, 1)
+ months = months[:, None] # (B, 1)
+
+ # --- hour encodings ---
+ hour_angles = 2 * torch.pi * hours * freqs # (B, K)
+ hour_feats = torch.stack(
+ [torch.sin(hour_angles), torch.cos(hour_angles)],
+ dim=2
+ ) # (B, K, 2)
+
+ # --- month encodings ---
+ month_angles = 2 * torch.pi * months * freqs
+ month_feats = torch.stack(
+ [torch.sin(month_angles), torch.cos(month_angles)],
+ dim=2
+ ) # (B, K, 2)
+
+ # concatenate and flatten channels
+ feats = torch.cat([hour_feats, month_feats], dim=1) # (B, 2K, 2)
+ feats = feats.reshape(B, -1) # (B, C)
+
+ # expand to spatial grid
+ # grid = feats[:, :, None, None].expand(B, feats.shape[1], H, W)
+
+ return feats
+
+ANEMOI_ERA5COPERNICUSTP_REAL = AnemoiDatasetCopernicus
+ANEMOI_ERA5COPERNICUSTP_COSMO = AnemoiDatasetCopernicus
diff --git a/src/hirad/datasets/base.py b/src/hirad/datasets/base.py
index 22b00d25..1fa57c1e 100644
--- a/src/hirad/datasets/base.py
+++ b/src/hirad/datasets/base.py
@@ -31,6 +31,20 @@ class ChannelMetadata:
auxiliary: bool = False
+def get_channels_from_strings(channel_strings: List[str] | str) -> List[ChannelMetadata] | ChannelMetadata:
+ """Convert list of channel strings to ChannelMetadata objects."""
+ if isinstance(channel_strings, str):
+ return ChannelMetadata(channel_strings) if len(channel_strings.split('_'))==1 else ChannelMetadata(channel_strings.split('_')[0],channel_strings.split('_')[1])
+ else:
+ return [ChannelMetadata(name) if len(name.split('_'))==1 else ChannelMetadata(name.split('_')[0],name.split('_')[1]) for name in channel_strings]
+
+def get_strings_from_channels(channels: List[ChannelMetadata] | ChannelMetadata) -> List[str] | str:
+ """Convert list of ChannelMetadata objects to channel strings."""
+ if isinstance(channels, ChannelMetadata):
+ return channels.name if not channels.level else f"{channels.name}_{channels.level}"
+ else:
+ return [ch.name if not ch.level else f"{ch.name}_{ch.level}" for ch in channels]
+
class DownscalingDataset(torch.utils.data.Dataset, ABC):
"""An abstract class that defines the interface for downscaling datasets."""
diff --git a/src/hirad/datasets/constants.py b/src/hirad/datasets/constants.py
new file mode 100644
index 00000000..2ebe4d68
--- /dev/null
+++ b/src/hirad/datasets/constants.py
@@ -0,0 +1,9 @@
+REAL_TO_ERA_CHANNEL_MAP = {
+ 'T_2M': '2t',
+ 'U_10M': '10u',
+ 'V_10M': '10v',
+ 'TOT_PREC_1H': 'tp'
+}
+
+ERA_TO_REAL_CHANNEL_MAP = {v: k for k, v in REAL_TO_ERA_CHANNEL_MAP.items()}
+
diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py
index c09a4f23..cb2124e7 100644
--- a/src/hirad/datasets/dataset.py
+++ b/src/hirad/datasets/dataset.py
@@ -22,12 +22,20 @@
from hirad.distributed import DistributedManager
from .era5_cosmo import ERA5_COSMO
+from .era5_real import ERA5_REAL
+from .anemoi_dataset import ANEMOI_ERA5_COSMO, ANEMOI_ERA5_REAL
+from .anemoi_dataset_copernicus_tp import ANEMOI_ERA5COPERNICUSTP_COSMO, ANEMOI_ERA5COPERNICUSTP_REAL
from .base import DownscalingDataset
# this maps all known dataset types to the corresponding init function
known_datasets = {
"era5_cosmo": ERA5_COSMO,
+ "era5_real": ERA5_REAL,
+ "anemoi_era5_cosmo": ANEMOI_ERA5_COSMO,
+ "anemoi_era5_real": ANEMOI_ERA5_REAL,
+ "anemoi_era5_copernicus_tp_real": ANEMOI_ERA5COPERNICUSTP_REAL,
+ "anemoi_era5_copernicus_tp_cosmo": ANEMOI_ERA5COPERNICUSTP_COSMO,
}
@@ -37,6 +45,7 @@ def init_train_valid_datasets_from_config(
batch_size: int = 1,
seed: int = 0,
train_test_split: bool = True,
+ sampler_start_idx: int = 0,
) -> Tuple[
DownscalingDataset,
Iterable,
@@ -52,21 +61,29 @@ def init_train_valid_datasets_from_config(
- batch_size (int): The number of samples in each batch of data. Defaults to 1.
- seed (int): The random seed for dataset shuffling. Defaults to 0.
- train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True.
+ - sampler_start_idx (int): The initial index of the sampler to use for resuming training. Defaults to 0.
Returns:
- Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True.
"""
config = copy.deepcopy(dataset_cfg)
- if 'validation_path' in config:
- del config['validation_path']
+ config.pop("validation", None)
+ config.pop("validation_start_date", None)
+ config.pop("validation_end_date", None)
(dataset, dataset_iter) = init_dataset_from_config(
- config, dataloader_cfg, batch_size=batch_size, seed=seed
+ config, dataloader_cfg, batch_size=batch_size, seed=seed, sampler_start_idx=sampler_start_idx,
)
if train_test_split:
valid_dataset_cfg = copy.deepcopy(dataset_cfg)
- valid_dataset_cfg["dataset_path"] = valid_dataset_cfg["validation_path"]
- del valid_dataset_cfg['validation_path']
+ del valid_dataset_cfg['validation']
+ if "validation_start_date" not in valid_dataset_cfg or "validation_end_date" not in valid_dataset_cfg:
+ raise ValueError("validation_start_date and validation_en_date must be specified in anemoi dataset_cfg when validation is set to True")
+ valid_dataset_cfg["start_date"] = valid_dataset_cfg["validation_start_date"]
+ valid_dataset_cfg["end_date"] = valid_dataset_cfg["validation_end_date"]
+ del valid_dataset_cfg['validation_start_date']
+ del valid_dataset_cfg['validation_end_date']
+
(valid_dataset, valid_dataset_iter) = init_dataset_from_config(
valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed
)
@@ -81,14 +98,12 @@ def init_dataset_from_config(
dataloader_cfg: Union[dict, None] = None,
batch_size: int = 1,
seed: int = 0,
+ sampler_start_idx: int = 0,
+ pop_type: bool = True,
) -> Tuple[DownscalingDataset, Iterable]:
+
dataset_cfg = copy.deepcopy(dataset_cfg)
- dataset_type = dataset_cfg.pop("type", "era5_cosmo")
- if "validation_path" in dataset_cfg:
- del dataset_cfg['validation_path']
- if "train_test_split" in dataset_cfg:
- # handled by init_train_valid_datasets_from_config
- del dataset_cfg["train_test_split"]
+ dataset_type = dataset_cfg.get("type", "era5_cosmo")
dataset_init_func = known_datasets[dataset_type]
dataset_obj = dataset_init_func(**dataset_cfg)
@@ -97,7 +112,7 @@ def init_dataset_from_config(
dist = DistributedManager()
dataset_sampler = InfiniteSampler(
- dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed
+ dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed, start_idx=sampler_start_idx,
)
dataset_iterator = iter(
@@ -111,3 +126,22 @@ def init_dataset_from_config(
)
return (dataset_obj, dataset_iterator)
+
+
+def get_dataset_and_sampler_inference(dataset_cfg, times, has_lead_time=False):
+ """
+ Get a dataset and sampler for generation.
+ """
+ (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1)
+ # if has_lead_time:
+ # plot_times = times
+ # else:
+ # plot_times = [
+ # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S")
+ # for time in times
+ # ]
+ all_times = dataset.time()
+ time_indices = [all_times.index(t) for t in times]
+ sampler = time_indices
+
+ return dataset, sampler
diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py
index f97dbc64..ff58b0ac 100644
--- a/src/hirad/datasets/era5_cosmo.py
+++ b/src/hirad/datasets/era5_cosmo.py
@@ -5,68 +5,188 @@
from typing import List, Tuple
import yaml
import torch.nn.functional as F
+import time
+# import zarr
+
+from hirad.utils.console import PythonLogger
+
+logger = PythonLogger(__name__)
+
+# DATASET_ORIG_PATH = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full'
+DATASET_ORIG_PATH = "/capstor/store/cscs/pasc/c38/basic-numpy/basic-numpy/era5-cosmo-1h-all-channels/train"
class ERA5_COSMO(DownscalingDataset):
- def __init__(self, dataset_path: str):
+ def __init__(self,
+ dataset_path: str,
+ input_channel_names: List[str] = [],
+ output_channel_names: List[str] = [],
+ static_channel_names: List[str] = [],
+ transform_channels: List[str] = [],
+ n_month_hour_channels: int = None,
+ input_dir_name: str = 'era-interpolated',
+ output_dir_name: str = 'cosmo',
+ ):
super().__init__()
#TODO switch hanbdling paths to Path rather than pure strings
+ self._n_month_hour_channels = n_month_hour_channels
self._dataset_path = dataset_path
- self._era5_path = os.path.join(dataset_path, 'era-interpolated')
- self._cosmo_path = os.path.join(dataset_path, 'cosmo')
- self._info_path = os.path.join(dataset_path, 'info')
+ print(f"Loading ERA5-COSMO dataset from path: {dataset_path}")
+ print(f"Input dir name: {input_dir_name}, output dir name: {output_dir_name}")
+ self._era5_path = os.path.join(dataset_path, input_dir_name)
+ self._cosmo_path = os.path.join(dataset_path, output_dir_name)
+ self._info_path = os.path.join(DATASET_ORIG_PATH, 'info')
+ # self._static_path = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/static'# os.path.join(dataset_path, 'static')
+ self._static_path = os.path.join(DATASET_ORIG_PATH, 'static')# os.path.join(dataset_path, 'static')
+ # self._zarr_path = os.path.join(dataset_path, 'dataset.zarr')
# load file list (each file is one date-time state)
- self._file_list = os.listdir(self._cosmo_path)
+ self._file_list = sorted(os.listdir(self._cosmo_path))
+
+ # open zarr store
+ # self._zarr_store = zarr.open(self._zarr_path, mode='r')
+ # self.era5 = self._zarr_store['era5']
+ # self.cosmo = self._zarr_store['cosmo']
+
+ # Load static info and channel names
+ if static_channel_names:
+ with open(os.path.join(self._static_path, 'cosmo-static.yaml'), 'r') as file:
+ self._static_info = yaml.safe_load(file)
+ self._static_indeces = [self._static_info['select'].index(name) for name in static_channel_names]
+ self._static_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in self._static_info['select'] if name in static_channel_names]
+ static_data = torch.load(os.path.join(self._static_path,'cosmo-static'), weights_only=False)[self._static_indeces]
+ orig_shape = self.image_shape()
+ self.static_data = np.flip(static_data \
+ .squeeze() \
+ .reshape(-1,*orig_shape),
+ 1)
+ self.static_mean = self.static_data.mean(axis=(1,2))
+ self.static_std = self.static_data.std(axis=(1,2))
+ else:
+ self.static_data = None
# Load cosmo info and channel names
with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file:
self._cosmo_info = yaml.safe_load(file)
- self._cosmo_channels = [ChannelMetadata(name) for name in self._cosmo_info['select']]
+ if output_channel_names:
+ self._cosmo_indeces = [self._cosmo_info['select'].index(name) for name in output_channel_names]
+ else:
+ self._cosmo_indeces = list(range(len(self._cosmo_info['select'])))
+ output_channel_names = self._cosmo_info['select']
+ self._cosmo_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in self._cosmo_info['select'] if name in output_channel_names]
# Load era5 info and channel names
with open(os.path.join(self._info_path,'era.yaml'), 'r') as file:
self._era_info = yaml.safe_load(file)
+ if input_channel_names:
+ self._era_indeces = [self._era_info['select'].index(name) for name in input_channel_names]
+ else:
+ self._era_indeces = list(range(len(self._era_info['select'])))
+ input_channel_names = self._era_info['select']
self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1
- else ChannelMetadata(name.split('_')[0],name.split('_')[1])
- for name in self._era_info['select']]
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in self._era_info['select'] if name in input_channel_names]
# Load stats for normalizing channels of input and output
cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False)
- self.output_mean = cosmo_stats['mean']
- self.output_std = cosmo_stats['stdev']
+ self.output_mean = cosmo_stats['mean'][self._cosmo_indeces]
+ self.output_std = cosmo_stats['stdev'][self._cosmo_indeces]
era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False)
- self.input_mean = era_stats['mean']
- self.input_std = era_stats['stdev']
+ self.input_mean = era_stats['mean'][self._era_indeces]
+ self.input_std = era_stats['stdev'][self._era_indeces]
+ if self.static_data is not None:
+ self.input_mean = np.concatenate((self.input_mean, self.static_mean), axis=0)
+ self.input_std = np.concatenate((self.input_std, self.static_std), axis=0)
+ # FEATURE: load the mean and std values for transformed channels and update the normalization statistics
+ self.input_transforms = {}
+ self.input_inverse_transforms = {}
+ self.output_transforms = {}
+ self.output_inverse_transforms = {}
+ for transform_descriptor in transform_channels:
+ channel, transformation = transform_descriptor.split('-')
+ input_channel_idx = input_channel_names.index(channel) if channel in input_channel_names else None
+ output_channel_idx = output_channel_names.index(channel) if channel in output_channel_names else None
+ if transformation.startswith('box_cox'):
+ lmbda_str = transformation.split('_')[-1]
+ lmbda = float(transformation.split('_')[-1])/(10**(len(lmbda_str)-1))
+ print(f"Applying Box-Cox transformation with lambda={lmbda} to channel {channel} (input idx: {input_channel_idx} ({input_channel_names[input_channel_idx]}), output idx: {output_channel_idx} ({output_channel_names[output_channel_idx]}))")
+ if input_channel_idx is not None:
+ self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.input_mean[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-mean"), weights_only=False)
+ self.input_std[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-std"), weights_only=False)
+ if output_channel_idx is not None:
+ self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.output_mean[output_channel_idx] = torch.load(os.path.join(self._info_path,f"cosmo-{transform_descriptor}-mean"), weights_only=False)
+ self.output_std[output_channel_idx] = torch.load(os.path.join(self._info_path,f"cosmo-{transform_descriptor}-std"), weights_only=False)
+ else:
+ raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.")
+
def __getitem__(self, idx):
"""Get cosmo and era5 interpolated to cosmo grid"""
- # get era5 data point
+ # get data point
# squeeze the ensemble dimesnsion
# reshape to image_shape
# flip so that it starts in top-left corner (by default it is bottom left)
# orig_shape = [350,542] #TODO currently padding to be divisible by 16
orig_shape = self.image_shape()
- era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\
+ try:
+ # start = time.perf_counter()
+ # era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx].split('.')[0]), weights_only=False)[self._era_indeces]
+ era5_data = np.load(os.path.join(self._era5_path,self._file_list[idx]), mmap_mode='r')[self._era_indeces]
+ # end = time.perf_counter()
+ # logger.info(f"Reading time era: {end - start:.6f} seconds")
+ except:
+ logger.error(f"Error loading file {os.path.join(self._era5_path,self._file_list[idx])}")
+ raise
+ # start = time.perf_counter()
+ era5_data = np.flip(era5_data \
.squeeze() \
.reshape(-1,*orig_shape),
1)
+ era5_data = np.concatenate((era5_data, self.static_data), axis=0) if self.static_data is not None else era5_data
era5_data = self.normalize_input(era5_data)
- # get cosmo data point
- cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\
+ # end = time.perf_counter()
+ # logger.info(f"Preprocess time era: {end - start:.6f} seconds")
+ try:
+ # start = time.perf_counter()
+ # cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)[self._cosmo_indeces]
+ cosmo_data = np.load(os.path.join(self._cosmo_path,self._file_list[idx]), mmap_mode='r')[self._cosmo_indeces]
+ # end = time.perf_counter()
+ # logger.info(f"Reading time cosmo: {end - start:.6f} seconds")
+ except:
+ logger.error(f"Error loading file {os.path.join(self._cosmo_path,self._file_list[idx])}")
+ raise
+ # start = time.perf_counter()
+ cosmo_data = np.flip(cosmo_data\
.squeeze() \
.reshape(-1,*orig_shape),
1)
cosmo_data = self.normalize_output(cosmo_data)
- # return samples
- return torch.tensor(cosmo_data),\
- torch.tensor(era5_data),
- # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \
- # F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \
- # 0
+ # end = time.perf_counter()
+ # logger.info(f"Preprocess time cosmo: {end - start:.6f} seconds")
+
+ if self._n_month_hour_channels is not None and self._n_month_hour_channels>0:
+ # extract month and hour from filename
+ filename = self._file_list[idx]
+ date_str, hour_str = filename.split('-')
+ month = int(date_str[4:6])
+ hour = int(hour_str[0:2])
+
+ time_grid = self.make_time_grids(hour, month)
+ era5_data = np.concatenate((era5_data, time_grid), axis=0)
+
+ return torch.from_numpy(cosmo_data),\
+ torch.from_numpy(era5_data)
def __len__(self):
return len(self._file_list)
@@ -86,8 +206,13 @@ def latitude(self) -> np.ndarray:
def input_channels(self) -> List[ChannelMetadata]:
"""Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
- return self._era_channels
-
+ channels = self._era_channels + self._static_channels if self.static_data is not None else self._era_channels
+ if self._n_month_hour_channels is not None and self._n_month_hour_channels>0:
+ for i in range(self._n_month_hour_channels):
+ channels.append(ChannelMetadata("hour-enc",f"{i}"))
+ for i in range(self._n_month_hour_channels):
+ channels.append(ChannelMetadata("month-enc",f"{i}"))
+ return channels
def output_channels(self) -> List[ChannelMetadata]:
"""Metadata for the output channels. A list of ChannelMetadata, one for each channel"""
@@ -108,23 +233,84 @@ def image_shape(self) -> Tuple[int, int]:
def normalize_input(self, x: np.ndarray) -> np.ndarray:
"""Convert input from physical units to normalized data."""
+ for channel_idx, transform in self.input_transforms.items():
+ x[channel_idx,::] = transform(x[channel_idx,::])
return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \
/ self.input_std.reshape((self.input_std.shape[0],1,1))
def denormalize_input(self, x: np.ndarray) -> np.ndarray:
"""Convert input from normalized data to physical units."""
- return x * self.input_std.reshape((self.input_std.shape[0],1,1)) \
+ if self._n_month_hour_channels is not None and self._n_month_hour_channels>0:
+ x = x[:,:-2*self._n_month_hour_channels,:,:]
+ x = x * self.input_std.reshape((self.input_std.shape[0],1,1)) \
+ self.input_mean.reshape((self.input_mean.shape[0],1,1))
+ for channel_idx, inverse_transform in self.input_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
def normalize_output(self, x: np.ndarray) -> np.ndarray:
"""Convert output from physical units to normalized data."""
+ for channel_idx, transform in self.output_transforms.items():
+ x[channel_idx,::] = transform(x[channel_idx,::])
return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \
/ self.output_std.reshape((self.output_std.shape[0],1,1))
def denormalize_output(self, x: np.ndarray) -> np.ndarray:
"""Convert output from normalized data to physical units."""
- return x * self.output_std.reshape((self.output_std.shape[0],1,1)) \
- + self.output_mean.reshape((self.output_mean.shape[0],1,1))
\ No newline at end of file
+ x = x * self.output_std.reshape((self.output_std.shape[0],1,1)) \
+ + self.output_mean.reshape((self.output_mean.shape[0],1,1))
+ for channel_idx, inverse_transform in self.output_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+ def box_cox_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray:
+ """Apply Box-Cox transformation to the data."""
+ channel_array = np.clip(channel_array, 0, None)
+ return (np.power(channel_array, lmbda) - 1) / lmbda
+
+ def box_cox_inverse_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray:
+ """Apply inverse Box-Cox transformation to the data."""
+ channel_array = np.clip(channel_array, -1/lmbda, None)
+ return np.power((lmbda * channel_array) + 1, 1 / lmbda)
+
+ def make_time_grids(self, hour, month):
+ """
+ Create multi-frequency cyclic sin/cos feature grids for hour and month.
+
+ Parameters
+ ----------
+ hour : int
+ Hour of day, 0-23
+ month : int
+ Month of year, 1-12
+
+ Returns
+ -------
+ grid : np.ndarray, shape (C, H, W)
+ Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k frequency]
+ """
+ H, W = self.image_shape()
+ hour_freqs = np.arange(1, self._n_month_hour_channels//2 + 1)
+ month_freqs = np.arange(1, self._n_month_hour_channels//2 + 1)
+
+ channels = []
+
+ # --- hour encodings ---
+ for k in hour_freqs:
+ angle = 2 * np.pi * k * (hour % 24) / 24.0
+ channels.append(np.sin(angle))
+ channels.append(np.cos(angle))
+
+ # --- month encodings ---
+ for k in month_freqs:
+ angle = 2 * np.pi * k * ((month - 1) % 12) / 12.0
+ channels.append(np.sin(angle))
+ channels.append(np.cos(angle))
+
+ channels = np.array(channels, dtype=np.float32)
+ grid = np.tile(channels[:, None, None], (1, H, W)) # (C, H, W)
+
+ return grid
\ No newline at end of file
diff --git a/src/hirad/datasets/era5_real.py b/src/hirad/datasets/era5_real.py
new file mode 100644
index 00000000..bd8265ee
--- /dev/null
+++ b/src/hirad/datasets/era5_real.py
@@ -0,0 +1,309 @@
+from .base import DownscalingDataset, ChannelMetadata
+import os
+import numpy as np
+import torch
+from typing import List, Tuple
+import yaml
+import torch.nn.functional as F
+import time
+# import zarr
+
+from hirad.utils.console import PythonLogger
+
+logger = PythonLogger(__name__)
+
+DATASET_ORIG_PATH = '/iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset'
+
+ERA5_TO_REAL_CHANNEL_MAP = {
+ '2t': 'TD_2M',
+ '10u': 'U_10M',
+ '10v': 'V_10M',
+ 'tp': 'TOT_PREC_1H'
+}
+
+class ERA5_REAL(DownscalingDataset):
+ def __init__(self,
+ dataset_path: str,
+ input_channel_names: List[str] = [],
+ output_channel_names: List[str] = [],
+ static_channel_names: List[str] = [],
+ transform_channels: List[str] = [],
+ n_month_hour_channels: int = 0,
+ input_dir_name: str = 'era-copernicus-interpolated',
+ output_dir_name: str = 'realch1',
+ ):
+ super().__init__()
+
+ #TODO switch hanbdling paths to Path rather than pure strings
+ self._n_month_hour_channels = n_month_hour_channels
+ self._dataset_path = dataset_path
+ self._era5_path = os.path.join(dataset_path, input_dir_name)
+ self._real_path = os.path.join(dataset_path, output_dir_name)
+ self._info_path = os.path.join(DATASET_ORIG_PATH, 'info')
+ # self._static_path = '/capstor/store/mch/msopr/hirad-gen/basic-torch/era5-real-1h-linear-interpolation-full/static'# os.path.join(dataset_path, 'static')
+ self._static_path = os.path.join(DATASET_ORIG_PATH, 'static')# os.path.join(dataset_path, 'static')
+ # self._zarr_path = os.path.join(dataset_path, 'dataset.zarr')
+
+ # load file list (each file is one date-time state)
+ self._file_list = sorted(os.listdir(self._real_path))
+
+ # open zarr store
+ # self._zarr_store = zarr.open(self._zarr_path, mode='r')
+ # self.era5 = self._zarr_store['era5']
+ # self.real = self._zarr_store['real']
+
+ # Load static info and channel names
+ if static_channel_names:
+ with open(os.path.join(self._static_path, output_dir_name+'-static.yaml'), 'r') as file:
+ self._static_info = yaml.safe_load(file)
+ self._static_indeces = [self._static_info['select'].index(name) for name in static_channel_names]
+ self._static_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in self._static_info['select'] if name in static_channel_names]
+ static_data = np.load(os.path.join(self._static_path, output_dir_name+'-static.npy'))[self._static_indeces]
+ orig_shape = self.image_shape()
+ self.static_data = np.flip(static_data \
+ .squeeze() \
+ .reshape(-1,*orig_shape),
+ 1)
+ self.static_mean = self.static_data.mean(axis=(1,2))
+ self.static_std = self.static_data.std(axis=(1,2))
+ else:
+ self.static_data = None
+
+ # Load real info and channel names
+ with open(os.path.join(self._info_path, output_dir_name+'.yaml'), 'r') as file:
+ self._real_info = yaml.safe_load(file)
+ if output_channel_names:
+ self._real_indeces = [self._real_info['select'].index(name) for name in output_channel_names]
+ else:
+ self._real_indeces = list(range(len(self._real_info['select'])))
+ output_channel_names = self._real_info['select']
+ self._real_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in self._real_info['select'] if name in output_channel_names]
+
+ # Load era5 info and channel names
+ with open(os.path.join(self._info_path,'era.yaml'), 'r') as file:
+ self._era_info = yaml.safe_load(file)
+ if input_channel_names:
+ self._era_indeces = [self._era_info['select'].index(name) for name in input_channel_names]
+ else:
+ self._era_indeces = list(range(len(self._era_info['select'])))
+ input_channel_names = self._era_info['select']
+ self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1
+ else ChannelMetadata(name.split('_')[0],name.split('_')[1])
+ for name in self._era_info['select'] if name in input_channel_names]
+
+ # Load stats for normalizing channels of input and output
+
+ era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False)
+ self.input_mean = era_stats['mean'][self._era_indeces]
+ self.input_std = era_stats['stdev'][self._era_indeces]
+
+ real_stats = torch.load(os.path.join(self._info_path,'realch1-stats'), weights_only=False)
+ self.output_mean = real_stats['mean'][self._real_indeces]
+ self.output_std = real_stats['stdev'][self._real_indeces]
+
+ if self.static_data is not None:
+ self.input_mean = np.concatenate((self.input_mean, self.static_mean), axis=0)
+ self.input_std = np.concatenate((self.input_std, self.static_std), axis=0)
+
+ # FEATURE: load the mean and std values for transformed channels and update the normalization statistics
+
+ self.input_transforms = {}
+ self.input_inverse_transforms = {}
+ self.output_transforms = {}
+ self.output_inverse_transforms = {}
+ for transform_descriptor in transform_channels:
+ channel, transformation = transform_descriptor.split('-')
+ input_channel_idx = self._era_info['select'].index(channel) if channel in self._era_info['select'] else None
+ # output_channel_idx = self._real_info['select'].index(ERA5_TO_REAL_CHANNEL_MAP[channel]) if ERA5_TO_REAL_CHANNEL_MAP[channel] in self._real_info['select'] else None
+ output_channel_idx = self._real_info['select'].index(channel) if channel in self._real_info['select'] else None
+ if transformation.startswith('box_cox'):
+ lmbda_str = transformation.split('_')[-1]
+ lmbda = float(transformation.split('_')[-1])/(10**(len(lmbda_str)-1))
+ print(f"Applying Box-Cox transformation with lambda={lmbda} to channel {channel} (input idx: {input_channel_idx}, output idx: {output_channel_idx})")
+ if input_channel_idx is not None:
+ self.input_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.input_inverse_transforms[input_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.input_mean[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-mean"), weights_only=False)
+ self.input_std[input_channel_idx] = torch.load(os.path.join(self._info_path,f"era5-{transform_descriptor}-std"), weights_only=False)
+ if output_channel_idx is not None:
+ self.output_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_transform(x, lmbda)
+ self.output_inverse_transforms[output_channel_idx] = lambda x, lmbda=lmbda: self.box_cox_inverse_transform(x, lmbda)
+ self.output_mean[output_channel_idx] = torch.load(os.path.join(self._info_path,f"realch1-{transform_descriptor}-mean"), weights_only=False)
+ self.output_std[output_channel_idx] = torch.load(os.path.join(self._info_path,f"realch1-{transform_descriptor}-std"), weights_only=False)
+ else:
+ raise ValueError(f"Transformation: {transformation} for channel {channel} not implemented.")
+
+ def __getitem__(self, idx):
+ """Get real and era5 interpolated to real grid"""
+ # get data point
+ # squeeze the ensemble dimesnsion
+ # reshape to image_shape
+ # flip so that it starts in top-left corner (by default it is bottom left)
+ # orig_shape = [350,542] #TODO currently padding to be divisible by 16
+ orig_shape = self.image_shape()
+ try:
+ era5_data = np.load(os.path.join(self._era5_path,self._file_list[idx]), mmap_mode='r')[self._era_indeces]
+ except:
+ logger.error(f"Error loading file {os.path.join(self._era5_path,self._file_list[idx])}")
+ raise
+ era5_data = np.flip(era5_data \
+ .squeeze() \
+ .reshape(-1,*orig_shape),
+ 1)
+ era5_data = np.concatenate((era5_data, self.static_data), axis=0) if self.static_data is not None else era5_data
+ era5_data = self.normalize_input(era5_data)
+
+ try:
+ real_data = np.load(os.path.join(self._real_path,self._file_list[idx]), mmap_mode='r')[self._real_indeces]
+ except:
+ logger.error(f"Error loading file {os.path.join(self._real_path,self._file_list[idx])}")
+ raise
+ real_data = np.flip(real_data\
+ .squeeze() \
+ .reshape(-1,*orig_shape),
+ 1)
+ real_data = self.normalize_output(real_data)
+
+ if self._n_month_hour_channels is not None and self._n_month_hour_channels>0:
+ # extract month and hour from filename
+ filename = self._file_list[idx]
+ date_str, hour_str = filename.split('-')
+ month = int(date_str[4:6])
+ hour = int(hour_str[0:2])
+
+ time_grid = self.make_time_grids(hour, month)
+ era5_data = np.concatenate((era5_data, time_grid), axis=0)
+
+ return torch.tensor(real_data),\
+ torch.tensor(era5_data)
+
+ def __len__(self):
+ return len(self._file_list)
+
+
+ def longitude(self) -> np.ndarray:
+ """Get longitude values from the dataset."""
+ lat_lon = torch.load(os.path.join(self._info_path,output_dir_name+'-lat-lon'), weights_only=False)
+ return lat_lon[:,1]
+
+
+ def latitude(self) -> np.ndarray:
+ """Get latitude values from the dataset."""
+ lat_lon = torch.load(os.path.join(self._info_path,output_dir_name+'-lat-lon'), weights_only=False)
+ return lat_lon[:,0]
+
+
+ def input_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the input channels. A list of ChannelMetadata, one for each channel"""
+ channels = self._era_channels + self._static_channels if self.static_data is not None else self._era_channels
+ if self._n_month_hour_channels is not None and self._n_month_hour_channels>0:
+ for i in range(self._n_month_hour_channels):
+ channels.append(ChannelMetadata("hour-enc",f"{i}"))
+ for i in range(self._n_month_hour_channels):
+ channels.append(ChannelMetadata("month-enc",f"{i}"))
+ return channels
+
+ def output_channels(self) -> List[ChannelMetadata]:
+ """Metadata for the output channels. A list of ChannelMetadata, one for each channel"""
+ return self._real_channels
+
+
+ def time(self) -> List:
+ """Get time values from the dataset."""
+ #TODO Choose the time format and convert to that, currently it's a string from a filename
+ return [file.split('.')[0] for file in self._file_list]
+
+
+ def image_shape(self) -> Tuple[int, int]:
+ """Get the (height, width) of the data (same for input and output)."""
+ #TODO load from info, I hardcode it for now (real from anemoi-datasets minus trim-edge=20)
+ return 704,1088
+
+
+ def normalize_input(self, x: np.ndarray) -> np.ndarray:
+ """Convert input from physical units to normalized data."""
+ for channel_idx, transform in self.input_transforms.items():
+ x[channel_idx,::] = transform(x[channel_idx,::])
+ return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \
+ / self.input_std.reshape((self.input_std.shape[0],1,1))
+
+
+ def denormalize_input(self, x: np.ndarray) -> np.ndarray:
+ """Convert input from normalized data to physical units."""
+ if self._n_month_hour_channels is not None and self._n_month_hour_channels>0:
+ x = x[:,:-2*self._n_month_hour_channels,:,:]
+ x = x * self.input_std.reshape((self.input_std.shape[0],1,1)) \
+ + self.input_mean.reshape((self.input_mean.shape[0],1,1))
+ for channel_idx, inverse_transform in self.input_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+
+ def normalize_output(self, x: np.ndarray) -> np.ndarray:
+ """Convert output from physical units to normalized data."""
+ for channel_idx, transform in self.output_transforms.items():
+ x[channel_idx,::] = transform(x[channel_idx,::])
+ return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \
+ / self.output_std.reshape((self.output_std.shape[0],1,1))
+
+
+ def denormalize_output(self, x: np.ndarray) -> np.ndarray:
+ """Convert output from normalized data to physical units."""
+ x = x * self.output_std.reshape((self.output_std.shape[0],1,1)) \
+ + self.output_mean.reshape((self.output_mean.shape[0],1,1))
+ for channel_idx, inverse_transform in self.output_inverse_transforms.items():
+ x[:,channel_idx,::] = inverse_transform(x[:,channel_idx,::])
+ return x
+
+ def box_cox_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray:
+ """Apply Box-Cox transformation to the data."""
+ channel_array = np.clip(channel_array, 0, None)
+ return (np.power(channel_array, lmbda) - 1) / lmbda
+
+ def box_cox_inverse_transform(self, channel_array: np.ndarray, lmbda: float) -> np.ndarray:
+ """Apply inverse Box-Cox transformation to the data."""
+ channel_array = np.clip(channel_array, -1/lmbda, None)
+ return np.power((lmbda * channel_array) + 1, 1 / lmbda)
+
+ def make_time_grids(self, hour, month):
+ """
+ Create multi-frequency cyclic sin/cos feature grids for hour and month.
+
+ Parameters
+ ----------
+ hour : int
+ Hour of day, 0-23
+ month : int
+ Month of year, 1-12
+
+ Returns
+ -------
+ grid : np.ndarray, shape (C, H, W)
+ Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k frequency]
+ """
+ H, W = self.image_shape()
+ hour_freqs = np.arange(1, self._n_month_hour_channels//2 + 1)
+ month_freqs = np.arange(1, self._n_month_hour_channels//2 + 1)
+
+ channels = []
+
+ # --- hour encodings ---
+ for k in hour_freqs:
+ angle = 2 * np.pi * k * (hour % 24) / 24.0
+ channels.append(np.sin(angle))
+ channels.append(np.cos(angle))
+
+ # --- month encodings ---
+ for k in month_freqs:
+ angle = 2 * np.pi * k * ((month - 1) % 12) / 12.0
+ channels.append(np.sin(angle))
+ channels.append(np.cos(angle))
+
+ channels = np.array(channels, dtype=np.float32)
+ grid = np.tile(channels[:, None, None], (1, H, W)) # (C, H, W)
+
+ return grid
\ No newline at end of file
diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py
index eca46c63..75c0fe7d 100644
--- a/src/hirad/distributed/manager.py
+++ b/src/hirad/distributed/manager.py
@@ -529,7 +529,7 @@ def setup(
DistributedManager._shared_state["_is_initialized"] = True
manager = DistributedManager()
- manager._distributed = torch.distributed.is_available()
+ manager._distributed = torch.distributed.is_available() and world_size > 1
if manager._distributed:
# Update rank and world_size if using distributed
manager._rank = rank
@@ -546,22 +546,23 @@ def setup(
#TODO device_id makes the init hang, couldn't figure out why
if manager._distributed:
# Setup distributed process group
- # try:
- dist.init_process_group(
- backend,
- rank=manager.rank,
- world_size=manager.world_size,
- )
+ try:
+ dist.init_process_group(
+ backend,
+ rank=manager.rank,
+ world_size=manager.world_size,
+ device_id=manager.device,
+ )
# rank=manager.rank,
# world_size=manager.world_size,
# device_id=manager.device,
- # except TypeError:
- # # device_id only introduced in PyTorch 2.3
- # dist.init_process_group(
- # backend,
- # rank=manager.rank,
- # world_size=manager.world_size,
- # )
+ except TypeError:
+ # device_id only introduced in PyTorch 2.3
+ dist.init_process_group(
+ backend,
+ rank=manager.rank,
+ world_size=manager.world_size,
+ )
if torch.cuda.is_available():
# Set device for this process and empty cache to optimize memory usage
diff --git a/src/hirad/eval.sh b/src/hirad/eval.sh
new file mode 100644
index 00000000..39890790
--- /dev/null
+++ b/src/hirad/eval.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+#SBATCH --job-name="testrun"
+
+### HARDWARE ###
+#SBATCH --partition=normal
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --gpus-per-node=2
+#SBATCH --cpus-per-task=72
+##SBATCH --time=00:30:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+### OUTPUT ###
+#SBATCH --output=./logs/eval_compute.log
+
+### ENVIRONMENT ####
+#SBATCH -A a161
+
+# Choose method to initialize dist in pythorch
+export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
+
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+echo "Master node : $MASTER_ADDR"
+# Get IP for hostname.
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
+echo "Master address : $MASTER_ADDR"
+export MASTER_ADDR
+export MASTER_PORT=29500
+echo "Master port: $MASTER_PORT"
+
+export OMP_NUM_THREADS=1
+
+# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml
+srun --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e . --no-dependencies
+ python src/hirad/eval/compute_eval.py --config-name=generate_era_cosmo.yaml
+"
\ No newline at end of file
diff --git a/src/hirad/eval/__init__.py b/src/hirad/eval/__init__.py
index 13d9eb37..3401c73d 100644
--- a/src/hirad/eval/__init__.py
+++ b/src/hirad/eval/__init__.py
@@ -1,2 +1,3 @@
-from .metrics import compute_mae, average_power_spectrum
-from .plotting import plot_error_projection, plot_power_spectra
+from .metrics import absolute_error, compute_mae, average_power_spectrum, crps
+from .plotting import plot_map, plot_error_projection, plot_power_spectra, plot_scores_vs_t
+from .eval_utils import GridConfig
diff --git a/src/hirad/eval/compute_eval.py b/src/hirad/eval/compute_eval.py
new file mode 100644
index 00000000..00e524e3
--- /dev/null
+++ b/src/hirad/eval/compute_eval.py
@@ -0,0 +1,229 @@
+import hydra
+import logging
+import os
+import json
+from omegaconf import OmegaConf, DictConfig
+import torch
+import numpy as np
+import contextlib
+import datetime
+from pandas import to_datetime
+
+from hirad.distributed import DistributedManager
+from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper
+from concurrent.futures import ThreadPoolExecutor
+
+from hirad.eval import absolute_error, crps, plot_scores_vs_t, plot_error_projection
+from hirad.models import EDMPrecondSuperResolution, UNet
+from hirad.inference import Generator
+from hirad.utils.inference_utils import save_images, save_results_as_torch
+from hirad.utils.function_utils import get_time_from_range
+from hirad.utils.checkpoint import load_checkpoint
+
+from hirad.datasets import get_dataset_and_sampler_inference
+
+from hirad.utils.train_helpers import set_patch_shape
+
+@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate")
+def main(cfg: DictConfig) -> None:
+
+
+ # Initialize distributed manager
+ DistributedManager.initialize()
+ dist = DistributedManager()
+ device = dist.device
+
+ # Initialize logger
+ logger = PythonLogger("generate") # General python logger
+
+ if cfg.generation.times_range:
+ times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M")
+
+ dataset_cfg = OmegaConf.to_container(cfg.dataset)
+ if "has_lead_time" in cfg.generation:
+ has_lead_time = cfg.generation["has_lead_time"]
+ else:
+ has_lead_time = False
+ dataset, sampler = get_dataset_and_sampler_inference(
+ dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time
+ )
+ pred_path = getattr(cfg.generation.io, "output_path", "./outputs")
+ output_path = './plots/analysis202511'
+
+ compute_crps_per_time(times, dataset, pred_path, output_path)
+ compute_crps_over_time_and_area(times, output_path)
+ plot_crps_over_time_and_area(times, dataset, output_path)
+
+def _get_data_path(output_path, time=None, filename=None):
+ if time:
+ return os.path.join(output_path, time, filename)
+ else:
+ return os.path.join(output_path, filename)
+
+def load_data(output_path, time=None, filename=None):
+ return torch.load(_get_data_path(output_path, time, filename), weights_only=False)
+
+def save_data(data, output_path, time=None, filename=None):
+ path = _get_data_path(output_path, time, filename)
+ torch.save(data, path)
+
+def compute_crps_per_time(times, dataset, pred_path, output_path):
+ logging.info('Computing CRPS for each time point')
+ input_channels = dataset.input_channels()
+ output_channels = dataset.output_channels()
+ start_time=times[0]
+
+ # Load one prediction ensemble to get the shape
+ prediction_ensemble = torch.load(os.path.join(pred_path, start_time, f'{start_time}-predictions'), weights_only=False)
+
+ # Get a map of output to input channel, for building baseline errors
+ output_to_input_channel_map = {}
+ for j in range(len(output_channels)):
+ index = -1
+ for k in range(len(input_channels)):
+ if input_channels[k].name == output_channels[j].name:
+ index = k
+ output_to_input_channel_map[j] = index
+
+
+ for i in range(len(times)):
+ curr_time = times[i]
+ if i % (24*5) == 0:
+ logging.info(f'on time {curr_time}')
+ prediction_ensemble = load_data(pred_path, time=curr_time, filename=f'{curr_time}-predictions')
+ baseline = load_data(pred_path, time=curr_time, filename=f'{curr_time}-baseline')
+ target = load_data(pred_path, time=curr_time, filename=f'{curr_time}-target')
+
+ # Calculate ensemble mean error
+ ensemble_mean = np.mean(prediction_ensemble, 0)
+ ensemble_mean_error = absolute_error(ensemble_mean, target)
+
+ # Calculate interpolation error (baseline #1)
+ interpolation_error = np.zeros(target.shape)
+ for j in range(len(output_channels)):
+ k = output_to_input_channel_map[j]
+ if k > -1:
+ interpolation_error[j,::] = absolute_error(baseline[k,::], target[j,::])
+
+ # Calculate persistence error (baseline #2)
+ persistence_error = np.zeros(target.shape)
+ if i > 0:
+ prev = load_data(pred_path, time=times[i-1], filename=f'{times[i-1]}-target')
+ persistence_error = absolute_error(prev, target)
+ else:
+ # for the first time point, persist the next-time-point target.
+ # This is fiction but it keeps the plots from looking weird.
+ prev = load_data(pred_path, time=times[i+1], filename=f'{times[i+1]}-target')
+ persistence_error = absolute_error(prev, target)
+
+
+ # Calculate CRPS
+ crps_diffusion_area = crps(prediction_ensemble, target, average_over_area=False, average_over_channels=False)
+
+ save_data(crps_diffusion_area, output_path, time=curr_time, filename=f'{curr_time}-crps-ensemble')
+ save_data(ensemble_mean_error, output_path, time=curr_time, filename=f'{curr_time}-ensemble-mean-error')
+ save_data(interpolation_error, output_path, time=curr_time, filename=f'{curr_time}-interpolation-error')
+ save_data(persistence_error, output_path, time=curr_time, filename=f'{curr_time}-persistence-error')
+
+def compute_crps_over_time_and_area(times, output_path):
+ logging.info('computing crps and errors')
+ start_time=times[0]
+ end_time=times[-1]
+
+ # shape = (channels, x, y)
+ crps_area = load_data(output_path, time=start_time, filename=f'{start_time}-crps-ensemble')
+ num_channels = crps_area.shape[0]
+
+ # make area and time plot
+ total_crps_area = np.zeros_like(crps_area)
+ total_ensemble_mean_area = np.zeros_like(crps_area)
+ total_interpolation_area = np.zeros_like(crps_area)
+ total_persistence_area = np.zeros_like(crps_area)
+
+ crps_over_time = np.zeros((num_channels, len(times)))
+ ensemble_mean_over_time = np.zeros_like(crps_over_time)
+ interpolation_over_time = np.zeros_like(crps_over_time)
+ persistence_over_time = np.zeros_like(crps_over_time)
+ for i in range(len(times)):
+ curr_time = times[i]
+ if i % (24*5) == 0:
+ logging.info(f'on time {times[i]}')
+ crps_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-crps-ensemble')
+ total_crps_area = total_crps_area + crps_area
+
+ ensemble_mean_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-ensemble-mean-error')
+ total_ensemble_mean_area = total_ensemble_mean_area + ensemble_mean_area
+ interpolation_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-interpolation-error')
+ total_interpolation_area = total_interpolation_area + interpolation_area
+ persistence_area = load_data(output_path, time=curr_time, filename=f'{curr_time}-persistence-error')
+ if i>0:
+ total_persistence_area = total_persistence_area + persistence_area
+
+ for j in range(num_channels):
+ crps_over_time[j,i] = np.mean(crps_area[j,::])
+ ensemble_mean_over_time[j,i] = np.mean(ensemble_mean_area[j,::])
+ interpolation_over_time[j,i] = np.mean(interpolation_area[j,::])
+ persistence_over_time[j,i] = np.mean(persistence_area[j,::])
+ mean_crps_area = total_crps_area / len(times)
+ mean_ensemble_mean_area = total_ensemble_mean_area / len(times)
+ mean_interpolation_area = total_interpolation_area / len(times)
+ mean_persistence_area = total_persistence_area / (len(times)-1)
+ save_data(mean_crps_area, output_path, filename=f'crps-ensemble-area-{start_time}-{end_time}')
+ save_data(mean_ensemble_mean_area, output_path, filename=f'mae-ensemble-mean-area-{start_time}-{end_time}')
+ save_data(mean_interpolation_area, output_path, filename=f'mae-interpolation-area-{start_time}-{end_time}')
+ save_data(mean_persistence_area, output_path, filename=f'mae-persistence-area-{start_time}-{end_time}')
+
+ # Little hack to make the plots look nicer, without having to change dimensions.
+ persistence_over_time[:,0] = persistence_over_time[:,1]
+
+ save_data(crps_over_time, output_path, filename=f'crps-ensemble-time-{start_time}-{end_time}')
+ save_data(ensemble_mean_over_time, output_path, filename=f'mae-ensemble-mean-time-{start_time}-{end_time}')
+ save_data(interpolation_over_time, output_path, filename=f'mae-interpolation-time-{start_time}-{end_time}')
+ save_data(persistence_over_time, output_path, filename=f'mae-persistence-time-{start_time}-{end_time}')
+
+def plot_crps_over_time_and_area(times, dataset, output_path):
+ logging.info('plotting crps and errors')
+ longitudes = dataset.longitude()
+ latitudes = dataset.latitude()
+ output_channels = dataset.output_channels()
+ start_time=times[0]
+ end_time=times[-1]
+
+ crps_area = load_data(output_path, filename=f'crps-ensemble-area-{start_time}-{end_time}')
+ ensemble_mean_area = load_data(output_path, filename=f'mae-ensemble-mean-area-{start_time}-{end_time}')
+ interpolation_area = load_data(output_path, filename=f'mae-interpolation-area-{start_time}-{end_time}')
+ persistence_area = load_data(output_path, filename=f'mae-persistence-area-{start_time}-{end_time}')
+
+ crps_ensemble_time = load_data(output_path, filename=f'crps-ensemble-time-{start_time}-{end_time}')
+ ensemble_mean_time = load_data(output_path, filename=f'mae-ensemble-mean-time-{start_time}-{end_time}')
+ interpolation_time = load_data(output_path, filename=f'mae-interpolation-time-{start_time}-{end_time}')
+ persistence_time = load_data(output_path, filename=f'mae-persistence-time-{start_time}-{end_time}')
+
+ for j in range(len(output_channels)):
+ plot_error_projection(crps_area[j,::], latitudes, longitudes,
+ _get_data_path(output_path, filename=f'NEW-crps-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'),
+ label=output_channels[j].name, title=f'Mean absolute error: CRPS: {output_channels[j].name}')
+ plot_error_projection(ensemble_mean_area[j,::], latitudes, longitudes,
+ _get_data_path(output_path, filename=f'NEW-mae-ensemble-mean-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'),
+ label=output_channels[j].name, title=f'Mean absolute error: Ensemble mean: {output_channels[j].name}')
+ plot_error_projection(interpolation_area[j,::], latitudes, longitudes,
+ _get_data_path(output_path, filename=f'NEW-mae-interpolation-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'),
+ label=output_channels[j].name, title=f'Mean absolute error: Interpolation: {output_channels[j].name}')
+ plot_error_projection(persistence_area[j,::], latitudes, longitudes,
+ _get_data_path(output_path, filename=f'NEW-mae-persistence-area-{start_time}-{end_time}-{output_channels[j].name}.jpg'),
+ label=output_channels[j].name, title=f'Mean absolute error: Persistence: {output_channels[j].name}')
+
+ maes = {}
+ maes['interpolation'] = interpolation_time[j,::]
+ maes['ensemble mean'] = ensemble_mean_time[j,::]
+ maes['crps'] = crps_ensemble_time[j,:]
+ maes['persistence'] = persistence_time[j,::]
+ # TODO: consider casting times to datetime objects to avoid warnings.
+ # However, this seems to be working OK, and a direct cast causes plotting errors
+ plot_scores_vs_t(maes, times,
+ _get_data_path(output_path, filename=f'NEW-error-plot-time-{start_time}-{end_time}-{output_channels[j].name}.jpg'),
+ title=f'Mean absolute error: {output_channels[j].name}', xlabel='time', ylabel='MAE')
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py
new file mode 100644
index 00000000..1bc62092
--- /dev/null
+++ b/src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py
@@ -0,0 +1,146 @@
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import xarray as xr
+
+from hirad.eval.eval_utils import concat_and_group_diurnal, get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir
+
+def save_plot(hour, means, stds, labels, ylabel, title, out_path):
+ hrs = np.concatenate([hour.values, [24]])
+ plt.figure(figsize=(8,4))
+ for mean, std, label in zip(means, stds, labels):
+ vals = np.append(mean.values, mean.values[0])
+ line, = plt.plot(hrs, vals, label=label)
+ if std is not None:
+ stdv = np.append(std.values, std.values[0])
+ plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3)
+ plt.xlabel('Hour (UTC)')
+ plt.xticks(range(0,25,3))
+ plt.xlim(0,24)
+ plt.ylabel(ylabel)
+ plt.title(title)
+ plt.grid(True)
+ plt.legend()
+ plt.tight_layout()
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+ plt.savefig(out_path)
+ plt.close()
+
+def main(cfg: dict):
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Starting computations for diurnal cycle of precipitation amount and wet-hours")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
+ logger.info(f"Loaded {len(times)} timesteps to process")
+
+ indices = get_channel_indices(gen_cfg)
+
+ # Location of the output from inference
+ out_root = Path(generation_dir)
+
+ # Find channel indices
+ indices = get_channel_indices(gen_cfg)
+ tp_out = indices['output']['tp']
+ tp_in = indices['input'].get('tp', tp_out)
+ logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}")
+
+ # Land-sea mask
+ land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))
+
+ # Prepare lists to collect DataArrays
+ target_precip, baseline_precip, pred_precip, mean_pred_precip = [], [], [], []
+ target_wet, baseline_wet, pred_wet, mean_pred_wet = [], [], [], []
+
+ # Collect data
+ for idx, ts in enumerate(times, 1):
+ dt = datetimes[idx-1]
+ target = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-target", weights_only=False)[tp_out] * cfg.get("conv_factor")
+ baseline = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-baseline", weights_only=False)[tp_in] * cfg.get("conv_factor")
+ preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False)[:, tp_out, :, :] * cfg.get("conv_factor")
+ try:
+ mean_pred = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-regression-prediction", weights_only=False)[tp_out] * cfg.get("conv_factor")
+ except:
+ mean_pred = None
+
+ # DataArrays for spatial means at each timestep
+ da_target = xr.DataArray(target, dims=("lat","lon"), coords=land_mask.coords)
+ da_baseline = xr.DataArray(baseline, dims=("lat","lon"), coords=land_mask.coords)
+ da_preds = xr.DataArray(preds, dims=("member","lat","lon"), coords={"member": np.arange(preds.shape[0]), **land_mask.coords})
+ if mean_pred is not None:
+ da_mean_pred = xr.DataArray(mean_pred, dims=("lat","lon"), coords=land_mask.coords)
+
+ # Apply land mask after conversion to xarray
+ da_target = da_target * land_mask
+ da_baseline = da_baseline * land_mask
+ da_preds = da_preds * land_mask
+ if mean_pred is not None:
+ da_mean_pred = da_mean_pred * land_mask
+
+ # Spatial mean
+ target_precip.append(da_target.mean(dim=("lat","lon")).assign_coords(time=dt))
+ baseline_precip.append(da_baseline.mean(dim=("lat","lon")).assign_coords(time=dt))
+ pred_precip.append(da_preds.mean(dim=("lat","lon")).assign_coords(time=dt))
+ if mean_pred is not None:
+ mean_pred_precip.append(da_mean_pred.mean(dim=("lat","lon")).assign_coords(time=dt))
+
+ # Wet-hour fraction, i.e., freq(precip) > wet_threshold
+ target_wet.append(((da_target / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt)))
+ baseline_wet.append(((da_baseline / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt)))
+ pred_wet.append(((da_preds / 24> cfg.get("wet_threshold")).mean(dim=("lat","lon")).assign_coords(time=dt)))
+ if mean_pred is not None:
+ mean_pred_wet.append(((da_mean_pred / 24 > cfg.get("wet_threshold")).mean().assign_coords(time=dt)))
+
+ if idx % cfg.get("log_interval") == 0 or idx == len(times):
+ logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})")
+
+ # Compute diurnal means and stds
+ amount_target_mean, _ = concat_and_group_diurnal(target_precip)
+ amount_baseline_mean, _ = concat_and_group_diurnal(baseline_precip)
+ amount_pred_mean, amount_pred_std = concat_and_group_diurnal(pred_precip, is_member=True)
+ if mean_pred_precip:
+ amount_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_precip)
+
+ wet_target_mean, _ = concat_and_group_diurnal(target_wet, scale=100.0) # scale to obtain percentages
+ wet_baseline_mean, _ = concat_and_group_diurnal(baseline_wet, scale=100.0)
+ wet_pred_mean, wet_pred_std = concat_and_group_diurnal(pred_wet, is_member=True, scale=100.0)
+ if mean_pred_wet:
+ wet_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wet, scale=100.0)
+
+ # Generate plots
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ save_plot(
+ amount_target_mean.hour,
+ [amount_target_mean, amount_baseline_mean, amount_pred_mean, amount_mean_pred_mean] if mean_pred_precip else [amount_target_mean, amount_baseline_mean, amount_pred_mean],
+ [None, None, amount_pred_std, None] if mean_pred_precip else [None, None, amount_pred_std],
+ ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_precip else ['Target','Input','CorrDiff ± Std(Members)'],
+ 'Precipitation (mm/day)',
+ 'Diurnal Cycle of Precip Amount',
+ output_path / 'diurnal_cycle_precip_amount.png'
+ )
+ save_plot(
+ wet_target_mean.hour,
+ [wet_target_mean, wet_baseline_mean, wet_pred_mean, wet_mean_pred_mean] if mean_pred_wet else [wet_target_mean, wet_baseline_mean, wet_pred_mean],
+ [None, None, wet_pred_std, None] if mean_pred_wet else [None, None, wet_pred_std],
+ ['Target','Input','CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wet else ['Target','Input','CorrDiff ± Std(Members)'],
+ 'Wet-Hour Fraction [%]',
+ 'Diurnal Cycle of Wet-Hours (>0.1 mm/h)',
+ output_path / 'diurnal_cycle_precip_wethours.png'
+ )
+
+ logger.info("Plots saved.")
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
\ No newline at end of file
diff --git a/src/hirad/eval/diurnal_cycle_precip_p99.py b/src/hirad/eval/diurnal_cycle_precip_p99.py
new file mode 100644
index 00000000..ecd257cb
--- /dev/null
+++ b/src/hirad/eval/diurnal_cycle_precip_p99.py
@@ -0,0 +1,165 @@
+"""
+Plots the diurnal cycle of the all-hour 99th percentile of
+precipitation, a somewhat reliable measure of the precipitation intensity.
+
+Each hour, member and type is treaded separately, to conserve memory... but if the
+period is long, this can still be a lot of data and thus an OOM error can occur.
+"""
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import xarray as xr
+
+from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir
+
+
+def save_plot(hours, lines, labels, ylabel, title, out_path):
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+ plt.figure(figsize=(8,4))
+ for data, label in zip(lines, labels):
+ if isinstance(data, tuple): # (mean, std)
+ mean, std = data
+ lower = np.maximum(np.array(mean) - std, 0)
+ upper = np.array(mean) + std
+ line, = plt.plot(hours, mean, label=label)
+ plt.fill_between(hours, lower, upper, alpha=0.3, color=line.get_color())
+ else:
+ plt.plot(hours, data, label=label)
+ plt.xlabel('Hour (UTC)')
+ plt.xticks(range(0,25,3))
+ plt.xlim(0,24)
+ plt.ylabel(ylabel)
+ plt.title(title)
+ plt.grid(True)
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(out_path)
+ plt.close()
+
+
+def main(cfg: dict):
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Starting computation for diurnal cycle of 99th-percentile of precipitation")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ logger.info(f"Loaded {len(times)} timesteps to process")
+
+ # Output root
+ out_root = Path(generation_dir)
+
+ # Find channel indices
+ indices = get_channel_indices(gen_cfg)
+ tp_out = indices['output']['tp']
+ tp_in = indices['input'].get('tp', tp_out)
+ logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}")
+
+ # Land-sea mask
+ land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))
+
+ # Storage for diurnal cycles
+ pct99_mean = {}
+ pct99_std = {}
+
+ # -- Process target and baseline --
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ logger.info(f"Processing mode: {mode}")
+
+ data_list = []
+ try:
+ for ts in times:
+ data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor")
+ data_list.append(data)
+ except:
+ logger.error(f"Error loading data for mode {mode}. Skipping.")
+ continue
+
+ da = xr.DataArray(
+ np.stack(data_list, axis=0),
+ dims=['time', 'lat', 'lon'],
+ coords={'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times],
+ 'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']}
+ )
+
+ # Select only land pixels to avoid all-NaN slices in quantile
+ land_bool = land_mask.notnull().stack(space=('lat', 'lon'))
+ da_land = da.stack(space=('lat', 'lon')).isel(space=land_bool.values)
+
+ # Group by hour and compute 99th percentile over time, then spatial mean
+ hourly_p99 = da_land.groupby('time.hour').quantile(0.99, dim='time')
+ pct99_mean[mode] = hourly_p99.mean(dim='space')
+
+ # -- Predictions: compute per hour per member, then mean+std across members --
+ logger.info("Processing predictions")
+
+ # Load all prediction data at once into xarray
+ pred_data_list = []
+ for ts in times:
+ preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor") # [n_members, n_channels, lat, lon]
+ tp_data = preds[:, tp_out] # [n_members, lat, lon]
+ tp_da = xr.DataArray(tp_data, dims=['member', 'lat', 'lon'],
+ coords={'lat': land_mask.coords['lat'], 'lon': land_mask.coords['lon']})
+ pred_data_list.append(tp_da)
+
+ pred_da = xr.concat(pred_data_list, dim='time') # [member, time, lat, lon]
+ pred_da = pred_da.assign_coords({
+ 'time': [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
+ })
+ pred_da = pred_da.transpose('member', 'time', 'lat', 'lon')
+
+ # Select only land pixels to avoid all-NaN slices in quantile
+ land_bool = land_mask.notnull().stack(space=('lat', 'lon'))
+ pred_da_land = pred_da.stack(space=('lat', 'lon')).isel(space=land_bool.values)
+
+ logger.info('Calculating 99th percentile for predictions')
+ # Group by hour, compute 99th percentile across time, then spatial mean over land
+ hourly_p99_by_member = pred_da_land.groupby('time.hour').quantile(0.99, dim='time').mean(dim='space')
+
+ # Store ensemble statistics as xarray DataArrays
+ pct99_mean['prediction'] = hourly_p99_by_member.mean(dim='member')
+ pct99_std['prediction'] = hourly_p99_by_member.std(dim='member')
+
+ # Prepare cyclic lists for plotting
+ def cycle_fn(x):
+ return x.values.tolist() + [x.values.tolist()[0]]
+
+ logger.info("Preparing data for plotting")
+ hrs_c = list(range(24)) + [0 + 24]
+ pct99_lines = [
+ cycle_fn(pct99_mean['target']),
+ cycle_fn(pct99_mean['baseline']),
+ (
+ cycle_fn(pct99_mean['prediction']),
+ cycle_fn(pct99_std['prediction'])
+ )
+ ]
+ if 'regression-prediction' in pct99_mean:
+ pct99_lines.append(cycle_fn(pct99_mean['regression-prediction']))
+
+ # Plot combined diurnal 99th-percentile cycle
+ labels = ['Target', 'Input', 'CorrDiff 99th Pct ± Std', 'Regression Prediction'] if 'regression-prediction' in pct99_mean else ['Target', 'Input', 'CorrDiff 99th Pct ± Std']
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ fn = output_path / 'diurnal_cycle_precip_99th_percentile.png'
+ save_plot(
+ hrs_c,
+ pct99_lines,
+ labels,
+ 'Precipitation (mm/day)',
+ 'Diurnal Cycle of 99th-Percentile Precipitation',
+ fn
+ )
+ logger.info(f"Combined plot saved: {fn}")
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
\ No newline at end of file
diff --git a/src/hirad/eval/diurnal_cycle_temp_wind.py b/src/hirad/eval/diurnal_cycle_temp_wind.py
new file mode 100644
index 00000000..c9b63f8d
--- /dev/null
+++ b/src/hirad/eval/diurnal_cycle_temp_wind.py
@@ -0,0 +1,167 @@
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import xarray as xr
+
+from hirad.eval.eval_utils import concat_and_group_diurnal, get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir
+
+def main(cfg: dict):
+ # Initialize
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Starting computation for diurnal cycles of 2m temperature and windspeed")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ datetimes = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
+ logger.info(f"Loaded {len(times)} timesteps to process")
+
+ # Indices for channels
+ indices = get_channel_indices(gen_cfg)
+ out_ch = indices['output']
+ in_ch = indices['input']
+
+ # Temperature channel (try '2t' first, fallback to 't2m')
+ t2m_out = out_ch.get('2t', out_ch.get('t2m'))
+ t2m_in = in_ch.get('2t', in_ch.get('t2m', t2m_out))
+
+ # Wind channels
+ u_out = out_ch['10u']
+ u_in = in_ch.get('10u', u_out)
+ v_out = out_ch['10v']
+ v_in = in_ch.get('10v', v_out)
+
+ # Output path
+ out_root = Path(generation_dir)
+ def load(ts, fn):
+ return torch.load(resolve_ts_dir(out_root, ts) / ts / fn, weights_only=False)
+
+ # Land-sea mask
+ land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))
+
+ # Prepare lists to collect DataArrays
+ target_temp, baseline_temp, pred_temp, mean_pred_temp = [], [], [], []
+ target_wind, baseline_wind, pred_wind, mean_pred_wind = [], [], [], []
+
+ def mean_over_land(data, dims, coords, time_coord):
+ da = xr.DataArray(data, dims=dims, coords=coords) * land_mask
+ return da.mean(dim=("lat","lon")).assign_coords(time=time_coord)
+
+ # Loop over timestamps
+ for idx, ts in enumerate(times, 1):
+ dt = datetimes[idx-1]
+
+ # Load data
+ target = load(ts, f"{ts}-target")
+ baseline = load(ts, f"{ts}-baseline")
+ predictions = load(ts, f"{ts}-predictions")
+ try:
+ regression_pred = load(ts, f"{ts}-regression-prediction")
+ except:
+ regression_pred = None
+
+ # Process temperature (convert to Celsius)
+ target_temp.append(mean_over_land(
+ target[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt))
+ baseline_temp.append(mean_over_land(
+ baseline[t2m_in] - 273.15, ("lat","lon"), land_mask.coords, dt))
+ pred_temp.append(mean_over_land(
+ predictions[:, t2m_out, :, :] - 273.15, ("member","lat","lon"),
+ {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt))
+ if regression_pred is not None:
+ mean_pred_temp.append(mean_over_land(
+ regression_pred[t2m_out] - 273.15, ("lat","lon"), land_mask.coords, dt))
+
+
+ # Process wind speed
+ target_wind.append(mean_over_land(
+ np.hypot(target[u_out], target[v_out]), ("lat","lon"), land_mask.coords, dt))
+ baseline_wind.append(mean_over_land(
+ np.hypot(baseline[u_in], baseline[v_in]), ("lat","lon"), land_mask.coords, dt))
+ pred_wind.append(mean_over_land(
+ np.hypot(predictions[:, u_out, :, :], predictions[:, v_out, :, :]),
+ ("member","lat","lon"), {"member": np.arange(predictions.shape[0]), **land_mask.coords}, dt))
+ if regression_pred is not None:
+ mean_pred_wind.append(mean_over_land(
+ np.hypot(regression_pred[u_out], regression_pred[v_out]), ("lat","lon"), land_mask.coords, dt))
+
+ if idx % cfg.get("log_interval") == 0 or idx == len(times):
+ logger.info(f"Processed {idx}/{len(times)} timesteps ({ts})")
+
+ # Compute diurnal means and stds
+ temp_target_mean, _ = concat_and_group_diurnal(target_temp)
+ temp_baseline_mean, _ = concat_and_group_diurnal(baseline_temp)
+ temp_pred_mean, temp_pred_std = concat_and_group_diurnal(pred_temp, is_member=True)
+ if mean_pred_temp:
+ temp_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_temp)
+
+ wind_target_mean, _ = concat_and_group_diurnal(target_wind)
+ wind_baseline_mean, _ = concat_and_group_diurnal(baseline_wind)
+ wind_pred_mean, wind_pred_std = concat_and_group_diurnal(pred_wind, is_member=True)
+ if mean_pred_wind:
+ wind_mean_pred_mean, _ = concat_and_group_diurnal(mean_pred_wind)
+
+ def save_plot(hour, means, stds, labels, ylabel, title, out_path):
+ hrs = np.concatenate([hour.values, [24]])
+ plt.figure(figsize=(8,4))
+ for mean, std, label in zip(means, stds, labels):
+ vals = np.append(mean.values, mean.values[0])
+ line, = plt.plot(hrs, vals, label=label)
+ if std is not None:
+ stdv = np.append(std.values, std.values[0])
+ plt.fill_between(hrs, np.maximum(vals - stdv, 0), vals + stdv, color=line.get_color(), alpha=0.3)
+ plt.xlabel('Hour (UTC)')
+ plt.xticks(range(0,25,3))
+ plt.xlim(0,24)
+ plt.ylabel(ylabel)
+ plt.title(title)
+ plt.grid(True)
+ plt.legend()
+ plt.tight_layout()
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+ plt.savefig(out_path)
+ plt.close()
+
+ data = [temp_target_mean, temp_baseline_mean, temp_pred_mean, temp_mean_pred_mean] if mean_pred_temp else [temp_target_mean, temp_baseline_mean, temp_pred_mean]
+ labels = ['Target', 'Input', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_temp else ['Target', 'Input', 'CorrDiff ± Std(Members)']
+ stds = [None, None, temp_pred_std, None] if mean_pred_temp else [None, None, temp_pred_std]
+
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ # Generate plots
+ save_plot(
+ temp_target_mean.hour,
+ data,
+ stds,
+ labels,
+ '2m Temperature [°C]',
+ 'Diurnal Cycle of 2m Temperature',
+ output_path / 'diurnal_cycle_2t.png'
+ )
+
+ data = [wind_target_mean, wind_baseline_mean, wind_pred_mean, wind_mean_pred_mean] if mean_pred_wind else [wind_target_mean, wind_baseline_mean, wind_pred_mean]
+ labels = ['Target', 'Input', 'CorrDiff ± Std(Members)', 'Regression Prediction'] if mean_pred_wind else ['Target', 'Input', 'CorrDiff ± Std(Members)']
+ stds = [None, None, wind_pred_std, None] if mean_pred_wind else [None, None, wind_pred_std]
+
+ save_plot(
+ wind_target_mean.hour,
+ data,
+ stds,
+ labels,
+ 'Windspeed [m/s]',
+ 'Diurnal Cycle of Windspeed',
+ output_path / 'diurnal_cycle_windspeed.png'
+ )
+
+ logger.info("Plots saved.")
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
\ No newline at end of file
diff --git a/src/hirad/eval/eval_utils.py b/src/hirad/eval/eval_utils.py
new file mode 100644
index 00000000..680f5d71
--- /dev/null
+++ b/src/hirad/eval/eval_utils.py
@@ -0,0 +1,211 @@
+import argparse
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple
+
+import numpy as np
+import xarray as xr
+import yaml
+
+from hirad.datasets import get_channels_from_strings, get_strings_from_channels, known_datasets
+from hirad.utils.function_utils import get_time_from_range
+
+
+@dataclass
+class GridConfig:
+ lat: np.ndarray
+ lon: np.ndarray
+ height: int
+ width: int
+ relax_zone: int
+
+
+DEFAULT_GRID_CONFIG = GridConfig(
+ lat=np.arange(-4.42, 3.36 + 0.02, 0.02),
+ lon=np.arange(-6.82, 4.80 + 0.02, 0.02),
+ height=352,
+ width=544,
+ relax_zone=19,
+)
+
+
+def grid_cfg_from_cfg(cfg) -> GridConfig:
+ """Build a :class:`GridConfig` from ``lat_*``/``lon_*``/``height``/``width``/``relax_zone`` fields of *cfg*."""
+ return GridConfig(
+ lat=np.arange(cfg.get("lat_start"), cfg.get("lat_end") + cfg.get("lat_step"), cfg.get("lat_step")),
+ lon=np.arange(cfg.get("lon_start"), cfg.get("lon_end") + cfg.get("lon_step"), cfg.get("lon_step")),
+ height=cfg.get("height"),
+ width=cfg.get("width"),
+ relax_zone=cfg.get("relax_zone"),
+ )
+
+
+def load_land_sea_mask(path, height=352, width=544):
+ """Load and return a land-sea mask as xarray DataArray."""
+ lsm_data = np.load(path).reshape(height, width)
+ return xr.DataArray(
+ np.where(lsm_data >= 0.5, 1.0, np.nan),
+ dims=['lat', 'lon'],
+ coords={"lat": np.arange(height), "lon": np.arange(width)},
+ )
+
+
+def concat_and_group_diurnal(list_of_da, is_member=False, scale=1.0):
+ """Concatenate DataArrays along ``time`` and compute diurnal mean (and member std)."""
+ da = xr.concat(list_of_da, dim="time")
+ if is_member:
+ mean = da.groupby("time.hour").mean(dim="time").mean(dim="member") * scale
+ std = da.std(dim="member").groupby("time.hour").mean(dim="time") * scale
+ else:
+ mean = da.groupby("time.hour").mean(dim="time") * scale
+ std = None
+ return mean, std
+
+
+def percentiles_from_histogram(hist_counts, bin_edges, percentiles_dict):
+ """Estimate percentiles from a histogram via linear interpolation on the CDF.
+
+ Parameters
+ ----------
+ hist_counts : np.ndarray
+ Raw (unnormalized) histogram counts per bin.
+ bin_edges : np.ndarray
+ Bin edges (length = ``len(hist_counts) + 1``).
+ percentiles_dict : dict
+ Mapping ``label -> fractional percentile`` (e.g. ``{99: 0.99}``).
+
+ Returns
+ -------
+ dict
+ Mapping ``label -> estimated value``.
+ """
+ cumulative = np.cumsum(hist_counts)
+ total = cumulative[-1]
+ if total == 0:
+ return {key: np.nan for key in percentiles_dict}
+
+ cdf = cumulative / total # CDF at upper bin edges
+ results = {}
+ for key, p in percentiles_dict.items():
+ idx = np.searchsorted(cdf, p)
+ if idx >= len(cdf):
+ results[key] = bin_edges[-1]
+ elif idx == 0:
+ frac = p / cdf[0] if cdf[0] > 0 else 0.0
+ results[key] = bin_edges[0] + frac * (bin_edges[1] - bin_edges[0])
+ else:
+ cdf_low, cdf_high = cdf[idx - 1], cdf[idx]
+ frac = (p - cdf_low) / (cdf_high - cdf_low) if (cdf_high - cdf_low) > 0 else 0.0
+ results[key] = bin_edges[idx] + frac * (bin_edges[idx + 1] - bin_edges[idx])
+ return results
+
+
+def load_generation_setup(cfg: dict) -> Tuple[Path, dict, list]:
+ """Validate ``cfg['inference_output_dir']``, load its generation config, and resolve times.
+
+ Returns ``(generation_dir, gen_cfg, times)``. Raises :class:`ValueError` on failure.
+ """
+ generation_dir = cfg.get("inference_output_dir")
+ if generation_dir is None:
+ raise ValueError("No inference_output_dir specified in config.")
+
+ generation_dir = Path(generation_dir)
+ if not generation_dir.is_dir():
+ raise ValueError(f"Inference output directory {generation_dir} does not exist or is not a directory.")
+
+ generation_config_path = min(generation_dir.glob("**/.hydra/config.yaml"), default=None)
+ if generation_config_path is None:
+ raise ValueError(f"No generation config file found in {generation_dir}.")
+
+ with open(generation_config_path, "r") as f:
+ gen_cfg = yaml.safe_load(f)
+
+ times = _resolve_times(cfg, gen_cfg)
+ if times is None:
+ raise ValueError("No times, times_range, or times_ranges specified in config or generation config.")
+
+ return generation_dir, gen_cfg, times
+
+
+def _resolve_times(cfg: dict, gen_cfg: dict, time_format: str = "%Y%m%d-%H%M") -> Optional[list]:
+ """Resolve timestep strings from eval cfg, falling back to ``gen_cfg['generation']``.
+
+ Priority (in each source): ``times_ranges`` > ``times_range`` > ``times``.
+ """
+ def _from(source: dict) -> Optional[list]:
+ if source.get("times_ranges"):
+ return [t for tr in source["times_ranges"] for t in get_time_from_range(tr, time_format=time_format)]
+ if source.get("times_range"):
+ return get_time_from_range(source["times_range"], time_format=time_format)
+ return source.get("times")
+
+ return _from(cfg) or _from(gen_cfg.get("generation", {}))
+
+
+def resolve_io_channels(gen_cfg: dict) -> Tuple[list, list]:
+ """Resolve ``(input_channels, output_channels)`` from a generation config.
+
+ Uses ``input_channel_names`` / ``output_channels_names`` from ``gen_cfg['dataset']``
+ when available; otherwise instantiates the dataset and queries it.
+ """
+ dataset_cfg = gen_cfg.get("dataset", {})
+ input_channels = get_channels_from_strings(dataset_cfg.get("input_channel_names", []))
+ output_channels = get_channels_from_strings(dataset_cfg.get("output_channels_names", []))
+ if not input_channels or not output_channels:
+ dataset = known_datasets[dataset_cfg.get("type")](**dataset_cfg)
+ input_channels = dataset.input_channels()
+ output_channels = dataset.output_channels()
+ return input_channels, output_channels
+
+
+def get_channel_indices(gen_cfg: dict, channels=None) -> dict:
+ """Return ``{'input': {name: idx}, 'output': {name: idx}}`` from a generation config.
+
+ When *channels* is given, the mappings are filtered to only those names.
+ """
+ input_channels, output_channels = resolve_io_channels(gen_cfg)
+ in_ch = {get_strings_from_channels(c): i for i, c in enumerate(input_channels)}
+ out_ch = {get_strings_from_channels(c): i for i, c in enumerate(output_channels)}
+ if channels is None:
+ return {'input': in_ch, 'output': out_ch}
+ return {
+ 'input': {ch: in_ch[ch] for ch in channels if ch in in_ch},
+ 'output': {ch: out_ch[ch] for ch in channels if ch in out_ch},
+ }
+
+
+def resolve_ts_dir(out_root: Path, ts: str) -> Path:
+ """Return the directory under *out_root* that contains the timestamp folder *ts*."""
+ if (out_root / ts).is_dir():
+ return out_root
+ matches = [p.parent for p in out_root.glob(f"*/{ts}") if p.is_dir()]
+ if matches:
+ return matches[0]
+ raise FileNotFoundError(f"Timestamp directory {ts} not found under {out_root}")
+
+
+
+def parse_eval_cli(allow_times: bool = False) -> dict:
+ """Parse standard eval CLI args (``--config-name``) and return the loaded YAML config.
+
+ When ``allow_times=True``, also accepts ``--times YYYYMMDD-HHMM ...`` to override
+ ``times`` / ``times_range`` / ``times_ranges`` in the config.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config-name", help="Path to YAML config file for evaluation.")
+ if allow_times:
+ parser.add_argument(
+ "--times", nargs="+",
+ help="One or more timesteps (YYYYMMDD-HHMM) overriding times/times_range/times_ranges.",
+ )
+ args = parser.parse_args()
+
+ with open(args.config_name, "r") as f:
+ cfg = yaml.safe_load(f)
+
+ if allow_times and getattr(args, "times", None):
+ for k in ("times", "times_range", "times_ranges"):
+ cfg.pop(k, None)
+ cfg["times"] = args.times
+
+ return cfg
diff --git a/src/hirad/eval/hist.py b/src/hirad/eval/hist.py
new file mode 100644
index 00000000..7e335ee8
--- /dev/null
+++ b/src/hirad/eval/hist.py
@@ -0,0 +1,244 @@
+"""
+Plots the domain-mean precipitation distribution over land.
+
+This script computes and visualizes the distribution of precipitation values
+over land.
+"""
+import logging
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir
+from hirad.eval.eval_utils import percentiles_from_histogram
+
+
+def save_distribution_plot(hist_data_dict, bin_edges, labels, colors, title, ylabel, out_path, percentiles_data=None):
+ """Save distribution plot with pre-computed histograms."""
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+
+ plt.figure(figsize=(10, 6))
+
+ # Plot histograms from pre-computed bin counts
+ for (key, hist_data), label, color in zip(hist_data_dict.items(), labels, colors):
+ if isinstance(hist_data, tuple): # Handle ensemble data
+ # Plot individual members with transparency
+ for i, member_hist in enumerate(hist_data):
+ alpha = 0.5 if i > 0 else 0.7
+ label_member = label if i == 0 else None
+ # Plot histogram from bin counts
+ bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
+ plt.plot(bin_centers, member_hist, alpha=alpha, color=color,
+ label=label_member, drawstyle='steps-mid')
+ else:
+ # Plot histogram from bin counts
+ bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
+ plt.plot(bin_centers, hist_data, alpha=0.7, color=color,
+ label=label, linewidth=2, drawstyle='steps-mid')
+
+ plt.xscale('log')
+ plt.yscale('log')
+ plt.xlabel(ylabel)
+ plt.ylabel('Probability Density')
+ plt.ylim(1e-8, 1)
+ plt.xlim(bin_edges[1], bin_edges[-1])
+ plt.title(title)
+ plt.grid(True, alpha=0.3)
+
+ # Add percentile lines if provided
+ if percentiles_data:
+ # Calculate y-range for percentile lines (lowest 10% of log scale)
+ y_bottom, y_top = plt.ylim()
+ log_bottom, log_top = np.log10(y_bottom), np.log10(y_top)
+ vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom))
+ vline_ymin = y_bottom
+
+ # Define line styles for percentiles
+ percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'}
+ percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'}
+ colors = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'}
+ legend_added = set()
+
+ # Plot all percentile lines
+ for dataset_name, data in percentiles_data.items():
+ color = colors[dataset_name]
+
+ if dataset_name in ['target', 'baseline', 'regression-prediction']:
+ # Single dataset
+ for percentile, value in data.items():
+ linestyle = percentile_styles[percentile]
+ legend_added.add(percentile) # Track percentiles for black legend entries
+
+ plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax,
+ linestyles=linestyle, alpha=0.8) # No label here
+ else:
+ # Ensemble members
+ for member_data in data.values():
+ for percentile, value in member_data.items():
+ linestyle = percentile_styles[percentile]
+ legend_added.add(percentile) # Track percentiles for black legend entries
+
+ plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax,
+ linestyles=linestyle, alpha=0.6) # No label here
+
+ # Add black legend entries for percentiles (override the colored ones)
+ for percentile in [99, 99.9, 99.99]:
+ if percentile in legend_added:
+ plt.plot([], [], color='black', linestyle=percentile_styles[percentile],
+ label=percentile_labels[percentile])
+
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(out_path, dpi=300, bbox_inches='tight')
+ plt.close()
+
+
+def main(cfg: dict):
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Starting computation for domain-mean precipitation distribution over land")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ logger.info(f"Loaded {len(times)} timesteps to process")
+
+
+ # Output root
+ out_root = Path(generation_dir)
+
+ # Find channel indices
+ indices = get_channel_indices(gen_cfg)
+ tp_out = indices['output']['tp']
+ tp_in = indices['input'].get('tp', tp_out)
+ logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}")
+
+ # Land-sea mask
+ land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))
+
+ # Define histogram bins
+ # bins = np.logspace(-1, 3.3, 200) # Log-spaced bins for precipitation
+ log_bins = np.logspace(-1, 3.3, 200) # Log-spaced bins for precipitation
+ bins = np.concatenate([[0], log_bins]) # Prepend 0 to capture all sub-0.1 values
+
+ # Storage for histogram data and land values
+ hist_data = {}
+ raw_hist_counts = {} # Store raw counts for percentile estimation
+
+ # -- Process target and baseline --
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ logger.info(f"Processing mode: {mode}")
+
+ hist_counts = np.zeros(len(bins) - 1)
+ total_samples = 0
+
+ try:
+ for i, ts in enumerate(times):
+ if i % cfg.get("log_interval") == 0:
+ logger.info(f"Processing timestep {i+1}/{len(times)}")
+
+ data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target', 'regression-prediction'] else tp_in] * cfg.get("conv_factor_hourly") * land_mask
+
+ land_values = data.values[~np.isnan(data.values)]
+
+ counts, _ = np.histogram(land_values, bins=bins)
+ hist_counts += counts
+ total_samples += len(land_values)
+ except:
+ logger.warning(f"{mode} not available, skipping")
+ continue
+ # Store raw counts for percentile estimation
+ raw_hist_counts[mode] = hist_counts.copy()
+ # Normalize to probability density
+ bin_widths = np.diff(bins)
+ hist_data[mode] = hist_counts[1:] / (total_samples * bin_widths[1:])
+ logger.info(f"Processed {total_samples} land values for {mode}")
+
+ # -- Process predictions: compute histogram for each ensemble member --
+ logger.info("Processing predictions")
+
+ n_members = None
+ member_hist_data = []
+
+ for i, ts in enumerate(times):
+ if i % cfg.get("log_interval") == 0:
+ logger.info(f"Processing timestep {i+1}/{len(times)}")
+
+ preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor_hourly") # [n_members, n_channels, lat, lon]
+
+ if n_members is None:
+ n_members = preds.shape[0]
+ member_hist_data = [np.zeros(len(bins) - 1) for _ in range(n_members)]
+ member_sample_counts = [0 for _ in range(n_members)]
+
+ for member_idx in range(n_members):
+ data = preds[member_idx, tp_out] * land_mask
+ land_values = data.values[~np.isnan(data.values)]
+
+ counts, _ = np.histogram(land_values, bins=bins)
+ member_hist_data[member_idx] += counts
+ member_sample_counts[member_idx] += len(land_values)
+
+ # Normalize member histograms to probability density
+ bin_widths = np.diff(bins)
+ normalized_member_hists = []
+ for member_idx in range(n_members):
+ normalized_hist = member_hist_data[member_idx][1:] / (member_sample_counts[member_idx] * bin_widths[1:])
+ normalized_member_hists.append(normalized_hist)
+
+ hist_data['predictions'] = tuple(normalized_member_hists)
+
+ logger.info(f"Collected {n_members} ensemble members for predictions")
+
+ # Compute percentiles for all datasets
+ percentiles_data = {}
+ percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999}
+
+ # Target and baseline percentiles
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ if mode in raw_hist_counts:
+ cumulative = np.cumsum(raw_hist_counts[mode])
+ total = cumulative[-1]
+ cdf = cumulative / total # CDF at upper bin edges
+ percentiles_data[mode] = percentiles_from_histogram(
+ raw_hist_counts[mode], bins, percentiles
+ )
+
+ # Ensemble member percentiles
+ percentiles_data['predictions'] = {}
+ for member_idx in range(n_members):
+ cumulative = np.cumsum(member_hist_data[member_idx])
+ total = cumulative[-1]
+ cdf = cumulative / total # CDF at upper bin edges
+ percentiles_data['predictions'][f'member_{member_idx}'] = percentiles_from_histogram(
+ member_hist_data[member_idx], bins, percentiles
+ )
+
+ # Create distribution plots
+ labels = ['Target', 'Input', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in hist_data else ['Target', 'Input', 'CorrDiff Ensemble']
+ colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in hist_data else ['blue', 'orange', 'green']
+
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ fn = output_path / 'precipitation_distribution_over_land.png'
+ save_distribution_plot(
+ hist_data, # Skip the first bin (0 to 0.1) for plotting
+ bins[1:],
+ labels,
+ colors,
+ 'Domain-Mean Precip. Over Land (Pooled Data)',
+ 'Precipitation (mm/h)',
+ fn,
+ percentiles_data
+ )
+ logger.info(f"Distribution plot saved: {fn}")
+
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
\ No newline at end of file
diff --git a/src/hirad/eval/map_precip_stats.py b/src/hirad/eval/map_precip_stats.py
new file mode 100644
index 00000000..30e2f601
--- /dev/null
+++ b/src/hirad/eval/map_precip_stats.py
@@ -0,0 +1,251 @@
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import torch
+import xarray as xr
+import numba
+
+from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir
+from hirad.eval.plotting import (
+ plot_map_precipitation, plot_map
+)
+
+
+@numba.njit
+def _longest_spell(x):
+ """Longest consecutive run of True values in a 1-D boolean array."""
+ best = 0
+ cur = 0
+ for i in range(x.shape[0]):
+ if x[i]:
+ cur += 1
+ if cur > best:
+ best = cur
+ else:
+ cur = 0
+ return best
+
+
+@numba.njit(parallel=True)
+def _consecutive_spell_2d(condition_3d):
+ """
+ condition_3d: bool array of shape (T, H, W).
+ Returns int array of shape (H, W) with longest spell per grid point.
+ """
+ T, H, W = condition_3d.shape
+ out = np.empty((H, W), dtype=np.int64)
+ for i in numba.prange(H):
+ for j in range(W):
+ out[i, j] = _longest_spell(condition_3d[:, i, j])
+ return out
+
+
+def consecutive_spell(data_np, condition_fn):
+ """
+ data_np: numpy array (T, H, W)
+ condition_fn: callable that takes the array and returns bool array of same shape
+ """
+ cond = condition_fn(data_np)
+ return _consecutive_spell_2d(cond)
+
+
+def apply_statistic(data_np, times_dt, stat_type, stat_param, wet_threshold=0.1):
+ """
+ Apply statistic on array containing time sequence of total precipitation map.
+ data_np: (T, H, W) float array
+ times_dt: list of datetime objects (length T)
+ Returns: (H, W) numpy array
+ """
+ if stat_type == 'mean':
+ return np.mean(data_np, axis=0)
+
+ if stat_type == 'quantile':
+ return np.quantile(data_np, stat_param, axis=0)
+
+ if stat_type == 'Rx1hr':
+ return np.max(data_np, axis=0)
+
+ # For daily aggregations, build daily sums using xarray (fast groupby)
+ if stat_type in ('Rx1day', 'Rx5day', 'cdd', 'cwd'):
+ da = xr.DataArray(
+ data_np, dims=['time', 'lat', 'lon'],
+ coords={'time': times_dt}
+ )
+ daily = da.resample(time="1D").sum("time").values # (D, H, W)
+
+ if stat_type == 'Rx1day':
+ return np.max(daily, axis=0)
+
+ if stat_type == 'Rx5day':
+ # Rolling sum along time axis using a cumsum trick
+ D, H, W = daily.shape
+ if D < 5:
+ return np.sum(daily, axis=0)
+ rolling5 = np.empty((D - 4, H, W), dtype=daily.dtype)
+ for t in range(D - 4):
+ rolling5[t] = daily[t:t+5].sum(axis=0)
+ return np.max(rolling5, axis=0)
+
+ if stat_type == 'cdd':
+ return consecutive_spell(daily, lambda x: x < 1.0)
+
+ if stat_type == 'cwd':
+ return consecutive_spell(daily, lambda x: x >= 1.0)
+
+ if stat_type == 'weth_freq':
+ return np.mean(data_np / 24.0 > wet_threshold, axis=0) * 100.0
+
+ raise ValueError(f"Unsupported statistic type: {stat_type}")
+
+
+def plot_stat_map(data, filename, stat_config, label, grid_cfg):
+ """Plot a single statistic map with appropriate styling."""
+ if stat_config['type'] == 'weth_freq':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]} (%)',
+ label='Wet-Hour Frequency [%]', vmin=0, vmax=30, cmap='PuBu', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] == 'cdd':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Days', vmin=0, vmax=60, cmap='viridis', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] == 'cwd':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Days', vmin=0, vmax=20, cmap='viridis', extend='max', grid_cfg=grid_cfg
+ )
+ else:
+ plot_map_precipitation(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]} Precipitation',
+ threshold=stat_config['threshold'], rfac=1.0, grid_cfg=grid_cfg
+ )
+
+
+def _load_predictions_all_members(filepath, conv_factor):
+ """Load prediction file once and return (n_members, C, H, W) tensor."""
+ return torch.load(filepath, weights_only=False) * conv_factor
+
+
+def main(cfg: dict):
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ grid_cfg = grid_cfg_from_cfg(cfg)
+
+ logger.info("Starting precipitation statistics generation")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ logger.info(f"Processing {len(times)} timesteps")
+
+ times_dt = [datetime.strptime(ts, "%Y%m%d-%H%M") for ts in times]
+
+ out_root = Path(generation_dir)
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ indices = get_channel_indices(gen_cfg)
+ tp_out = indices['output']['tp']
+ tp_in = indices['input'].get('tp', tp_out)
+ conv_factor = cfg.get("conv_factor")
+ log_interval = cfg.get("log_interval", 100)
+ wet_threshold = cfg.get("wet_threshold", 0.1)
+
+ STATISTICS_CONFIG = {
+ 'mean': {'type': 'mean', 'threshold': 0.01, 'title': 'Mean'},
+ 'p99': {'type': 'quantile', 'param': 0.99, 'threshold': 0.1, 'title': '99th Percentile'},
+ 'p99.9': {'type': 'quantile', 'param': 0.999, 'threshold': 0.1, 'title': '99.9th Percentile'},
+ 'p99.99': {'type': 'quantile', 'param': 0.9999, 'threshold': 0.1, 'title': '99.99th Percentile'},
+ 'Rx1hr': {'type': 'Rx1hr', 'threshold': 0.1, 'title': 'Maximum (Rx1hr)'},
+ 'Rx1day': {'type': 'Rx1day', 'threshold': 0.1, 'title': 'Maximum 1-day Amount (Rx1day)'},
+ 'Rx5day': {'type': 'Rx5day', 'threshold': 0.1, 'title': 'Maximum 5-day Total (Rx5day)'},
+ 'cdd': {'type': 'cdd', 'threshold': 0.1, 'title': 'Consecutive Dry Days (CDD)'},
+ 'cwd': {'type': 'cwd', 'threshold': 0.1, 'title': 'Consecutive Wet Days (CWD)'},
+ 'weth_freq': {'type': 'weth_freq', 'threshold': 0.01, 'title': 'Wet-Hour Frequency'}
+ }
+ stat_configs = [
+ {'stat_name': name, 'title_stat': config['title'], 'param': config.get('param'), **config}
+ for name, config in STATISTICS_CONFIG.items()
+ ]
+
+ # --- Basic modes: target, baseline, regression-prediction ---
+ basic_modes = {
+ 'target': (tp_out, 'Target'),
+ 'baseline': (tp_in, 'Input'),
+ 'regression-prediction': (tp_out, 'Regression Prediction')
+ }
+
+ for mode, (tp_channel, label) in basic_modes.items():
+ logger.info(f"Processing mode: {mode}")
+ data_list = []
+ try:
+ for i, ts in enumerate(times):
+ if i % log_interval == 0:
+ logger.info(f"Loading {mode} timestep {i+1}/{len(times)}: {ts}")
+ data = torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", weights_only=False) * conv_factor
+ data_list.append(data[tp_channel].numpy() if isinstance(data, torch.Tensor) else data[tp_channel])
+ except Exception:
+ logger.warning(f"{mode} not available, skipping")
+ continue
+
+ # Stack into (T, H, W) numpy array
+ mode_data = np.stack(data_list, axis=0).astype(np.float64)
+ del data_list
+
+ for stat_config in stat_configs:
+ logger.info(f"Computing {stat_config['title_stat']} for {mode}...")
+ result = apply_statistic(mode_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold)
+ map_output_dir = output_path / f"maps_{stat_config['stat_name']}"
+ map_output_dir.mkdir(parents=True, exist_ok=True)
+ plot_stat_map(result, str(map_output_dir / f'{mode}_{stat_config["stat_name"]}'), stat_config, label, grid_cfg)
+
+ # --- Predictions: load each file ONCE, distribute to all members ---
+ logger.info("Processing predictions mode...")
+ sample_data = torch.load(resolve_ts_dir(out_root, times[0]) / times[0] / f"{times[0]}-predictions", weights_only=False)
+ n_members = sample_data.shape[0]
+ del sample_data
+ logger.info(f"Found {n_members} ensemble members")
+
+ # Pre-allocate arrays for ALL members at once: (n_members, T, H, W)
+ # If memory is tight, we can do this in chunks. For 16 members × 2200 × 704 × 1088 × 4 bytes ≈ 107 GB
+ # Instead we can do cummulative statistics on the fly without storing all members in memory (like in map_wind_stats), but this works for now.
+ H, W = cfg.get("height"), cfg.get("width")
+ member_arrays = [np.empty((len(times), H, W), dtype=np.float32) for _ in range(n_members)]
+
+ logger.info("Loading all prediction timesteps (single pass over files)...")
+ for i, ts in enumerate(times):
+ if i % log_interval == 0:
+ logger.info(f"Loading predictions timestep {i+1}/{len(times)}: {ts}")
+ pred_data = torch.load(out_root / ts / f"{ts}-predictions", weights_only=False) * conv_factor
+ for m in range(n_members):
+ member_arrays[m][i] = (pred_data[m, tp_out].numpy() if isinstance(pred_data, torch.Tensor)
+ else pred_data[m, tp_out])
+ del pred_data
+
+ for member_idx in range(n_members):
+ logger.info(f"Computing statistics for prediction member {member_idx+1}/{n_members}")
+ member_data = member_arrays[member_idx].astype(np.float64)
+
+ for stat_config in stat_configs:
+ logger.info(f"Computing {stat_config['title_stat']} for member {member_idx+1}...")
+ member_result = apply_statistic(member_data, times_dt, stat_config['type'], stat_config['param'], wet_threshold)
+ map_output_dir = output_path / f"maps_{stat_config['stat_name']}"
+ map_output_dir.mkdir(parents=True, exist_ok=True)
+ member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_config["stat_name"]}')
+ member_label = f'CorrDiff Member {member_idx+1}'
+ plot_stat_map(member_result, member_filename, stat_config, member_label, grid_cfg)
+
+ del member_arrays
+ logger.info("All precipitation statistics maps generated successfully")
+
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
\ No newline at end of file
diff --git a/src/hirad/eval/map_wind_stats.py b/src/hirad/eval/map_wind_stats.py
new file mode 100644
index 00000000..a6367b10
--- /dev/null
+++ b/src/hirad/eval/map_wind_stats.py
@@ -0,0 +1,484 @@
+import logging
+from pathlib import Path
+
+import hydra
+import numpy as np
+import torch
+
+from hirad.datasets import get_channels_from_strings, get_strings_from_channels
+from hirad.utils.function_utils import get_time_from_range
+from hirad.eval.eval_utils import get_channel_indices, grid_cfg_from_cfg, load_generation_setup, parse_eval_cli, resolve_ts_dir
+from hirad.eval.plotting import plot_map
+
+
+def compute_wind_speed(u, v):
+ """Compute wind speed from U and V."""
+ return np.hypot(u, v)
+
+
+def compute_wind_direction(u, v, calm_threshold=0.0):
+ """Compute wind direction in degrees from N."""
+ dir_deg = (np.degrees(np.arctan2(-u, -v)) % 360)
+ if calm_threshold > 0:
+ speed = np.hypot(u, v)
+ dir_deg = np.where(speed <= calm_threshold, np.nan, dir_deg)
+ return dir_deg
+
+
+def apply_all_wind_statistics_streaming(times, out_root, mode, u_channel, v_channel, stat_configs, logger=None, log_interval=100):
+ """Compute ALL wind statistics in a single pass through timesteps."""
+ accumulators = {}
+ counts = {}
+ sin_accs = {}
+ cos_accs = {}
+ speed_accs = {}
+
+ for sc in stat_configs:
+ key = sc['stat_name']
+ accumulators[key] = None
+ counts[key] = 0
+ sin_accs[key] = None
+ cos_accs[key] = None
+ speed_accs[key] = None
+
+ for i, ts in enumerate(times):
+ if logger and i % log_interval == 0:
+ logger.info(f" Streaming {mode} timestep {i+1}/{len(times)}: {ts}")
+
+ data = torch.load(resolve_ts_dir(out_root, ts) / ts / f"{ts}-{mode}", weights_only=False)
+ u = data[u_channel].cpu().numpy() if torch.is_tensor(data[u_channel]) else data[u_channel]
+ v = data[v_channel].cpu().numpy() if torch.is_tensor(data[v_channel]) else data[v_channel]
+ del data
+
+ # Pre-compute shared quantities once per timestep
+ speed = compute_wind_speed(u, v)
+ direction_calm = None # lazy
+ direction_raw = None # lazy
+
+ for sc in stat_configs:
+ key = sc['stat_name']
+ stype = sc['type']
+
+ if stype == 'mean_speed':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(speed)
+ accumulators[key] += speed
+ elif stype == 'max_speed':
+ if accumulators[key] is None:
+ accumulators[key] = np.full_like(speed, -np.inf)
+ np.maximum(accumulators[key], speed, out=accumulators[key])
+ elif stype == 'wind_power':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(speed)
+ accumulators[key] += speed ** 3
+ elif stype == 'mean_u':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(u)
+ accumulators[key] += u
+ elif stype == 'mean_v':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(v)
+ accumulators[key] += v
+ elif stype in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq',
+ 'strong_breeze_freq', 'gale_freq']:
+ thresholds = {
+ 'calm_freq': 2.0, 'light_breeze_freq': 1.6,
+ 'moderate_breeze_freq': 5.5, 'strong_breeze_freq': 10.8,
+ 'gale_freq': 17.2
+ }
+ threshold = thresholds[stype]
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(speed)
+ if stype == 'calm_freq':
+ accumulators[key] += (speed < threshold).astype(float)
+ else:
+ accumulators[key] += (speed > threshold).astype(float)
+ elif stype == 'prevailing_direction':
+ if direction_calm is None:
+ direction_calm = compute_wind_direction(u, v, calm_threshold=1.0)
+ rad = np.deg2rad(direction_calm)
+ weighted_sin = np.sin(rad) * speed
+ weighted_cos = np.cos(rad) * speed
+ if sin_accs[key] is None:
+ sin_accs[key] = np.zeros_like(weighted_sin)
+ cos_accs[key] = np.zeros_like(weighted_cos)
+ speed_accs[key] = np.zeros_like(speed)
+ sin_accs[key] += np.nan_to_num(weighted_sin, 0)
+ cos_accs[key] += np.nan_to_num(weighted_cos, 0)
+ speed_accs[key] += speed
+ elif stype == 'direction_variability':
+ if direction_raw is None:
+ direction_raw = compute_wind_direction(u, v)
+ rad = np.deg2rad(direction_raw)
+ if sin_accs[key] is None:
+ sin_accs[key] = np.zeros_like(speed)
+ cos_accs[key] = np.zeros_like(speed)
+ sin_accs[key] += np.sin(rad)
+ cos_accs[key] += np.cos(rad)
+
+ counts[key] += 1
+
+ del u, v, speed, direction_calm, direction_raw
+
+ # Finalize all statistics
+ results = {}
+ for sc in stat_configs:
+ key = sc['stat_name']
+ stype = sc['type']
+ count = counts[key]
+
+ if stype == 'prevailing_direction':
+ mean_dir = np.arctan2(
+ sin_accs[key] / (speed_accs[key] + 1e-10),
+ cos_accs[key] / (speed_accs[key] + 1e-10)
+ )
+ results[key] = np.mod(np.rad2deg(mean_dir), 360)
+ elif stype == 'direction_variability':
+ R = np.clip(np.hypot(sin_accs[key] / count, cos_accs[key] / count), 1e-10, 1.0)
+ results[key] = np.rad2deg(np.sqrt(-2 * np.log(R)))
+ elif stype == 'max_speed':
+ results[key] = accumulators[key]
+ elif stype in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq',
+ 'strong_breeze_freq', 'gale_freq']:
+ results[key] = (accumulators[key] / count) * 100
+ else:
+ results[key] = accumulators[key] / count
+
+ return results
+
+
+def plot_wind_stat_map(data, filename, stat_config, label, grid_cfg):
+ """Plot wind statistic map."""
+ if stat_config['type'] == 'mean_speed':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Wind Speed [m/s]', vmin=0, vmax=10, cmap='inferno', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] == 'max_speed':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Wind Speed [m/s]', vmin=0, vmax=30, cmap='inferno', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] == 'wind_power':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Wind Power Density [m³/s³]', vmin=0, vmax=1000, cmap='plasma', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq', 'strong_breeze_freq', 'gale_freq']:
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Frequency [%]', vmin=0, vmax=80, cmap='GnBu', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] == 'prevailing_direction':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Direction [degrees from N]', vmin=0, vmax=360, cmap='twilight', extend='neither', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] == 'direction_variability':
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Circular Std Dev [degrees]', vmin=20, vmax=140, cmap='viridis', extend='max', grid_cfg=grid_cfg
+ )
+ elif stat_config['type'] in ['mean_u', 'mean_v']:
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Wind Component [m/s]', vmin=-5, vmax=5, cmap='RdBu_r', extend='both', grid_cfg=grid_cfg
+ )
+ else:
+ plot_map(
+ data, filename,
+ title=f'{label}: {stat_config["title_stat"]}',
+ label='Value', vmin=None, vmax=None, cmap='viridis', extend='neither', grid_cfg=grid_cfg
+ )
+
+
+def main(cfg: dict):
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ grid_cfg = grid_cfg_from_cfg(cfg)
+
+ logger.info("Starting wind statistics generation")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ logger.info(f"Processing {len(times)} timesteps")
+
+ out_root = Path(generation_dir)
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ indices = get_channel_indices(gen_cfg)
+
+ u10_out = indices['output'].get('10u')
+ v10_out = indices['output'].get('10v')
+ u10_in = indices['input'].get('10u', u10_out)
+ v10_in = indices['input'].get('10v', v10_out)
+
+
+ WIND_STATISTICS_CONFIG = {
+ 'mean_speed': {
+ 'type': 'mean_speed',
+ 'title': 'Mean Wind Speed'
+ },
+ 'max_speed': {
+ 'type': 'max_speed',
+ 'title': 'Maximum Wind Speed'
+ },
+ 'wind_power': {
+ 'type': 'wind_power',
+ 'title': 'Mean Wind Power Density'
+ },
+ 'calm_freq': {
+ 'type': 'calm_freq',
+ 'title': 'Calm Frequency (<2 m/s, Beaufort 0-1)'
+ },
+ 'light_breeze_freq': {
+ 'type': 'light_breeze_freq',
+ 'title': 'Light Breeze Frequency (>1.6 m/s, Beaufort 2+)'
+ },
+ 'moderate_breeze_freq': {
+ 'type': 'moderate_breeze_freq',
+ 'title': 'Moderate Breeze Frequency (>5.5 m/s, Beaufort 4+)'
+ },
+ 'strong_breeze_freq': {
+ 'type': 'strong_breeze_freq',
+ 'title': 'Strong Breeze Frequency (>10.8 m/s, Beaufort 6+)'
+ },
+ 'gale_freq': {
+ 'type': 'gale_freq',
+ 'title': 'Gale Frequency (>17.2 m/s, Beaufort 8+)'
+ },
+ 'prevailing_dir': {
+ 'type': 'prevailing_direction',
+ 'title': 'Prevailing Wind Direction'
+ },
+ 'dir_variability': {
+ 'type': 'direction_variability',
+ 'title': 'Wind Direction Variability'
+ },
+ 'mean_u': {
+ 'type': 'mean_u',
+ 'title': 'Mean U-Component'
+ },
+ 'mean_v': {
+ 'type': 'mean_v',
+ 'title': 'Mean V-Component'
+ }
+ }
+
+ stat_configs = [
+ {
+ 'stat_name': name,
+ 'title_stat': config['title'],
+ 'param': config.get('param'),
+ **config
+ }
+ for name, config in WIND_STATISTICS_CONFIG.items()
+ ]
+
+ basic_modes = {
+ 'target': ((u10_out, v10_out), 'Target'),
+ 'baseline': ((u10_in, v10_in), 'Input'),
+ 'regression-prediction': ((u10_out, v10_out), 'Regression Prediction')
+ }
+
+ logger.info(f"Generating {len(stat_configs)} statistics for {len(basic_modes)} modes + predictions")
+
+ log_interval = cfg.get("log_interval", 100)
+
+ for mode, (wind_channels, label) in basic_modes.items():
+ logger.info(f"Processing mode: {mode}")
+ u_channel, v_channel = wind_channels
+
+ try:
+ test_data = torch.load(resolve_ts_dir(out_root, times[0])/times[0]/f"{times[0]}-{mode}", weights_only=False)
+ del test_data
+ except Exception as e:
+ logger.warning(f"{mode} not available: {e}")
+ continue
+
+ try:
+ results = apply_all_wind_statistics_streaming(
+ times, out_root, mode, u_channel, v_channel,
+ stat_configs, logger=logger, log_interval=log_interval
+ )
+
+ for stat_config in stat_configs:
+ key = stat_config['stat_name']
+ map_output_dir = output_path / f"maps_wind_{key}"
+ map_output_dir.mkdir(parents=True, exist_ok=True)
+ plot_wind_stat_map(
+ results[key],
+ str(map_output_dir / f'{mode}_{key}'),
+ stat_config,
+ label,
+ grid_cfg
+ )
+ del results
+ except Exception as e:
+ logger.error(f"Failed computing statistics for {mode}: {e}")
+ continue
+
+
+ logger.info("Processing predictions mode...")
+ try:
+ data = torch.load(resolve_ts_dir(out_root, times[0])/times[0]/f"{times[0]}-predictions", weights_only=False)
+ n_members = data.shape[0]
+ del data
+ logger.info(f"Found {n_members} ensemble members")
+
+ # Initialize accumulators for all members × all statistics at once
+ # Each accumulator is keyed by (member_idx, stat_name)
+ accumulators = {}
+ counts = {}
+ sin_accs = {}
+ cos_accs = {}
+ speed_accs = {}
+
+ for member_idx in range(n_members):
+ for stat_config in stat_configs:
+ key = (member_idx, stat_config['stat_name'])
+ accumulators[key] = None
+ counts[key] = 0
+ sin_accs[key] = None
+ cos_accs[key] = None
+ speed_accs[key] = None
+
+ # Single pass over timesteps — load each file once for all members
+ for i, ts in enumerate(times):
+ if i % log_interval == 0:
+ logger.info(f"Loading predictions timestep {i+1}/{len(times)}: {ts}")
+
+ pred_data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False)
+
+ for member_idx in range(n_members):
+ u_data = pred_data[member_idx, u10_out]
+ v_data = pred_data[member_idx, v10_out]
+ u = u_data.cpu().numpy() if torch.is_tensor(u_data) else u_data
+ v = v_data.cpu().numpy() if torch.is_tensor(v_data) else v_data
+
+ # Pre-compute shared quantities once per member per timestep
+ speed = compute_wind_speed(u, v)
+ direction_calm = None
+ direction_raw = None
+
+ for stat_config in stat_configs:
+ key = (member_idx, stat_config['stat_name'])
+ stype = stat_config['type']
+
+ if stype == 'mean_speed':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(speed)
+ accumulators[key] += speed
+ elif stype == 'max_speed':
+ if accumulators[key] is None:
+ accumulators[key] = np.full_like(speed, -np.inf)
+ np.maximum(accumulators[key], speed, out=accumulators[key])
+ elif stype == 'wind_power':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(speed)
+ accumulators[key] += speed ** 3
+ elif stype == 'mean_u':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(u)
+ accumulators[key] += u
+ elif stype == 'mean_v':
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(v)
+ accumulators[key] += v
+ elif stype in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq',
+ 'strong_breeze_freq', 'gale_freq']:
+ thresholds = {
+ 'calm_freq': 2.0, 'light_breeze_freq': 1.6,
+ 'moderate_breeze_freq': 5.5, 'strong_breeze_freq': 10.8,
+ 'gale_freq': 17.2
+ }
+ threshold = thresholds[stype]
+ if accumulators[key] is None:
+ accumulators[key] = np.zeros_like(speed)
+ if stype == 'calm_freq':
+ accumulators[key] += (speed < threshold).astype(float)
+ else:
+ accumulators[key] += (speed > threshold).astype(float)
+ elif stype == 'prevailing_direction':
+ if direction_calm is None:
+ direction_calm = compute_wind_direction(u, v, calm_threshold=1.0)
+ rad = np.deg2rad(direction_calm)
+ weighted_sin = np.sin(rad) * speed
+ weighted_cos = np.cos(rad) * speed
+ if sin_accs[key] is None:
+ sin_accs[key] = np.zeros_like(weighted_sin)
+ cos_accs[key] = np.zeros_like(weighted_cos)
+ speed_accs[key] = np.zeros_like(speed)
+ sin_accs[key] += np.nan_to_num(weighted_sin, 0)
+ cos_accs[key] += np.nan_to_num(weighted_cos, 0)
+ speed_accs[key] += speed
+ elif stype == 'direction_variability':
+ if direction_raw is None:
+ direction_raw = compute_wind_direction(u, v)
+ rad = np.deg2rad(direction_raw)
+ if sin_accs[key] is None:
+ sin_accs[key] = np.zeros_like(speed)
+ cos_accs[key] = np.zeros_like(speed)
+ sin_accs[key] += np.sin(rad)
+ cos_accs[key] += np.cos(rad)
+
+ counts[key] += 1
+
+ del u, v, speed, direction_calm, direction_raw
+
+ del pred_data
+
+ # Finalize and plot all statistics for all members
+ for member_idx in range(n_members):
+ logger.info(f"Finalizing and plotting member {member_idx+1}/{n_members}")
+ for stat_config in stat_configs:
+ stat_key = stat_config['stat_name']
+ key = (member_idx, stat_key)
+ stype = stat_config['type']
+ count = counts[key]
+
+ try:
+ if stype == 'prevailing_direction':
+ mean_dir = np.arctan2(
+ sin_accs[key] / (speed_accs[key] + 1e-10),
+ cos_accs[key] / (speed_accs[key] + 1e-10)
+ )
+ member_result = np.mod(np.rad2deg(mean_dir), 360)
+ elif stype == 'direction_variability':
+ R = np.clip(np.hypot(sin_accs[key] / count, cos_accs[key] / count), 1e-10, 1.0)
+ member_result = np.rad2deg(np.sqrt(-2 * np.log(R)))
+ elif stype == 'max_speed':
+ member_result = accumulators[key]
+ elif stype in ['calm_freq', 'light_breeze_freq', 'moderate_breeze_freq',
+ 'strong_breeze_freq', 'gale_freq']:
+ member_result = (accumulators[key] / count) * 100
+ else:
+ member_result = accumulators[key] / count
+
+ map_output_dir = output_path / f"maps_wind_{stat_key}"
+ map_output_dir.mkdir(parents=True, exist_ok=True)
+ member_filename = str(map_output_dir / f'prediction_member_{member_idx:02d}_{stat_key}')
+ plot_wind_stat_map(member_result, member_filename, stat_config, f'CorrDiff Member {member_idx+1}', grid_cfg)
+ del member_result
+ except Exception as e:
+ logger.error(f"Failed {stat_config['title_stat']} for member {member_idx+1}: {e}")
+ continue
+
+ except Exception as e:
+ logger.warning(f"Predictions not available: {e}")
+
+ logger.info("Wind statistics generation complete")
+
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
\ No newline at end of file
diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py
index b73caf0e..af9d5a9a 100644
--- a/src/hirad/eval/metrics.py
+++ b/src/hirad/eval/metrics.py
@@ -3,9 +3,10 @@
import numpy as np
import torch
import xskillscore
-import scoringrules as sr
from scipy.signal import periodogram
+import xskillscore
+import xarray as xr
# set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated)
@@ -58,6 +59,41 @@ def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default
return freqs, power_spectra
-def crps():
- # Time, variable, ensemble, x, y
- xskillscore.crps_ensemble()
\ No newline at end of file
+def crps(prediction_ensemble, target, average_over_area=True, average_over_channels=True, average_over_time=True):
+ # Assumes that prediction_ensemble is in form:
+ # (member, channel, x, y) or
+ # (time, member, channel, x, y)
+ # Returns: a k-dimensional array of continuous ranked probability scores,
+ # where k is the number of dimensions that were not averaged over.
+ # For example, if average_over_area is False (and all others true), will
+ # return an ndarray of shape (X,Y)
+ target_coords = [('channel', np.arange(target.shape[-3])),
+ ('x', np.arange(target.shape[-2])),
+ ('y', np.arange(target.shape[-1]))]
+
+
+ forecasts_coords = [('member', np.arange(prediction_ensemble.shape[-4])),
+ ('channel', np.arange(prediction_ensemble.shape[-3])),
+ ('x', np.arange(prediction_ensemble.shape[-2])),
+ ('y', np.arange(prediction_ensemble.shape[-1]))]
+
+ if prediction_ensemble.ndim > 4 and target.ndim > 3:
+ forecasts_coords.insert(0, ('time', np.arange(prediction_ensemble.shape[-5])))
+ target_coords.insert(0, ('time', np.arange(target.shape[-4])))
+
+
+
+ forecasts = xr.DataArray(prediction_ensemble, coords = forecasts_coords)
+ observations = xr.DataArray(target, coords = target_coords)
+
+ dim = []
+ if prediction_ensemble.ndim > 4 and average_over_time:
+ dim.append('time')
+ if average_over_area:
+ dim.append('x')
+ dim.append('y')
+ if average_over_channels:
+ dim.append('channel')
+ crps = xskillscore.crps_ensemble(observations=observations, forecasts=forecasts, dim=dim)
+ crps = crps.to_numpy()
+ return crps
diff --git a/src/hirad/eval/plot_maps.py b/src/hirad/eval/plot_maps.py
new file mode 100644
index 00000000..a2230a68
--- /dev/null
+++ b/src/hirad/eval/plot_maps.py
@@ -0,0 +1,163 @@
+import torch
+import yaml
+import numpy as np
+import os
+from pathlib import Path
+import argparse
+from hirad.eval import plotting
+from hirad.utils.inference_utils import calculate_bounds, transform_channel
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+channel_plot_args = {
+ "2t": {"label": "°C"},
+ "10u": {"label": "m/s"},
+ "10v": {"label": "m/s"},
+ "tp": {"label": "boxcox(mm/h)"},
+}
+
+
+def get_available_time_steps(results_dir):
+ """Find all available time step directories"""
+ results_path = Path(results_dir)
+ time_step_dirs = [d.name for d in results_path.iterdir() if d.is_dir() and not d.name.startswith('plots') and not d.name.startswith('maps')]
+ return sorted(time_step_dirs)
+
+def plot_time_step(results_dir, output_dir, time_step, input_channels, output_channels, cfg):
+ """Plot all channels for a single time step"""
+ print(f"Processing time step: {time_step}")
+
+ # Set up paths for this time step
+ ts_results_dir = Path(results_dir) / time_step
+
+ # Load tensors
+ try:
+ target = torch.load(ts_results_dir / f"{time_step}-target", weights_only=False)
+ baseline = torch.load(ts_results_dir / f"{time_step}-baseline", weights_only=False)
+ predictions = torch.load(ts_results_dir / f"{time_step}-predictions", weights_only=False)
+
+ try:
+ mean_pred = torch.load(ts_results_dir / f"{time_step}-regression-prediction", weights_only=False)
+ except FileNotFoundError:
+ mean_pred = None
+ print(f" Warning: No mean prediction found for {time_step}")
+
+ except FileNotFoundError as e:
+ print(f" Error: Missing required file for {time_step}: {e}")
+ return
+
+ # Create output directory for this time step
+ ts_output_dir = Path(output_dir) / time_step
+ ts_output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Plot each channel
+ for idx, channel in enumerate(output_channels):
+ print(f" Plotting channel: {channel}")
+
+ # Get input channel index (handle case where input/output channels differ)
+ try:
+ input_idx = input_channels.index(channel)
+ except ValueError:
+ print(f" Warning: Channel {channel} not in input channels, skipping baseline")
+ continue
+
+ # Transform data
+ if channel != "tp" or not cfg.get("plot_box_precipitation", False):
+ tgt = transform_channel(target[idx], channel)
+ base = transform_channel(baseline[input_idx], channel)
+ preds = transform_channel(predictions[:, idx], channel)
+ mean = transform_channel(mean_pred[idx], channel) if mean_pred is not None else None
+ if channel == "tp":
+ threshold = transform_channel(np.array([cfg.get("tp_threshold", 0.002)]), "tp")[0] # Transform threshold too
+ tgt = np.ma.masked_where(tgt <= threshold, tgt)
+ base = np.ma.masked_where(base <= threshold, base)
+ preds = np.ma.masked_where(preds <= threshold, preds)
+ if mean is not None:
+ mean = np.ma.masked_where(mean <= threshold, mean)
+ else:
+ # For precipitation, use raw values if plotting box precipitation
+ tgt = target[idx]
+ base = baseline[input_idx]
+ preds = predictions[:, idx]
+ mean = mean_pred[idx] if mean_pred is not None else None
+
+ # Calculate consistent bounds (skip for precipitation)
+ if channel != "tp" or not cfg.get("plot_box_precipitation", False):
+ arrays = [tgt, base] + [preds[i] for i in range(preds.shape[0])]
+ if mean is not None:
+ arrays.append(mean)
+ vmin, vmax = calculate_bounds(*arrays)
+ else:
+ vmin, vmax = None, None
+
+ base_channel_dir = ts_output_dir / channel
+ base_channel_dir.mkdir(parents=True, exist_ok=True)
+
+ # Plot target
+ fname = ts_output_dir / channel / f"target"
+ if channel == "tp" and cfg.get("plot_box_precipitation", False):
+ plotting.plot_map_precipitation(tgt, str(fname), title=f"Target - {channel}")
+ else:
+ plotting.plot_map(tgt, str(fname), vmin=vmin, vmax=vmax,
+ title=f"Target - {channel}", **channel_plot_args.get(channel, {}))
+
+ # Plot baseline
+ fname = ts_output_dir / channel / "baseline"
+ if channel == "tp" and cfg.get("plot_box_precipitation", False):
+ plotting.plot_map_precipitation(base, str(fname), title=f"Baseline - {channel}")
+ else:
+ plotting.plot_map(base, str(fname), vmin=vmin, vmax=vmax,
+ title=f"Baseline - {channel}", **channel_plot_args.get(channel, {}))
+
+ # Plot mean prediction if available
+ if mean is not None:
+ fname = ts_output_dir / channel / "mean-prediction"
+ if channel == "tp" and cfg.get("plot_box_precipitation", False):
+ plotting.plot_map_precipitation(mean, str(fname), title=f"Mean Prediction - {channel}")
+ else:
+ plotting.plot_map(mean, str(fname), vmin=vmin, vmax=vmax,
+ title=f"Mean Prediction - {channel}", **channel_plot_args.get(channel, {}))
+
+ # Plot ensemble members
+ for member_idx in range(preds.shape[0]):
+ fname = ts_output_dir / channel / f"prediction_{member_idx:02d}"
+ if channel == "tp" and cfg.get("plot_box_precipitation", False):
+ plotting.plot_map_precipitation(
+ preds[member_idx], str(fname),
+ title=f"Prediction {member_idx} - {channel}"
+ )
+ else:
+ plotting.plot_map(
+ preds[member_idx], str(fname),
+ vmin=vmin, vmax=vmax,
+ title=f"Prediction {member_idx} - {channel}",
+ **channel_plot_args.get(channel, {})
+ )
+
+@hydra.main(version_base=None, config_path="../conf", config_name="plotting")
+def main(cfg: DictConfig) -> None:
+ OmegaConf.resolve(cfg)
+
+ input_channels = cfg.dataset.input_channel_names
+ output_channels = cfg.dataset.output_channel_names
+
+ # Set up directories
+ results_dir = Path(cfg.results_dir)
+ output_dir = Path(cfg.output_dir) if "output_dir" in cfg and cfg.output_dir else results_dir / "plots"
+ output_dir.mkdir(exist_ok=True)
+
+ # Determine time steps to process
+ if cfg.time_steps:
+ time_steps = cfg.time_steps
+ else:
+ time_steps = get_available_time_steps(results_dir)
+ print(f"Found {len(time_steps)} time steps: {time_steps}")
+
+ # Process each time step
+ for time_step in time_steps:
+ plot_time_step(results_dir, output_dir, time_step, input_channels, output_channels, cfg)
+
+ print(f"Plotting complete. Results saved to: {output_dir}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py
index 1ca11c2e..61e0ff0b 100644
--- a/src/hirad/eval/plotting.py
+++ b/src/hirad/eval/plotting.py
@@ -1,22 +1,228 @@
import logging
import cartopy.crs as ccrs
+import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
+from matplotlib.colors import BoundaryNorm, ListedColormap
-def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str):
+from hirad.eval.eval_utils import GridConfig, DEFAULT_GRID_CONFIG
+
+
+def plot_map(values: np.array,
+ filename: str,
+ label='',
+ title='',
+ vmin=None,
+ vmax=None,
+ cmap=None,
+ extend='neither',
+ norm=None,
+ ticks=None,
+ grid_cfg = DEFAULT_GRID_CONFIG,
+ patch_idx=0,
+ patch_size=None
+ ):
+ """Plot observed or interpolated data in a scatter plot."""
+ logging.info(f'Creating map: {filename}')
+
+ if patch_size is None:
+ patch_size = (grid_cfg.height, grid_cfg.width)
+ # TODO: implement properly plotting of pathces for patched diffusion inference inspection
+ # n_col_stacked = math.ceil(704/patch_size[0])
+ # last_start = (n_col_stacked-1) * patch_size[0]
+ # # print(n_col_stacked)
+ # latitudes_start = last_start-(patch_idx%n_col_stacked)*patch_size[0]
+ # # print(latitudes_start)
+ # longitudes_start = (patch_idx//n_col_stacked)*patch_size[1]
+ # # print(longitudes_start)
+ # lat = LAT if lat is None else lat
+ # lon = LON if lon is None else lon
+ # latitudes = lat[RELAX_ZONE : RELAX_ZONE + patch_size[0]] #LAT[RELAX_ZONE+latitudes_start:RELAX_ZONE+latitudes_start+patch_size[0]] #LAT[RELAX_ZONE : RELAX_ZONE + 352]
+ # longitudes = lon[RELAX_ZONE : RELAX_ZONE + patch_size[1]] #LON[RELAX_ZONE+longitudes_start:RELAX_ZONE+longitudes_start+patch_size[1]] #LON[RELAX_ZONE : RELAX_ZONE + 544]
+ latitudes = grid_cfg.lat[grid_cfg.relax_zone : grid_cfg.relax_zone + grid_cfg.height]
+ longitudes = grid_cfg.lon[grid_cfg.relax_zone : grid_cfg.relax_zone + grid_cfg.width]
+ lon2d, lat2d = np.meshgrid(longitudes, latitudes)
+
+ fig, ax = plt.subplots(
+ figsize=(8, 6),
+ subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0,
+ pole_latitude= 43.0)}
+ )
+ contour = ax.pcolormesh(
+ lon2d, lat2d, values,
+ cmap=cmap, shading="auto",
+ norm=norm if norm else None,
+ vmin=None if norm else vmin,
+ vmax=None if norm else vmax,
+ )
+ ax.coastlines()
+ ax.add_feature(cfeature.BORDERS, linewidth=1)
+ ax.gridlines(visible=False)
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+ plt.title(title)
+ cbar = plt.colorbar(
+ contour,
+ label=label,
+ orientation="horizontal",
+ extend=extend,
+ shrink=0.75,
+ pad=0.02
+ )
+ if ticks is not None:
+ cbar.set_ticks(ticks)
+ cbar.set_ticklabels([f'{tick:g}' for tick in ticks])
+
+ plt.tight_layout()
+ fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight")
+ plt.close(fig)
+
+def plot_map_precipitation(values, filename, title='', threshold=0.01, rfac=1000.0, grid_cfg=DEFAULT_GRID_CONFIG):
+ """Plot precipitation data with specific colormap and thresholds."""
+ # Scale and mask values below threshold
+ values = rfac * values # m/h --> mm/h
+ values = np.ma.masked_where(values <= threshold, values)
+
+ # Predefined colors and bounds specific for precipitation
+ colors = ['none', 'powderblue', 'dodgerblue', 'mediumblue',
+ 'forestgreen', 'limegreen', 'lawngreen',
+ 'yellow', 'gold', 'darkorange', 'red',
+ 'darkviolet', 'violet', 'thistle']
+ bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]
+
+ cmap = ListedColormap(colors)
+ norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False)
+
+ plot_map(
+ values, filename,
+ cmap=cmap,
+ norm=norm,
+ ticks=bounds,
+ title=title,
+ label='mm/h',
+ extend='max',
+ grid_cfg=grid_cfg,
+ )
+
+def plot_map_wind_precip(
+ u: np.ndarray,
+ v: np.ndarray,
+ tp: np.ndarray,
+ filename: str,
+ title: str = '',
+ tp_threshold: float = 0.1,
+ tp_rfac: float = 1000.0,
+ wind_vmax: float = 15.0,
+ grid_cfg: GridConfig = DEFAULT_GRID_CONFIG,
+):
+ """Plot surface windspeed as filled background with precipitation overlaid.
+
+ Parameters
+ ----------
+ u, v : wind component arrays (H, W), in m/s
+ tp : total precipitation array (H, W), in m/h (ERA5 units)
+ filename : output path without extension
+ tp_threshold : minimum precipitation to show in mm/h (after rfac scaling)
+ tp_rfac : conversion factor applied to tp before plotting (default 1000 → m/h → mm/h)
+ wind_vmax : upper end of the wind-speed colorbar [m/s]
+ """
+ logging.info(f'Creating wind+precip map: {filename}')
+
+ wind_speed = np.hypot(u, v)
+
+ precip = tp_rfac * tp
+ precip_masked = np.ma.masked_where(precip <= tp_threshold, precip)
+
+ precip_colors = ['powderblue', 'dodgerblue', 'mediumblue',
+ 'forestgreen', 'limegreen', 'lawngreen',
+ 'yellow', 'gold', 'darkorange', 'red']
+ precip_bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200]
+ precip_cmap = ListedColormap(precip_colors)
+ precip_norm = BoundaryNorm(precip_bounds, ncolors=len(precip_colors), clip=False)
+
+ latitudes = grid_cfg.lat[grid_cfg.relax_zone : grid_cfg.relax_zone + grid_cfg.height]
+ longitudes = grid_cfg.lon[grid_cfg.relax_zone : grid_cfg.relax_zone + grid_cfg.width]
+ lon2d, lat2d = np.meshgrid(longitudes, latitudes)
+
+ fig, ax = plt.subplots(
+ figsize=(10, 6),
+ subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0, pole_latitude=43.0)},
+ )
+
+ # Background: wind speed
+ wind_mesh = ax.pcolormesh(
+ lon2d, lat2d, wind_speed,
+ cmap='inferno', shading='auto', vmin=0, vmax=wind_vmax,
+ )
+
+ # Overlay: precipitation (semi-transparent so wind field remains visible)
+ precip_mesh = ax.pcolormesh(
+ lon2d, lat2d, precip_masked,
+ cmap=precip_cmap, norm=precip_norm, shading='auto', alpha=0.75,
+ )
+
+ ax.coastlines()
+ ax.add_feature(cfeature.BORDERS, linewidth=1)
+ ax.gridlines(visible=False)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_title(title)
+
+ _ = fig.colorbar(
+ wind_mesh, ax=ax, label='Wind Speed [m/s]',
+ orientation='horizontal', shrink=0.7, pad=0.04, extend='max',
+ )
+
+ cbar_precip = fig.colorbar(
+ precip_mesh, ax=ax, label='Precipitation [mm/h]',
+ orientation='vertical', shrink=0.6, pad=0.02, extend='max',
+ )
+ cbar_precip.set_ticks(precip_bounds)
+ cbar_precip.set_ticklabels([f'{b:g}' for b in precip_bounds])
+
+ fig.savefig(f'{filename}.png', dpi=300, bbox_inches='tight')
+ plt.close(fig)
+
+
+@DeprecationWarning
+def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str, label='', title='', vmin=None, vmax=None):
+ """Plot observed or interpolated data in a scatter plot."""
fig = plt.figure()
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
logging.info(f'plotting values to {filename}')
- p = ax.scatter(x=longitudes, y=latitudes, c=values)
+ p = ax.scatter(x=longitudes, y=latitudes, c=values, vmin=vmin, vmax=vmax)
ax.coastlines()
ax.gridlines(draw_labels=True)
- plt.colorbar(p, label="absolute error", orientation="horizontal")
+ plt.colorbar(p, label=label, orientation="horizontal")
+ plt.savefig(filename)
+ plt.close('all')
+
+def plot_scores_vs_t(scores: dict[str,np.ndarray], times: np.array, filename: str, xlabel='', ylabel='', title=''):
+
+ ax = plt.subplot()
+ colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'] # TODO, add more
+ i=0
+ for k in scores.keys():
+ style = colors[i]
+ # If more than 50 points, don't connect lines
+ if len(times) > 50:
+ style = style + '.'
+ else:
+ style = style + '-'
+ p, = ax.plot(times, scores[k], style)
+ i=i+1
+ p.set_label(k)
+ ax.legend()
+ ax.set_xticks([times[0],times[-1]])
+ plt.xlabel(xlabel)
+ plt.ylabel(ylabel)
+ plt.title(title)
plt.savefig(filename)
plt.close('all')
def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename):
- fig = plt.figure()
for k in freqs.keys():
plt.loglog(freqs[k], spec[k], label=k)
plt.title(channel_name)
diff --git a/src/hirad/eval/probability_of_exceedance.py b/src/hirad/eval/probability_of_exceedance.py
new file mode 100644
index 00000000..d6bda7d9
--- /dev/null
+++ b/src/hirad/eval/probability_of_exceedance.py
@@ -0,0 +1,260 @@
+"""
+Plots the probability of exceedance for precipitation over land.
+
+This script computes and visualizes the complementary cumulative distribution
+(probability of exceeding x mm/h) over land).
+"""
+import logging
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import xarray as xr
+
+from hirad.datasets import get_channels_from_strings, get_strings_from_channels
+from hirad.utils.function_utils import get_time_from_range
+from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, load_land_sea_mask, parse_eval_cli, resolve_ts_dir
+from hirad.eval.eval_utils import percentiles_from_histogram
+
+
+def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None):
+ """Save probability of exceedance plot with pre-computed data."""
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+
+ plt.figure(figsize=(10, 6))
+
+ # Plot exceedance curves
+ for (key, exceedance_data), label, color in zip(exceedance_data_dict.items(), labels, colors):
+ if isinstance(exceedance_data, tuple): # Handle ensemble data
+ # Plot individual members with transparency
+ for i, member_exceedance in enumerate(exceedance_data):
+ alpha = 0.5 if i > 0 else 0.7
+ label_member = label if i == 0 else None
+ plt.plot(thresholds, member_exceedance, alpha=alpha, color=color,
+ label=label_member, linewidth=1)
+ else:
+ # Plot single dataset
+ plt.plot(thresholds, exceedance_data, alpha=0.7, color=color,
+ label=label, linewidth=2)
+
+ plt.xscale('log')
+ plt.yscale('log')
+ plt.xlabel(ylabel)
+ plt.ylabel('Probability of Exceedance')
+ plt.ylim(1e-8, 1)
+ plt.xlim(thresholds[1], thresholds[-1])
+ plt.title(title)
+ plt.grid(True, alpha=0.3)
+
+ # Add percentile lines if provided
+ if percentiles_data:
+ # Calculate y-range for percentile lines (lowest 10% of log scale)
+ y_bottom, y_top = plt.ylim()
+ log_bottom, log_top = np.log10(y_bottom), np.log10(y_top)
+ vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom))
+ vline_ymin = y_bottom
+
+ # Define line styles for percentiles
+ percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'}
+ percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'}
+ colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'}
+ legend_added = set()
+
+ # Plot all percentile lines
+ for dataset_name, data in percentiles_data.items():
+ color = colors_perc[dataset_name]
+
+ if dataset_name in ['target', 'baseline', 'regression-prediction']:
+ # Single dataset
+ for percentile, value in data.items():
+ linestyle = percentile_styles[percentile]
+ legend_added.add(percentile) # Track percentiles for black legend entries
+
+ plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax,
+ linestyles=linestyle, alpha=0.8) # No label here
+ else:
+ # Ensemble members
+ for member_data in data.values():
+ for percentile, value in member_data.items():
+ linestyle = percentile_styles[percentile]
+ legend_added.add(percentile) # Track percentiles for black legend entries
+
+ plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax,
+ linestyles=linestyle, alpha=0.6) # No label here
+
+ # Add black legend entries for percentiles (override the colored ones)
+ for percentile in [99, 99.9, 99.99]:
+ if percentile in legend_added:
+ plt.plot([], [], color='black', linestyle=percentile_styles[percentile],
+ label=percentile_labels[percentile])
+
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(out_path, dpi=300, bbox_inches='tight')
+ plt.close()
+
+
+def main(cfg: dict):
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Starting computation for probability of exceedance over land")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ logger.info(f"Loaded {len(times)} timesteps to process")
+
+ # Output root
+ out_root = Path(generation_dir)
+
+ # Find channel indices
+ indices = get_channel_indices(gen_cfg)
+ tp_out = indices['output']['tp']
+ tp_in = indices['input'].get('tp', tp_out)
+ logger.info(f"TP channel indices - output: {tp_out}, input: {tp_in}")
+
+ # Land-sea mask
+ land_mask = load_land_sea_mask(cfg.get("land_sea_mask_path"), cfg.get("height"), cfg.get("width"))
+
+ # Define thresholds for exceedance calculation
+ thresholds = np.logspace(-2, 3.0, 200) # From 0.01 to 1000 mm/h
+ n_thresholds = len(thresholds)
+
+ # Histogram bins for percentile estimation (fine-grained log-spaced)
+ hist_bins = np.concatenate([
+ np.array([0.0]),
+ np.logspace(-2, 3.2, 5000) # From 0.01 to ~1585 mm/h
+ ])
+ n_hist_bins = len(hist_bins) - 1
+
+ # Storage for exceedance data and land values
+ exceedance_counts = {}
+ totals = {}
+ hist_counts = {} # For percentile estimation
+
+ # -- Process target and baseline --
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ logger.info(f"Processing mode: {mode}")
+
+ mode_exc_counts = np.zeros(n_thresholds, dtype=np.int64)
+ mode_total = 0
+ mode_hist = np.zeros(n_hist_bins, dtype=np.int64)
+
+ try:
+ for i, ts in enumerate(times):
+ if i % cfg.get("log_interval") == 0:
+ logger.info(f"Processing timestep {i+1}/{len(times)}")
+
+ data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)[tp_out if mode in ['target','regression-prediction'] else tp_in] * cfg.get("conv_factor_hourly") * land_mask
+
+ land_values = data.values[~np.isnan(data.values)]
+ n_vals = len(land_values)
+ mode_total += n_vals
+ # Update counts for exceedance calculation
+ mode_exc_counts += np.sum(
+ land_values[:, None] > thresholds[None, :], axis=0
+ )
+ # Update histogram counts for percentile estimation
+ mode_hist += np.histogram(land_values, bins=hist_bins)[0]
+ except:
+ logger.warning(f"{mode} data not found, skipping")
+ continue
+
+ # Compute exceedance probabilities
+ exceedance_counts[mode] = mode_exc_counts
+ totals[mode] = mode_total
+ hist_counts[mode] = mode_hist
+ logger.info(f"Processed {mode_total} land values for {mode}")
+
+ # -- Process predictions: compute exceedance for each ensemble member --
+ logger.info("Processing predictions")
+
+ n_members = None
+ member_exc_counts = None
+ member_totals = None
+ member_hist = None
+
+ for i, ts in enumerate(times):
+ if i % cfg.get("log_interval") == 0:
+ logger.info(f"Processing timestep {i+1}/{len(times)}")
+
+ preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) * cfg.get("conv_factor_hourly") # [n_members, n_channels, lat, lon]
+
+ if n_members is None:
+ n_members = preds.shape[0]
+ member_exc_counts = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)]
+ member_totals = [0] * n_members
+ member_hist = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)]
+
+ for member_idx in range(n_members):
+ data = preds[member_idx, tp_out] * land_mask
+ land_values = data.values[~np.isnan(data.values)]
+ n_vals = len(land_values)
+ member_totals[member_idx] += n_vals
+ member_exc_counts[member_idx] += np.sum(
+ land_values[:, None] > thresholds[None, :], axis=0
+ )
+ member_hist[member_idx] += np.histogram(land_values, bins=hist_bins)[0]
+
+ # Compute exceedance probabilities
+ exceedance_data = {}
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ if mode in exceedance_counts and totals[mode] > 0:
+ exceedance_data[mode] = exceedance_counts[mode] / totals[mode]
+
+ member_exceedance_data = []
+ for member_idx in range(n_members):
+ if member_totals[member_idx] > 0:
+ member_exceedance_data.append(
+ member_exc_counts[member_idx] / member_totals[member_idx]
+ )
+
+ exceedance_data['predictions'] = tuple(member_exceedance_data)
+
+ logger.info(f"Collected {n_members} ensemble members for predictions")
+
+ # Compute percentiles for all datasets
+ percentiles_data = {}
+ percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999}
+
+ # Estimating percentiles from fine-grained histograms
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ if mode in hist_counts and totals[mode] > 0:
+ percentiles_data[mode] = percentiles_from_histogram(
+ hist_counts[mode], hist_bins, percentiles
+ )
+
+ percentiles_data['predictions'] = {}
+ for member_idx in range(n_members):
+ if member_totals[member_idx] > 0:
+ percentiles_data['predictions'][f'member_{member_idx}'] = percentiles_from_histogram(
+ member_hist[member_idx], hist_bins, percentiles
+ )
+
+ # Create exceedance plots
+ labels = ['Target', 'Input', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in exceedance_data else ['Target', 'Input', 'CorrDiff Ensemble']
+ colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in exceedance_data else ['blue', 'orange', 'green']
+
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ fn = output_path / 'precipitation_exceedance_over_land.png'
+ save_exceedance_plot(
+ exceedance_data,
+ thresholds,
+ labels,
+ colors,
+ 'Probability of Exceedance',
+ 'All-hour Precipitation Over Land [mm/h] (Pooled Data)',
+ fn,
+ percentiles_data
+ )
+ logger.info(f"Exceedance plot saved: {fn}")
+
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
diff --git a/src/hirad/eval/probability_of_exceedance_wind.py b/src/hirad/eval/probability_of_exceedance_wind.py
new file mode 100644
index 00000000..f280e2db
--- /dev/null
+++ b/src/hirad/eval/probability_of_exceedance_wind.py
@@ -0,0 +1,343 @@
+"""Probability of exceedance for wind speed and components."""
+import logging
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import xarray as xr
+
+from hirad.datasets import get_channels_from_strings, get_strings_from_channels
+from hirad.utils.function_utils import get_time_from_range
+from hirad.eval.eval_utils import get_channel_indices, load_generation_setup, parse_eval_cli, resolve_ts_dir
+from hirad.eval.eval_utils import percentiles_from_histogram
+
+
+def compute_wind_speed(u, v):
+ """Compute wind speed from U and V components."""
+ return np.hypot(u, v)
+
+
+def compute_exceedance_probs(values, thresholds, use_abs=False):
+ """Compute exceedance probabilities."""
+ if use_abs:
+ return np.array([np.mean(np.abs(values) > t) for t in thresholds])
+ else:
+ return np.array([np.mean(values > t) for t in thresholds])
+
+
+def update_exceedance_counts(counts, total, values, thresholds, use_abs=False):
+ """Update exceedance counts incrementally."""
+ data = np.abs(values) if use_abs else values
+ counts += (data[:, None] > thresholds[None, :]).sum(axis=0)
+ total += len(values)
+ return counts, total
+
+
+def compute_percentiles(values, percentile_dict, use_abs=False):
+ """Compute percentiles."""
+ data = np.abs(values) if use_abs else values
+ data_array = xr.DataArray(data)
+ return {key: data_array.quantile(p).item() for key, p in percentile_dict.items()}
+
+
+def save_exceedance_plot(exceedance_data_dict, thresholds, labels, colors, title, ylabel, out_path, percentiles_data=None):
+ """Save probability of exceedance plot."""
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+
+ plt.figure(figsize=(10, 6))
+
+ # Plot exceedance curves
+ for (key, exceedance_data), label, color in zip(exceedance_data_dict.items(), labels, colors):
+ if isinstance(exceedance_data, tuple): # Handle ensemble data
+ # Plot individual members with transparency
+ for i, member_exceedance in enumerate(exceedance_data):
+ alpha = 0.5 if i > 0 else 0.7
+ label_member = label if i == 0 else None
+ plt.plot(thresholds, member_exceedance, alpha=alpha, color=color,
+ label=label_member, linewidth=1)
+ else:
+ # Plot single dataset
+ plt.plot(thresholds, exceedance_data, alpha=0.7, color=color,
+ label=label, linewidth=2)
+
+ plt.xscale('log')
+ plt.xlim(thresholds[1], thresholds[-1])
+ plt.yscale('log')
+ plt.xlabel(ylabel)
+ plt.ylabel('Probability of Exceedance')
+ plt.ylim(1e-8, 1)
+ plt.title(title)
+ plt.grid(True, alpha=0.3)
+
+ # Add percentile lines if provided
+ if percentiles_data:
+ # Calculate y-range for percentile lines (lowest 10% of log scale)
+ y_bottom, y_top = plt.ylim()
+ log_bottom, log_top = np.log10(y_bottom), np.log10(y_top)
+ vline_ymax = 10**(log_bottom + 0.1 * (log_top - log_bottom))
+ vline_ymin = y_bottom
+
+ # Define line styles for percentiles
+ percentile_styles = {99: '--', 99.9: ':', 99.99: '-.'}
+ percentile_labels = {99: '99th all-hour percentiles', 99.9: '99.9th all-hour percentiles', 99.99: '99.99th all-hour percentiles'}
+ colors_perc = {'target': 'blue', 'baseline': 'orange', 'predictions': 'green', 'regression-prediction': 'red'}
+ legend_added = set()
+
+ # Plot all percentile lines
+ for dataset_name, data in percentiles_data.items():
+ color = colors_perc[dataset_name]
+
+ if dataset_name in ['target', 'baseline', 'regression-prediction']:
+ # Single dataset
+ for percentile, value in data.items():
+ linestyle = percentile_styles[percentile]
+ legend_added.add(percentile) # Track percentiles for black legend entries
+
+ plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax,
+ linestyles=linestyle, alpha=0.8) # No label here
+ else:
+ # Ensemble members
+ for member_data in data.values():
+ for percentile, value in member_data.items():
+ linestyle = percentile_styles[percentile]
+ legend_added.add(percentile) # Track percentiles for black legend entries
+
+ plt.vlines(x=value, colors=color, ymin=vline_ymin, ymax=vline_ymax,
+ linestyles=linestyle, alpha=0.6) # No label here
+
+ # Add black legend entries for percentiles (override the colored ones)
+ for percentile in [99, 99.9, 99.99]:
+ if percentile in legend_added:
+ plt.plot([], [], color='black', linestyle=percentile_styles[percentile],
+ label=percentile_labels[percentile])
+
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(out_path, dpi=300, bbox_inches='tight')
+ plt.close()
+
+
+def main(cfg: dict):
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Starting computation for probability of exceedance for wind speed")
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+ logger.info(f"Loaded {len(times)} timesteps to process")
+
+ # Output root
+ out_root = Path(generation_dir)
+
+ # Find channel indices for wind components
+ indices = get_channel_indices(gen_cfg)
+ u10_out = indices['output'].get('10u')
+ v10_out = indices['output'].get('10v')
+ u10_in = indices['input'].get('10u', u10_out)
+ v10_in = indices['input'].get('10v', v10_out)
+
+ if u10_out is None or v10_out is None:
+ logger.error("Wind components (10u, 10v) not found in dataset!")
+ return
+
+ logger.info(f"Wind component channel indices - output: 10u={u10_out}, 10v={v10_out}, input: 10u={u10_in}, 10v={v10_in}")
+
+ # Define thresholds for exceedance calculation (same for all variables)
+ thresholds = np.logspace(-1, 2, 200) # From 0.1 to ~100 m/s
+ n_thresholds = len(thresholds)
+
+ # Histogram bins for percentile estimation (fine-grained log-spaced)
+ hist_bins = np.concatenate([
+ np.array([0.0]),
+ np.logspace(-1, 2.5, 5000) # From 0.1 to ~316 m/s
+ ])
+ n_hist_bins = len(hist_bins) - 1
+
+ # Storage for exceedance counts (incremental computation)
+ exceedance_counts = {
+ 'speed': {}, 'u': {}, 'v': {}
+ }
+ totals = {'speed': {}, 'u': {}, 'v': {}}
+ hist_counts = {
+ 'speed': {}, 'u': {}, 'v': {}
+ }
+
+ # -- Process target and baseline --
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ logger.info(f"Processing mode: {mode}")
+
+ # Initialize counts
+ for var in ['speed', 'u', 'v']:
+ exceedance_counts[var][mode] = np.zeros(n_thresholds, dtype=np.int64)
+ totals[var][mode] = 0
+ hist_counts[var][mode] = np.zeros(n_hist_bins, dtype=np.int64)
+
+ try:
+ for i, ts in enumerate(times):
+ if i % cfg.get("log_interval") == 0:
+ logger.info(f"Processing timestep {i+1}/{len(times)}")
+
+ data = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-{mode}", weights_only=False)
+
+ # Extract wind components
+ if mode in ['target', 'regression-prediction']:
+ u = data[u10_out]
+ v = data[v10_out]
+ else: # baseline
+ u = data[u10_in]
+ v = data[v10_in]
+
+ wind_speed = compute_wind_speed(u, v)
+
+ # Get valid values
+ valid_mask = ~np.isnan(wind_speed)
+ speed_vals = wind_speed[valid_mask].flatten()
+ u_vals = u[valid_mask].flatten()
+ v_vals = v[valid_mask].flatten()
+
+ # Update exceedance counts incrementally
+ exceedance_counts['speed'][mode], totals['speed'][mode] = update_exceedance_counts(
+ exceedance_counts['speed'][mode], totals['speed'][mode], speed_vals, thresholds, use_abs=False
+ )
+ exceedance_counts['u'][mode], totals['u'][mode] = update_exceedance_counts(
+ exceedance_counts['u'][mode], totals['u'][mode], u_vals, thresholds, use_abs=True
+ )
+ exceedance_counts['v'][mode], totals['v'][mode] = update_exceedance_counts(
+ exceedance_counts['v'][mode], totals['v'][mode], v_vals, thresholds, use_abs=True
+ )
+
+ # Collect samples for percentiles (subsample to save memory)
+ hist_counts['speed'][mode] += np.histogram(speed_vals, bins=hist_bins)[0]
+ hist_counts['u'][mode] += np.histogram(np.abs(u_vals), bins=hist_bins)[0]
+ hist_counts['v'][mode] += np.histogram(np.abs(v_vals), bins=hist_bins)[0]
+
+ except Exception as e:
+ logger.warning(f"{mode} data not found or error occurred, skipping: {e}")
+ continue
+
+ logger.info(f"Processed {totals['speed'][mode]} values for {mode}")
+
+ # -- Process predictions: compute exceedance for each ensemble member --
+ logger.info("Processing predictions")
+
+ n_members = None
+ member_counts = {'speed': [], 'u': [], 'v': []}
+ member_totals = {'speed': [], 'u': [], 'v': []}
+ member_hist_counts = {'speed': [], 'u': [], 'v': []}
+
+ for i, ts in enumerate(times):
+ if i % cfg.get("log_interval") == 0:
+ logger.info(f"Processing timestep {i+1}/{len(times)}")
+
+ preds = torch.load(resolve_ts_dir(out_root, ts)/ts/f"{ts}-predictions", weights_only=False) # [n_members, n_channels, lat, lon]
+
+ if n_members is None:
+ n_members = preds.shape[0]
+ for var in ['speed', 'u', 'v']:
+ member_counts[var] = [np.zeros(n_thresholds, dtype=np.int64) for _ in range(n_members)]
+ member_totals[var] = [0 for _ in range(n_members)]
+ member_hist_counts[var] = [np.zeros(n_hist_bins, dtype=np.int64) for _ in range(n_members)]
+
+ for member_idx in range(n_members):
+ u = preds[member_idx, u10_out]
+ v = preds[member_idx, v10_out]
+ wind_speed = compute_wind_speed(u, v)
+
+ valid_mask = ~np.isnan(wind_speed)
+ speed_vals = wind_speed[valid_mask].flatten()
+ u_vals = u[valid_mask].flatten()
+ v_vals = v[valid_mask].flatten()
+
+ # Update counts
+ member_counts['speed'][member_idx], member_totals['speed'][member_idx] = update_exceedance_counts(
+ member_counts['speed'][member_idx], member_totals['speed'][member_idx], speed_vals, thresholds, use_abs=False
+ )
+ member_counts['u'][member_idx], member_totals['u'][member_idx] = update_exceedance_counts(
+ member_counts['u'][member_idx], member_totals['u'][member_idx], u_vals, thresholds, use_abs=True
+ )
+ member_counts['v'][member_idx], member_totals['v'][member_idx] = update_exceedance_counts(
+ member_counts['v'][member_idx], member_totals['v'][member_idx], v_vals, thresholds, use_abs=True
+ )
+
+ # Collect samples for percentiles
+ member_hist_counts['speed'][member_idx] += np.histogram(speed_vals, bins=hist_bins)[0]
+ member_hist_counts['u'][member_idx] += np.histogram(np.abs(u_vals), bins=hist_bins)[0]
+ member_hist_counts['v'][member_idx] += np.histogram(np.abs(v_vals), bins=hist_bins)[0]
+
+ logger.info(f"Collected {n_members} ensemble members for predictions")
+
+ # Convert counts to probabilities
+ exceedance_data = {'speed': {}, 'u': {}, 'v': {}}
+
+ for var in ['speed', 'u', 'v']:
+ # Single datasets
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ if mode in exceedance_counts[var] and totals[var][mode] > 0:
+ exceedance_data[var][mode] = exceedance_counts[var][mode] / totals[var][mode]
+
+ # Ensemble members
+ member_probs = []
+ for member_idx in range(n_members):
+ if member_totals[var][member_idx] > 0:
+ member_probs.append(member_counts[var][member_idx] / member_totals[var][member_idx])
+ exceedance_data[var]['predictions'] = tuple(member_probs)
+
+ # Compute percentiles for all datasets and variables
+ percentiles = {99: 0.99, 99.9: 0.999, 99.99: 0.9999}
+ percentiles_data = {'speed': {}, 'u': {}, 'v': {}}
+
+ # Single datasets (target, baseline, regression-prediction)
+ for var in ['speed', 'u', 'v']:
+ for mode in ['target', 'baseline', 'regression-prediction']:
+ if mode in hist_counts[var] and totals[var][mode] > 0:
+ percentiles_data[var][mode] = percentiles_from_histogram(
+ hist_counts[var][mode], hist_bins, percentiles
+ )
+
+ # Ensemble members
+ percentiles_data[var]['predictions'] = {}
+ for member_idx in range(n_members):
+ if member_totals[var][member_idx] > 0:
+ percentiles_data[var]['predictions'][f'member_{member_idx}'] = percentiles_from_histogram(
+ member_hist_counts[var][member_idx], hist_bins, percentiles
+ )
+
+ # Create exceedance plots
+ labels = ['Target', 'Input', 'Regression Prediction', 'CorrDiff Ensemble'] if 'regression-prediction' in exceedance_data['speed'] else ['Target', 'Input', 'CorrDiff Ensemble']
+ colors = ['blue', 'orange', 'red', 'green'] if 'regression-prediction' in exceedance_data['speed'] else ['blue', 'orange', 'green']
+
+ # Define plot configurations
+ plot_configs = [
+ ('windspeed_exceedance.png', 'speed', 'Probability of Exceedance for Wind Speed',
+ 'All-hour Wind Speed [m/s] (Pooled Data)'),
+ ('wind_u_exceedance.png', 'u', 'Probability of Exceedance for abs(10u)',
+ 'All-hour 10u Component [m/s] (Pooled Data)'),
+ ('wind_v_exceedance.png', 'v', 'Probability of Exceedance for abs(10v)',
+ 'All-hour 10v Component [m/s] (Pooled Data)'),
+ ]
+
+ output_path = out_root / cfg.get("results_dir_name", "evaluation_maps")
+ output_path.mkdir(parents=True, exist_ok=True)
+ for filename, var, title, ylabel in plot_configs:
+ fn = output_path / filename
+ save_exceedance_plot(
+ exceedance_data[var],
+ thresholds,
+ labels,
+ colors,
+ title,
+ ylabel,
+ fn,
+ percentiles_data[var]
+ )
+ logger.info(f"{var.capitalize()} exceedance plot saved: {fn}")
+
+
+if __name__ == '__main__':
+ main(parse_eval_cli())
diff --git a/src/hirad/eval/snapshots.py b/src/hirad/eval/snapshots.py
new file mode 100644
index 00000000..c76f4b05
--- /dev/null
+++ b/src/hirad/eval/snapshots.py
@@ -0,0 +1,295 @@
+"""Generates maps of precipitation, temperature, and wind components/speed/direction."""
+import logging
+from dataclasses import dataclass, field, replace
+from datetime import datetime
+from pathlib import Path
+
+import hydra
+import numpy as np
+import torch
+
+from hirad.datasets import get_channels_from_strings, get_strings_from_channels
+from hirad.eval import compute_mae, plot_map
+from hirad.eval.eval_utils import (
+ DEFAULT_GRID_CONFIG,
+ grid_cfg_from_cfg,
+ load_generation_setup,
+ parse_eval_cli,
+ resolve_io_channels,
+ resolve_ts_dir,
+)
+from hirad.eval.plotting import plot_map_precipitation, plot_map_wind_precip
+from hirad.utils.inference_utils import calculate_bounds
+
+
+def wind_direction(u, v):
+ """Compute wind direction from u and v components."""
+ return (np.arctan2(-u, -v) * 180 / np.pi) % 360
+
+@dataclass
+class ChannelMeta:
+ """Metadata for a channel."""
+ name: str
+ cmap: str = "viridis"
+ me_cmap: str | None = None
+ unit: str = ""
+ norm: any = None
+ err_vmin: float = None
+ err_vmax: float = None
+ vmin: float = None
+ vmax: float = None
+ extend: str = "both"
+ precip_kwargs: dict = field(default_factory=lambda: {"threshold": 0.01, "rfac": 1000.0})
+
+ @classmethod
+ def get(cls, ch_or_name: "ChannelMeta | str | None", *, vmin=None, vmax=None) -> "ChannelMeta":
+ name = getattr(ch_or_name, "name", ch_or_name or "")
+ base = CHANNELS.get(name) or cls(name=name)
+ if vmin is not None or vmax is not None:
+ return replace(base, vmin=vmin, vmax=vmax)
+ return base
+
+CHANNELS = {
+ "tp": ChannelMeta(name="tp", cmap=None, unit="mm/h", extend="max", precip_kwargs={"threshold": 0.01, "rfac": 1000.0}),
+ "2t": ChannelMeta(name="2t", cmap="RdYlBu_r", me_cmap="RdBu", unit="K", err_vmin=-4.5, err_vmax=4.5),
+ "10u": ChannelMeta(name="10u", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10),
+ "10v": ChannelMeta(name="10v", cmap="BrBG", me_cmap="BrBG", unit="m/s", err_vmin=-10, err_vmax=10, vmin=-10, vmax=10),
+}
+
+def format_time_str(dt_str, input_fmt="%Y%m%d-%H%M", output_fmt="%d-%m-%Y %H:%M"):
+ """Convert time string from input_fmt to output_fmt."""
+ dt = datetime.strptime(dt_str, input_fmt)
+ return dt.strftime(output_fmt)
+
+class FileRepository:
+ def __init__(self, root_path):
+ self.root = Path(root_path)
+
+ def load(self, time, filename):
+ return torch.load(resolve_ts_dir(self.root, time) / time / filename, weights_only=False)
+
+ def _ensure_dir(self, *subdirs):
+ """Make (and return) root_path/subdir1/subdir2/…."""
+ d = self.root.joinpath(*subdirs)
+ d.mkdir(parents=True, exist_ok=True)
+ return d
+
+ def _make_fname(self, curr_time, prefix, suffix, member_idx):
+ """Build a filename like '20250724-1230-prefix-suffix[_member]'."""
+ base = f"{curr_time}-{prefix}-{suffix}"
+ if member_idx is not None:
+ base += f"_{member_idx}"
+ return base
+
+ def output_file(self, channel, curr_time, suffix, member_idx=None):
+ # decide on the folder name: e.g. 'tp_100m' or just 'tp'
+ folder = f"{channel.name}_{channel.level}" if getattr(channel, "level", None) else channel.name
+ fname = self._make_fname(curr_time, channel.name, suffix, member_idx)
+ return self._ensure_dir(folder) / fname
+
+ def wind_file(self, wind_type, curr_time, suffix, member_idx=None):
+ # e.g. wind_type = "FF10m" or "DD10m"
+ fname = self._make_fname(curr_time, wind_type, suffix, member_idx)
+ return self._ensure_dir(wind_type) / fname
+
+
+def _pred_members(arr, *channel_idxs):
+ """Yield ``(member_idx_or_None, *2d_slices)`` for 3D or 4D ensemble arrays."""
+ if arr.ndim > 3:
+ for m in range(arr.shape[0]):
+ yield (m, *(arr[m, i, :, :] for i in channel_idxs))
+ else:
+ yield (None, *(arr[i, :, :] for i in channel_idxs))
+
+
+def save_field(name, data, meta, output_files, channel, t, member=None, kind=None, cmap=None, vmin=None, vmax=None, custom_path=None, plot_func=None, title=None, grid_cfg=DEFAULT_GRID_CONFIG, **plot_kwargs):
+ """Save a field by plotting it with the appropriate function and parameters. Handles precipitation specially and supports custom paths."""
+ # Determine output path
+ suffix = f"{name}-{kind}" if kind else f"{name}"
+ out_path = custom_path or output_files.output_file(channel, t, suffix, member)
+
+ # Choose plotting function
+ plot = plot_func or plot_map
+
+ # Case dependent plot parameters
+ title = title or meta.name
+ extend = 'max' if kind == 'mae' else meta.extend
+ plot_map_args = {'title': title, 'norm': meta.norm, 'extend': extend, **plot_kwargs}
+ common_args = {'grid_cfg': grid_cfg}
+
+ # Precipitation case
+ if plot.__name__ == 'plot_map_precipitation':
+ precip_args = {**meta.precip_kwargs, **{k: plot_kwargs[k] for k in meta.precip_kwargs if k in plot_kwargs}}
+ plot(data, out_path, title=title, **precip_args, **common_args)
+ else:
+ plot(data, out_path, vmin=vmin if vmin is not None else meta.vmin, vmax=vmax if vmax is not None else meta.vmax, cmap=cmap or meta.cmap, label=meta.unit, **plot_map_args, **common_args)
+
+
+def main(cfg: dict) -> None:
+ # Initialize logger
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("plot_maps")
+
+ grid_cfg = grid_cfg_from_cfg(cfg)
+
+ try:
+ generation_dir, gen_cfg, times = load_generation_setup(cfg)
+ except ValueError as exc:
+ logger.error(str(exc))
+ return
+
+ input_channels, output_channels = resolve_io_channels(gen_cfg)
+
+ plot_channels = cfg.get("plot_channels", None)
+ if plot_channels is not None:
+ plot_channels = get_channels_from_strings(plot_channels)
+ else:
+ plot_channels = output_channels
+
+ logger.info(f"Processing {len(times)} timestep(s): {times}")
+ logger.info(f"Plot channels : {get_strings_from_channels(plot_channels)}")
+ logger.info(f"Input channels : {get_strings_from_channels(input_channels)}")
+ logger.info(f"Output channels: {get_strings_from_channels(output_channels)}")
+
+ input_channel_indices = []
+ output_channel_indices = []
+ for channel in plot_channels or []:
+ input_channel_indices.append(input_channels.index(channel) if channel in input_channels else -1)
+ output_channel_indices.append(output_channels.index(channel) if channel in output_channels else -1)
+
+ output_path = Path(generation_dir) / cfg.get("results_dir_name", "evaluation_maps") / "snapshots"
+ output_path.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Output directory: {output_path}")
+ input_files = FileRepository(generation_dir)
+ output_files = FileRepository(output_path)
+
+ for curr_time in times:
+ logger.info(f"Plotting timestep: {curr_time}")
+ prediction = input_files.load(curr_time, f'{curr_time}-predictions')
+ baseline = input_files.load(curr_time, f'{curr_time}-baseline')
+ target = input_files.load(curr_time, f'{curr_time}-target')
+ try:
+ mean_pred = input_files.load(curr_time, f'{curr_time}-regression-prediction')
+ except FileNotFoundError:
+ mean_pred = None
+
+ for idx, channel in enumerate(plot_channels):
+ in_idx = input_channel_indices[idx]
+ out_idx = output_channel_indices[idx]
+
+ # TODO Implement that it plots just output or just input channel if the other is missing
+ if in_idx == -1 or out_idx == -1:
+ logger.warning(f"Channel {channel.name} not found in input or output channels. Skipping.")
+ continue
+
+ plot_title = f"{format_time_str(curr_time)}: {getattr(channel, 'title', channel.name if channel.level == '' else f'{channel.name}_{channel.level}')}"
+ target_2d = target[out_idx, :, :]
+ baseline_2d = baseline[in_idx, :, :]
+ vmin, vmax = calculate_bounds(
+ target_2d,
+ prediction[:, out_idx, :, :] if prediction.ndim > 3 else prediction[idx, :, :],
+ None if channel.name == "tp" else baseline_2d,
+ mean_pred[out_idx, :, :] if mean_pred is not None else None,
+ )
+ meta = ChannelMeta.get(channel, vmin=vmin, vmax=vmax)
+
+ # Build sources to plot: (label, member_idx, 2D field).
+ sources: list = [("target", None, target_2d), ("baseline", None, baseline_2d)]
+ if mean_pred is not None:
+ sources.append(("mean-prediction", None, mean_pred[out_idx, :, :]))
+ for m, p in _pred_members(prediction, out_idx):
+ sources.append(("prediction", m, p))
+
+ if channel.name == "tp":
+ for label, m, data in sources:
+ save_field(label, data, meta, output_files, channel, curr_time,
+ member=m, plot_func=plot_map_precipitation,
+ title=plot_title, grid_cfg=grid_cfg)
+ continue
+
+ err_cmap = meta.cmap if channel.name not in ("10u", "10v", "2t") else 'viridis'
+ for label, m, data in sources:
+ save_field(label, data, meta, output_files, channel, curr_time,
+ member=m, title=plot_title, grid_cfg=grid_cfg)
+ if label == "target":
+ continue
+ _, mae = compute_mae(data, target_2d)
+ me = data - target_2d
+ save_field(label, mae.reshape(data.shape), meta, output_files, channel, curr_time,
+ member=m, kind="mae", cmap=err_cmap, vmin=0, vmax=meta.err_vmax,
+ title=plot_title, grid_cfg=grid_cfg)
+ save_field(label, me, meta, output_files, channel, curr_time,
+ member=m, kind="me", cmap=meta.me_cmap,
+ vmin=meta.err_vmin, vmax=meta.err_vmax,
+ title=plot_title, grid_cfg=grid_cfg)
+
+ # Wind speed / direction (and combined wind+precip) plots
+ wind_out = {ch.name: i for i, ch in enumerate(output_channels) if ch.name in ("10u", "10v")}
+ wind_in = {ch.name: i for i, ch in enumerate(input_channels) if ch.name in ("10u", "10v")}
+ if "10u" not in wind_out or "10v" not in wind_out:
+ continue
+
+ o_u, o_v = wind_out["10u"], wind_out["10v"]
+ i_u, i_v = wind_in["10u"], wind_in["10v"]
+
+ # Build wind sources: (label, member_idx, u_2d, v_2d).
+ wind_sources: list = [
+ ("target", None, target[o_u, :, :], target[o_v, :, :]),
+ ("baseline", None, baseline[i_u, :, :], baseline[i_v, :, :]),
+ ]
+ if mean_pred is not None:
+ wind_sources.append(("mean-prediction", None, mean_pred[o_u, :, :], mean_pred[o_v, :, :]))
+ for m, pu, pv in _pred_members(prediction, o_u, o_v):
+ wind_sources.append(("prediction", m, pu, pv))
+
+ title_speed = f"{format_time_str(curr_time)}: FF10m"
+ title_dir = f"{format_time_str(curr_time)}: DD10m"
+ speed_meta = ChannelMeta.get("10u", vmin=0, vmax=10)
+ dir_meta = ChannelMeta.get("10u", vmin=0, vmax=360)
+
+ wind_kinds = (
+ ("FF10m", speed_meta, "viridis", 0, 10, "max", title_speed),
+ ("DD10m", dir_meta, "twilight", 0, 360, "neither", title_dir),
+ )
+ for label, m, u, v in wind_sources:
+ speed = np.hypot(u, v)
+ direction = wind_direction(u, v)
+ for kind, w_meta, w_cmap, w_vmin, w_vmax, w_extend, w_title in wind_kinds:
+ data = speed if kind == "FF10m" else direction
+ save_field(
+ f"{kind}-{label}", data, w_meta, output_files, None, curr_time,
+ member=m, cmap=w_cmap, vmin=w_vmin, vmax=w_vmax, extend=w_extend,
+ custom_path=output_files.wind_file(kind, curr_time, f"{kind}-{label}", m),
+ plot_func=plot_map, title=w_title, grid_cfg=grid_cfg,
+ )
+
+ # Combined wind-speed + precipitation maps
+ tp_out_idx = next((i for i, ch in enumerate(output_channels) if ch.name == "tp"), None)
+ if tp_out_idx is None:
+ continue
+ tp_in_idx = next((i for i, ch in enumerate(input_channels) if ch.name == "tp"), tp_out_idx)
+ title_wp = f"{format_time_str(curr_time)}: FF10m + Precipitation"
+ wp_dir = output_files._ensure_dir("wind_precip")
+
+ wp_sources: list = [
+ ("target", None, target[o_u, :, :], target[o_v, :, :], target[tp_out_idx, :, :]),
+ ("baseline", None, baseline[i_u, :, :], baseline[i_v, :, :], baseline[tp_in_idx, :, :]),
+ ]
+ if mean_pred is not None:
+ wp_sources.append(("mean-prediction", None,
+ mean_pred[o_u, :, :], mean_pred[o_v, :, :], mean_pred[tp_out_idx, :, :]))
+ for m, pu, pv, ptp in _pred_members(prediction, o_u, o_v, tp_out_idx):
+ wp_sources.append(("prediction", m, pu, pv, ptp))
+
+ for label, m, u, v, tp in wp_sources:
+ suffix = f"{label}_{m:02d}" if m is not None else label
+ plot_map_wind_precip(
+ u, v, tp,
+ str(wp_dir / f"{curr_time}-wind_precip-{suffix}"),
+ title=title_wp, grid_cfg=grid_cfg,
+ )
+
+ logger.info(f"Snapshots saved to: {output_path}")
+
+if __name__ == "__main__":
+ main(parse_eval_cli(allow_times=True))
diff --git a/src/hirad/eval/video_from_snapshots.py b/src/hirad/eval/video_from_snapshots.py
new file mode 100644
index 00000000..c1562653
--- /dev/null
+++ b/src/hirad/eval/video_from_snapshots.py
@@ -0,0 +1,248 @@
+import cv2
+import numpy as np
+from PIL import Image
+import os
+import re
+from datetime import datetime
+import glob
+from pathlib import Path
+from tqdm import tqdm
+import subprocess
+import shutil
+
+def parse_filename(filename):
+ """Parse filename to extract date, hour, type, and member"""
+ # Pattern for prediction files: YYYYMMDD-HH-tp-prediction_MM.png
+ pred_pattern = r'(\d{8})-(\d+)-tp-prediction_(\d+)\.png'
+ # Pattern for target/mean files: YYYYMMDD-HH-tp-(target|mean-prediction).png
+ other_pattern = r'(\d{8})-(\d+)-tp-(target|mean-prediction)\.png'
+
+ pred_match = re.match(pred_pattern, filename)
+ if pred_match:
+ date, hour, member = pred_match.groups()
+ return date, hour, 'prediction', int(member)
+
+ other_match = re.match(other_pattern, filename)
+ if other_match:
+ date, hour, img_type = other_match.groups()
+ return date, hour, img_type, None
+
+ return None
+
+def get_sorted_timestamps(image_folder):
+ """Get all unique timestamps sorted chronologically"""
+ timestamps = set()
+ for filename in os.listdir(image_folder):
+ parsed = parse_filename(filename)
+ if parsed:
+ date, hour, _, _ = parsed
+ timestamps.add((date, hour))
+
+ return sorted(list(timestamps))
+
+def load_and_resize_image(filepath, target_size):
+ """Load image and resize to target size, converting RGBA to RGB"""
+ if not os.path.exists(filepath):
+ # Create blank image if file doesn't exist
+ return np.zeros((*target_size[::-1], 3), dtype=np.uint8)
+
+ img = Image.open(filepath)
+
+ # Convert RGBA to RGB if needed
+ if img.mode == 'RGBA':
+ # Create white background
+ background = Image.new('RGB', img.size, (255, 255, 255))
+ background.paste(img, mask=img.split()[-1]) # Use alpha channel as mask
+ img = background
+ elif img.mode != 'RGB':
+ img = img.convert('RGB')
+
+ img = img.resize(target_size, Image.Resampling.LANCZOS)
+ return np.array(img)
+
+def create_grid_layout(images, grid_shape):
+ """Create a grid layout from list of images"""
+ rows, cols = grid_shape
+ if len(images) != rows * cols:
+ # Pad with blank images if needed
+ while len(images) < rows * cols:
+ images.append(np.zeros_like(images[0]))
+
+ # Arrange images in grid
+ image_rows = []
+ for i in range(rows):
+ row_images = images[i*cols:(i+1)*cols]
+ image_rows.append(np.hstack(row_images))
+
+ return np.vstack(image_rows)
+
+
+def create_video_from_existing_images(image_folder, output_path, layout_type="all_members", fps=12):
+ """Create video directly from existing images without creating intermediate frames"""
+ timestamps = get_sorted_timestamps(image_folder)
+
+ if not timestamps:
+ print("No valid images found!")
+ return
+
+ # Get image size
+ sample_file = next(Path(image_folder).glob("*.png"))
+ sample_img = Image.open(sample_file)
+ img_width, img_height = sample_img.size
+ if layout_type == "all_members":
+ img_width //= 8
+ img_height //= 8
+
+ print(f"Sample image size: {img_width} x {img_height}")
+
+ # Create one test frame to get EXACT dimensions
+ if layout_type == "all_members":
+ # Load real images for the first timestamp to get exact dimensions
+ first_date, first_hour = timestamps[0]
+
+ # Load 16 prediction images
+ pred_images = []
+ for member in range(16):
+ filename = f"{first_date}-{first_hour}-tp-prediction_{member}.png"
+ filepath = Path(image_folder) / filename
+ img = load_and_resize_image(str(filepath), (img_width, img_height))
+ pred_images.append(img)
+
+ pred_grid = create_grid_layout(pred_images, (4, 4))
+
+ # Load target and mean and baseline
+ target_file = f"{first_date}-{first_hour}-tp-target.png"
+ target_img = load_and_resize_image(str(Path(image_folder) / target_file), (img_width, img_height))
+
+ mean_file = f"{first_date}-{first_hour}-tp-mean-prediction.png"
+ mean_img = load_and_resize_image(str(Path(image_folder) / mean_file), (img_width, img_height))
+
+ baseline_file = f"{first_date}-{first_hour}-tp-baseline.png"
+ baseline_img = load_and_resize_image(str(Path(image_folder) / baseline_file), (img_width, img_height))
+
+ bottom_row = np.hstack([baseline_img, mean_img, target_img])
+
+ # Pad if needed
+ pad_width = pred_grid.shape[1] - bottom_row.shape[1]
+ if pad_width > 0:
+ padding = np.zeros((bottom_row.shape[0], pad_width, 3), dtype=np.uint8)
+ bottom_row = np.hstack([bottom_row, padding])
+
+ test_frame = np.vstack([pred_grid, bottom_row])
+ frame_height, frame_width = test_frame.shape[:2] # numpy: (height, width)
+
+ else: # single_member
+ frame_width = 4 * img_width
+ frame_height = img_height
+
+ # print(f"Frame dimensions (H x W): {frame_height} x {frame_width}")
+
+ # Initialize VideoWriter with CORRECT dimension order (width, height)
+ fourcc = cv2.VideoWriter_fourcc(*'MJPG')
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height)) # OpenCV: (width, height)
+
+ if not out.isOpened():
+ print("Failed to initialize VideoWriter!")
+ return False
+
+ # print(f"VideoWriter initialized with (W x H): {frame_width} x {frame_height}")
+
+ # Create video
+ for date, hour in tqdm(timestamps, desc=f"Creating {layout_type} video"):
+
+ if layout_type == "all_members":
+ # Load 16 prediction images
+ pred_images = []
+ for member in range(16):
+ filename = f"{date}-{hour}-tp-prediction_{member}.png"
+ filepath = Path(image_folder) / filename
+ img = load_and_resize_image(str(filepath), (img_width, img_height))
+ pred_images.append(img)
+
+ pred_grid = create_grid_layout(pred_images, (4, 4))
+
+ # Load target and mean
+ target_file = f"{date}-{hour}-tp-target.png"
+ target_img = load_and_resize_image(str(Path(image_folder) / target_file), (img_width, img_height))
+
+ mean_file = f"{date}-{hour}-tp-mean-prediction.png"
+ mean_img = load_and_resize_image(str(Path(image_folder) / mean_file), (img_width, img_height))
+
+ baseline_file = f"{date}-{hour}-tp-baseline.png"
+ baseline_img = load_and_resize_image(str(Path(image_folder) / baseline_file), (img_width, img_height))
+
+ bottom_row = np.hstack([baseline_img, mean_img, target_img])
+
+ # Pad if needed (using same logic as test frame)
+ pad_width = pred_grid.shape[1] - bottom_row.shape[1]
+ if pad_width > 0:
+ padding = np.zeros((bottom_row.shape[0], pad_width, 3), dtype=np.uint8)
+ bottom_row = np.hstack([bottom_row, padding])
+
+ frame = np.vstack([pred_grid, bottom_row])
+
+ else: # single_member
+ pred_file = f"{date}-{hour}-tp-prediction_15.png"
+ pred_img = load_and_resize_image(str(Path(image_folder) / pred_file), (img_width, img_height))
+
+ target_file = f"{date}-{hour}-tp-target.png"
+ target_img = load_and_resize_image(str(Path(image_folder) / target_file), (img_width, img_height))
+
+ mean_file = f"{date}-{hour}-tp-mean-prediction.png"
+ mean_img = load_and_resize_image(str(Path(image_folder) / mean_file), (img_width, img_height))
+
+ baseline_file = f"{date}-{hour}-tp-baseline.png"
+ baseline_img = load_and_resize_image(str(Path(image_folder) / baseline_file), (img_width, img_height))
+
+ frame = np.hstack([baseline_img, mean_img, pred_img, target_img])
+
+ # Verify frame dimensions EXACTLY match VideoWriter
+ actual_height, actual_width = frame.shape[:2]
+ if actual_width != frame_width or actual_height != frame_height:
+ print(f"ERROR: Frame size mismatch!")
+ print(f"Expected: {frame_width} x {frame_height}")
+ print(f"Got: {actual_width} x {actual_height}")
+ print(f"Resizing frame to match...")
+ # Force resize to exact dimensions
+ frame = cv2.resize(frame, (frame_width, frame_height))
+
+ # Ensure 3 channels
+ if frame.shape[2] != 3:
+ print(f"ERROR: Frame has {frame.shape[2]} channels, expected 3")
+ if frame.shape[2] == 4:
+ frame = frame[:, :, :3] # Drop alpha channel
+
+ # print(f"Frame shape: {frame.shape}")
+
+ # Convert RGB to BGR for OpenCV
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+
+ # Write frame
+ success = out.write(frame_bgr)
+ # if not success:
+ # print(f"Failed to write frame for {date}-{hour}")
+ # break
+
+ out.release()
+ print(f"Video saved: {output_path}")
+ return True
+
+
+def main():
+ image_folder = Path("/capstor/scratch/cscs/pstamenk/outputs/generation/generate_8_attention/tp")
+
+ # Create both videos directly from existing images
+ create_video_from_existing_images(
+ image_folder,
+ image_folder / "precipitation_all_members.avi",
+ layout_type="all_members"
+ )
+
+ create_video_from_existing_images(
+ image_folder,
+ image_folder / "precipitation_single_member.avi",
+ layout_type="single_member"
+ )
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/eval_precip.sh b/src/hirad/eval_precip.sh
new file mode 100644
index 00000000..4630d490
--- /dev/null
+++ b/src/hirad/eval_precip.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+#SBATCH --job-name="eval_precip"
+
+### HARDWARE ###
+#SBATCH --partition=normal
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=72
+#SBATCH --time=12:00:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+### OUTPUT ###
+#SBATCH --output=./logs/plots_precip.log
+
+### ENVIRONMENT ####
+#SBATCH -A a161
+
+### CONFIG ###
+CONFIG_NAME="src/hirad/conf/eval_real.yaml"
+
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+
+ # Diurnal cycle
+ # python src/hirad/eval/diurnal_cycle_precip_mean_wet-hour.py --config-name=${CONFIG_NAME}
+ # python src/hirad/eval/diurnal_cycle_precip_p99.py --config-name=${CONFIG_NAME}
+
+ # Histograms
+ # python src/hirad/eval/hist.py --config-name=${CONFIG_NAME}
+ # python src/hirad/eval/probability_of_exceedance.py --config-name=${CONFIG_NAME}
+
+ # Maps
+ # python src/hirad/eval/map_precip_stats.py --config-name=${CONFIG_NAME}
+"
\ No newline at end of file
diff --git a/src/hirad/eval_wind.sh b/src/hirad/eval_wind.sh
new file mode 100644
index 00000000..88f3f221
--- /dev/null
+++ b/src/hirad/eval_wind.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+#SBATCH --job-name="eval_wind"
+
+### HARDWARE ###
+#SBATCH --partition=normal
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=72
+#SBATCH --time=12:00:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+### OUTPUT ###
+#SBATCH --output=./logs/plots_wind.log
+
+### ENVIRONMENT ####
+#SBATCH -A a161
+
+### CONFIG ###
+CONFIG_NAME="src/hirad/conf/eval_real.yaml"
+
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+
+ # Diurnal cycle
+ # python src/hirad/eval/diurnal_cycle_temp_wind.py --config-name=${CONFIG_NAME}
+
+ # Probability of exceedance
+ # python src/hirad/eval/probability_of_exceedance_wind.py --config-name=${CONFIG_NAME}
+
+ # Maps
+ # python src/hirad/eval/map_wind_stats.py --config-name=${CONFIG_NAME}
+"
\ No newline at end of file
diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh
old mode 100644
new mode 100755
index 87c8979c..40e3c8e4
--- a/src/hirad/generate.sh
+++ b/src/hirad/generate.sh
@@ -1,25 +1,43 @@
#!/bin/bash
+# Generate predictions.
+#
+# Default mode (no SLURM array): single run using config defaults.
+# Monthly array mode (SLURM_ARRAY_TASK_ID and START_MONTH set):
+# each array task generates one month, starting from START_MONTH
+# (array index 0 -> START_MONTH, index 1 -> next month, ...).
+# Submit via ./src/hirad/submit_monthly.sh START_MONTH END_MONTH.
-#SBATCH --job-name="testrun"
+#SBATCH --job-name="generate"
### HARDWARE ###
-#SBATCH --partition=debug
+#SBATCH --partition=normal
#SBATCH --nodes=1
-#SBATCH --ntasks-per-node=1
-#SBATCH --gpus-per-node=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --gpus-per-node=2
#SBATCH --cpus-per-task=72
#SBATCH --time=00:30:00
#SBATCH --no-requeue
#SBATCH --exclusive
### OUTPUT ###
-#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.log
-#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.err
+#SBATCH --output=./logs/regression_generation_%A_%a.log
### ENVIRONMENT ####
-#SBATCH --uenv=pytorch/v2.6.0:/user-environment
-#SBATCH --view=default
-#SBATCH -A a-a122
+#SBATCH -A a161
+
+set -euo pipefail
+
+# Optional Hydra overrides for monthly array mode.
+EXTRA_ARGS=()
+if [[ -n "${SLURM_ARRAY_TASK_ID:-}" && -n "${START_MONTH:-}" ]]; then
+ START=$(date -u -d "${START_MONTH}-01 +${SLURM_ARRAY_TASK_ID} months" +%Y%m%d-%H%M)
+ END=$(date -u -d "${START_MONTH}-01 +$((SLURM_ARRAY_TASK_ID + 1)) months" +%Y%m%d-%H%M)
+ echo "Generating ${START:0:4}_${START:4:2}: ${START} -> ${END}"
+ EXTRA_ARGS+=(
+ "generation.times_range=[${START},${END},1]"
+ "hydra.run.dir=./outputs/generation/era_real_${START:0:4}"
+ )
+fi
# Choose method to initialize dist in pythorch
export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
@@ -33,19 +51,10 @@ export MASTER_ADDR
export MASTER_PORT=29500
echo "Master port: $MASTER_PORT"
-# Get number of physical cores using Python
-PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))")
-# Use SLURM_NTASKS (number of processes to be launched by torchrun)
-LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1}
-# Compute threads per process
-OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS ))
-export OMP_NUM_THREADS=$OMP_THREADS
-echo "Physical cores: $PHYSICAL_CORES"
-echo "Local processes: $LOCAL_PROCS"
-echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS"
-
-# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml
-srun bash -c "
- . ./train_env/bin/activate
- python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml
-"
\ No newline at end of file
+export OMP_NUM_THREADS=1
+
+EXTRA_ARGS_STR="${EXTRA_ARGS[*]@Q}"
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/inference/generate.py --config-name=generate_era_real.yaml ${EXTRA_ARGS_STR}
+"
diff --git a/src/hirad/generate_test.sh b/src/hirad/generate_test.sh
new file mode 100644
index 00000000..0f12f9ea
--- /dev/null
+++ b/src/hirad/generate_test.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+#SBATCH --job-name="corrdiff-test-genreate"
+
+### HARDWARE ###
+#SBATCH --partition=debug
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=2
+#SBATCH --gpus-per-node=2
+#SBATCH --cpus-per-task=72
+#SBATCH --time=00:30:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+### OUTPUT ###
+#SBATCH --output=./logs/generation_test.log
+
+### ENVIRONMENT ####
+#SBATCH -A a161
+
+# Choose method to initialize dist in pythorch
+export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
+
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+echo "Master node : $MASTER_ADDR"
+# Get IP for hostname.
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
+echo "Master address : $MASTER_ADDR"
+export MASTER_ADDR
+export MASTER_PORT=29500
+echo "Master port: $MASTER_PORT"
+
+# Get number of physical cores using Python
+# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))")
+# # Use SLURM_NTASKS (number of processes to be launched by torchrun)
+# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1}
+# # Compute threads per process
+# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS ))
+# export OMP_NUM_THREADS=$OMP_THREADS
+export OMP_NUM_THREADS=72
+# echo "Physical cores: $PHYSICAL_CORES"
+# echo "Local processes: $LOCAL_PROCS"
+# echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS"
+
+srun --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e . --no-dependencies
+ python src/hirad/inference/generate.py --config-name=generate_era_cosmo_test.yaml
+"
\ No newline at end of file
diff --git a/src/hirad/inference/README.md b/src/hirad/inference/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/src/hirad/inference/__init__.py b/src/hirad/inference/__init__.py
new file mode 100644
index 00000000..1593b3a8
--- /dev/null
+++ b/src/hirad/inference/__init__.py
@@ -0,0 +1,3 @@
+from .deterministic_sampler import deterministic_sampler
+from .stochastic_sampler import stochastic_sampler
+from .generator import Generator
\ No newline at end of file
diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/inference/deterministic_sampler.py
similarity index 99%
rename from src/hirad/utils/deterministic_sampler.py
rename to src/hirad/inference/deterministic_sampler.py
index e502875e..af02c3bd 100644
--- a/src/hirad/utils/deterministic_sampler.py
+++ b/src/hirad/inference/deterministic_sampler.py
@@ -20,7 +20,7 @@
import nvtx
import torch
-from hirad.models import EDMPrecond
+from hirad.models import EDMPrecondSuperResolution as EDMPrecond
# ruff: noqa: E731
diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py
index 35f856f3..2c5d361c 100644
--- a/src/hirad/inference/generate.py
+++ b/src/hirad/inference/generate.py
@@ -1,51 +1,42 @@
import hydra
import os
import json
+import time
+from collections import defaultdict
from omegaconf import OmegaConf, DictConfig
import torch
import torch._dynamo
-import nvtx
import numpy as np
import contextlib
from hirad.distributed import DistributedManager
from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper
from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-import cartopy.crs as ccrs
-from matplotlib import pyplot as plt
-from einops import rearrange
-from torch.distributed import gather
-
-
-from hydra.utils import to_absolute_path
from hirad.models import EDMPrecondSuperResolution, UNet
-from hirad.utils.patching import GridPatching2D
-from hirad.utils.stochastic_sampler import stochastic_sampler
-from hirad.utils.deterministic_sampler import deterministic_sampler
-from hirad.utils.inference_utils import (
- get_time_from_range,
- regression_step,
- diffusion_step,
-)
+from hirad.inference import Generator
+from hirad.utils.inference_utils import save_results_as_torch
+from hirad.utils.function_utils import get_time_from_range
from hirad.utils.checkpoint import load_checkpoint
+from hirad.utils.dataset_utils import regrid_icon_to_rotlatlon
-
-from hirad.utils.generate_utils import (
- get_dataset_and_sampler
-)
+from hirad.datasets import get_dataset_and_sampler_inference
from hirad.utils.train_helpers import set_patch_shape
-from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra
+def _sync_t() -> float:
+ """Return wall-clock time after synchronizing all pending CUDA ops."""
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.perf_counter()
+
@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate")
def main(cfg: DictConfig) -> None:
"""Generate random dowscaled atmospheric states using the techniques described in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models".
"""
- torch.backends.cudnn.enabled = False
+ # torch.backends.cudnn.enabled = False
# Initialize distributed manager
DistributedManager.initialize()
dist = DistributedManager()
@@ -55,25 +46,22 @@ def main(cfg: DictConfig) -> None:
logger = PythonLogger("generate") # General python logger
logger0 = RankZeroLoggingWrapper(logger, dist)
- # Handle the batch size
- seeds = list(np.arange(cfg.generation.num_ensembles))
- num_batches = (
- (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1
- ) * dist.world_size
- all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
- rank_batches = all_batches[dist.rank :: dist.world_size]
-
# Synchronize
if dist.world_size > 1:
torch.distributed.barrier()
+ # Set precision for inference
+ input_dtype = torch.float16 if cfg.generation.perf.get("force_fp16", False) else torch.float32
+
# Parse the inference input times
- if cfg.generation.times_range and cfg.generation.times:
+ if cfg.generation.get("times_range", None) and cfg.generation.get("times", None):
raise ValueError("Either times_range or times must be provided, but not both")
- if cfg.generation.times_range:
+ if cfg.generation.get("times_range", None):
times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt
- else:
+ elif cfg.generation.get("times", None):
times = cfg.generation.times
+ else:
+ raise ValueError("Either times_range or times must be provided")
# Create dataset object
dataset_cfg = OmegaConf.to_container(cfg.dataset)
@@ -81,32 +69,20 @@ def main(cfg: DictConfig) -> None:
has_lead_time = cfg.generation["has_lead_time"]
else:
has_lead_time = False
- dataset, sampler = get_dataset_and_sampler(
+ dataset, sampler = get_dataset_and_sampler_inference(
dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time
)
+ dataset.stats_to_torch(device=dist.device, dtype=input_dtype)
+ dataset.interpolator.to(device=dist.device)
+ is_real_target = dataset_cfg.get("type").split("_")[-1] == "real"
+ if is_real_target:
+ dataset.regrid_indices_real = dataset.regrid_indices_real.to(dist.device)
+ dataset.regrid_weights_real = dataset.regrid_weights_real.to(dist.device, dtype=input_dtype)
img_shape = dataset.image_shape()
img_out_channels = len(dataset.output_channels())
- # Parse the patch shape
- if cfg.generation.patching:
- patch_shape_x = cfg.generation.patch_shape_x
- patch_shape_y = cfg.generation.patch_shape_y
- else:
- patch_shape_x, patch_shape_y = None, None
- patch_shape = (patch_shape_y, patch_shape_x)
- use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
- if use_patching:
- patching = GridPatching2D(
- img_shape=img_shape,
- patch_shape=patch_shape,
- boundary_pix=cfg.generation.boundary_pix,
- overlap_pix=cfg.generation.overlap_pix,
- )
- logger0.info("Patch-based training enabled")
- else:
- patching = None
- logger0.info("Patch-based training disabled")
+ #TODO: Isolate loading into the method of generator
# Parse the inference mode
if cfg.generation.inference_mode == "regression":
load_net_reg, load_net_res = True, False
@@ -127,6 +103,10 @@ def main(cfg: DictConfig) -> None:
raise FileNotFoundError(f"Missing config file at '{diffusion_model_args_path}'.")
with open(diffusion_model_args_path, 'r') as f:
diffusion_model_args = json.load(f)
+ # Disable AMP for inference (even if model is trained with AMP)
+ if "amp_mode" in diffusion_model_args:
+ diffusion_model_args["amp_mode"] = False
+ use_apex_gn = diffusion_model_args.get("use_apex_gn", False)
net_res = EDMPrecondSuperResolution(**diffusion_model_args)
@@ -136,14 +116,11 @@ def main(cfg: DictConfig) -> None:
device=dist.device
)
- #TODO fix to use channels_last which is optimal for H100
- net_res = net_res.eval().to(device).to(memory_format=torch.channels_last)
+ net_res = net_res.eval().to(device)
+ if use_apex_gn:
+ net_res = net_res.to(memory_format=torch.channels_last)
if cfg.generation.perf.force_fp16:
net_res.use_fp16 = True
-
- # Disable AMP for inference (even if model is trained with AMP)
- if hasattr(net_res, "amp_mode"):
- net_res.amp_mode = False
else:
net_res = None
@@ -158,6 +135,10 @@ def main(cfg: DictConfig) -> None:
raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.")
with open(regression_model_args_path, 'r') as f:
regression_model_args = json.load(f)
+ # Disable AMP for inference (even if model is trained with AMP)
+ if "amp_mode" in regression_model_args:
+ regression_model_args["amp_mode"] = False
+ use_apex_gn_reg = regression_model_args.get("use_apex_gn", False)
net_reg = UNet(**regression_model_args)
@@ -167,320 +148,275 @@ def main(cfg: DictConfig) -> None:
device=dist.device
)
- net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last)
+ net_reg = net_reg.eval().to(device)
+ if use_apex_gn_reg:
+ net_reg = net_reg.to(memory_format=torch.channels_last)
if cfg.generation.perf.force_fp16:
net_reg.use_fp16 = True
-
- # Disable AMP for inference (even if model is trained with AMP)
- if hasattr(net_reg, "amp_mode"):
- net_reg.amp_mode = False
else:
net_reg = None
- # Reset since we are using a different mode.
+ # Reset since we are using a different mode.
if cfg.generation.perf.use_torch_compile:
+ torch._dynamo.config.cache_size_limit = 264
torch._dynamo.reset()
- # Only compile residual network
- # Overhead of compiling regression network outweights any benefits
if net_res:
- net_res = torch.compile(net_res, mode="reduce-overhead")
-
- # Partially instantiate the sampler based on the configs
- if cfg.sampler.type == "deterministic":
- if cfg.generation.hr_mean_conditioning:
- raise NotImplementedError(
- "High-res mean conditioning is not yet implemented for the deterministic sampler"
- )
- sampler_fn = partial(
- deterministic_sampler,
- num_steps=cfg.sampler.num_steps,
- # num_ensembles=cfg.generation.num_ensembles,
- solver=cfg.sampler.solver,
- )
- elif cfg.sampler.type == "stochastic":
- sampler_fn = partial(stochastic_sampler, patching=patching)
- else:
- raise ValueError(f"Unknown sampling method {cfg.sampling.type}")
-
+ net_res = torch.compile(net_res)
+ if net_reg:
+ net_reg = torch.compile(net_reg)
- # Main generation definition
- def generate_fn(image_lr, lead_time_label):
- with nvtx.annotate("generate_fn", color="green"):
- # (1, C, H, W)
- image_lr = image_lr.to(memory_format=torch.channels_last)
-
- if net_reg:
- with nvtx.annotate("regression_model", color="yellow"):
- image_reg = regression_step(
- net=net_reg,
- img_lr=image_lr,
- latents_shape=(
- cfg.generation.seed_batch_size,
- img_out_channels,
- img_shape[0],
- img_shape[1],
- ), # (batch_size, C, H, W)
- lead_time_label=lead_time_label,
- )
- if net_res:
- if cfg.generation.hr_mean_conditioning:
- mean_hr = image_reg[0:1]
- else:
- mean_hr = None
- with nvtx.annotate("diffusion model", color="purple"):
- image_res = diffusion_step(
- net=net_res,
- sampler_fn=sampler_fn,
- img_shape=img_shape,
- img_out_channels=img_out_channels,
- rank_batches=rank_batches,
- img_lr=image_lr.expand(
- cfg.generation.seed_batch_size, -1, -1, -1
- ), #.to(memory_format=torch.channels_last),
- rank=dist.rank,
- device=device,
- mean_hr=mean_hr,
- lead_time_label=lead_time_label,
- )
- if cfg.generation.inference_mode == "regression":
- image_out = image_reg
- elif cfg.generation.inference_mode == "diffusion":
- image_out = image_res
- else:
- image_out = image_reg[0:1,::] + image_res
-
- # Gather tensors on rank 0
- if dist.world_size > 1:
- if dist.rank == 0:
- gathered_tensors = [
- torch.zeros_like(
- image_out, dtype=image_out.dtype, device=image_out.device
- )
- for _ in range(dist.world_size)
- ]
- else:
- gathered_tensors = None
- torch.distributed.barrier()
- gather(
- image_out,
- gather_list=gathered_tensors if dist.rank == 0 else None,
- dst=0,
- )
- if dist.rank == 0:
- if cfg.generation.inference_mode != "regression":
- return torch.cat(gathered_tensors), image_reg[0:1,::]
- return torch.cat(gathered_tensors), None
- else:
- return None, None
- else:
- #TODO do this for multi-gpu setting above too
- if cfg.generation.inference_mode != "regression":
- return image_out, image_reg
- return image_out, None
+ generator = Generator(
+ net_reg=net_reg,
+ net_res=net_res,
+ batch_size=cfg.generation.seed_batch_size,
+ ensemble_size=cfg.generation.num_ensembles,
+ hr_mean_conditioning=cfg.generation.hr_mean_conditioning,
+ n_out_channels=img_out_channels,
+ inference_mode=cfg.generation.inference_mode,
+ dist=dist,
+ )
+
+ # Parse the patch shape
+ if cfg.generation.patching:
+ patch_shape_x = cfg.generation.patch_shape_x
+ patch_shape_y = cfg.generation.patch_shape_y
+ else:
+ patch_shape_x, patch_shape_y = None, None
+ patch_shape = (patch_shape_y, patch_shape_x)
+ use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
+ if use_patching:
+ generator.initialize_patching(img_shape=img_shape,
+ patch_shape=patch_shape,
+ boundary_pix=cfg.generation.boundary_pix,
+ overlap_pix=cfg.generation.overlap_pix,
+ )
+ sampler_params = dict(OmegaConf.to_container(cfg.sampler.params, resolve=True)) if "params" in cfg.sampler else {}
+ sampler_params["use_apex_gn"] = use_apex_gn
+ generator.initialize_sampler(cfg.sampler.type, **sampler_params)
# generate images
output_path = getattr(cfg.generation.io, "output_path", "./outputs")
logger0.info(f"Generating images, saving results to {output_path}...")
batch_size = 1
warmup_steps = min(len(times) - 1, 2)
- # Generates model predictions from the input data using the specified
- # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates
- # through the dataset using a data loader, computes predictions, and saves them along
- # with associated metadata.
+ enable_timing = cfg.generation.perf.get("enable_timing", True)
+ _t = _sync_t if enable_timing else (lambda: 0.0)
torch_cuda_profiler = (
torch.cuda.profiler.profile()
- if torch.cuda.is_available()
+ if torch.cuda.is_available() and cfg.generation.perf.get("profile", False)
else contextlib.nullcontext()
)
torch_nvtx_profiler = (
torch.autograd.profiler.emit_nvtx()
- if torch.cuda.is_available()
+ if torch.cuda.is_available() and cfg.generation.perf.get("profile", False)
else contextlib.nullcontext()
)
with torch_cuda_profiler:
with torch_nvtx_profiler:
+ with torch.inference_mode():
- data_loader = torch.utils.data.DataLoader(
- dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True
- )
- time_index = -1
- if dist.rank == 0:
- writer_executor = ThreadPoolExecutor(
- max_workers=cfg.generation.perf.num_writer_workers
+ data_loader = torch.utils.data.DataLoader(
+ dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True,
+ num_workers=4, persistent_workers=True,
)
- writer_threads = []
-
- # Create timer objects only if CUDA is available
- use_cuda_timing = torch.cuda.is_available()
- if use_cuda_timing:
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
- else:
- # Dummy no-op functions for CPU case
- class DummyEvent:
- def record(self):
- pass
-
- def synchronize(self):
- pass
-
- def elapsed_time(self, _):
- return 0
-
- start = end = DummyEvent()
-
- times = dataset.time()
- for index, (image_tar, image_lr, *lead_time_label) in enumerate(
- iter(data_loader)
- ):
- time_index += 1
+ time_index = -1
if dist.rank == 0:
- logger0.info(f"starting index: {time_index}")
-
- if time_index == warmup_steps:
- start.record()
+ writer_executor = ThreadPoolExecutor(
+ max_workers=cfg.generation.perf.num_writer_workers
+ )
+ writer_threads = []
- # continue
- if lead_time_label:
- lead_time_label = lead_time_label[0].to(dist.device).contiguous()
+ # Create timer objects only if CUDA is available
+ use_cuda_timing = torch.cuda.is_available()
+ if use_cuda_timing:
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
else:
- lead_time_label = None
- image_lr = (
- image_lr.to(device=device)
- .to(torch.float32)
- .to(memory_format=torch.channels_last)
- )
- image_tar = image_tar.to(device=device).to(torch.float32)
- image_out, image_reg = generate_fn(image_lr,lead_time_label)
- if dist.rank == 0:
- batch_size = image_out.shape[0]
- # write out data in a seperate thread so we don't hold up inferencing
- writer_threads.append(
- writer_executor.submit(
- save_images,
- output_path,
- times[sampler[time_index]],
- dataset,
- image_out.cpu().numpy(),
- image_tar.cpu().numpy(),
- image_lr.cpu().numpy(),
- image_reg.cpu().numpy() if image_reg is not None else None,
+ # Dummy no-op functions for CPU case
+ class DummyEvent:
+ def record(self):
+ pass
+
+ def synchronize(self):
+ pass
+
+ def elapsed_time(self, _):
+ return 0
+
+ start = end = DummyEvent()
+
+ # Per-section timing accumulators (wall-clock, GPU-synchronized)
+ step_timings = defaultdict(float)
+ timed_step_count = 0
+
+ #TODO: Isolate static channel loading into the method of generator or reuse training manager static channel loading
+ static_channels = dataset.get_static_data()
+ if static_channels is not None:
+ static_channels = static_channels[None, ::].flip(-2)
+ if use_apex_gn:
+ static_channels = static_channels.to(
+ dist.device,
+ dtype=input_dtype,
+ non_blocking=True,
+ ).to(memory_format=torch.channels_last)
+ else:
+ static_channels = (
+ static_channels.to(dist.device)
+ .to(input_dtype)
+ .contiguous()
+ )
+ lead_time_label = None
+
+ times = dataset.time()
+ # t_iter_end is updated at the end of each loop body; the gap between
+ # t_iter_end[i] and the start of body[i+1] equals DataLoader fetch time.
+ t_iter_end = _t()
+ for index, (image_tar, image_lr, *date_str) in enumerate(
+ iter(data_loader)
+ ):
+ t_iter_start = _t()
+ t_data_load = t_iter_start - t_iter_end
+
+ t_preproc_start = _t()
+
+ time_index += 1
+ if dist.rank == 0:
+ logger0.info(f"starting index: {time_index} time: {times[sampler[time_index]]}")
+
+ if time_index == warmup_steps:
+ start.record()
+
+ savedir = os.path.join(output_path,f"{times[sampler[time_index]]}")
+ os.makedirs(savedir,exist_ok=True)
+
+ #TODO: Move all the data processing inside the generator and just pass raw data to it. This includes regridding, normalization, date embedding creation, etc.
+ # Same as with static channel loading, we can reuse some of the code from training manager for this. This will also make it easier to maintain and update the data processing steps in one place.
+ if is_real_target:
+ image_tar = regrid_icon_to_rotlatlon(
+ image_tar.to(dist.device, dtype=input_dtype),
+ dataset.regrid_indices_real,
+ dataset.regrid_weights_real,
)
+ if dataset.trim_edge > 0:
+ image_tar = image_tar[:, :, dataset.trim_edge:-dataset.trim_edge, dataset.trim_edge:-dataset.trim_edge]
+ else:
+ image_tar = image_tar.reshape(*image_tar.shape[:-1], *dataset.image_shape())
+ if lead_time_label:
+ lead_time_label = lead_time_label[0].to(dist.device).contiguous()
+ else:
+ lead_time_label = None
+ image_lr = dataset.interpolator(image_lr.to(dist.device, dtype=input_dtype)).reshape(*image_lr.shape[:-1], *dataset.image_shape()).flip(-2)
+ image_lr = dataset.normalize_input(image_lr)
+ if use_apex_gn:
+ image_lr = image_lr.to(memory_format=torch.channels_last)
+ date_embedding = None
+ if dataset._n_month_hour_channels:
+ date_embedding = dataset.make_time_grids(*date_str, dist.device, dtype=input_dtype)
+
+ random_seed = cfg.generation.get("random_seed", None)+index if cfg.generation.get("randomize", False) and cfg.generation.get("random_seed", None) is not None else None
+ t_preproc_end = _t()
+
+ t_gen_start = _t()
+ image_out, image_reg = generator.generate(
+ image_lr,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ lead_time_label=lead_time_label,
+ randomize=cfg.generation.get("randomize", False),
+ random_seed=random_seed,
+ skip_timing=(not enable_timing or time_index < warmup_steps),
+ )
+ t_gen_end = _t()
+
+ t_postproc_start = _t()
+ if dist.rank == 0:
+ batch_size = image_out.shape[0]
+ # write out data in a seperate thread so we don't hold up inferencing
+ image_tar = image_tar[0].squeeze().cpu().numpy()
+ prediction_ensemble = dataset.denormalize_output(image_out).squeeze().flip(-2).cpu().numpy()
+ baseline = dataset.denormalize_input(image_lr)[0].squeeze().flip(-2).cpu().numpy()
+ if image_reg is not None:
+ mean_pred = dataset.denormalize_output(image_reg)[0].squeeze().flip(-2).cpu().numpy()
+ t_postproc_end = _t()
+
+ t_write_start = _t()
+ if dist.rank == 0:
+ writer_threads.append(
+ writer_executor.submit(
+ save_results_as_torch,
+ savedir,
+ times[sampler[time_index]],
+ prediction_ensemble,
+ image_tar,
+ baseline,
+ mean_pred if image_reg is not None else None,
+ )
+ )
+ t_write_end = _t()
+
+ if enable_timing and time_index >= warmup_steps:
+ timed_step_count += 1
+ step_timings["data_loading"] += t_data_load
+ step_timings["preprocessing"] += t_preproc_end - t_preproc_start
+ step_timings["generation"] += t_gen_end - t_gen_start
+ step_timings["postprocessing"] += t_postproc_end - t_postproc_start
+ step_timings["io_submit"] += t_write_end - t_write_start
+
+ t_iter_end = _t()
+
+ end.record()
+ end.synchronize()
+ elapsed_time = (
+ start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0
+ ) # Convert ms to s
+ timed_steps = time_index + 1 - warmup_steps
+ if dist.rank == 0 and use_cuda_timing:
+ average_time_per_batch_element = elapsed_time / timed_steps / batch_size
+ logger.info(
+ f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s"
+ )
+ logger.info(
+ f"Average time per batch element = {average_time_per_batch_element} s"
)
- end.record()
- end.synchronize()
- elapsed_time = (
- start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0
- ) # Convert ms to s
- timed_steps = time_index + 1 - warmup_steps
- if dist.rank == 0 and use_cuda_timing:
- average_time_per_batch_element = elapsed_time / timed_steps / batch_size
- logger.info(
- f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s"
- )
- logger.info(
- f"Average time per batch element = {average_time_per_batch_element} s"
- )
- # make sure all the workers are done writing
- if dist.rank == 0:
- for thread in list(writer_threads):
- thread.result()
- writer_threads.remove(thread)
- writer_executor.shutdown()
+ # Log per-section timing breakdown
+ if dist.rank == 0 and timed_step_count > 0:
+ logger0.info("--- Inference timing breakdown (avg over timed steps, wall-clock GPU-synced) ---")
+ for key in ["data_loading", "preprocessing", "generation", "postprocessing", "io_submit"]:
+ avg = step_timings[key] / timed_step_count
+ logger0.info(f" {key:20s}: {avg:.3f} s/step")
+ # Log generator's internal breakdown (regression / diffusion / gather)
+ gen_timings = generator.get_timings()
+ if gen_timings:
+ logger0.info("--- Generator internal timing breakdown (avg over timed steps, warmup excluded) ---")
+ for key, (total, count) in gen_timings.items():
+ avg = total / count if count > 0 else 0.0
+ if key.endswith("_calls"):
+ logger0.info(f" {key:20s}: {avg:.1f} calls/step (n={count})")
+ else:
+ logger0.info(f" {key:20s}: {avg:.3f} s/step (n={count})")
+ # Derived: average time per individual net() forward call
+ if "diff_net_forward" in gen_timings and "diff_net_forward_calls" in gen_timings:
+ fwd_total, fwd_count = gen_timings["diff_net_forward"]
+ calls_total, calls_count = gen_timings["diff_net_forward_calls"]
+ if calls_total > 0:
+ per_call_ms = (fwd_total / calls_total) * 1000
+ logger0.info(f" {'diff_net_fwd/call':20s}: {per_call_ms:.1f} ms/call")
+
+ # make sure all the workers are done writing
+ if dist.rank == 0:
+ for thread in list(writer_threads):
+ thread.result()
+ writer_threads.remove(thread)
+ writer_executor.shutdown()
if dist.rank == 0:
f.close()
logger0.info("Generation Completed.")
-def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred):
-
- os.makedirs(output_path, exist_ok=True)
-
- longitudes = dataset.longitude()
- latitudes = dataset.latitude()
- input_channels = dataset.input_channels()
- output_channels = dataset.output_channels()
-
- target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1)
- prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1)
- baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1)
- if mean_pred is not None:
- mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze()),1) #.reshape(len(output_channels),-1)
-
-
- freqs = {}
- power = {}
- for idx, channel in enumerate(output_channels):
- input_channel_idx = input_channels.index(channel)
-
- if channel.name=="tp":
- target[idx,::] = prepare_precipitaiton(target[idx,:,:])
- prediction[idx,::] = prepare_precipitaiton(prediction[idx,:,:])
- baseline[input_channel_idx,:,:] = prepare_precipitaiton(baseline[input_channel_idx])
- if mean_pred is not None:
- mean_pred[idx,::] = prepare_precipitaiton(mean_pred[idx,::])
-
- _plot_projection(longitudes, latitudes, target[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-target.jpg'))
- _plot_projection(longitudes, latitudes, prediction[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-prediction.jpg'))
- _plot_projection(longitudes, latitudes, baseline[input_channel_idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-input.jpg'))
- if mean_pred is not None:
- _plot_projection(longitudes, latitudes, mean_pred[idx,:,:], os.path.join(output_path, f'{time_step}-{channel.name}-mean_prediction.jpg'))
-
- _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:])
- _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:])
- if mean_pred is not None:
- _, mean_prediction_errors = compute_mae(mean_pred[idx,:,:], target[idx,:,:])
-
-
- plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg'))
- plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg'))
- if mean_pred is not None:
- plot_error_projection(mean_prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-mean-prediction-error.jpg'))
-
- b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0)
- freqs['baseline'] = b_freq
- power['baseline'] = b_power
- #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates'))
- t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0)
- freqs['target'] = t_freq
- power['target'] = t_power
- p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0)
- freqs['prediction'] = p_freq
- power['prediction'] = p_power
- if mean_pred is not None:
- mp_freq, mp_power = average_power_spectrum(mean_pred[idx,:,:].squeeze(), 2.0)
- freqs['mean_prediction'] = mp_freq
- power['mean_prediction'] = mp_power
- plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg'))
-
-
-def prepare_precipitaiton(precip_array):
- precip_array = np.clip(precip_array, 0, None)
- epsilon = 1e-2
- precip_array = precip_array + epsilon
- precip_array = np.log(precip_array)
- # log_min, log_max = precip_array.min(), precip_array.max()
- # precip_array = (precip_array-log_min)/(log_max-log_min)
- return precip_array
-
-
-def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None):
-
- """Plot observed or interpolated data in a scatter plot."""
- # TODO: Refactor this somehow, it's not really generalizing well across variables.
- fig = plt.figure()
- fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
- p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax)
- ax.coastlines()
- ax.gridlines(draw_labels=True)
- plt.colorbar(p, label="K", orientation="horizontal")
- plt.savefig(filename)
- plt.close('all')
-
if __name__ == "__main__":
main()
\ No newline at end of file
diff --git a/src/hirad/inference/generator.py b/src/hirad/inference/generator.py
new file mode 100644
index 00000000..ff990c1b
--- /dev/null
+++ b/src/hirad/inference/generator.py
@@ -0,0 +1,190 @@
+from typing import Callable
+from functools import partial
+import time
+from collections import defaultdict
+import nvtx
+import numpy as np
+import random
+import torch
+from torch.distributed import gather
+from hirad.utils.inference_utils import regression_step, diffusion_step
+from hirad.distributed import DistributedManager
+from hirad.utils.patching import GridPatching2D
+from hirad.inference import stochastic_sampler, deterministic_sampler
+
+
+def _sync_t() -> float:
+ """Return wall-clock time after synchronizing all pending CUDA ops."""
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.perf_counter()
+
+class Generator():
+ def __init__(self,
+ net_reg: torch.nn.Module,
+ net_res: torch.nn.Module,
+ batch_size: int,
+ ensemble_size: int,
+ hr_mean_conditioning: bool,
+ n_out_channels: int,
+ inference_mode: str,
+ dist: DistributedManager,
+ ):
+
+ self.net_reg = net_reg
+ self.net_res = net_res
+ self.batch_size = batch_size
+ self.hr_mean_conditioning = hr_mean_conditioning
+ self.n_out_channels = n_out_channels
+ self.inference_mode = inference_mode
+ self.ensemble_size = ensemble_size
+ self.dist = dist
+ self.get_rank_batches()
+ self.patching = None
+ self._timings: dict[str, float] = defaultdict(float)
+ self._timing_counts: dict[str, int] = defaultdict(int)
+
+ def get_rank_batches(self, seeds=None):
+ if seeds is None:
+ seeds = list(np.arange(self.ensemble_size))
+ num_batches = (
+ (len(seeds) - 1) // (self.batch_size * self.dist.world_size) + 1
+ ) * self.dist.world_size
+ all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
+ self.rank_batches = all_batches[self.dist.rank :: self.dist.world_size]
+
+ def initialize_sampler(self, sampler_type, **sampler_args):
+ if sampler_type == "deterministic":
+ if self.hr_mean_conditioning:
+ raise NotImplementedError(
+ "High-res mean conditioning is not yet implemented for the deterministic sampler"
+ )
+ self.sampler = partial(
+ deterministic_sampler,
+ **sampler_args
+ )
+ elif sampler_type == "stochastic":
+ self.sampler = partial(stochastic_sampler, patching=self.patching, **sampler_args)
+ else:
+ raise ValueError(f"Unknown sampling method {sampler_type}")
+
+ def get_timings(self) -> dict[str, tuple[float, int]]:
+ """Return accumulated timing stats: {key: (total_seconds, call_count)}."""
+ return {k: (self._timings[k], self._timing_counts[k]) for k in self._timings}
+
+ def initialize_patching(self, img_shape, patch_shape, boundary_pix, overlap_pix):
+ self.patching = GridPatching2D(
+ img_shape=img_shape,
+ patch_shape=patch_shape,
+ boundary_pix=boundary_pix,
+ overlap_pix=overlap_pix,
+ )
+
+ def generate(self, image_lr, static_channels=None, date_embedding=None, lead_time_label=None, randomize=False, random_seed=None, use_apex_gn=False, skip_timing=False):
+ with nvtx.annotate("generate_fn", color="green"):
+ # (1, C, H, W)
+ img_shape = image_lr.shape[-2:]
+
+ _step_timings: dict = {} if not skip_timing else None
+ _t = _sync_t if not skip_timing else (lambda: 0.0)
+
+ if self.net_reg:
+ with nvtx.annotate("regression_model", color="yellow"):
+ _t0 = _t()
+ image_reg = regression_step(
+ net=self.net_reg,
+ img_lr=image_lr,
+ latents_shape=(
+ self.batch_size,
+ self.n_out_channels,
+ img_shape[0],
+ img_shape[1],
+ ), # (batch_size, C, H, W)
+ lead_time_label=lead_time_label,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ use_apex_gn=use_apex_gn,
+ _timings=_step_timings,
+ )
+ if not skip_timing:
+ self._timings["regression"] += _t() - _t0
+ self._timing_counts["regression"] += 1
+ if self.net_res:
+ if self.hr_mean_conditioning:
+ mean_hr = image_reg[0:1]
+ else:
+ mean_hr = None
+ if randomize:
+ # Set random seed for numpy
+ if random_seed is not None:
+ np.random.seed((random_seed) % (1 << 31))
+ seeds = np.random.randint(0, 1<<31, size=self.ensemble_size)
+ self.get_rank_batches(seeds=seeds)
+ with nvtx.annotate("diffusion model", color="purple"):
+ _t0 = _t()
+ image_res = diffusion_step(
+ net=self.net_res,
+ sampler_fn=self.sampler,
+ img_shape=img_shape,
+ img_out_channels=self.n_out_channels,
+ rank_batches=self.rank_batches,
+ img_lr=image_lr.expand(
+ self.batch_size, -1, -1, -1
+ ).to(memory_format=torch.channels_last),
+ rank=self.dist.rank,
+ device=image_lr.device,
+ mean_hr=mean_hr,
+ lead_time_label=lead_time_label,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ use_apex_gn=use_apex_gn,
+ _timings=_step_timings,
+ )
+ if not skip_timing:
+ self._timings["diffusion"] += _t() - _t0
+ self._timing_counts["diffusion"] += 1
+
+ if not skip_timing and _step_timings:
+ for k, v in _step_timings.items():
+ self._timings[k] += v
+ self._timing_counts[k] += 1
+ if self.inference_mode == "regression":
+ image_out = image_reg[0:1,::]
+ elif self.inference_mode == "diffusion":
+ image_out = image_res
+ else:
+ image_out = image_reg[0:1,::] + image_res
+
+ # Gather tensors on rank 0
+ if self.dist.world_size > 1:
+ if self.dist.rank == 0:
+ gathered_tensors = [
+ torch.zeros_like(
+ image_out, dtype=image_out.dtype, device=image_out.device
+ )
+ for _ in range(self.dist.world_size)
+ ]
+ else:
+ gathered_tensors = None
+
+ _t0 = _t()
+ torch.distributed.barrier()
+ gather(
+ image_out,
+ gather_list=gathered_tensors if self.dist.rank == 0 else None,
+ dst=0,
+ )
+ if not skip_timing:
+ self._timings["gather"] += _t() - _t0
+ self._timing_counts["gather"] += 1
+
+ if self.dist.rank == 0:
+ if self.inference_mode != "regression":
+ return torch.cat(gathered_tensors), image_reg[0:1,::]
+ return torch.cat(gathered_tensors)[0:1,::], None
+ else:
+ return None, None
+ else:
+ if self.inference_mode != "regression":
+ return image_out, image_reg[0:1,::]
+ return image_out, None
diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/inference/stochastic_sampler.py
similarity index 72%
rename from src/hirad/utils/stochastic_sampler.py
rename to src/hirad/inference/stochastic_sampler.py
index 198fde43..05120048 100644
--- a/src/hirad/utils/stochastic_sampler.py
+++ b/src/hirad/inference/stochastic_sampler.py
@@ -16,6 +16,7 @@
from typing import Callable, Optional
+import time
import torch
from torch import Tensor
@@ -23,6 +24,12 @@
from hirad.utils.patching import GridPatching2D
+def _sync_t() -> float:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.perf_counter()
+
+
def stochastic_sampler(
net: torch.nn.Module,
latents: torch.Tensor,
@@ -32,6 +39,8 @@ def stochastic_sampler(
patching: Optional[GridPatching2D] = None,
mean_hr: Optional[torch.Tensor] = None,
lead_time_label: Optional[torch.Tensor] = None,
+ static_channels: Optional[torch.Tensor] = None,
+ date_embedding: Optional[torch.Tensor] = None,
num_steps: int = 18,
sigma_min: float = 0.002,
sigma_max: float = 800,
@@ -40,6 +49,8 @@ def stochastic_sampler(
S_min: float = 0,
S_max: float = float("inf"),
S_noise: float = 1,
+ use_apex_gn: bool = False,
+ _timings: Optional[dict] = None,
) -> torch.Tensor:
"""
Proposed EDM sampler (Algorithm 2) with minor changes to enable
@@ -97,6 +108,10 @@ def stochastic_sampler(
of `img_lr`. By default None.
lead_time_label : Optional[Tensor], optional
Optional lead time labels. By default None.
+ static_channels : Optional[Tensor], optional
+ Optional static channels input of shape (1, C_static, H, W). By default None.
+ date_embedding : Optional[Tensor], optional
+ Optional date embedding input of shape (B, C_date). By default None.
num_steps : int
Number of time steps for the sampler. By default 18.
sigma_min : float
@@ -114,6 +129,8 @@ def stochastic_sampler(
Maximum time step for applying churn. By default float("inf").
S_noise : float
Noise scaling factor applied during the churn step. By default 1.
+ use_apex_gn : bool
+ Whether Apex's fused group normalization is used.
Returns
-------
@@ -130,8 +147,8 @@ def stochastic_sampler(
# Adjust noise levels based on what's supported by the network.
# Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution.
- sigma_min = max(sigma_min, net.sigma_min)
- sigma_max = min(sigma_max, net.sigma_max)
+ sigma_min = max(sigma_min, net.module.sigma_min if hasattr(net, "module") else net.sigma_min)
+ sigma_max = min(sigma_max, net.module.sigma_max if hasattr(net, "module") else net.sigma_max)
if patching is not None and not isinstance(patching, GridPatching2D):
raise ValueError("patching must be an instance of GridPatching2D.")
@@ -153,6 +170,9 @@ def stochastic_sampler(
f"{img_lr.shape[0]} vs {latents.shape[0]}."
)
+ _t = _sync_t if _timings is not None else (lambda: 0.0)
+ _t_preproc_start = _t()
+
# Time step discretization.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
t_steps = (
@@ -162,7 +182,8 @@ def stochastic_sampler(
* (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
) ** rho
t_steps = torch.cat(
- [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
+ [net.module.round_sigma(t_steps) if hasattr(net, "module") else net.round_sigma(t_steps),
+ torch.zeros_like(t_steps[:1])]
) # t_N = 0
batch_size = img_lr.shape[0]
@@ -177,12 +198,36 @@ def stochastic_sampler(
)
x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)
+ if static_channels is not None:
+ # Expand static channels to batch size
+ if static_channels.shape[-2:] != img_lr.shape[-2:]:
+ raise ValueError(
+ f"mean_hr and img_lr must have the same height and width, "
+ f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
+ )
+ static_expanded = static_channels.expand(batch_size, -1, -1, -1)
+ x_lr = torch.cat((x_lr, static_expanded), dim=1)
+
# input and position padding + patching
if patching:
+ # print(f"Input for generator beofre patching {x_lr.shape}")
# Patched conditioning [x_lr, mean_hr]
+ if static_channels is not None:
+ img_lr = torch.cat(
+ (img_lr, static_channels.expand(img_lr.shape[0], *static_channels.shape[1:])),
+ dim=1,
+ )
+ # print(f"Shape of img_lr after static channels diffusion patching: img_lr {img_lr.shape}")
+ if date_embedding is not None:
+ date_embedding = date_embedding[:, :, None, None].expand(img_lr.shape[0], date_embedding.shape[1], *img_lr.shape[2:])
+ if use_apex_gn:
+ date_embedding = date_embedding.to(img_lr.dtype, non_blocking=True).to(memory_format=torch.channels_last)
+ else:
+ date_embedding = date_embedding.to(img_lr.dtype, non_blocking=True).contiguous()
+ img_lr = torch.cat((img_lr, date_embedding), dim=1)
# (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x)
x_lr = patching.apply(input=x_lr, additional_input=img_lr)
-
+ # print(f"Input for generator after patching {x_lr.shape}")
# Function to select the correct positional embedding for each patch
def patch_embedding_selector(emb):
# emb: (N_pe, image_shape_y, image_shape_x)
@@ -190,17 +235,35 @@ def patch_embedding_selector(emb):
return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
else:
+ if date_embedding is not None:
+ date_embedding = date_embedding[:, :, None, None].expand(x_lr.shape[0], date_embedding.shape[1], *x_lr.shape[2:])
+ if use_apex_gn:
+ date_embedding = date_embedding.to(x_lr.dtype, non_blocking=True).to(memory_format=torch.channels_last)
+ else:
+ date_embedding = date_embedding.to(x_lr.dtype, non_blocking=True).contiguous()
+ x_lr = torch.cat((x_lr, date_embedding), dim=1)
+
patch_embedding_selector = None
+ _t_preproc_end = _t()
+
# Main sampling loop.
+ x_lr = x_lr.to(latents.device) # ensure correct device once, before the loop
+ _t_net_forward = 0.0
+ _n_net_forward = 0
+ _t_loop_start = _t()
x_next = latents.to(torch.float64) * t_steps[0]
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next
# Increase noise temporarily.
gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0
- t_hat = net.round_sigma(t_cur + gamma * t_cur)
+ t_hat = net.module.round_sigma(t_cur + gamma * t_cur) if hasattr(net, "module") else net.round_sigma(t_cur + gamma * t_cur)
- x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
+ # Only generate noise when it will actually be used (gamma > 0).
+ if gamma > 0:
+ x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
+ else:
+ x_hat = x_cur
# Euler step. Perform patching operation on score tensor if patch-based
# generation is used denoised = net(x_hat, t_hat,
@@ -209,8 +272,8 @@ def patch_embedding_selector(emb):
x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to(
latents.device
)
- x_lr = x_lr.to(latents.device)
+ _tn0 = _t()
if lead_time_label is not None:
denoised = net(
x_hat_batch,
@@ -234,6 +297,9 @@ def patch_embedding_selector(emb):
class_labels,
embedding_selector=patch_embedding_selector,
).to(torch.float64)
+ _t_net_forward += _t() - _tn0
+ _n_net_forward += 1
+
if patching:
# Un-patch the denoised image
# (batch_size, C_out, img_shape_y, img_shape_x)
@@ -250,6 +316,7 @@ def patch_embedding_selector(emb):
latents.device
)
+ _tn0 = _t()
if lead_time_label is not None:
denoised = net(
x_next_batch,
@@ -267,6 +334,9 @@ def patch_embedding_selector(emb):
class_labels,
embedding_selector=patch_embedding_selector,
).to(torch.float64)
+ _t_net_forward += _t() - _tn0
+ _n_net_forward += 1
+
if patching:
# Un-patch the denoised image
# (batch_size, C_out, img_shape_y, img_shape_x)
@@ -274,4 +344,13 @@ def patch_embedding_selector(emb):
d_prime = (x_next - denoised) / t_next
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
+
+ _t_loop_total = _t() - _t_loop_start
+
+ if _timings is not None:
+ _timings["diff_sampler_preproc"] = _timings.get("diff_sampler_preproc", 0.0) + (_t_preproc_end - _t_preproc_start)
+ _timings["diff_net_forward"] = _timings.get("diff_net_forward", 0.0) + _t_net_forward
+ _timings["diff_net_forward_calls"] = _timings.get("diff_net_forward_calls", 0) + _n_net_forward
+ _timings["diff_loop_overhead"] = _timings.get("diff_loop_overhead", 0.0) + (_t_loop_total - _t_net_forward)
+
return x_next
diff --git a/src/hirad/input_data/__init__.py b/src/hirad/input_data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/hirad/input_data/calculate_transformed_stats.py b/src/hirad/input_data/calculate_transformed_stats.py
new file mode 100644
index 00000000..0e3d701e
--- /dev/null
+++ b/src/hirad/input_data/calculate_transformed_stats.py
@@ -0,0 +1,73 @@
+import numpy as np
+import yaml
+import os
+from tqdm import tqdm
+import torch
+
+
+def transform_channel(channel_array, channel_name="tp"):
+ if channel_name == "tp":
+ channel_array = np.clip(channel_array, 0, None)
+ channel_array = (np.power(channel_array,0.25)-1)/0.25
+ return channel_array
+
+def main():
+ base_path = '/iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/train'
+ all_files_input = os.listdir(os.path.join(base_path, 'era-copernicus-interpolated/'))
+ all_files_output = os.listdir(os.path.join(base_path, 'realch1/'))
+ input_info = yaml.safe_load(open(os.path.join(base_path, 'info', 'era.yaml')))
+ output_info = yaml.safe_load(open(os.path.join(base_path, 'info', 'realch1.yaml')))
+
+ tp_input_index = input_info['select'].index('tp')
+ tp_output_index = output_info['select'].index('tp')
+
+ print('Total precipitation input index:', tp_input_index)
+ print('Total precipitation output index:', tp_output_index)
+
+ print(f"Found {len(all_files_input)} input files and {len(all_files_output)} output files.")
+ print("Calculating transformed mean and std...")
+
+ # calculate mean
+ input_values = []
+ output_values = []
+ print("Calculating input mean")
+ for f in tqdm(all_files_input):
+ data = np.load(os.path.join(base_path, 'era-copernicus-interpolated', f))
+ data = transform_channel(data[tp_input_index,:])
+ input_values.append(np.mean(data))
+ input_mean = np.mean(input_values)
+ print("Calculating output mean")
+ for f in tqdm(all_files_output):
+ data = np.load(os.path.join(base_path, 'realch1', f))
+ data = transform_channel(data[tp_output_index,:])
+ output_values.append(np.mean(data))
+ output_mean = np.mean(output_values)
+
+ torch.save(input_mean, os.path.join('./info', 'era5-tp-box_cox_025-mean'))
+ torch.save(output_mean, os.path.join('./info','realch1-tp-box_cox_025-mean'))
+
+ # calculate std
+ input_values = []
+ output_values = []
+ print("Calculating input std")
+ for f in tqdm(all_files_input):
+ data = np.load(os.path.join(base_path, 'era-copernicus-interpolated', f))
+ data = transform_channel(data[tp_input_index,:])
+ input_values.append(np.mean((data - input_mean)**2))
+ input_std = np.sqrt(np.mean(input_values))
+ print("Calculating output std")
+ for f in tqdm(all_files_output):
+ data = np.load(os.path.join(base_path, 'realch1', f))
+ data = transform_channel(data[tp_output_index,:])
+ output_values.append(np.mean((data - output_mean)**2))
+ output_std = np.sqrt(np.mean(output_values))
+
+ torch.save(input_std, os.path.join('./info','era5-tp-box_cox_025-std'))
+ torch.save(output_std, os.path.join('./info','realch1-tp-box_cox_025-std'))
+
+ print(f"Input Mean: {input_mean}, Input Std: {input_std}")
+ print(f"Output Mean: {output_mean}, Output Std: {output_std}")
+ print("Done.")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/input_data/check-fileload.py b/src/hirad/input_data/check-fileload.py
new file mode 100644
index 00000000..ff78b70d
--- /dev/null
+++ b/src/hirad/input_data/check-fileload.py
@@ -0,0 +1,19 @@
+import os
+import sys
+import numpy as np
+
+dir = sys.argv[1]
+files = os.listdir(dir)
+data = np.load(dir + files[0])
+shape = data.shape
+for i in range(len(files)):
+ if i % 10000 == 0:
+ print(i)
+ try:
+ data = np.load(dir + files[i])
+ except:
+ print(f'{files[i]} does not load')
+ if data.shape != shape:
+ print(f'{files[i]} has shape {data.shape}')
+
+
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/copernicus.yaml b/src/hirad/input_data/configs/copernicus.yaml
new file mode 100644
index 00000000..2a9ccbf9
--- /dev/null
+++ b/src/hirad/input_data/configs/copernicus.yaml
@@ -0,0 +1,6 @@
+path: '/capstor/store/mch/msopr/hirad-gen/copernicus-datasets/'
+channels: ['tp']
+start: 2015-10-01
+#start: 2020-01-01
+end: 2020-11-30
+frequency: 1
\ No newline at end of file
diff --git a/src/hirad/input_data/cosmo-all.yaml b/src/hirad/input_data/configs/cosmo-all.yaml
similarity index 79%
rename from src/hirad/input_data/cosmo-all.yaml
rename to src/hirad/input_data/configs/cosmo-all.yaml
index 034210e8..d807691f 100644
--- a/src/hirad/input_data/cosmo-all.yaml
+++ b/src/hirad/input_data/configs/cosmo-all.yaml
@@ -1,4 +1,5 @@
-dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr'
+dataset: '/capstor/store/mch/msopr/hirad-gen/anemoi-datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr'
select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'hsurf', 'insolation',
'lsm', 'msl',
'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
@@ -13,7 +14,7 @@ select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_
# ALL COSMO CHANNELS.
# Intersection between ERA and COSMO excludes hsurf, tcc, tqv
trim_edge: 19 # Removes boundary
-start: 2016-01-01
-# start: 2015-11-29
-end: 2016-02-29
-# end: 2020-12-31
\ No newline at end of file
+#start: 2016-01-01
+start: 2015-10-01
+#end: 2016-01-01
+end: 2020-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/cosmo-static.yaml b/src/hirad/input_data/configs/cosmo-static.yaml
new file mode 100644
index 00000000..9bece97b
--- /dev/null
+++ b/src/hirad/input_data/configs/cosmo-static.yaml
@@ -0,0 +1,8 @@
+dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr'
+select: ['hsurf', 'lsm',]
+# Static cosmo channels
+trim_edge: 19 # Removes boundary
+start: 2016-01-01
+# start: 2015-11-29
+end: 2016-01-01
+# end: 2020-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/cosmo.yaml b/src/hirad/input_data/configs/cosmo.yaml
similarity index 100%
rename from src/hirad/input_data/cosmo.yaml
rename to src/hirad/input_data/configs/cosmo.yaml
diff --git a/src/hirad/input_data/configs/era-all-2015q4.yaml b/src/hirad/input_data/configs/era-all-2015q4.yaml
new file mode 100644
index 00000000..58242e7f
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2015q4.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2015-10-01
+end: 2015-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2016q1.yaml b/src/hirad/input_data/configs/era-all-2016q1.yaml
new file mode 100644
index 00000000..794cec23
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2016q1.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2016-03-01
+end: 2016-03-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2016q2.yaml b/src/hirad/input_data/configs/era-all-2016q2.yaml
new file mode 100644
index 00000000..96aeefd7
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2016q2.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2016-04-01
+end: 2016-06-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2016q3.yaml b/src/hirad/input_data/configs/era-all-2016q3.yaml
new file mode 100644
index 00000000..97ded368
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2016q3.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2016-07-01
+end: 2016-09-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2016q4.yaml b/src/hirad/input_data/configs/era-all-2016q4.yaml
new file mode 100644
index 00000000..a3d035fd
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2016q4.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2016-10-01
+end: 2016-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2017q1.yaml b/src/hirad/input_data/configs/era-all-2017q1.yaml
new file mode 100644
index 00000000..e39399f1
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2017q1.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2017-01-01
+end: 2017-03-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2017q2.yaml b/src/hirad/input_data/configs/era-all-2017q2.yaml
new file mode 100644
index 00000000..8504768e
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2017q2.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2017-04-01
+end: 2017-06-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2017q3.yaml b/src/hirad/input_data/configs/era-all-2017q3.yaml
new file mode 100644
index 00000000..97caadd5
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2017q3.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2017-07-01
+end: 2017-09-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2017q4.yaml b/src/hirad/input_data/configs/era-all-2017q4.yaml
new file mode 100644
index 00000000..505c490f
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2017q4.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2017-10-01
+end: 2017-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2018q1.yaml b/src/hirad/input_data/configs/era-all-2018q1.yaml
new file mode 100644
index 00000000..9e3ac5f8
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2018q1.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2018-01-01
+end: 2018-03-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2018q2.yaml b/src/hirad/input_data/configs/era-all-2018q2.yaml
new file mode 100644
index 00000000..71182629
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2018q2.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2018-04-01
+end: 2018-06-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2018q3.yaml b/src/hirad/input_data/configs/era-all-2018q3.yaml
new file mode 100644
index 00000000..a951d6ba
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2018q3.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2018-07-01
+end: 2018-09-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-2018q4.yaml b/src/hirad/input_data/configs/era-all-2018q4.yaml
new file mode 100644
index 00000000..0cf76140
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-2018q4.yaml
@@ -0,0 +1,19 @@
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2018-10-01
+end: 2018-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201901.yaml b/src/hirad/input_data/configs/era-all-201901.yaml
new file mode 100755
index 00000000..4a2f1bc9
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201901.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-01-01
+end: 2019-01-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201902.yaml b/src/hirad/input_data/configs/era-all-201902.yaml
new file mode 100755
index 00000000..86c12bfd
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201902.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-02-01
+end: 2019-02-28
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201903.yaml b/src/hirad/input_data/configs/era-all-201903.yaml
new file mode 100755
index 00000000..351a2ffc
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201903.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-03-01
+end: 2019-03-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201904.yaml b/src/hirad/input_data/configs/era-all-201904.yaml
new file mode 100755
index 00000000..2eac8680
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201904.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-04-01
+end: 2019-04-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201905.yaml b/src/hirad/input_data/configs/era-all-201905.yaml
new file mode 100755
index 00000000..f0fd37a8
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201905.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-05-01
+end: 2019-05-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201906.yaml b/src/hirad/input_data/configs/era-all-201906.yaml
new file mode 100755
index 00000000..698e6898
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201906.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-06-01
+end: 2019-06-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201907.yaml b/src/hirad/input_data/configs/era-all-201907.yaml
new file mode 100755
index 00000000..8a3430fd
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201907.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-07-01
+end: 2019-07-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201908.yaml b/src/hirad/input_data/configs/era-all-201908.yaml
new file mode 100755
index 00000000..54d5b4e6
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201908.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-08-01
+end: 2019-08-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201909.yaml b/src/hirad/input_data/configs/era-all-201909.yaml
new file mode 100755
index 00000000..48e004fa
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201909.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-09-01
+end: 2019-09-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201910.yaml b/src/hirad/input_data/configs/era-all-201910.yaml
new file mode 100755
index 00000000..f4c7d7f4
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201910.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-10-01
+end: 2019-10-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201911.yaml b/src/hirad/input_data/configs/era-all-201911.yaml
new file mode 100755
index 00000000..4e6dab71
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201911.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-11-01
+end: 2019-11-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-201912.yaml b/src/hirad/input_data/configs/era-all-201912.yaml
new file mode 100755
index 00000000..e8072aff
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-201912.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-12-01
+end: 2019-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202004.yaml b/src/hirad/input_data/configs/era-all-202004.yaml
new file mode 100755
index 00000000..4a76c512
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202004.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-04-01
+end: 2020-04-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202005.yaml b/src/hirad/input_data/configs/era-all-202005.yaml
new file mode 100755
index 00000000..a8356dbd
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202005.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-05-01
+end: 2020-05-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202006.yaml b/src/hirad/input_data/configs/era-all-202006.yaml
new file mode 100755
index 00000000..e98a075c
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202006.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-06-01
+end: 2020-06-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202007.yaml b/src/hirad/input_data/configs/era-all-202007.yaml
new file mode 100755
index 00000000..5d009358
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202007.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-07-01
+end: 2020-07-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202008.yaml b/src/hirad/input_data/configs/era-all-202008.yaml
new file mode 100755
index 00000000..d9fd8d65
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202008.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-08-01
+end: 2020-08-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202009.yaml b/src/hirad/input_data/configs/era-all-202009.yaml
new file mode 100755
index 00000000..e63731df
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202009.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-09-01
+end: 2020-09-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202010.yaml b/src/hirad/input_data/configs/era-all-202010.yaml
new file mode 100755
index 00000000..e566ace6
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202010.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-10-01
+end: 2020-10-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202011.yaml b/src/hirad/input_data/configs/era-all-202011.yaml
new file mode 100755
index 00000000..a681b3c1
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202011.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-11-01
+end: 2020-11-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era-all-202012.yaml b/src/hirad/input_data/configs/era-all-202012.yaml
new file mode 100755
index 00000000..00f76db3
--- /dev/null
+++ b/src/hirad/input_data/configs/era-all-202012.yaml
@@ -0,0 +1,20 @@
+
+#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
+ 'lsm', 'msl',
+ 'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
+ 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude', 'sdor', 'slor', 'skt', 'sp', 'tcw', 'tp',
+ 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925',
+ 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925',
+ 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925',
+ 'w_100', 'w_1000', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_50', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925',
+ 'z',
+ 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925',
+]
+# ALL ERA CHANNELS
+# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-12-01
+end: 2020-12-31
\ No newline at end of file
diff --git a/src/hirad/input_data/era-all.yaml b/src/hirad/input_data/configs/era-all.yaml
similarity index 79%
rename from src/hirad/input_data/era-all.yaml
rename to src/hirad/input_data/configs/era-all.yaml
index bca216c3..3bb3f8df 100644
--- a/src/hirad/input_data/era-all.yaml
+++ b/src/hirad/input_data/configs/era-all.yaml
@@ -1,6 +1,7 @@
#dataset: '/store_new/mch/msopr/hirad-gen/era5-1h-new/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr'
-dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp', 'insolation',
'lsm', 'msl',
'q_100', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_50', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925',
@@ -14,4 +15,6 @@ select: ['10u', '10v', '2d', '2t', 'cos_julian_day', 'cos_latitude', 'cos_local_
]
# ALL ERA CHANNELS
# Intersection between ERA and COSMO excludes cp, sdor, slor, tcw
-# Note: Bounding dates/area will be done in .py code.
\ No newline at end of file
+# Note: Bounding dates/area will be done in .py code.
+start: 2020-01-01
+end: 2020-01-31
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/era.yaml b/src/hirad/input_data/configs/era.yaml
new file mode 100644
index 00000000..4a161130
--- /dev/null
+++ b/src/hirad/input_data/configs/era.yaml
@@ -0,0 +1,8 @@
+#dataset: '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr'
+#dataset: '/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr'
+dataset: '/capstor/store/cscs/swissai/weatherbench/aifs-ea-an-oper-0001-mars-n320-1979-2023-1h-v1-with-ERA51.zarr/'
+select: ['2t', '10u', '10v', 'tcw', 't_850', 'z_850', 'u_850', 'v_850', 't_500', 'z_500', 'u_500', 'v_500', 'tp']
+ # See table S2 from corrdiff paper for the inputs.
+# Note: Bounding dates/area will be done in .py code.
+start: 2019-01-01
+end: 2019-04-30
\ No newline at end of file
diff --git a/src/hirad/input_data/configs/realch1-2005-2009.yaml b/src/hirad/input_data/configs/realch1-2005-2009.yaml
new file mode 100644
index 00000000..fc112042
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-2005-2009.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2005-01-01
+end: 2009-12-31
diff --git a/src/hirad/input_data/configs/realch1-2009-2011.yaml b/src/hirad/input_data/configs/realch1-2009-2011.yaml
new file mode 100644
index 00000000..4e59a48f
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-2009-2011.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2009-01-01
+end: 2011-12-31
diff --git a/src/hirad/input_data/configs/realch1-2012-2014.yaml b/src/hirad/input_data/configs/realch1-2012-2014.yaml
new file mode 100644
index 00000000..f7da639d
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-2012-2014.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2012-01-01
+end: 2014-12-31
diff --git a/src/hirad/input_data/configs/realch1-2015-2017.yaml b/src/hirad/input_data/configs/realch1-2015-2017.yaml
new file mode 100644
index 00000000..cc7885cf
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-2015-2017.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2015-01-01
+end: 2017-12-31
diff --git a/src/hirad/input_data/configs/realch1-2018-2020.yaml b/src/hirad/input_data/configs/realch1-2018-2020.yaml
new file mode 100644
index 00000000..13c967d8
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-2018-2020.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2018-01-01
+end: 2020-12-31
diff --git a/src/hirad/input_data/configs/realch1-2021-2024.yaml b/src/hirad/input_data/configs/realch1-2021-2024.yaml
new file mode 100644
index 00000000..4788a998
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-2021-2024.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2021-01-01
+end: 2024-12-31
diff --git a/src/hirad/input_data/configs/realch1-all.yaml b/src/hirad/input_data/configs/realch1-all.yaml
new file mode 100644
index 00000000..6b0f7765
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-all.yaml
@@ -0,0 +1,26 @@
+dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.1.zarr'
+select: ['CLCH', 'CLCL', 'CLCM', 'CLCT',
+ 'FI_100', 'FI_1000', 'FI_150', 'FI_200', 'FI_250', 'FI_300', 'FI_400',
+ 'FI_50', 'FI_500', 'FI_600', 'FI_700', 'FI_850', 'FI_925',
+ 'FR_LAND', 'HSURF',
+ 'OMEGA_100', 'OMEGA_1000', 'OMEGA_150', 'OMEGA_200', 'OMEGA_250',
+ 'OMEGA_300', 'OMEGA_400', 'OMEGA_50', 'OMEGA_500', 'OMEGA_600',
+ 'OMEGA_700', 'OMEGA_850', 'OMEGA_925',
+ 'PLCOV', 'PMSL', 'PS',
+ 'QV_100', 'QV_1000', 'QV_150', 'QV_200', 'QV_250', 'QV_300', 'QV_400',
+ 'QV_50', 'QV_500', 'QV_600', 'QV_700', 'QV_850', 'QV_925',
+ 'SKC', 'SKT', 'SOILTYP', 'SSO_GAMMA', 'SSO_SIGMA', 'SSO_STDH',
+ 'SSO_THETA', 'TD_2M', 'TOT_PREC', 'TOT_PREC_6H',
+ 'T_100', 'T_1000', 'T_150', 'T_200', 'T_250', 'T_2M', 'T_300', 'T_400',
+ 'T_50', 'T_500', 'T_600', 'T_700', 'T_850', 'T_925',
+ 'U_100', 'U_1000', 'U_10M', 'U_150', 'U_200', 'U_250', 'U_300', 'U_400',
+ 'U_50', 'U_500', 'U_600', 'U_700', 'U_850', 'U_925',
+ 'V_100', 'V_1000', 'V_10M', 'V_150', 'V_200', 'V_250', 'V_300', 'V_400',
+ 'V_50', 'V_500', 'V_600', 'V_700', 'V_850', 'V_925',
+ 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude',
+ 'insolation', 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude']
+# ALL REALCH1 CHANNELS.
+start: 2020-01-01
+# start: 2015-11-29
+end: 2020-01-01
+# end: 2020-12-31
diff --git a/src/hirad/input_data/configs/realch1-static.yaml b/src/hirad/input_data/configs/realch1-static.yaml
new file mode 100644
index 00000000..55f42368
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1-static.yaml
@@ -0,0 +1,8 @@
+
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['z']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2021-01-01
+end: 2021-01-01
+# date is irrelevant as we only need one time point
diff --git a/src/hirad/input_data/configs/realch1.yaml b/src/hirad/input_data/configs/realch1.yaml
new file mode 100644
index 00000000..c9c02a02
--- /dev/null
+++ b/src/hirad/input_data/configs/realch1.yaml
@@ -0,0 +1,8 @@
+#dataset: '/capstor/store/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr'
+dataset: '/store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-ifsnames-v1.0.zarr'
+#dataset: '/scratch/mch/fzanetta/data/anemoi/datasets/mch-realch1-fdb-1km-2020-2020-1h-pl13-v0.2.zarr'
+select: ['2t', '10u', '10v', 'tp']
+#select: ['TD_2M', 'U_10M', 'V_10M', 'TOT_PREC_1H'] # TOT_PREC in v0.1
+start: 2020-01-01
+end: 2020-01-10
diff --git a/src/hirad/input_data/download_copernicus_tp.py b/src/hirad/input_data/download_copernicus_tp.py
index 55031eee..6eee50cf 100644
--- a/src/hirad/input_data/download_copernicus_tp.py
+++ b/src/hirad/input_data/download_copernicus_tp.py
@@ -5,10 +5,14 @@
"product_type": ["reanalysis"],
"variable": ["total_precipitation"],
"year": [
- "2016"
+ "2015", "2016", "2017",
+ "2018", "2019", "2020",
],
"month": [
- "01", "02"
+ "01", "02", "03",
+ "04", "05", "06",
+ "07", "08", "09",
+ "10", "11", "12",
],
"day": [
"01", "02", "03",
@@ -34,7 +38,9 @@
"21:00", "22:00", "23:00"
],
"data_format": "netcdf",
- "download_format": "unarchived"
+ "download_format": "unarchived",
+ "grid": "N320",
+ "area": [60, 0, 40, 20]
}
client = cdsapi.Client()
diff --git a/src/hirad/input_data/era.yaml b/src/hirad/input_data/era.yaml
deleted file mode 100644
index 3234321a..00000000
--- a/src/hirad/input_data/era.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-dataset: '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr'
-select: ['2t', '10u', '10v', 'tcw', 't_850', 'z_850', 'u_850', 'v_850', 't_500', 'z_500', 'u_500', 'v_500', 'tp']
- # See table S2 from corrdiff paper for the inputs.
-# Note: Bounding dates/area will be done in .py code.
\ No newline at end of file
diff --git a/src/hirad/input_data/generate_lmdb.py b/src/hirad/input_data/generate_lmdb.py
new file mode 100644
index 00000000..04134fc2
--- /dev/null
+++ b/src/hirad/input_data/generate_lmdb.py
@@ -0,0 +1,72 @@
+import lmdb
+import os
+import torch
+import numpy as np
+
+IN_DIR = '/store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/era-interpolated'
+DB_FILENAME='/store_new/mch/msopr/hirad-gen/all-channels-lmdb.db'
+
+class HiradLmdb:
+ def __init__(self, db_path: str, to_write = False):
+ self.env = lmdb.open(db_path)
+ self.txn = self.env.begin(write=to_write)
+ self.cursor = self.txn.cursor()
+ return
+
+ def get_next(self) -> (str, np.ndarray):
+ return key, nparray
+
+
+def fill_database():
+ print('starting')
+
+ lmdb_env = lmdb.open(DB_FILENAME, map_size=int(1e11))
+ lmdb_txn = lmdb_env.begin(write=True)
+ files = sorted(os.listdir(IN_DIR))
+ #for i in range(len(files)):
+ for i in range(10):
+ f = files[i]
+ print(f)
+ torchdata = torch.load(os.path.join(IN_DIR, f), weights_only=False)
+ lmdb_txn.put(f.encode(), torchdata)
+ print(torchdata.shape)
+ print(torchdata.dtype)
+ lmdb_txn.commit()
+ lmdb_env.close()
+
+def get_data_for_time(datefmt: str):
+ lmdb_env = lmdb.open(DB_FILENAME)
+ lmdb_txn = lmdb_env.begin(write=False)
+ data = lmdb_txn.get(datefmt.encode())
+ if data == None:
+ raise KeyError(f'date {datefmt} not found in database')
+ data = np.frombuffer(data, dtype=np.float64).reshape([101,1,191488])
+ lmdb_txn.commit()
+ lmdb_env.close()
+ return data
+
+
+
+def read_database():
+ lmdb_env = lmdb.open(DB_FILENAME)
+ lmdb_txn = lmdb_env.begin(write=False)
+ lmdb_cursor = lmdb_txn.cursor()
+ nsamples = 0
+ for key, value in lmdb_cursor:
+ print (type(key))
+ print(key)
+ print (type(value))
+ nsamples = nsamples + 1
+ print(nsamples)
+
+
+
+def main():
+ fill_database()
+ read_database()
+ nparray = get_data_for_time('20160101-0900')
+ print(nparray.dtype)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/input_data/hdf5.py b/src/hirad/input_data/hdf5.py
new file mode 100644
index 00000000..b45c29da
--- /dev/null
+++ b/src/hirad/input_data/hdf5.py
@@ -0,0 +1,71 @@
+from h5py import File
+import os
+import torch
+import logging
+import numpy as np
+
+IN_DIR = '/store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/era-interpolated'
+DB_FILENAME='/store_new/mch/msopr/hirad-gen/input-data-hdf5.db'
+
+class HiradHdf5:
+ def __init__(self, in_dir: str, db_path: str, data_shape: tuple, to_write = False):
+ self.in_dir = in_dir
+ self.files = sorted(os.listdir(in_dir))
+ self.db_path = db_path
+ shape = list(data_shape)
+ shape.insert(0,len(self.files))
+ self.db_shape = tuple(shape)
+ self.to_write = to_write
+ if to_write:
+ self.dbf = File(DB_FILENAME, "w")
+ self.dset = self.dbf.create_dataset("all-channels", self.db_shape, dtype='float64')
+ else:
+ self.dbf = File(DB_FILENAME, "r")
+ self.dset = self.dbf.require_dataset("all-channels", self.db_shape, dtype='float64', exact=True)
+ return
+
+ def fill_database(self):
+ if not self.to_write:
+ raise PermissionError('database not opened with write')
+ logging.info('filling database')
+ for i in range(len(self.files)):
+ #for i in range(10):
+ f = self.files[i]
+ logging.info(f'saving {f}')
+ torchdata = torch.load(os.path.join(IN_DIR, f), weights_only=False)
+ self.dset[i,:] = torchdata
+
+ def read_index(self, i: int):
+ return self.dset[i,:]
+
+ def read_datetime(self, datefmt: str):
+ # Raises ValueError if not found
+ i = self.files.index(datefmt)
+ return self.read_index(i)
+
+ def test_read_database(self):
+ for i in range(10):
+ nparray = self.dset[i,:]
+ f = self.files[i]
+ torchdata = torch.load(os.path.join(IN_DIR, f), weights_only=False)
+ logging.info(np.array_equal(torchdata, nparray))
+
+def main():
+ logging.basicConfig(
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+ # write DB
+ #hdf5db = HiradHdf5(IN_DIR, DB_FILENAME, (101, 1, 191488), True)
+ #hdf5db.fill_database()
+ # read DB
+ hdf5db = HiradHdf5(IN_DIR, DB_FILENAME, (101, 1, 191488), False)
+ hdf5db.test_read_database()
+ jan10000 = hdf5db.read_index(0)
+ jan10900 = hdf5db.read_datetime('20160101-0900')
+ logging.info(jan10000.shape)
+ logging.info(jan10900.shape)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/input_data/interpolate_basic.py b/src/hirad/input_data/interpolate_basic.py
index 7ec58017..7ba29cbf 100644
--- a/src/hirad/input_data/interpolate_basic.py
+++ b/src/hirad/input_data/interpolate_basic.py
@@ -1,6 +1,6 @@
-import datetime
import logging
import os
+import re
import shutil
import sys
import yaml
@@ -8,247 +8,462 @@
from anemoi.datasets import open_dataset
from anemoi.datasets.data.dataset import Dataset
import cartopy.crs as ccrs
-import matplotlib.pyplot as plt
+import matplotlib.pyplot as plt
+import netCDF4
import numpy as np
from pandas import to_datetime
from scipy.interpolate import griddata
import torch
import multiprocessing
+from hirad.utils.dataset_utils import GridData
+
# Margin to use for ERA dataset (to avoid nans from interpolation at boundary)
ERA_MARGIN_DEGREES = 1.0
-def _read_input(era_config_file: str, cosmo_config_file: str, bound_to_cosmo_area=True) -> tuple[Dataset, Dataset]:
- """
- Read both ERA and COSMO data, optionally bounding to the COSMO data area, and return the 2m
- temperature values for the time range under COSMO.
- """
- # trim edge removes boundary
- with open(cosmo_config_file) as cosmo_file:
- cosmo_config = yaml.safe_load(cosmo_file)
- cosmo = open_dataset(cosmo_config)
- with open(era_config_file) as era_file:
- era_config = yaml.safe_load(era_file)
- era = open_dataset(era_config)
- # Subset the ERA dataset to have COSMO area/dates.
- start_date = cosmo.metadata()['start_date']
- end_date = cosmo.metadata()['end_date']
- # load era5 2m-temperature in the time-range of cosmo
- # area = N, W, S, E
- if bound_to_cosmo_area:
- min_lat = min(cosmo.latitudes) - ERA_MARGIN_DEGREES
- max_lat = max(cosmo.latitudes) + ERA_MARGIN_DEGREES
- min_lon = min(cosmo.longitudes) - ERA_MARGIN_DEGREES
- max_lon = max(cosmo.longitudes) + ERA_MARGIN_DEGREES
- era = open_dataset(era, start=start_date, end=end_date,
- area=(max_lat, min_lon, min_lat, max_lon))
- else:
- era = open_dataset(era, start=start_date, end=end_date)
-
- return (era, cosmo)
+def read_anemoi_ds(config_file: str, start_date = None, end_date = None, area = None) -> Dataset:
+ """Read an Anemoi dataset from config file, given (optional) date/area parameters.
+ Start/end and area from config file will also be subsetted, if present,
+ so start_date and end_date and area parameters will be additional subsetting,
+ not an override.
+ Parameters:
+ config_file: str
+ YAML file with anemoi recipe.
+ start_date, end_date: str (optional)
+ e.g. '2020-01-01', see anemoi open_dataset documentation.
+ area: tuple
+ (N, W, S, E) lat/lon lines to bound the area, see anemoi open_dataset.
-def regrid(era_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.ndarray):
- # shape (channel, ensemble, grid)
- interpolated_data = np.empty([era_for_time.shape[0], 1, output_grid.shape[0]])
- for j in range(era_for_time.shape[0]):
- values = np.array(era_for_time[j,0,:]) # get era grid values on the given date-time and channel
- regrid = griddata(input_grid, values, output_grid, method='linear') # interpolate era5 to cosmo grid using scipy griddata linear
- interpolated_data[j,0,:] = regrid
- return interpolated_data
+ Returns:
+ Dataset
+ anemoi.Dataset of the dataset in question
+ """
+ with open(config_file) as cfg_file:
+ config = yaml.safe_load(cfg_file)
+ if area:
+ return open_dataset(config, start=start_date, end=end_date, area=area)
+ return open_dataset(config, start=start_date, end=end_date)
-def _interpolate_task(i: int, era: Dataset, cosmo: Dataset, input_grid: np.ndarray, output_grid: np.ndarray, intermediate_files_path: str, outfile_plots_path: str = None, plot_indices=[0]):
- logging.info('interpolating time point ' + _format_date(cosmo.dates[i]))
- interpolated_data = np.empty([era.shape[1], 1, cosmo.shape[3]])
- for j in range(era.shape[1]):
- values = np.array(era[i,j,0,:]) # get era grid values on the given date-time and channel
- regrid = griddata(input_grid, values, output_grid, method='linear') # interpolate era5 to cosmo grid using scipy griddata linear
- interpolated_data[j,0,:] = regrid
- logging.info(f'writing time point { _format_date(cosmo.dates[i])} to files in path {intermediate_files_path}')
- if (intermediate_files_path):
- _save_datetime_file(interpolated_data, era.variables, era.dates[i], os.path.join(intermediate_files_path, "era-interpolated/"))
- _save_datetime_file(era[i,:,:,:], era.variables, era.dates[i], os.path.join(intermediate_files_path, "era/"))
- _save_datetime_file(cosmo[i,:,:,:], cosmo.variables, cosmo.dates[i], os.path.join(intermediate_files_path, "cosmo/"))
- logging.info(f'finished writing time point { _format_date(cosmo.dates[i])}')
+def save_anemoi_latlon_grid(dataset: Dataset, filename: str):
+ """Save lat/lon grid of an Anemoi dataset into a Torch file. (Note that
+ array will have column 0 with latitudes, and column 1 with longitutdes)
+
+ Parameters:
+ dataset: anemoi.Dataset
+ Dataset to extract lat/lon from.
+ filename: str
+ Full file path to output to.
- if outfile_plots_path and i in plot_indices:
- datestr = _format_date(era.dates[i])
- logging.info(f'plotting {datestr} to {outfile_plots_path}')
- for j,var in enumerate(era.variables):
- # plot era original
- _plot_and_save_projection(era.longitudes, era.latitudes, era[i, j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era.jpg')
+ Returns: None
+ """
+ grid = np.column_stack((dataset.latitudes, dataset.longitudes))
+ torch.save(grid, filename)
- _plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, interpolated_data[j, 0, :], f'{outfile_plots_path}{era.variables[j]}-{datestr}-era-interpolated.jpg')
- for j,var in enumerate(cosmo.variables):
- _plot_and_save_projection(cosmo.longitudes, cosmo.latitudes, cosmo[i, j, 0, :], f'{outfile_plots_path}{cosmo.variables[j]}-{datestr}-cosmo.jpg')
+def save_anemoi_stats(dataset: Dataset, filename: str):
+ """Save stats of an Anemoi dataset into a Torch file. (The torch file
+ will be a dictionary of stat to value)
+ Parameters:
+ dataset: anemoi.Dataset
+ Dataset to extract stats from.
+ filename: str
+ Full file path to output to.
+ Returns: None
+ """
+ torch.save(dataset.statistics, filename)
-def _interpolate_basic(era: Dataset, cosmo: Dataset, intermediate_files_path: str, threaded = True, outfile_plots_path: str =None, plot_indices=[0]):
- """Perform simple interpolation from ERA5 to COSMO grid for all data points in the COSMO date range.
+def regrid(input_values_for_time: np.ndarray, input_grid: np.ndarray, output_grid: np.ndarray):
+ """Regrid an array of values for a given time point from an input to output grid.
Parameters:
- era: Dataset
- Pre-loaded anemoi dataset for ERA
- cosmo: Dataset
- Pre-loaded anemoi dataset for COSMO
- intermediate_files_path
- If set, will save each date point to a new file.
+ input_values_for_time: np.ndarray
+ An array of dimension (channels, N) (where N = X x Y)
+ filename: str
+ Full file path to output to.
- Returns:
- np.ndarray
- 4-D array of interpolated values. (date, variable, ensemble, grid-point)
+ Returns: None
"""
- # Check that our date ranges do in fact line up.
- assert (era.start_date == cosmo.start_date and
- era.end_date == cosmo.end_date and
- era.frequency == cosmo.frequency and
- era.shape[0] == cosmo.shape[0]), "ERA and COSMO date ranges or frequencies do not align."
- input_grid = np.column_stack((era.longitudes, era.latitudes)) # stack lon-lat columns of era5 points
- output_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes)) # stack lon-lat column of cosmo points
-
- dates = range(cosmo.shape[0])
-
- if (threaded):
- pool = multiprocessing.Pool()
- for i in dates:
- pool.apply_async(_interpolate_task, (i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices))
-
- pool.close()
- pool.join()
- else:
- for i in dates:
- _interpolate_task(i, era, cosmo, input_grid, output_grid, intermediate_files_path, outfile_plots_path, plot_indices)
+ # shape (channel, grid)
+ assert(len(input_values_for_time.shape) == 2)
+ interpolated_data = np.empty([input_values_for_time.shape[0], output_grid.shape[0]])
+ for j in range(input_values_for_time.shape[0]):
+ values = np.array(input_values_for_time[j,:]) # get era grid values on the given date-time and channel
+ regrid = griddata(input_grid, values, output_grid, method='linear') # interpolate era5 to cosmo grid using scipy griddata linear
+ interpolated_data[j,:] = regrid
+ return interpolated_data
- return
+def regrid_with_interpolator(input_values_for_time: np.ndarray, interpolator: GridData):
+ assert(len(input_values_for_time.shape) == 2)
+ interpolated_data = np.empty([input_values_for_time.shape[0], interpolator.longitudes_target.shape[0]])
+ for j in range(input_values_for_time.shape[0]):
+ values = np.array(input_values_for_time[j,:]) # get input grid values on the given date-time and channel
+ values = values.reshape(1, values.shape[0])
+ regrid = interpolator.interpolate(values) # interpolate input to output grid using GridData method
+ interpolated_data[j,:] = regrid
+ return interpolated_data
-def _format_date(dt64: np.datetime64) -> str:
+def format_date(dt64: np.datetime64) -> str:
"""Makes date string from date time point, for saving files."""
return to_datetime(dt64).strftime('%Y%m%d-%H%M')
-def _save_datetime_file(values: np.ndarray[np.intp], variables: np.ndarray, date: np.datetime64, filepath: str):
- filename = filepath + _format_date(date)
- torch.save(values, filename)
-
-def _save_latlon_grid(dataset: Dataset, filename: str):
- grid = np.column_stack((dataset.latitudes, dataset.longitudes))
- torch.save(grid, filename)
-
-def _save_stats(dataset: Dataset, filename: str):
- torch.save(dataset.statistics, filename)
-
-def _save_interpolation(values: np.ndarray[np.intp], filename: str):
- """Output interpolated data to a given filename, in PyTorch tensor format."""
- torch_data = torch.from_numpy(values)
- torch.save(torch_data, filename)
+def save_datetime_file(values: np.ndarray[np.intp], date: np.datetime64, filepath: str, format='torch'):
+ """saves array of values for a given date into a torch file"""
+ filename = os.path.join(filepath, format_date(date))
+ logging.info(f'writing data to {filename}')
+ if format == 'torch':
+ torch.save(values, filename)
+ elif format == 'numpy':
+ np.save(filename, values)
+ else:
+ raise NotImplementedError(f'invalid format {format}; currently only ' \
+ 'output to torch or numpy')
-def _get_plot_indices(era: Dataset, cosmo: Dataset) -> np.ndarray[np.intp]:
- """
- Get indices of ERA5 data that is in the bounding rectangle of COSMO data.
- This is useful for plotting in the case where read_input(..., bound_to_cosmo_area=False) was used.
- In this case, one would then feed e.g. era.latitudes[indices] into _plot_projection.
- """
- min_lat_cosmo = min(cosmo.latitudes)
- max_lat_cosmo = max(cosmo.latitudes)
- min_lon_cosmo = min(cosmo.longitudes)
- max_lon_cosmo = max(cosmo.longitudes)
- box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo)
- box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo)
- indices = np.where(box_lon*box_lat)
- return indices
-
-def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None):
- p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax)
+def plot_projection(ax, longitudes: np.array, latitudes: np.array, values: np.array, cmap=None, vmin = None, vmax = None, s = None):
+ """Plot observed or interpolated data in a scatter plot"""
+ p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax, s=s)
ax.coastlines()
- ax.gridlines(draw_labels=False)
+ ax.gridlines(draw_labels=True)
plt.colorbar(p, orientation="horizontal")
-def _plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None):
- """Plot observed or interpolated data in a scatter plot."""
+def plot_and_save_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, projection=ccrs.PlateCarree(), cmap=None, vmin = None, vmax = None, s = None):
+ """Plot observed or interpolated data in a scatter plot and save to file."""
# TODO: Refactor this somehow, it's not really generalizing well across variables.
fig = plt.figure()
- fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
+ fig, ax = plt.subplots(subplot_kw={"projection": projection})
logging.info(f'plotting values to {filename}')
- plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax)
- #p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax)
- #ax.coastlines()
- #ax.gridlines(draw_labels=True)
- #plt.colorbar(p, orientation="horizontal")
+ plot_projection(ax, longitudes, latitudes, values, cmap, vmin, vmax, s)
plt.savefig(filename)
plt.close('all')
-def interpolate_and_save(infile_era: str, infile_cosmo: str, outfile_data_path: str, threaded=True, outfile_plots_path: str = None, plot_indices=[0]):
- """Read both ERA and COSMO data and perform basic interpolation. Save output into Pytorch format, and (optionally) plot
- ERA, COSMO, and interpolated data.
+def interpolate_anemoi_time_point_to_grid(i: int, ds: Dataset, ds_name: str, input_grid: np.ndarray, output_grid: np.ndarray, output_data_path: str, format='torch', output_plots_path: str = None, plot_indices=[0]):
+ """Interpolate a given time index in a dataset from its grid (input_grid)
+ to an output_grid, and save the interpolated values. In certain cases,
+ additionally save a plot of the input and interpolated data.
+
+
+ i: int
+ Index of time to interpolate (0 is the first time point). This will
+ correspond to ds.dates[i]
+ ds: anemoi.Dataset
+ anemoi.Dataset to interpolate
+ ds_name: str
+ Name for the dataset (e.g. 'era'). This name will be used for the output
+ directory ('era-interpolated') and in the plot filenames.
+ input_grid: np.ndarray
+ An ndarray of shape (N,2), where N is the number of datapoints (N = X x Y).
+ ATTN: Longitudes are column 0 and Latitudes are column 1.
+ This should be equal to np.column_stack((ds.longitudes, ds.latitudes))
+ output_grid: np.ndarray
+ Target grid to interpolate to.
+ An ndarray of shape (N,2), where N is the number of datapoints (N = X x Y).
+ ATTN: Longitudes are column 0 and Latitudes are column 1.
+ output_data_path: str
+ Path of directory for output data.
+ format: str (Optional)
+ Format of output (torch or numpy)
+ output_plots_path: str (Optional)
+ Path of directory to output plots. If None, no plots will be created.
+ plot_indices: Array (Optional)
+ Indices of time points for which to plot data
+ """
+ logging.info('interpolating time point ' + format_date(ds.dates[i]))
+ # remove ensemble (3rd) dimension
+ interpolated_data = regrid(ds[i,:,0,:], input_grid=input_grid, output_grid=output_grid)
+ logging.info(f'writing time point { format_date(ds.dates[i])} to files in path {output_data_path}')
+ save_datetime_file(interpolated_data, ds.dates[i], os.path.join(output_data_path), format=format)
+ if output_plots_path and i in plot_indices:
+ datestr = format_date(ds.dates[i])
+ logging.info(f'plotting {datestr} to {output_plots_path}')
+ #for j in range(10):
+ for j in range(len(ds.variables)):
+ var = ds.variables[j]
+ # plot era original
+ plot_and_save_projection(input_grid[:,0], input_grid[:,1], ds[i, j, 0, :], f'{output_plots_path}/{ds.variables[j]}-{datestr}-{ds_name}.jpg')
+ # plot interpolated
+ plot_and_save_projection(output_grid[:,0], output_grid[:,1], interpolated_data[j, :], f'{output_plots_path}/{ds.variables[j]}-{datestr}-{ds_name}-interpolated.jpg')
+
+def save_anemoi_time_point(i: int, ds: Dataset, ds_name: str, data_output_path: str, plots_output_path: str = None, plot_indices=[0], format='torch'):
+ """Save a time point of anemoi data (either input or target) directly into a given format.
+ If the time point is in the """
+ logging.info(f'saving {ds_name} {ds.dates[i]}')
+ save_datetime_file(ds[i,:,0,:], ds.dates[i], data_output_path, format)
+ datestr = format_date(ds.dates[i])
+ if plots_output_path and i in plot_indices:
+ for j,var in enumerate(ds.variables):
+ plot_and_save_projection(ds.longitudes, ds.latitudes, ds[i, j, 0, :], f'{plots_output_path}/{var}-{datestr}-{ds_name}.jpg')
+
+def interpolate_anemoi_to_grid(infile_anemoi: str, ds_name: str, output_grid: np.ndarray, output_path: str, format='torch', plot_indices=[0]):
+ """Perform basic interpolation on an input dataset in anemoi format from
+ its native grid to a given output grid.
+ Save output as intermediate datetime files in a given format (torch/numpy)
+ Optionally plot interpolated data.
Parameters:
- infile_era: str
- Local file path to ERA5 data
- infile_cosmo: str
- Local file path to COSMO2 data
- outfile_data_path: str
- Local file path to intended output file
- outfile_plots_path: str (Optional)
- Local file path to plots. If specified, plots will be saved as "{plotfilepath_prefix}-(era|cosmo|interpolated).jpg"
-
- Returns:
- tuple[Dataset, Dataset]
- A tuple of ERA and COSMO 2m temperature data, in anemoi Dataset format, restricted to COSMO's date ranges
- (optionally the COSMO area as well).
+ infile_anemoi: str
+ Path to an anemoi recipe in YAML format.
+ ds_name: str
+ Name for the dataset (e.g. 'era'). This name will be used for the output
+ directory ('era-interpolated') and in the plot filenames.
+ output_grid: np.ndarray
+ An ndarray of shape (N,2), where N is the number of datapoints (N = X x Y).
+ ATTN: Longitudes are column 0 and Latitudes are column 1.
+ output_path: str
+ Path of parent directory for output. (sub-directories for plots, info,
+ and interpolated data will be created if they do not already exist)
+ format: str (Optional)
+ Format of output (torch or numpy)
+ plot_indices: Array (Optional)
+ Indices of time points for which to plot data
"""
- os.makedirs(outfile_data_path, exist_ok=True)
- os.makedirs(os.path.join(outfile_data_path, "info"), exist_ok=True)
- os.makedirs(os.path.join(outfile_data_path, "era"), exist_ok=True)
- os.makedirs(os.path.join(outfile_data_path, "cosmo"), exist_ok=True)
- os.makedirs(os.path.join(outfile_data_path, "era-interpolated"), exist_ok=True)
- if outfile_plots_path:
- os.makedirs(outfile_plots_path, exist_ok=True)
-
- logging.info(f'reading input according to configs {infile_era} and {infile_cosmo}')
- era, cosmo = _read_input(infile_era, infile_cosmo, bound_to_cosmo_area=True)
+ os.makedirs(os.path.join(output_path, 'info'), exist_ok=True)
+ os.makedirs(os.path.join(output_path, 'plots'), exist_ok=True)
+ os.makedirs(os.path.join(output_path, f'{ds_name}-interpolated'), exist_ok=True)
+
+ # read data
+ lats = output_grid[:,1]
+ lons = output_grid[:,0]
+ min_lat = min(lats) - ERA_MARGIN_DEGREES
+ max_lat = max(lats) + ERA_MARGIN_DEGREES
+ min_lon = min(lons) - ERA_MARGIN_DEGREES
+ max_lon = max(lons) + ERA_MARGIN_DEGREES
+ area=(max_lat, min_lon, min_lat, max_lon)
+ logging.info(f'projecting onto era area {area}')
+ ds = read_anemoi_ds(infile_anemoi, area = area)
logging.info('Successfully read input')
-
+
# Output stats and grid
- _save_stats(era, os.path.join(outfile_data_path, "info/era-stats"))
- _save_stats(cosmo, os.path.join(outfile_data_path, "info/cosmo-stats"))
- _save_latlon_grid(cosmo, os.path.join(outfile_data_path, "info/cosmo-lat-lon"))
- _save_latlon_grid(era, os.path.join(outfile_data_path, "info/era-lat-lon"))
+ save_anemoi_stats(ds, os.path.join(output_path, f'info/{ds_name}-stats'))
+ save_anemoi_latlon_grid(ds, os.path.join(output_path, f'info/{ds_name}-lat-lon'))
# Copy the .yaml files over for recording purposes
- shutil.copy(infile_cosmo, os.path.join(outfile_data_path, "info/cosmo.yaml"))
- shutil.copy(infile_era, os.path.join(outfile_data_path, "info/era.yaml"))
+ shutil.copy(infile_anemoi, os.path.join(output_path, f'info/{ds_name}.yaml'))
+
+ input_grid = np.column_stack((ds.longitudes, ds.latitudes))
+
+ for i in range(len(ds.dates)):
+ interpolate_anemoi_time_point_to_grid(i, ds, ds_name, input_grid, output_grid,
+ os.path.join(output_path, f'{ds_name}-interpolated'),
+ format=format,
+ output_plots_path=os.path.join(output_path, 'plots'),
+ plot_indices=plot_indices)
+
+
+def save_anemoi_as_format(infile_anemoi: str, ds_name: str, output_path: str, plot_indices=[0], format='torch',
+ start_date = None, end_date = None, area = None):
+ """ Output anemoi data into the same file structure/format as the regridded data, e.g.
+ to use as target data.
+ No regridding is performed."""
+ os.makedirs(os.path.join(output_path, 'info'), exist_ok=True)
+ plots_path = os.path.join(output_path, 'plots')
+ os.makedirs(plots_path, exist_ok=True)
+ ds_output_path = os.path.join(output_path, ds_name)
+ os.makedirs(ds_output_path, exist_ok=True)
+
+ ds = read_anemoi_ds(infile_anemoi, start_date=start_date, end_date=end_date, area=area)
+ # Copy the .yaml files over for recording purposes
+ shutil.copy(infile_anemoi, os.path.join(output_path, f'info/{ds_name}.yaml'))
+ save_anemoi_stats(ds, os.path.join(output_path, f'info/{ds_name}-stats'))
+ save_anemoi_latlon_grid(ds, os.path.join(output_path, f'info/{ds_name}-lat-lon'))
+ os.makedirs(ds_output_path, exist_ok=True)
+ for i in range(len(ds.dates)):
+ save_anemoi_time_point(i, ds, ds_name, data_output_path=ds_output_path, plots_output_path=plots_path, plot_indices=[0], format=format)
+
+def load_netcdf_file(path: str, variable: str, index_date: np.datetime64):
+ """
+ Get the corresponding netCDF file for a given variable, that includes
+ a given date.
- # generate interpolated data
- _interpolate_basic(era, cosmo, outfile_data_path, threaded=threaded, outfile_plots_path=outfile_plots_path, plot_indices=plot_indices)
+ Returns: tuple of (netCDF.Dataset, int) where
+ int is the index of where the date is within that dataset.
+ """
+ files = os.listdir(path)
+ pattern = '(.*)-([0-9]+)-([0-9]+).nc'
+ for filename in files:
+ matches = re.match(pattern, filename)
+ if matches:
+ f_var = matches[1]
+ f_start_year = int(matches[2])
+ f_end_year = int(matches[3])
+ if f_var == variable and to_datetime(index_date).year >= f_start_year and to_datetime(index_date).year <= f_end_year:
+ ds = netCDF4.Dataset(os.path.join(path, filename))
+ # Raises ValueError if number of instances not exactly 1, which indicates
+ # an implementation error somewhere.
+ index = np.where(ds['valid_time'][:] == index_date.astype('datetime64[s]').astype('int'))[0].item()
+ netcdf_date = ds['valid_time'][index]
+ logging.info(f'index {index} has datetime {format_date(np.datetime64(int(netcdf_date), 's'))}')
+ return ds, index
+ raise FileNotFoundError(f'Could not find .nc file for variable {variable} and date {index_date} in {path}')
+
+def load_netcdf_files_as_dict(input_path: str, variables: list, index_date: np.datetime64, expected_frequency: np.timedelta64,
+ reference_nc_dataset=None):
+ """
+ Get the corresponding netCDF Datasets for a given list of variables, that includes
+ a given date. Also checks to make sure each dataset lines up in terms of
+ dates and grids, with each other. Additionally, checks that grids match
+ up against another reference dataset.
+
+ Returns: tuple of (dict[netCDF.Dataset], int) where
+ int is the index of where the date is within that dataset.
+ It should be the same for all datasets
+ """
+ # Set up a dict that corresponds to [year][variable] with references to the
+ # corresponding NC dataset.
+ # (Note: I believe that because the NetCDF datasets are read-only, they
+ # will be cached so it doesn't matter for memory purposes if they are stored
+ # as objects or references)
+ curr_nc = {}
+ indices = {}
+ for var in variables:
+ ds, index = load_netcdf_file(input_path, var, index_date)
+ curr_nc[var] = ds
+ indices[var] = index
+
+ # Check that all the NC datasets match up in terms of time and grid
+ nc_times = curr_nc[variables[0]]['valid_time']
+ nc_date_index = indices[variables[0]]
+ grid_size = curr_nc[variables[0]][variables[0]][:].shape[1:]
+ nc_latitudes = curr_nc[variables[0]]['latitude'][:]
+ nc_longitudes = curr_nc[variables[0]]['longitude'][:]
+ if reference_nc_dataset:
+ reference_latitudes = curr_nc[variables[0]]['latitude'][:]
+ reference_longitudes = curr_nc[variables[0]]['longitude'][:]
+ assert np.array_equal(reference_longitudes, nc_longitudes), 'New NC datasets longitudes do not match reference dataset'
+ assert np.array_equal(reference_latitudes, nc_latitudes), 'New NC datasets latitudes do not match reference dataset'
+ for v in range(1, len(variables)):
+ # Check the times line up
+ more_times = curr_nc[variables[v]]['valid_time'][:]
+ assert np.array_equal(nc_times, more_times), 'Times between variable datasets do not line up; this is not yet supported'
+ assert len(more_times) == len(nc_times), 'Variable datasets appear to have different frequencies; this is not yet supported'
+ # Just for ease of use, we won't allow different indices.
+ assert nc_date_index == indices[variables[v]], 'Variable datasets do not line up with same start date'
+
+ # Check frequecy matches the config
+ nc_delta = np.datetime64(int(more_times[1]), 's') - np.datetime64(int(more_times[0]), 's')
+ assert nc_delta == expected_frequency, 'Frequency of NetCDF dataset for variable {variables[v]} is not the same as requested frequency.'
+
+ # Check the grid size and lat/lon is consistent
+ curr_grid_size = curr_nc[variables[v]][variables[v]][:].shape[1:]
+ assert np.array_equal(grid_size, curr_grid_size), 'NC datasets appear to have different grid sizes'
+ lon = curr_nc[variables[v]]['longitude'][:]
+ lat = curr_nc[variables[v]]['latitude'][:]
+ logging.info(lon)
+ logging.info(lon.shape)
+ logging.info(nc_longitudes)
+ logging.info(nc_longitudes.shape)
+ assert np.array_equal(lon, nc_longitudes), 'NC datasets appear to have different longitudes'
+ assert np.array_equal(lat, nc_latitudes), 'NC datasets appear to have different longitudes'
+ return curr_nc, nc_date_index
+
+def extract_netcdf_input_grid_025(nc: netCDF4.Dataset):
+ """
+ Gives reshaped lon-lat coordinates for NetCDF dataset.
+ For X x Y grid, shape will be (X*Y, 2)
+ Outputs in longitude as first column, latitude as second column,
+ for feeding into regridding.
+ """
+ logging.info('extracting lat/lon')
+ lon = nc['longitude'][:]
+ lat = nc['latitude'][:]
+ output_lon = np.zeros(len(lat) * len(lon))
+ output_lat = np.zeros(len(lat)* len(lon))
+ for i in range(len(lat)):
+ for j in range(len(lon)):
+ grid_index = i * len(lon) + j
+ output_lon[grid_index] = lon[j]
+ output_lat[grid_index] = lat[i]
+ return np.column_stack((output_lon, output_lat))
+
+def interpolate_netcdf_to_grid(infile_nc: str, ds_name: str, output_grid: np.ndarray, output_path: str, format='torch', plot_indices=[0]):
+ """
+ Corresponds to interpolate_anemoi_to_grid, for netCDF files.
+ More assumptions are made here about e.g. the filenames of the NetCDF files.
+
+ infile_nc: str: Filepath to a .yaml file with config info
+ ds_name: Dataset name (e.g. 'copernicus'), for output filenames
+ output_grid: np.ndarray of lon/lat coordinates to regrid to
+ output_path: parent directory for output. Subdirectories for data and plots will be created.
+ format: format to output data to ('torch' or 'numpy')
+ plot_indices: indices of time points to plot.
+ """
+ # set up output dirs
+ logging.info(f'setting up subdirs in {output_path}')
+ os.makedirs(os.path.join(output_path, 'info'), exist_ok=True)
+ os.makedirs(os.path.join(output_path, 'plots'), exist_ok=True)
+ os.makedirs(os.path.join(output_path, f'{ds_name}'), exist_ok=True)
+ os.makedirs(os.path.join(output_path, f'{ds_name}-interpolated'), exist_ok=True)
+
+ # extract from yaml config
+ with open(infile_nc) as cfg_file:
+ config = yaml.safe_load(cfg_file)
+ input_path = config['path'] # string
+ variables = config['channels'] # array
-def plot_tp(path_6h: str, path_1h: str):
- fig, axs = plt.subplots(2, 3, subplot_kw={"projection": ccrs.PlateCarree()})
+ start_date = np.datetime64(config['start']) # np.datetime64 type
+ end_date = np.datetime64(config['end']) # datetime type
+ frequency = np.timedelta64(int(config['frequency']), 'h')
- logging.info(f'plotting values to {filename}')
- p = ax.scatter(x=longitudes, y=latitudes, c=values)
- ax.coastlines()
- ax.gridlines(draw_labels=True)
- plt.colorbar(p, label="absolute error", orientation="horizontal")
- plt.savefig(filename)
- plt.close('all')
-
-def main():
- # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs.
- if len(sys.argv) < 4:
- raise ValueError('Expected call interpolate_basic.py [era.yaml] [cosmo.yaml] [output directory]')
- infile_era = sys.argv[1]
- infile_cosmo = sys.argv[2]
- output_directory = sys.argv[3]
+ # Set up a dict that corresponds to [year][variable] with references to the
+ # corresponding NC dataset.
+ curr_nc, nc_date_index = load_netcdf_files_as_dict(input_path, variables, start_date, frequency)
+ grid_size = curr_nc[variables[0]][variables[0]][:].shape[1:]
+ input_grid = extract_netcdf_input_grid_025(curr_nc[variables[0]])
+
+ # Output grids as torch files
+ torch.save(np.column_stack((input_grid[:,1], input_grid[:,0])), os.path.join(output_path, 'info', f'{ds_name}-lat-lon'))
+
+ torch.save(np.column_stack((output_grid[:,1], output_grid[:,0])), os.path.join(output_path, 'info', f'target-lat-lon'))
+ # TODO consider outputting stats, but this would require additional calculations
+
+ # Copy the .yaml file over for recording purposes
+ shutil.copy(infile_nc, os.path.join(output_path, f'info/{ds_name}.yaml'))
- logging.basicConfig(
- filename=os.path.join(output_directory, 'interpolate_basic.log'),
- format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
- interpolate_and_save(infile_era, infile_cosmo, output_directory, threaded=False, outfile_plots_path=os.path.join(output_directory, "plots/"))
+
-if __name__ == "__main__":
- main()
+ # Set up the interpolator
+ interpolator = GridData(input_grid[:,0], input_grid[:,1], output_grid[:,0], output_grid[:,1])
+
+ # Iterate through each date
+ t = start_date
+ t_i = 0
+ while to_datetime(t).date() <= to_datetime(end_date).date():
+ logging.info(f'processing {t} nc date index {nc_date_index}')
+ # Set up an array to hold input data
+ input_values = np.ndarray((len(variables), grid_size[0]*grid_size[1]))
+ if nc_date_index >= curr_nc[variables[0]][variables[0]][:].shape[0]:
+ logging.info(f'{t} not found in current files, loading new netcdf files')
+ curr_nc, nc_date_index = load_netcdf_files_as_dict(input_path, variables, t, frequency,
+ reference_nc_dataset=curr_nc[variables[0]])
+ #timestamp = curr_nc[variables[0]]['valid_time'][nc_date_index]
+ #logging.info(f'time {t} has timestamp {timestamp}')
+ for v in range(len(variables)):
+ values = curr_nc[variables[v]][variables[v]][:][nc_date_index,:]
+ input_values[v,:] = values.flatten()
+ # Save the input data
+ save_datetime_file(input_values, t, os.path.join(output_path, ds_name), format=format)
+
+ # Regrid
+ interpolated_data = regrid_with_interpolator(input_values, interpolator)
+ save_datetime_file(interpolated_data, t, os.path.join(output_path, f'{ds_name}-interpolated'), format=format)
+
+ # Plot
+ if t_i in plot_indices:
+ output_plots_path = os.path.join(output_path, 'plots')
+ datestr = format_date(t)
+ logging.info(f'plotting {datestr} to {output_plots_path}')
+ #for j in range(10):
+ for j in range(len(variables)):
+ var = variables[j]
+ # plot original
+ plot_and_save_projection(input_grid[:,0], input_grid[:,1], input_values[j,:], f'{output_plots_path}/{variables[j]}-{datestr}-{ds_name}.jpg')
+ # plot interpolated
+ plot_and_save_projection(output_grid[:,0], output_grid[:,1], interpolated_data[j, :], f'{output_plots_path}/{variables[j]}-{datestr}-{ds_name}-interpolated.jpg')
+
+ nc_date_index = nc_date_index + 1
+ t = t + frequency
+ t_i = t_i + 1
+
\ No newline at end of file
diff --git a/src/hirad/input_data/interpolate_realch1.py b/src/hirad/input_data/interpolate_realch1.py
new file mode 100644
index 00000000..46c08674
--- /dev/null
+++ b/src/hirad/input_data/interpolate_realch1.py
@@ -0,0 +1,152 @@
+import hirad.input_data.interpolate_basic as interpolate_basic
+import hirad.input_data.regrid_copernicus_tp as regrid_copernicus_tp
+
+import datetime
+import logging
+import os
+import shutil
+import sys
+import yaml
+import array
+
+from anemoi.datasets import open_dataset
+from anemoi.datasets.data.dataset import Dataset
+import netCDF4
+import numpy as np
+from pandas import to_datetime
+from scipy.interpolate import griddata
+import torch
+import multiprocessing
+import xarray
+
+# Margin to use for ERA dataset (to avoid nans from interpolation at boundary)
+ERA_MARGIN_DEGREES = 1.0
+COPERNICUS_FILES = ['/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2015-2016.nc',
+ '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2017-2018.nc',
+ '/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-2019-2020.nc']
+
+def _read_input(era_config_file: str, realch1_latlon_file: str) -> tuple[Dataset, Dataset, array.array, np.ndarray]:
+ """
+ Read ERA data, and return the values for the area under REA-L-CH1 (plus a margin).
+ """
+ # read the lat/lon data for REA-L-CH1
+ realch1_latlon = torch.load(realch1_latlon_file)
+ # we expect the start and end dates to be specified in config.
+ with open(era_config_file) as era_file:
+ era_config = yaml.safe_load(era_file)
+ era = open_dataset(era_config)
+ # Subset the ERA dataset to have REAL-CH-1 area.
+ # area = N, W, S, E
+ min_lat = min(realch1_latlon[:,0]) - ERA_MARGIN_DEGREES
+ max_lat = max(realch1_latlon[:,0]) + ERA_MARGIN_DEGREES
+ min_lon = min(realch1_latlon[:,1]) - ERA_MARGIN_DEGREES
+ max_lon = max(realch1_latlon[:,1]) + ERA_MARGIN_DEGREES
+ era = open_dataset(era,
+ area=(max_lat, min_lon, min_lat, max_lon))
+ copernicus_netcdf = []
+ for f in COPERNICUS_FILES:
+ netcdf_data = netCDF4.Dataset(f)
+ copernicus_netcdf.append(netcdf_data)
+ return (era, copernicus_netcdf, realch1_latlon)
+
+
+def main():
+ # read REA-L-CH1 latlon grid
+ era_config_file = sys.argv[1]
+ realch1_latlon_file = sys.argv[2]
+ netcdf_file = sys.argv[3]
+ output_directory = sys.argv[4]
+
+ logging.basicConfig(
+ filename=os.path.join(output_directory, 'interpolate_realch1.log'),
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ # Copy ERA yml file
+ shutil.copy(era_config_file, os.path.join(output_directory, 'info', 'era.yaml'))
+
+ logging.info('reading realch1 lat/lon')
+ realch1_latlon = torch.load(realch1_latlon_file, weights_only=False)
+ realch1_lat = realch1_latlon[:,0]
+ realch1_lon = realch1_latlon[:,1]
+ # read ERA input
+ min_lat = min(realch1_lat) - interpolate_basic.ERA_MARGIN_DEGREES
+ max_lat = max(realch1_lat) + interpolate_basic.ERA_MARGIN_DEGREES
+ min_lon = min(realch1_lon) - interpolate_basic.ERA_MARGIN_DEGREES
+ max_lon = max(realch1_lon) + interpolate_basic.ERA_MARGIN_DEGREES
+ logging.info('reading era')
+
+ era = interpolate_basic.read_anemoi_ds(era_config_file,
+ area=(max_lat, min_lon, min_lat, max_lon))
+ era_grid = np.column_stack((era.longitudes, era.latitudes))
+ realch1_grid = np.column_stack((realch1_lon, realch1_lat))
+ logging.info(f'lat lon area is {min_lat}-{max_lat} {min_lon}-{max_lon}')
+
+ # save era stats and lat lon
+ interpolate_basic.save_anemoi_latlon_grid(era, os.path.join(output_directory, 'info', 'era-lat-lon'))
+ interpolate_basic.save_anemoi_stats(era, os.path.join(output_directory, 'info', 'era-stats'))
+
+ # read copernicus input for tp variable
+ logging.info('reading copernicus')
+ netcdf_data = netCDF4.Dataset(netcdf_file)
+ logging.info('processing netcdf data')
+ netcdf_latitudes, netcdf_longitudes = regrid_copernicus_tp.extract_lat_lon_025(netcdf_data)
+ netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes))
+
+ # TODO: start and end date functionality
+ netcdf_tp_values = regrid_copernicus_tp.extract_values(netcdf_data, 'tp', start_date=era.start_date, end_date=era.end_date)
+ assert(netcdf_tp_values.shape[0] == era.shape[0])
+ # todo: incorporate this somehow
+ netcdf_tp_values = netcdf_tp_values.reshape((netcdf_tp_values.shape[0], 1,1, netcdf_tp_values.shape[1]))
+
+ # save copernicus stats and lat lon
+ torch.save(np.column_stack((netcdf_grid[:,1], netcdf_grid[:,0])),
+ os.path.join(output_directory, 'info', 'copernicus-lat-lon'))
+ regrid_copernicus_tp.make_stats(os.path.join(output_directory, 'info'),
+ os.path.join(output_directory, 'info'),
+ netcdf_tp_values)
+
+ # Iterate over ERA time range, which should be subsetted in configuration.
+ tp_index = era.variables.index('tp')
+ logging.info(f'tp index {tp_index}')
+
+ plot_indices = {12}
+
+ logging.info('interpolating')
+ #for i in plot_indices:
+ for i in range(era.shape[0]):
+ # T
+ t = era.dates[i]
+ # Get everything but the tp variable
+ era_for_time = era[i,:,:,:]
+ era_regridded = interpolate_basic.regrid(era_for_time, era_grid, realch1_grid)
+ # Regrid TP from copernicus
+ copernicus_regridded = interpolate_basic.regrid(netcdf_tp_values[i,:], netcdf_grid, realch1_grid)
+ # Concatenate and save
+ era_regridded[tp_index,:] = copernicus_regridded
+ #output=np.concatenate((era_regridded, copernicus_regridded), axis=0)
+ datefmt = interpolate_basic._format_date(t)
+ filename = os.path.join(output_directory, 'era-copernicus-interpolated',
+ datefmt)
+ torch.save(era_regridded, filename)
+
+ if i in plot_indices:
+ realch1var = ['t2m', '10u', '10v', 'tp']
+ realch1_data = torch.load(os.path.join(output_directory, 'realch1', datefmt), weights_only=False)
+ for j in range(realch1_data.shape[0]):
+ interpolate_basic.plot_and_save_projection(realch1_lon, realch1_lat, realch1_data[j,:],
+ os.path.join(output_directory, 'plots',
+ f'{datefmt}-{realch1var[j]}-realch1'))
+ for j in range(era_regridded.shape[0]):
+ interpolate_basic.plot_and_save_projection(era.longitudes, era.latitudes, era_for_time[j,:],
+ os.path.join(output_directory, 'plots',
+ f'{datefmt}-{era.variables[j]}-era'))
+ interpolate_basic.plot_and_save_projection(realch1_lon, realch1_lat,
+ era_regridded[j,0,:],
+ os.path.join(output_directory, 'plots',
+ f'{datefmt}-{era.variables[j]}-interpolated'))
+ return 0
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/process_copernicus_cosmo.py b/src/hirad/input_data/process_copernicus_cosmo.py
new file mode 100644
index 00000000..b9f210d3
--- /dev/null
+++ b/src/hirad/input_data/process_copernicus_cosmo.py
@@ -0,0 +1,49 @@
+import logging
+import os
+import sys
+
+import numpy as np
+import torch
+
+import interpolate_basic
+
+def main():
+ # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs.
+ if len(sys.argv) < 4:
+ raise ValueError('Expected call process_copernicus_cosmo.py [copernicus.yaml] [cosmo.yaml] [output directory]')
+ infile_copernicus = sys.argv[1]
+ infile_cosmo = sys.argv[2]
+ output_path = sys.argv[3]
+
+ os.makedirs(output_path, exist_ok=True)
+
+ logging.basicConfig(
+ filename=os.path.join(output_path, f'process-copernicus-cosmo.log'),
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ logging.info(f'running {sys.argv}')
+
+ output_grid = None
+
+ if infile_cosmo.endswith('yaml'):
+ cosmo = interpolate_basic.read_anemoi_ds(infile_cosmo)
+ output_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes))
+ else:
+ # This must be a lat-lon torch file.
+ cosmo_latlon = torch.load(infile_cosmo, weights_only=False)
+ lats = cosmo_latlon[:,0]
+ lons = cosmo_latlon[:,1]
+ output_grid = np.column_stack((lons, lats))
+
+ # interpolate copernicus
+ format = 'numpy'
+ plot_indices=[0]
+ interpolate_basic.interpolate_netcdf_to_grid(infile_copernicus, 'copernicus',
+ output_grid, output_path=output_path,
+ format=format, plot_indices=plot_indices)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/process_era5_cosmo.py b/src/hirad/input_data/process_era5_cosmo.py
new file mode 100644
index 00000000..be120198
--- /dev/null
+++ b/src/hirad/input_data/process_era5_cosmo.py
@@ -0,0 +1,68 @@
+import logging
+import os
+import sys
+
+import numpy as np
+import torch
+
+import interpolate_basic
+
+def main():
+ # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs.
+ if len(sys.argv) < 4:
+ raise ValueError('Expected call process_era5_cosmo.py [era.yaml] [cosmo.yaml] [output directory]')
+ infile_era = sys.argv[1]
+ infile_cosmo = sys.argv[2]
+ output_path = sys.argv[3]
+
+ os.makedirs(output_path, exist_ok=True)
+
+ erashortname = infile_era.split('/')[-1].split('.')[0]
+
+ logging.basicConfig(
+ filename=os.path.join(output_path, f'process-era5-cosmo-{erashortname}.log'),
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ logging.info(f'running {sys.argv}')
+ #output_plots_path = None
+
+ output_grid = None
+
+ if infile_cosmo.endswith('yaml'):
+ cosmo = interpolate_basic.read_anemoi_ds(infile_cosmo)
+ output_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes))
+ else:
+ # This must be a lat-lon torch file.
+ cosmo_latlon = torch.load(infile_cosmo, weights_only=False)
+ lats = cosmo_latlon[:,0]
+ lons = cosmo_latlon[:,1]
+ output_grid = np.column_stack((lons, lats))
+
+ # interpolate era
+ format = 'numpy'
+ plot_indices=[0]
+ interpolate_basic.interpolate_anemoi_to_grid(infile_era, 'era', output_grid, output_path=output_path, format=format, plot_indices=plot_indices)
+ # save era and cosmo input/output data into same format
+ # Save cosmo data
+ #if infile_cosmo.endswith('yaml'):
+ # interpolate_basic.save_anemoi_as_format(infile_cosmo, 'cosmo', output_path, plot_indices=plot_indices, format=format)
+
+ # Save ERA data (subsetted)
+ lats = output_grid[:,1]
+ lons = output_grid[:,0]
+ min_lat = min(lats) - interpolate_basic.ERA_MARGIN_DEGREES
+ max_lat = max(lats) + interpolate_basic.ERA_MARGIN_DEGREES
+ min_lon = min(lons) - interpolate_basic.ERA_MARGIN_DEGREES
+ max_lon = max(lons) + interpolate_basic.ERA_MARGIN_DEGREES
+ area=(max_lat, min_lon, min_lat, max_lon)
+ logging.info(f'projecting onto era area {area}')
+ # skip plotting as we did it already in interpolation step
+ #interpolate_basic.save_anemoi_as_format(infile_era, 'era', output_path, plot_indices=plot_indices, format=format, area=area,
+ # start_date=cosmo.start_date, end_date=cosmo.end_date)
+ #interpolate_basic.save_anemoi_as_format(infile_cosmo, 'cosmo', output_path, plot_indices=plot_indices, format=format)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/process_era5_realch1.py b/src/hirad/input_data/process_era5_realch1.py
new file mode 100644
index 00000000..dcc14134
--- /dev/null
+++ b/src/hirad/input_data/process_era5_realch1.py
@@ -0,0 +1,37 @@
+import logging
+import os
+import sys
+
+import numpy as np
+import torch
+
+import interpolate_basic
+
+def main():
+ # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs.
+ if len(sys.argv) < 4:
+ raise ValueError('Expected call process_era5_realch1.py [era.yaml] [cosmo.yaml] [output directory]')
+ infile_era = sys.argv[1]
+ infile_realch1 = sys.argv[2]
+ output_path = sys.argv[3]
+
+ os.makedirs(output_path, exist_ok=True)
+
+ erashortname = infile_era.split('/')[-1].split('.')[0]
+
+ logging.basicConfig(
+ filename=os.path.join(output_path, f'process-era5-realch1-{erashortname}.log'),
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ logging.info(f'running {sys.argv}')
+
+ format='numpy'
+ plot_indices=[0]
+
+ interpolate_basic.save_anemoi_as_format(infile_realch1, 'realch1', output_path, plot_indices=plot_indices, format=format)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/process_era5_with_copernicus.py b/src/hirad/input_data/process_era5_with_copernicus.py
new file mode 100644
index 00000000..4bb0aec5
--- /dev/null
+++ b/src/hirad/input_data/process_era5_with_copernicus.py
@@ -0,0 +1,102 @@
+import logging
+import os
+import sys
+
+import numpy as np
+import torch
+import yaml
+
+import interpolate_basic
+
+def load(filename: str):
+ if filename.endswith('.npy'):
+ return np.load(filename)
+ return torch.load(filename, weights_only=False)
+
+def save(values: np.ndarray, filename: str):
+ if filename.endswith('.npy'):
+ np.save(filename, values)
+ else:
+ torch.save(values, filename)
+ return
+
+def main():
+ # TODO: Do better arg parsing so it's not as easy to reverse era and cosmo configs.
+ if len(sys.argv) < 4:
+ raise ValueError('Expected call process_copernicus_era.py [era.yaml] [copernicus.yaml] [path]')
+
+ infile_era = sys.argv[1]
+ infile_copernicus = sys.argv[2]
+ era_dir = sys.argv[3]
+ copernicus_dir = sys.argv[4]
+ output_path = sys.argv[5]
+
+ os.makedirs(os.path.join(output_path, 'era-copernicus-interpolated'), exist_ok=True)
+
+ logging.basicConfig(
+ filename=os.path.join(output_path, f'process-era-copernicus.log'),
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ logging.info(f'running {sys.argv}')
+
+ # Extract ERA channels from input yaml
+ with open(infile_era) as era_cfg_file:
+ era_config = yaml.safe_load(era_cfg_file)
+ with open(infile_copernicus) as cop_cfg_file:
+ copernicus_config = yaml.safe_load(cop_cfg_file)
+
+ era_channels = era_config['select']
+ copernicus_channels = copernicus_config['channels']
+ mapping = {}
+ for c in range(len(copernicus_channels)):
+ if era_channels.count(copernicus_channels[c]) == 1:
+ e = era_channels.index(copernicus_channels[c])
+ mapping[e] = c
+ logging.info('replacing {len(mapping)} channels in ERA data: {mapping}')
+
+ #era_dir = os.path.join(output_path, 'era-interpolated')
+ #copernicus_dir = os.path.join(path, 'copernicus-interpolated')
+ output_dir = os.path.join(output_path, 'era-copernicus-interpolated')
+ era_files = os.listdir(era_dir)
+ copernicus_files = os.listdir(copernicus_dir)
+ era_files.sort()
+ copernicus_files.sort()
+ #intersect_files = list(set(era_files).intersection(set(copernicus_files)))
+ era_format = 'torch'
+ copernicus_format = 'torch'
+ if era_files[0].endswith('npy'):
+ era_format = 'numpy'
+ if copernicus_files[0].endswith('npy'):
+ copernicus_format = 'numpy'
+
+ #for f in era_files:
+ #for i in range(15000):
+ for i in range(10000,12000):
+ #for i in range(12000,15000):
+ #for i in range(15000,20000):
+ #for i in range(20000,25000):
+ #for i in range(25000,30000):
+ #for i in range(30000,35000):
+ #for i in range(35000,40000):
+ #for i in range(40000,len(era_files)):
+ f = era_files[i]
+ logging.info(f'{i} {f}')
+ era = load(os.path.join(era_dir, f))
+ era = era.squeeze() # get rid of extra dimension, if present
+ base_filename = f
+ if era_format == 'numpy':
+ base_filename = f[:-4]
+ c_f = base_filename
+ if copernicus_format == 'numpy':
+ c_f = base_filename + '.npy'
+ if c_f in copernicus_files:
+ copernicus = load(os.path.join(copernicus_dir, c_f))
+ assert np.array_equal(era.shape[1:], copernicus.shape[1:]), f'Era and Copernicus files appear to have different grid shapes: {era.shape} vs {copernicus.shape}'
+ for k,v in mapping.items():
+ era[k,:] = copernicus[v,:]
+ save(era, os.path.join(output_dir, base_filename + '.npy'))
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/process_torch_to_numpy.py b/src/hirad/input_data/process_torch_to_numpy.py
new file mode 100644
index 00000000..3a83a64a
--- /dev/null
+++ b/src/hirad/input_data/process_torch_to_numpy.py
@@ -0,0 +1,27 @@
+
+import os
+import sys
+
+import numpy as np
+import torch
+
+
+def main():
+ input_dir = sys.argv[1]
+ output_dir = sys.argv[2]
+ in_files = os.listdir(input_dir)
+ out_files = os.listdir(output_dir)
+ in_files.sort()
+ for i in range(len(in_files)):
+ if i % 1000 == 0:
+ print(i)
+ f = in_files[i]
+ #for f in in_files
+ if not (f + '.npy') in out_files:
+ print(f)
+ data = torch.load(os.path.join(input_dir, f), weights_only=False)
+ np.save(os.path.join(output_dir, f), data)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/src/hirad/input_data/read_tp.py b/src/hirad/input_data/read_tp.py
deleted file mode 100644
index a13fed31..00000000
--- a/src/hirad/input_data/read_tp.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import logging
-import netCDF4
-from anemoi.datasets import open_dataset
-import numpy as np
-import yaml
-
-import matplotlib.pyplot as plt
-import cartopy.crs as ccrs
-import cartopy.feature as cfeature
-from matplotlib.colors import BoundaryNorm, ListedColormap
-
-import interpolate_basic
-
-import sys
-from pathlib import Path
-
-import os
-print (os.getcwd())
-
-sys.path.insert(0, Path(__file__).parent.as_posix())
-
-ANEMOI_1H_FILENAME = "/scratch/mch/omiralle/anemoi/aifs-ea-an-oper-0001-mars-n320-2015-2020-1h-v1-with-ERA51.zarr"
-ANEMOI_6H_FILENAME = "/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr"
-COSMO_6H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr"
-COSMO_1H_FILENAME = "/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr"
-COSMO_CONFIG_FILE="src/input_data/cosmo.yaml"
-CDF_FILENAME = "8e49f064d738154bed136666ff72ae1c.nc"
-
-
-LAT = np.arange(-4.42, 3.36 + 0.02, 0.02)
-LON = np.arange(-6.82, 4.80 + 0.02, 0.02)
-RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone)
-
-
-def extract_values(netcdf_data):
- netcdf_lat = netcdf_data['latitude'][:]
- netcdf_lon = netcdf_data['longitude'][:]
- netcdf_tp = netcdf_data['tp'][:,:]
- values = np.zeros((netcdf_tp.shape[0], netcdf_tp.shape[1]*netcdf_tp.shape[2]))
- latitudes = np.zeros(values.shape[1])
- longitudes = np.zeros(values.shape[1])
- # You could probably get this by reshaping, but I can't be bothered.
- for i in range(len(netcdf_lat)):
- if i % 10 == 0:
- print(i)
- for j in range(len(netcdf_lon)):
- grid_index = i * len(netcdf_lon) + j
- values[:,grid_index] = netcdf_tp[:,i,j]
- latitudes[grid_index] = netcdf_lat[i]
- longitudes[grid_index] = netcdf_lon[j]
- return values, latitudes, longitudes
-
-def plot_map(values: np.array,
- filename: str,
- label='',
- title='',
- vmin=None,
- vmax=None,
- cmap=None,
- extend='neither',
- norm=None,
- ticks=None):
- """Plot observed or interpolated data in a scatter plot."""
- logging.info(f'Creating map: {filename}')
-
- latitudes = LAT[RELAX_ZONE : RELAX_ZONE + 352]
- longitudes = LON[RELAX_ZONE : RELAX_ZONE + 544]
- lon2d, lat2d = np.meshgrid(longitudes, latitudes)
-
- fig, ax = plt.subplots(
- figsize=(8, 6),
- subplot_kw={"projection": ccrs.RotatedPole(pole_longitude=-170.0,
- pole_latitude= 43.0)}
- )
- values = values.reshape((len(latitudes), len(longitudes)))
- contour = ax.pcolormesh(
- lon2d, lat2d, values,
- cmap=cmap, shading="auto",
- norm=norm if norm else None,
- vmin=None if norm else vmin,
- vmax=None if norm else vmax,
- )
- ax.coastlines()
- ax.add_feature(cfeature.BORDERS, linewidth=1)
- ax.gridlines(visible=False)
- ax.set_xticks([])
- ax.set_yticks([])
-
- plt.title(title)
- cbar = plt.colorbar(
- contour,
- label=label,
- orientation="horizontal",
- extend=extend,
- shrink=0.75,
- pad=0.02
- )
- if ticks is not None:
- cbar.set_ticks(ticks)
- cbar.set_ticklabels([f'{tick:g}' for tick in ticks])
-
- plt.tight_layout()
- fig.savefig(f"{filename}.png", dpi=300, bbox_inches="tight")
- plt.close(fig)
-
-def plot_map_precipitation(values, filename, title='', threshold=0.1, rfac=1000.0):
- """Plot precipitation data with specific colormap and thresholds."""
- # Scale and mask values below threshold
- values = rfac * values # m/h --> mm/h
- values = np.ma.masked_where(values <= threshold, values)
-
- # Predefined colors and bounds specific for precipitation
- colors = ['none', 'powderblue', 'dodgerblue', 'mediumblue',
- 'forestgreen', 'limegreen', 'lawngreen',
- 'yellow', 'gold', 'darkorange', 'red',
- 'darkviolet', 'violet', 'thistle']
- bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 30, 50, 70, 100, 150, 200]
- bounds = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]
-
- cmap = ListedColormap(colors)
- norm = BoundaryNorm(bounds, ncolors=len(colors), clip=False)
-
- plot_map(
- values, filename,
- cmap=cmap,
- norm=norm,
- ticks=bounds,
- title=title,
- label='mm/h',
- extend='max'
- )
-
-
-print(interpolate_basic.regrid)
-
-file_id = netCDF4.Dataset(CDF_FILENAME)
-#anemoi1 = open_dataset(ANEMOI_1H_FILENAME)
-#anemoi6 = open_dataset(ANEMOI_6H_FILENAME)
-#with open(COSMO_CONFIG_FILE) as cosmo_file:
-# cosmo_config = yaml.safe_load(cosmo_file)
-#cosmo = open_dataset(cosmo_config)
-cosmo1 = open_dataset(COSMO_1H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29')
-cosmo6 = open_dataset(COSMO_6H_FILENAME, trim_edge=19, select=['tp'],start='2016-01-01',end='2016-02-29')
-
-
-output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes))
-print(output_grid.shape)
-print(cosmo1[0,0,0,:].shape)
-
-plot_map_precipitation(values=cosmo1[0,:], filename="cosmo1.png")
-plot_map_precipitation(values=cosmo6[0,:], filename="cosmo6.png")
-
-#fig = plt.figure()
-#fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
-#interpolate_basic.plot_projection(ax, longitudes=cosmo1.longitudes, latitudes=cosmo1.latitudes, values=cosmo1[0,:])
-#fig.savefig('cosmo1.png')
-
-#fig = plt.figure()
-#fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
-#interpolate_basic.plot_projection(ax, longitudes=cosmo1.longitudes, latitudes=cosmo1.latitudes, values=cosmo6[0,:])
-#fig.savefig('cosmo6.png')
-
-
-values, latitudes, longitudes = extract_values(netcdf_data=file_id)
-input_grid=np.column_stack((longitudes, latitudes))
-vals = values[0,:].reshape((1,1,values.shape[1]))
-regrid=interpolate_basic.regrid(vals, input_grid, output_grid)
-plot_map_precipitation(regrid, 'netcdf.png')
-
-
-era1 = open_dataset(ANEMOI_1H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29')
-era6 = open_dataset(ANEMOI_6H_FILENAME, select=['tp'],start='2016-01-01',end='2016-02-29')
-era_grid = np.column_stack((era1.longitudes, era1.latitudes))
-era1_regrid = interpolate_basic.regrid(era1[0,:], era_grid, output_grid)
-plot_map_precipitation(era1_regrid, "era1.png")
-
-
-
-era6_regrid = interpolate_basic.regrid(era6[0,:], era_grid, output_grid)
-plot_map_precipitation(era1_regrid, "era6.png")
-
-
-
diff --git a/src/hirad/input_data/regrid_copernicus_tp.py b/src/hirad/input_data/regrid_copernicus_tp.py
new file mode 100644
index 00000000..adcea16c
--- /dev/null
+++ b/src/hirad/input_data/regrid_copernicus_tp.py
@@ -0,0 +1,274 @@
+import logging
+import netCDF4
+import xarray
+import numpy as np
+import torch
+import datetime
+from scipy.interpolate import griddata
+
+from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t
+from hirad.eval.metrics import absolute_error
+
+import interpolate_basic
+
+import sys
+from pathlib import Path
+
+import os
+print (os.getcwd())
+
+sys.path.insert(0, Path(__file__).parent.as_posix())
+
+
+CDF_FILENAME_BALFRIN = "/store_new/mch/msopr/hirad-gen/copernicus-datasets/tp-janfeb2020.nc"
+#CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc"
+#CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018-n320.nc"
+CDF_FILENAME_CLARIDEN_TP = "/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc"
+
+
+BASE_FILEPATH = "/capstor/store/"
+INPUT_DATA_FILEPATH = "mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/"
+OUTPUT_DATA_FILEPATH_ERA_INTERPOLATED = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp"
+OUTPUT_DATA_FILEPATH_ERA = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-with-copernicus-tp"
+TP_INDEX = 12
+
+LAT = np.arange(-4.42, 3.36 + 0.02, 0.02)
+LON = np.arange(-6.82, 4.80 + 0.02, 0.02)
+RELAX_ZONE = 19 # Number of points dropped on each side (relaxation zone)
+
+def extract_grib_values(grib_data):
+ grib_lat = grib_data['latitude'][:]
+ grib_lon = grib_data['longitude'][:]
+ grib_t2m = grib_data['t2m'][:]
+
+def extract_lat_lon_025(data):
+ logging.info('extracting lat/lon')
+ lat = data['latitude'][:]
+ lon = data['longitude'][:]
+ output_lat = np.zeros(len(lat)* len(lon))
+ output_lon = np.zeros(len(lat) * len(lon))
+ for i in range(len(lat)):
+ if i % 10 == 0:
+ print(i)
+ for j in range(len(lon)):
+ grid_index = i * len(lon) + j
+ output_lat[grid_index] = lat[i]
+ output_lon[grid_index] = lon[j]
+ return output_lat, output_lon
+
+def extract_lat_lon_n320(data):
+ lat = data['latitudes'][:]
+ lon = data['longitudes'][:]
+ logging.info('extracting lat/lon')
+ logging.info(f'lat lon shapes {lat.shape} {lon.shape}')
+
+# Get values for a given date range (inclusive)
+def extract_values(data: netCDF4.Dataset, variable, start_date=None, end_date=None, area=None):
+ values = data[variable][:]
+ #if area:
+ # Not sure this is working.
+ # lat = data['latitude'][:]
+ # lon = data['longitude'][:]
+ # https://stackoverflow.com/questions/29135885/netcdf4-extract-for-subset-of-lat-lon
+ # latli = np.argmin( np.abs(lat - area[2]))
+ # latui = np.argmin( np.abs(lat - area[0]))
+ # lonli = np.argmin( np.abs(lon - area[1]))
+ # lonui = np.argmin( np.abs(lon - area[3]))
+ # lat = data['latitude'][latli:latui]
+ # lon = data['longitude'][lonli:lonui]
+ # values = data[variable][latli:latui,lonli:lonui]
+ date_indices = range(values.shape[0])
+ if start_date and end_date:
+ date_indices = np.intersect1d(
+ np.where(data['valid_time'][:] >= start_date.astype(np.int64)),
+ np.where(data['valid_time'][:] <= end_date.astype(np.int64)))
+ values = values[date_indices,:]
+ if len(date_indices) == 0:
+ raise KeyError(f'{start_date} and {end_date} not valid range')
+
+ return np.reshape(values, (values.shape[0], values.shape[1]*values.shape[2]))
+
+def reshape_to_cosmo(vals):
+ return vals.reshape((len(LAT)-RELAX_ZONE*2, len(LON)-RELAX_ZONE*2))
+
+def calc_errors(cosmo1, era1):
+ make_plots = True
+
+ prev_netcdf_regrid = []
+
+ netcdf_error = np.zeros(cosmo1.dates.shape)
+ era_norm_error = np.zeros(cosmo1.dates.shape)
+ netcdf_early_error = np.zeros(cosmo1.dates.shape)
+ netcdf_late_error = np.zeros(cosmo1.dates.shape)
+
+ output_grid= np.column_stack((cosmo1.longitudes, cosmo1.latitudes))
+
+ for t in range(4):
+ #for t in range(len(cosmo1.dates)):
+ date = cosmo1.dates[t]
+ era_date = era1.dates[t]
+ if date != era_date:
+ logging.error('dates do not match: cosmo date: {date}, era date: {era_date}')
+ if date != netcdf_data['valid_time'][t]:
+ logging.error(f'dates do not match: cosmo date: {date}, netcdf: {netcdf_data["valid_time"][t]}')
+
+
+ # plot cosmo
+ if make_plots:
+ plot_map_precipitation(values=reshape_to_cosmo(cosmo1[t,:]), filename=f'plots/tp/{date}-cosmo1')
+
+ # plot netcdf
+ netcdf_vals = netcdf_values[t,:].reshape((1,1,netcdf_values.shape[1]))
+ netcdf_regrid=interpolate_basic.regrid(netcdf_vals, netcdf_grid, output_grid)
+ if make_plots:
+ plot_map_precipitation(reshape_to_cosmo(netcdf_regrid), f'plots/tp/{date}-netcdf-refactor')
+
+ # plot era
+ era_grid = np.column_stack((era1.longitudes, era1.latitudes))
+ era1_regrid = interpolate_basic.regrid(era1[t,:], era_grid, output_grid)
+ if make_plots:
+ plot_map_precipitation(reshape_to_cosmo(era1_regrid/6), f'plots/tp/{date}-era1-norm')
+
+ #if t % 6 == 0:
+ # if era6.dates[t//6] != date:
+ # logging.error(f'dates do not match: era1: {date}, era6: {era6.dates[t//6]}')
+ # era6_regrid = interpolate_basic.regrid(era6[t//6,:], era_grid, output_grid)
+ # plot_map_precipitation(reshape_to_cosmo(era6_regrid), f'plots/tp/{date}-era6')
+
+ era_norm_error[t] = np.mean(absolute_error(era1_regrid/6, cosmo1[t,:]))
+ netcdf_error[t] = np.mean(absolute_error(netcdf_regrid, cosmo1[t,:]))
+ logging.info(f'era norm error: {era_norm_error[t]} netcdf err: {netcdf_error[t]}')
+ if t>0:
+ netcdf_early_error[t] = np.mean(absolute_error(prev_netcdf_regrid, cosmo1[t,:]))
+ netcdf_late_error[t-1] = np.mean(absolute_error(netcdf_regrid, cosmo1[t-1,:]))
+ logging.info(f'netcdf early err: {netcdf_early_error[t]}, netcdf late err: {netcdf_late_error[t-1]}')
+ prev_netcdf_regrid = netcdf_regrid
+
+ maes = {}
+ maes['era normalized'] = era_norm_error
+ maes['copernicus'] = netcdf_error
+ maes['copernicus-early'] = netcdf_early_error
+ maes['copernicus-late'] = netcdf_late_error
+ plot_scores_vs_t(maes, times=cosmo1.dates, filename='plots/errors.png')
+
+def process_era_interpolated(netcdf_data, netcdf_tp_values, input_data_filepath, output_interpolated_filepath, netcdf_grid, cosmo_grid):
+ make_plots = False
+ #for t in range(100):
+ for t in range(netcdf_tp_values.shape[0]):
+ netcdf_date = netcdf_data['valid_time'][t]
+ date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M')
+ t1 = datetime.datetime.now()
+ era_filename = os.path.join(input_data_filepath, "era-interpolated", date_filename)
+ output_filename = os.path.join(output_interpolated_filepath, date_filename)
+ if os.path.exists(era_filename):
+ requires_processing = False
+ if date_filename == '20200615-2100':
+ requires_processing = True
+ #if os.path.exists(output_filename):
+ # test the output to make sure it is not corrupted.
+ #requires_processing = False
+ #try:
+ # torch.load(output_filename, weights_only=False)
+ #except:
+ # requires_processing = True
+ if requires_processing:
+ era_data = torch.load(era_filename, weights_only=False)
+ t2 = datetime.datetime.now()
+ #if t % 100 == 0:
+ logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})')
+ interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], cosmo_grid, method='linear')
+ t3 = datetime.datetime.now()
+ if make_plots and max > 0.0002:
+ nans = np.count_nonzero(np.isnan(interpolated_tp))
+ nonzeros = np.count_nonzero(interpolated_tp)
+ max = np.max(interpolated_tp)
+ logging.info(f'nonzeros: {nonzeros} nans: {nans} max: {max}')
+ if max > 0.0002:
+ cosmo_filename = os.path.join(input_data_filepath, "cosmo", date_filename)
+ cosmo_data = torch.load(cosmo_filename, weights_only=False)
+ plot_map_precipitation(reshape_to_cosmo(interpolated_tp), f'plots/tp/{date_filename}-netcdf-regrid')
+ # 3 is tp index in cosmo data
+ plot_map_precipitation(reshape_to_cosmo(cosmo_data[3,:]), f'plots/tp/{date_filename}-cosmo')
+ plot_map_precipitation(reshape_to_cosmo(era_data[TP_INDEX,0,:]/6), f'plots/tp/{date_filename}-era-norm')
+ era_data[TP_INDEX,0,:] = interpolated_tp
+ torch.save(era_data, output_filename)
+ t4 = datetime.datetime.now()
+
+def process_era(netcdf_data, netcdf_tp_values):
+ for t in range(netcdf_tp_values.shape[0]):
+ netcdf_date = netcdf_data['valid_time'][t]
+ date_filename = datetime.datetime.fromtimestamp(netcdf_date, datetime.UTC).strftime('%Y%m%d-%H%M')
+ era_filename = os.path.join(BASE_FILEPATH, INPUT_DATA_FILEPATH, "era", date_filename)
+ if os.path.exists(era_filename):
+ era_data = torch.load(era_filename, weights_only=False)
+ t2 = datetime.datetime.now()
+ logging.info(f'regridding {date_filename} (netcdf date: {netcdf_date})')
+ interpolated_tp = griddata(netcdf_grid, netcdf_tp_values[t,:], era_grid, method='linear')
+ t3 = datetime.datetime.now()
+ era_data[TP_INDEX,0,:] = interpolated_tp
+ torch.save(era_data, os.path.join(OUTPUT_DATA_FILEPATH_ERA, date_filename))
+ t4 = datetime.datetime.now()
+
+def extract_all_values():
+ set1 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2015-2016.nc")
+ set2 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2017-2018.nc")
+ set3 = netCDF4.Dataset("/capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc")
+ set1_tp = extract_values(set1, 'tp')
+ set2_tp = extract_values(set2, 'tp')
+ set3_tp = extract_values(set3, 'tp')
+ all_tp = np.row_stack((set1_tp, set2_tp, set3_tp))
+ print(all_tp.shape)
+ return all_tp
+
+# Get stats from ERA and replace the TP variable with stats from Copernicus
+def make_stats(input_stats_directory: str, output_stats_directory: str, extracted_tp_values: np.ndarray):
+ stats = torch.load(os.path.join(input_stats_directory, 'era-stats'), weights_only=False)
+ print(stats)
+ #extracted_tp_values = extracted_tp_values.reshape(extracted_tp_values.shape[0] * extracted_tp_values.shape[1], 1)
+ flat_values = extracted_tp_values.flatten()
+ mean = np.mean(flat_values)
+ max = np.max(flat_values)
+ min = np.min(flat_values)
+ stdev = np.std(flat_values)
+ stats['mean'][TP_INDEX] = mean
+ stats['maximum'][TP_INDEX] = max
+ stats['minimum'][TP_INDEX] = min
+ stats['stdev'][TP_INDEX] = stdev
+ print(stats)
+ torch.save(stats, os.path.join(output_stats_directory, 'era-copernicus-stats'))
+
+
+#process_era(netcdf_data, netcdf_tp_values)
+
+
+def main():
+ root = logging.getLogger()
+ root.setLevel(logging.INFO)
+
+ logging.info('loading data')
+ netcdf_file = sys.argv[1]
+ input_data_filepath = sys.argv[2]
+ output_interpolated_filepath = sys.argv[3]
+
+ netcdf_data = netCDF4.Dataset(netcdf_file)
+ logging.info(netcdf_data)
+
+ logging.info('processing netcdf data')
+ netcdf_latitudes, netcdf_longitudes = extract_lat_lon_025(netcdf_data)
+ netcdf_grid=np.column_stack((netcdf_longitudes, netcdf_latitudes))
+ cosmo_grid = torch.load(os.path.join(input_data_filepath, 'info/cosmo-lat-lon'), weights_only=False)
+ cosmo_grid = np.column_stack((cosmo_grid[:,1], cosmo_grid[:,0]))
+ logging.info(f'netcdf grid shape {netcdf_grid.shape}')
+ logging.info(f'{netcdf_grid[1:10,:]}')
+ logging.info(f'cosmo grid shape {cosmo_grid.shape}')
+ logging.info(f'{cosmo_grid[1:10,:]}')
+
+ netcdf_tp_values = extract_values(netcdf_data, 'tp')
+
+
+ process_era_interpolated(netcdf_data, netcdf_tp_values, input_data_filepath, output_interpolated_filepath, netcdf_grid, cosmo_grid)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/src/hirad/input_data/regrid_realch1.py b/src/hirad/input_data/regrid_realch1.py
new file mode 100644
index 00000000..b1d873de
--- /dev/null
+++ b/src/hirad/input_data/regrid_realch1.py
@@ -0,0 +1,201 @@
+
+import datetime
+import logging
+import os
+import shutil
+import sys
+
+from anemoi.datasets import open_dataset
+from anemoi.datasets.data.dataset import Dataset
+import numpy as np
+from meteodatalab.operators import regrid
+import xarray as xr
+from meteodatalab import ogd_api
+from hirad.input_data import interpolate_basic
+import yaml
+import torch
+from pandas import to_datetime
+
+import matplotlib.pyplot as plt
+import cartopy.crs as ccrs
+
+TRIM_EDGE = 41
+XARRAY_BATCH = 4
+
+# Take anemoi dataset and provide xarray dataarrays for a set of variables.
+# returns: list of xarray dataarrays
+def anemoi_to_xarray(anemoi_data: Dataset, start_date_index=-1, end_date_index=-1):
+ if start_date_index == -1:
+ start_date_index = 0
+ if end_date_index == -1:
+ end_date_index = len(anemoi_data.dates)
+ lon = anemoi_data.longitudes
+ lat = anemoi_data.latitudes
+ eps = [0] # deterministic
+ time = anemoi_data.dates[start_date_index:end_date_index]
+ metadata = getMetadataFromOGD()
+ dataarrays = []
+ variables = anemoi_data.variables
+ for var_index in range(anemoi_data.shape[1]):
+ logging.info(f'building xarray for {variables[var_index]}')
+ ds = xr.Dataset(
+ data_vars=dict(
+ variable=(["time", "eps", "cell"],
+ np.array(anemoi_data[start_date_index:end_date_index,var_index,:,:])),
+ ),
+ coords=dict(
+ eps=eps,
+ time=time,
+ lon=("cell", lon),
+ lat=("cell", lat),
+ ),
+ attrs=dict(description=f'xarray from anemoi dataset for {variables[var_index]}',
+ metadata=metadata),
+ )
+ dataarrays.append(ds.to_dataarray())
+ return dataarrays
+
+# Run a request to get the metadata, so that we can fake out an xarray.
+def getMetadataFromOGD():
+ lead_times = ["P0DT0H"]
+ req = ogd_api.Request(
+ collection="ogd-forecasting-icon-ch1",
+ variable="TOT_PREC", #assuming this won't cause problems; we're only using grid info
+ ref_time="latest",
+ perturbed=False,
+ lead_time=lead_times,
+ )
+ tot_prec = ogd_api.get_from_ogd(req)
+ return tot_prec.metadata
+
+# get the geo coordinates for the rotated lat/lon dataset.
+# returns np.array of lats and array of lons
+def get_geo_coords(regridded_data: xr.Dataset, trim_edge=0):
+ xmin = regridded_data.metadata.get("longitudeOfFirstGridPointInDegrees")
+ xmax = regridded_data.metadata.get("longitudeOfLastGridPointInDegrees")
+ dx = regridded_data.metadata.get("iDirectionIncrementInDegrees")
+ ymin = regridded_data.metadata.get("latitudeOfFirstGridPointInDegrees")
+ ymax = regridded_data.metadata.get("latitudeOfLastGridPointInDegrees")
+ dy = regridded_data.metadata.get("jDirectionIncrementInDegrees")
+ y = np.arange(ymin,ymax+dy,dy)
+ x = np.arange(xmin,xmax+dx,dx)
+ # trim x and y according to trim_edge.
+ # (Have manually verified that when doing this, the outputs are the same as
+ # trimming post-projection)
+ y = y[trim_edge:len(y)-trim_edge]
+ x = x[trim_edge:len(x)-trim_edge]
+ sp_lat = regridded_data.metadata.get("latitudeOfSouthernPoleInDegrees") # -43.0. north_pole_lat = 43.0
+ sp_lon = regridded_data.metadata.get("longitudeOfSouthernPoleInDegrees") # 10.0. north_pole_lon = 190.0
+ xcoords = np.meshgrid(x,y)[0].flatten()
+ ycoords = np.meshgrid(x,y)[1].flatten()
+ # Expect south pole rotation of lon=10, latitude=-43
+ logging.info(f'sp_lat = {sp_lat}, sp_lon = {sp_lon}')
+ rotated_crs = ccrs.RotatedPole(
+ pole_longitude=(sp_lon + 180) % 360, pole_latitude=sp_lat * -1 # 190, 43
+ )
+ # Project onto PlateCarree. Geodetic produces similar coordinates (within 10 nanometers)
+ dst_grid = ccrs.PlateCarree()
+ geo_coords = dst_grid.transform_points(rotated_crs, xcoords, ycoords)
+ lats = geo_coords[:,1]
+ lons = geo_coords[:,0]
+ return lats, lons
+
+def regridded_to_numpy(regridded: xr.DataArray, trim_edge=0):
+ # regridded is in shape (eps, time, variable, x, y)
+ # want this in shape (time, channel, ensemble, grid)
+ # First, trim the edge
+ data = regridded.data[:,:,:,
+ trim_edge:regridded.data.shape[3]-trim_edge,
+ trim_edge:regridded.data.shape[4]-trim_edge]
+ # reshape to (time,channel,ensemble,grid)
+ data = data.reshape(data.shape[1], data.shape[0], data.shape[3]*data.shape[4])
+ return data
+
+def interpolate_anemoi_range_to_rotlatlon(i_start: int, i_end: int, ds: Dataset, ds_name: str, input_grid: np.ndarray, output_grid: np.ndarray, output_data_path: str, format='torch', output_plots_path: str = None, plot_indices=[0]):
+ torch_data = np.zeros([i_end-i_start, len(ds.variables), 1, output_grid.shape[0]])
+
+ xarrays = anemoi_to_xarray(ds, i_start, i_end)
+ for j in range(len(xarrays)):
+ logging.info(f'regridding {ds.variables[j]} for time {ds.dates[i_start]} to {ds.dates[i_end-1]}')
+ xarray = xarrays[j]
+ start = datetime.datetime.now()
+ regridded=regrid.icon2rotlatlon(xarray)
+ end = datetime.datetime.now()
+ logging.info(f' regridding took {end-start} seconds')
+ torch_data[0:i_end-i_start,j,:,:] = regridded_to_numpy(regridded, trim_edge=TRIM_EDGE)
+
+ logging.info('saving torch data')
+ for k in range(torch_data.shape[0]):
+ interpolate_basic.save_datetime_file(torch_data[k,:], ds.dates[i_start + k], output_data_path, format=format)
+ if (i_start + k) in plot_indices:
+ logging.info(f'plotting {i_start+k}')
+ datestr = interpolate_basic.format_date(ds.dates[i_start+k])
+ for v in range(torch_data.shape[1]):
+ interpolate_basic.plot_and_save_projection(input_grid[:,0], input_grid[:,1],
+ ds[i_start+k,v,0,:],
+ os.path.join(output_plots_path, f'{datestr}-{ds.variables[v]}-iconnative.png'),
+ s=0.005)
+ interpolate_basic.plot_and_save_projection(output_grid[:,0], output_grid[:,1],
+ torch_data[k,v,0,:],
+ os.path.join(output_plots_path, f'{datestr}-{ds.variables[v]}-rotlatlon.png'),
+ s=0.005)
+
+def interpolate_anemoi_to_rotlatlon(infile_anemoi: str, ds_name: str, output_grid: np.ndarray, output_path: str, format='torch', plot_indices=[0]):
+
+ # Copy the realch1.yml file to the info directory
+ shutil.copy(infile_anemoi, os.path.join(output_path, 'info'))
+
+ with open(infile_anemoi) as realch1_file:
+ realch1_config = yaml.safe_load(realch1_file)
+ realch1 = open_dataset(realch1_config)
+ variables = realch1.variables
+ input_grid = np.column_stack((realch1.longitudes, realch1.latitudes))
+
+ # Get the lat/lon info by regridding one variable
+ xarrays = anemoi_to_xarray(realch1, 0, 1)
+ regridded=regrid.icon2rotlatlon(xarrays[0])
+ logging.info('getting geo coords')
+ lats, lons = get_geo_coords(regridded, trim_edge=TRIM_EDGE)
+ output_grid=np.column_stack((lons, lats))
+
+ # Save grid to file
+ grid = np.column_stack((lats, lons))
+ torch.save(grid, os.path.join(output_path, 'info', 'realch1-lat-lon'))
+
+ # Save stats
+ interpolate_basic.save_anemoi_stats(realch1, os.path.join(output_path, f'info/{ds_name}-stats'))
+
+ output_data_path = os.path.join(output_path, ds_name)
+ output_plots_path = os.path.join(output_path, 'plots')
+
+ # Split regridding into batches; too many time points seems to not scale well.
+ for i in range(0, len(realch1.dates), XARRAY_BATCH):
+ start_index = i
+ end_index = min(i+XARRAY_BATCH, len(realch1.dates))
+ logging.info(f'start={start_index} end={end_index}')
+ interpolate_anemoi_range_to_rotlatlon(start_index, end_index, realch1, ds_name,
+ input_grid, output_grid, output_data_path, format, output_plots_path, plot_indices)
+
+
+def main():
+ # yml format
+ realch1_config_file = sys.argv[1]
+ output_directory = sys.argv[2]
+ if not os.path.exists(output_directory):
+ os.mkdir(output_directory)
+ for subdir in ['info', 'plots', 'realch1']:
+ if not os.path.exists(os.path.join(output_directory, subdir)):
+ os.mkdir(os.path.join(output_directory, subdir))
+
+ logging.basicConfig(
+ filename=os.path.join(output_directory, 'regrid_realch1.log'),
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+
+ interpolate_anemoi_to_rotlatlon(realch1_config_file, 'realch1', None, output_directory, 'numpy', [0])
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/hirad/input_data/reprocess_change_tp_accum.py b/src/hirad/input_data/reprocess_change_tp_accum.py
new file mode 100644
index 00000000..56423000
--- /dev/null
+++ b/src/hirad/input_data/reprocess_change_tp_accum.py
@@ -0,0 +1,58 @@
+import logging
+import os
+import sys
+
+import torch
+import numpy as np
+
+# Reprocess ERA-interpolated data to exclude the tp variable.
+
+# 6H data is all channels, but with 6h accumulation
+DATA_SOURCE_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/era-interpolated"
+STATS_FILEPATH_6H = "/capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/info"
+# 1h data is the updated
+DATA_SOURCE_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/validation/era-interpolated-with-copernicus-tp/"
+STATS_FILEPATH_1H = "/capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/validation/info"
+OUTPUT_DIR = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/validation/era-interpolated"
+OUTPUT_STATS_FILEPATH = "/iopsstor/scratch/cscs/mmcgloho/run-1_4/validation/info/"
+TP_INDEX_6H = 34 # in era-all.yaml
+TP_INDEX_1H = 12 # in era.yaml
+
+def process(input_directory_6h: str, input_directory_1h: str, output_directory: str):
+ input_1h_filepath = os.path.join(input_directory_1h)
+ files = os.listdir(input_1h_filepath)
+ files.sort()
+ for f in range(len(files)):
+ if f % 100 == 0:
+ logging.info(f)
+ input_1h_file = os.path.join(input_directory_1h, files[f])
+ input_6h_file = os.path.join(input_directory_6h, files[f])
+ outfile = os.path.join(output_directory, files[f])
+ in_data_6h = torch.load(input_6h_file, weights_only=False)
+ in_data_1h = torch.load(input_1h_file, weights_only=False)
+ in_data_6h[TP_INDEX_6H,:] = in_data_1h[TP_INDEX_1H,:]
+ torch.save(in_data_6h, outfile)
+
+def edit_info(info_6h_filepath: str, info_1h_filepath: str, output_filepath: str):
+ stats_6h = torch.load(os.path.join(info_6h_filepath, 'era-stats'), weights_only=False)
+ stats_1h = torch.load(os.path.join(info_1h_filepath, 'era-stats'), weights_only=False)
+ logging.info(f'6h stats: {stats_6h}')
+ logging.info(f'1h stats: {stats_1h}')
+ for k in stats_6h.keys():
+ logging.info(k, stats_6h[k])
+ tmp = stats_6h[k]
+ tmp[TP_INDEX_6H] = stats_1h[k][TP_INDEX_1H]
+ stats_6h[k] = tmp
+ logging.info(stats_6h)
+ torch.save(stats_6h, os.path.join(output_filepath, 'era-stats'))
+
+def main():
+ logging.basicConfig(
+ format='%(asctime)s %(levelname)-8s %(message)s',
+ level=logging.INFO,
+ datefmt='%Y-%m-%d %H:%M:%S')
+ process(DATA_SOURCE_6H, DATA_SOURCE_1H, OUTPUT_DIR)
+ edit_info(STATS_FILEPATH_6H, STATS_FILEPATH_1H, OUTPUT_STATS_FILEPATH)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/reprocess_exclude_tp.py b/src/hirad/input_data/reprocess_exclude_tp.py
new file mode 100644
index 00000000..190a4742
--- /dev/null
+++ b/src/hirad/input_data/reprocess_exclude_tp.py
@@ -0,0 +1,43 @@
+import logging
+import os
+import sys
+
+import torch
+import numpy as np
+
+# Reprocess ERA-interpolated data to exclude the tp variable.
+
+TP_INDEX = 12
+
+def process(input_directory: str, output_directory: str):
+ input_filepath = os.path.join(input_directory, 'era-interpolated')
+ files = os.listdir(input_filepath)
+ files.sort()
+ for f in range(len(files)):
+ if f % 100 == 0:
+ logging.info(f)
+ outfile = os.path.join(output_directory, 'era-interpolated', files[f])
+ if (not os.path.exists(outfile)) or (os.path.getsize(outfile) < 26000000):
+ in_data = torch.load(os.path.join(input_filepath, files[f]), weights_only=False)
+ out_data = in_data[0:TP_INDEX,:]
+ torch.save(out_data, outfile)
+
+def edit_info(input_filepath: str, output_filepath: str):
+ stats = torch.load(os.path.join(input_filepath, '/info', 'era-stats'), weights_only=False)
+ logging.info(stats)
+ for k in stats.keys():
+ logging.info(k, stats[k])
+ stats[k] = stats[k][0:TP_INDEX]
+ logging.info(stats)
+ torch.save(stats, os.path.join(output_filepath, "/info", "era-stats"))
+
+def main():
+ root = logging.getLogger()
+ root.setLevel(logging.INFO)
+ input_directory = sys.argv[1]
+ output_directory = sys.argv[2]
+ process(input_directory, output_directory)
+ #edit_info(input_directory, output_directory)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/input_data/scripts/copernicus-tp.sh b/src/hirad/input_data/scripts/copernicus-tp.sh
new file mode 100644
index 00000000..943812f6
--- /dev/null
+++ b/src/hirad/input_data/scripts/copernicus-tp.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
+
+pip install -e .
+pip install anemoi.datasets
+python src/hirad/input_data/read_tp.py
diff --git a/src/hirad/input_data/scripts/copyanemoi.sh b/src/hirad/input_data/scripts/copyanemoi.sh
new file mode 100644
index 00000000..361a9fbb
--- /dev/null
+++ b/src/hirad/input_data/scripts/copyanemoi.sh
@@ -0,0 +1,11 @@
+#!/bin/bash -l
+#
+#SBATCH --time=23:59:00
+#SBATCH --ntasks=1
+#SBATCH --partition=xfer
+
+echo -e "$SLURM_JOB_NAME started on $(date):\n $command $1 $2"
+cp -rvn $1 $2
+
+echo -e "$SLURM_JOB_NAME finished on $(date)\n"
+
diff --git a/src/hirad/input_data/scripts/fix-files.sh b/src/hirad/input_data/scripts/fix-files.sh
new file mode 100644
index 00000000..35a160a6
--- /dev/null
+++ b/src/hirad/input_data/scripts/fix-files.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+#clariden
+regex_pattern="^era-interpolated(.*)"
+for f in $(ls /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/era-interpolated/);
+do
+ if [[ "$f" =~ $regex_pattern ]]; then
+ newf=${BASH_REMATCH[1]}
+ echo "newf: $newf"
+ mv /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/era-interpolated/$f /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/era-interpolated/$newf
+ fi
+done
\ No newline at end of file
diff --git a/src/hirad/input_data/scripts/interpolate-batches.sh b/src/hirad/input_data/scripts/interpolate-batches.sh
new file mode 100644
index 00000000..e95b82df
--- /dev/null
+++ b/src/hirad/input_data/scripts/interpolate-batches.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+#clariden
+for cfgfile in $(ls src/hirad/input_data/configs/era-all-2020*.yaml);
+do
+ cmd="sbatch -A a161 -t 12:00:00 -n 1 -c 1 --begin=now+18hour --environment=modulus_env src/hirad/interpolate.sh ${cfgfile}"
+ echo $cmd
+ $cmd
+done
+
+# balfrin
+#for cfgfile in $(ls src/hirad/input_data/configs/era-all-2016*.yaml);
+#do
+# cmd="sbatch -p postproc -t 12:00:00 -n 1 -c 1 src/hirad/interpolate.sh ${cfgfile}"
+# echo $cmd
+# $cmd
+#done
diff --git a/src/hirad/input_data/scripts/interpolate-reprocess.sh b/src/hirad/input_data/scripts/interpolate-reprocess.sh
new file mode 100644
index 00000000..054e7397
--- /dev/null
+++ b/src/hirad/input_data/scripts/interpolate-reprocess.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+#SBATCH --time=12:00:00
+
+echo 'activating env'
+source /users/mmcgloho/interpolate-env-ssp/bin/activate
+echo 'running'
+#python src/hirad/input_data/process_torch_to_numpy.py /store_new/mch/msopr/hirad-gen/basic-torch/era5-realch1/v1.0/era-copernicus-interpolated/ /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/era-copernicus-interpolated/
+#python src/hirad/input_data/process_torch_to_numpy.py /capstor/scratch/cscs/mmcgloho/basic-torch/era5-realch1/v1.0/era-copernicus-interpolated/ /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/era-copernicus-interpolated/
+python src/hirad/input_data/check-fileload.py /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/realch1/
\ No newline at end of file
diff --git a/src/hirad/input_data/scripts/interpolate.sh b/src/hirad/input_data/scripts/interpolate.sh
new file mode 100755
index 00000000..9210c5d2
--- /dev/null
+++ b/src/hirad/input_data/scripts/interpolate.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+
+#SBATCH --time=12:00:00
+
+source /users/mmcgloho/interpolate-env-ssp/bin/activate
+python src/hirad/input_data/regrid_realch1.py src/hirad/input_data/configs/realch1-static.yaml /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-realch1/v1.0-channel-subset/static
+
+
+#srun -A a161 -t 12:00:00 --environment=modulus_env bash -c "
+# pip install -e . --no-dependencies
+# pip install anemoi.datasets
+# python src/hirad/input_data/interpolate_basic.py src/hirad/input_data/era-all.yaml src/hirad/input_data/cosmo-all.yaml /capstor/scratch/cscs/mmcgloho/datasets/processed/era5-cosmo-1h-all-channels/
+#"
+#pip install -e . --no-dependencies
+#pip install anemoi.datasets
+#pip install meteodata-lab
+#python src/hirad/input_data/interpolate_realch1.py \
+# src/hirad/input_data/era.yaml \
+# /capstor/store/mch/msopr/hirad-gen/basic-torch/era5-realch1/v0.2/info/realch1-lat-lon \
+# /capstor/store/mch/msopr/hirad-gen/copernicus-datasets/tp-2023-2024.nc \
+# /capstor/scratch/cscs/mmcgloho/basic-torch/era5-realch1/v1.0/
+#python src/hirad/input_data/interpolate_basic.py \
+# $1 \
+# /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/info/cosmo-lat-lon \
+# /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/
+
+#python src/hirad/input_data/process_copernicus_cosmo.py \
+# src/hirad/input_data/copernicus.yaml \
+# src/hirad/input_data/cosmo-all.yaml \
+# /capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-cosmo-1h
+
+#python src/hirad/input_data/process_era5_cosmo.py \
+# src/hirad/input_data/era-all.yaml \
+# src/hirad/input_data/cosmo-all.yaml \
+# /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/
+
+#python src/hirad/input_data/process_era5_cosmo.py \
+# src/hirad/input_data/configs/era-all-202006.yaml \
+# src/hirad/input_data/cosmo-all.yaml \
+# /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/
+
+#python src/hirad/input_data/process_era5_with_copernicus.py \
+ # src/hirad/input_data/era-all.yaml \
+ # src/hirad/input_data/copernicus.yaml \
+ # /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/era-interpolated/ \
+ # /capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-cosmo-1h/copernicus-interpolated/ \
+ # /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/
+
+#python src/hirad/input_data/process_era5_with_copernicus.py \
+# src/hirad/input_data/era-all.yaml \
+# src/hirad/input_data/copernicus.yaml \
+# /capstor/scratch/cscs/mmcgloho/basic-torch/era5-cosmo-1h-all-channels/eram-interpolated/ \
+# /capstor/scratch/cscs/mmcgloho/basic-numpy/copernicus-cosmo-1h/copernicus-interpolated/ \
+# /iopsstor/scratch/cscs/mmcgloho/basic-numpy/era5-cosmo-1h-all-channels/
+
+
+# balfrin
+#python src/hirad/input_data/interpolate_basic.py \
+# $1 \
+# /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/info/cosmo-lat-lon \
+# /store_new/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-all-channels/
+
diff --git a/src/hirad/input_data/scripts/regrid_copernicus_tp.sh b/src/hirad/input_data/scripts/regrid_copernicus_tp.sh
new file mode 100755
index 00000000..f7fc5501
--- /dev/null
+++ b/src/hirad/input_data/scripts/regrid_copernicus_tp.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+pip install -e .
+pip install anemoi.datasets
+
+python src/hirad/input_data/regrid_copernicus_tp.py \
+ /capstor/store/cscs/swissai/a161/datasets/copernicus/tp-2019-2020.nc \
+ /capstor/store/mch/msopr/hirad-gen/basic-torch/era5-cosmo-1h-linear-interpolation-full/ \
+ /capstor/store/cscs/swissai/a161/era5-cosmo-1h-linear-interpolation/train/era-interpolated-with-copernicus-tp/
\ No newline at end of file
diff --git a/src/hirad/input_data/scripts/rsync.sh b/src/hirad/input_data/scripts/rsync.sh
new file mode 100644
index 00000000..a5828794
--- /dev/null
+++ b/src/hirad/input_data/scripts/rsync.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+#SBATCH --time=12:00:00
+#SBATCH --partition=xfer
+
+rsync -av /iopsstor/scratch/cscs/mmcgloho/basic-numpy /capstor/store/cscs/pasc/c38/basic-numpy
\ No newline at end of file
diff --git a/src/hirad/input_data/test_input_data.py b/src/hirad/input_data/test_input_data.py
new file mode 100644
index 00000000..d8d0f6ee
--- /dev/null
+++ b/src/hirad/input_data/test_input_data.py
@@ -0,0 +1,92 @@
+import logging
+import os
+import sys
+
+import datetime
+import torch
+import numpy as np
+
+
+from hirad.eval.plotting import plot_map_precipitation, plot_scores_vs_t
+
+def load_all_data(filepath: str):
+ files = os.listdir(filepath)
+ example = torch.load(os.path.join(filepath, files[0]), weights_only=False)
+ dims = (len(files),) + example.shape
+ data = np.zeros(dims)
+ for f in range(100):
+ #for f in range(len(files)):
+ if f % 100 == 0:
+ logging.info(f)
+ curr = torch.load(os.path.join(filepath, files[f]), weights_only=False)
+ data[f,:] = curr
+ return data
+
+def count_nans(data: np.array):
+ nans = np.count_nonzero(np.isnan(data))
+ return nans
+
+def make_stats(filepath: str):
+ data = load_all_data(filepath)
+ stats = {}
+ num_channels = data.shape[1]
+ stats['mean'] = np.zeros(num_channels)
+ stats['stdev'] = np.zeros(num_channels)
+ stats['minimum'] = np.zeros(num_channels)
+ stats['maximum'] = np.zeros(num_channels)
+ for k in range(num_channels):
+ logging.info(f'channel {k}')
+ stats['mean'][k] = np.mean(data[:,k,:,:])
+ stats['minimum'][k] = np.min(data[:,k,:,:])
+ stats['maximum'][k] = np.max(data[:,k,:,:])
+ stats['stdev'][k] = np.std(data[:,k,:,:])
+ return stats
+
+def main():
+ root = logging.getLogger()
+ root.setLevel(logging.INFO)
+ input_directory = sys.argv[1]
+
+ logging.info(f'checking input directory {input_directory}')
+
+ missing_data = []
+ corrupt_data = []
+ nan_data = []
+ check_for_nans = False
+ check_for_corrupt = False
+
+
+
+ files = os.listdir(input_directory)
+ files.sort()
+ start_date = datetime.datetime.strptime(files[0],'%Y%m%d-%H%M')
+ next_date = datetime.datetime.strptime(files[1],'%Y%m%d-%H%M')
+ delta = next_date - start_date
+ prev_date = start_date - delta
+
+ for f in files:
+ curr_date = datetime.datetime.strptime(f,'%Y%m%d-%H%M')
+ if curr_date - prev_date != delta:
+ logging.info(f'missing data: {prev_date} and {curr_date} not {delta} apart')
+ expected_date = prev_date + delta
+ while (expected_date < curr_date):
+ missing_data.append(datetime.datetime.strftime(expected_date, '%Y%m%d-%H%M'))
+ expected_date = expected_date + delta
+ if check_for_corrupt:
+ try:
+ data = torch.load(os.path.join(input_directory, f), weights_only=False)
+ except:
+ logging.info(f'corrupt data: {curr_date}')
+ corrupt_data.append(curr_date)
+ if check_for_nans or curr_date == start_date:
+ if count_nans(data):
+ logging.info(f'data nans: {curr_date}')
+ nan_data.append(curr_date)
+ prev_date = curr_date
+ logging.info(f'missing data size {len(missing_data)}: {missing_data}')
+ logging.info(f'corrupt data size {len(corrupt_data)}: {corrupt_data}')
+ if check_for_nans:
+ logging.info(f'nan data: {nan_data}')
+
+if __name__ == "__main__":
+ main()
diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py
index 1494a54c..ab2b2a93 100644
--- a/src/hirad/losses/__init__.py
+++ b/src/hirad/losses/__init__.py
@@ -1 +1 @@
-from .loss import ResidualLoss, RegressionLoss, RegressionLossCE
+from .loss import ResidualLoss, RegressionLoss
diff --git a/src/hirad/losses/loss.py b/src/hirad/losses/loss.py
index fb659607..4030bfc2 100644
--- a/src/hirad/losses/loss.py
+++ b/src/hirad/losses/loss.py
@@ -25,327 +25,6 @@
from hirad.utils.patching import RandomPatching2D
-class VPLoss:
- """
- Loss function corresponding to the variance preserving (VP) formulation.
-
- Parameters
- ----------
- beta_d: float, optional
- Coefficient for the diffusion process, by default 19.9.
- beta_min: float, optional
- Minimum bound, by defaults 0.1.
- epsilon_t: float, optional
- Small positive value, by default 1e-5.
-
- Note:
- -----
- Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and
- Poole, B., 2020. Score-based generative modeling through stochastic differential
- equations. arXiv preprint arXiv:2011.13456.
-
- """
-
- def __init__(
- self, beta_d: float = 19.9, beta_min: float = 0.1, epsilon_t: float = 1e-5
- ):
- self.beta_d = beta_d
- self.beta_min = beta_min
- self.epsilon_t = epsilon_t
-
- def __call__(
- self,
- net: torch.nn.Module,
- images: torch.Tensor,
- labels: torch.Tensor,
- augment_pipe: Optional[Callable] = None,
- ):
- """
- Calculate and return the loss corresponding to the variance preserving (VP)
- formulation.
-
- The method adds random noise to the input images and calculates the loss as the
- square difference between the network's predictions and the input images.
- The noise level is determined by 'sigma', which is computed as a function of
- 'epsilon_t' and random values. The calculated loss is weighted based on the
- inverse of 'sigma^2'.
-
- Parameters:
- ----------
- net: torch.nn.Module
- The neural network model that will make predictions.
-
- images: torch.Tensor
- Input images to the neural network.
-
- labels: torch.Tensor
- Ground truth labels for the input images.
-
- augment_pipe: callable, optional
- An optional data augmentation function that takes images as input and
- returns augmented images. If not provided, no data augmentation is applied.
-
- Returns:
- -------
- torch.Tensor
- A tensor representing the loss calculated based on the network's
- predictions.
- """
- rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
- sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
- weight = 1 / sigma**2
- y, augment_labels = (
- augment_pipe(images) if augment_pipe is not None else (images, None)
- )
- n = torch.randn_like(y) * sigma
- D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
- loss = weight * ((D_yn - y) ** 2)
- return loss
-
- def sigma(
- self, t: Union[float, torch.Tensor]
- ): # NOTE: also exists in preconditioning
- """
- Compute the sigma(t) value for a given t based on the VP formulation.
-
- The function calculates the noise level schedule for the diffusion process based
- on the given parameters `beta_d` and `beta_min`.
-
- Parameters
- ----------
- t : Union[float, torch.Tensor]
- The timestep or set of timesteps for which to compute sigma(t).
-
- Returns
- -------
- torch.Tensor
- The computed sigma(t) value(s).
- """
- t = torch.as_tensor(t)
- return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt()
-
-
-class VELoss:
- """
- Loss function corresponding to the variance exploding (VE) formulation.
-
- Parameters
- ----------
- sigma_min : float
- Minimum supported noise level, by default 0.02.
- sigma_max : float
- Maximum supported noise level, by default 100.0.
-
- Note:
- -----
- Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and
- Poole, B., 2020. Score-based generative modeling through stochastic differential
- equations. arXiv preprint arXiv:2011.13456.
- """
-
- def __init__(self, sigma_min: float = 0.02, sigma_max: float = 100.0):
- self.sigma_min = sigma_min
- self.sigma_max = sigma_max
-
- def __call__(self, net, images, labels, augment_pipe=None):
- """
- Calculate and return the loss corresponding to the variance exploding (VE)
- formulation.
-
- The method adds random noise to the input images and calculates the loss as the
- square difference between the network's predictions and the input images.
- The noise level is determined by 'sigma', which is computed as a function of
- 'sigma_min' and 'sigma_max' and random values. The calculated loss is weighted
- based on the inverse of 'sigma^2'.
-
- Parameters:
- ----------
- net: torch.nn.Module
- The neural network model that will make predictions.
-
- images: torch.Tensor
- Input images to the neural network.
-
- labels: torch.Tensor
- Ground truth labels for the input images.
-
- augment_pipe: callable, optional
- An optional data augmentation function that takes images as input and
- returns augmented images. If not provided, no data augmentation is applied.
-
- Returns:
- -------
- torch.Tensor
- A tensor representing the loss calculated based on the network's
- predictions.
- """
- rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
- sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
- weight = 1 / sigma**2
- y, augment_labels = (
- augment_pipe(images) if augment_pipe is not None else (images, None)
- )
- n = torch.randn_like(y) * sigma
- D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
- loss = weight * ((D_yn - y) ** 2)
- return loss
-
-
-class EDMLoss:
- """
- Loss function proposed in the EDM paper.
-
- Parameters
- ----------
- P_mean: float, optional
- Mean value for `sigma` computation, by default -1.2.
- P_std: float, optional:
- Standard deviation for `sigma` computation, by default 1.2.
- sigma_data: float, optional
- Standard deviation for data, by default 0.5.
-
- Note
- ----
- Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the
- design space of diffusion-based generative models. Advances in Neural Information
- Processing Systems, 35, pp.26565-26577.
- """
-
- def __init__(
- self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5
- ):
- self.P_mean = P_mean
- self.P_std = P_std
- self.sigma_data = sigma_data
-
- def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
- """
- Calculate and return the loss corresponding to the EDM formulation.
-
- The method adds random noise to the input images and calculates the loss as the
- square difference between the network's predictions and the input images.
- The noise level is determined by 'sigma', which is computed as a function of
- 'P_mean' and 'P_std' random values. The calculated loss is weighted as a
- function of 'sigma' and 'sigma_data'.
-
- Parameters:
- ----------
- net: torch.nn.Module
- The neural network model that will make predictions.
-
- images: torch.Tensor
- Input images to the neural network.
-
- labels: torch.Tensor
- Ground truth labels for the input images.
-
- augment_pipe: callable, optional
- An optional data augmentation function that takes images as input and
- returns augmented images. If not provided, no data augmentation is applied.
-
- Returns:
- -------
- torch.Tensor
- A tensor representing the loss calculated based on the network's
- predictions.
- """
- rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
- sigma = (rnd_normal * self.P_std + self.P_mean).exp()
- weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
- y, augment_labels = (
- augment_pipe(images) if augment_pipe is not None else (images, None)
- )
- n = torch.randn_like(y) * sigma
- if condition is not None:
- D_yn = net(
- y + n,
- sigma,
- condition=condition,
- class_labels=labels,
- augment_labels=augment_labels,
- )
- else:
- D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
- loss = weight * ((D_yn - y) ** 2)
- return loss
-
-
-class EDMLossSR:
- """
- Variation of the loss function proposed in the EDM paper for Super-Resolution.
-
- Parameters
- ----------
- P_mean: float, optional
- Mean value for `sigma` computation, by default -1.2.
- P_std: float, optional:
- Standard deviation for `sigma` computation, by default 1.2.
- sigma_data: float, optional
- Standard deviation for data, by default 0.5.
-
- Note
- ----
- Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y.,
- Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023.
- Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling.
- arXiv preprint arXiv:2309.15214.
- """
-
- def __init__(
- self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5
- ):
- self.P_mean = P_mean
- self.P_std = P_std
- self.sigma_data = sigma_data
-
- def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None):
- """
- Calculate and return the loss corresponding to the EDM formulation.
-
- The method adds random noise to the input images and calculates the loss as the
- square difference between the network's predictions and the input images.
- The noise level is determined by 'sigma', which is computed as a function of
- 'P_mean' and 'P_std' random values. The calculated loss is weighted as a
- function of 'sigma' and 'sigma_data'.
-
- Parameters:
- ----------
- net: torch.nn.Module
- The neural network model that will make predictions.
-
- images: torch.Tensor
- Input images to the neural network.
-
- labels: torch.Tensor
- Ground truth labels for the input images.
-
- augment_pipe: callable, optional
- An optional data augmentation function that takes images as input and
- returns augmented images. If not provided, no data augmentation is applied.
-
- Returns:
- -------
- torch.Tensor
- A tensor representing the loss calculated based on the network's
- predictions.
- """
- rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device)
- sigma = (rnd_normal * self.P_std + self.P_mean).exp()
- weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
-
- # augment for conditional generation
- img_tot = torch.cat((img_clean, img_lr), dim=1)
- y_tot, augment_labels = (
- augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None)
- )
- y = y_tot[:, : img_clean.shape[1], :, :]
- y_lr = y_tot[:, img_clean.shape[1] :, :, :]
-
- n = torch.randn_like(y) * sigma
- D_yn = net(y + n, y_lr, sigma, labels, augment_labels=augment_labels)
- loss = weight * ((D_yn - y) ** 2)
- return loss
-
class RegressionLoss:
"""
@@ -377,9 +56,13 @@ def __call__(
net: torch.nn.Module,
img_clean: torch.Tensor,
img_lr: torch.Tensor,
+ static_channels: Optional[torch.Tensor] = None,
+ date_embedding: Optional[torch.Tensor] = None,
augment_pipe: Optional[
Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]
] = None,
+ lead_time_label: Optional[torch.Tensor] = None,
+ use_apex_gn: bool = False,
) -> torch.Tensor:
"""
Calculate and return the regression loss for
@@ -408,6 +91,12 @@ def __call__(
Low-resolution input images of shape (B, C_lr, H, W).
Used as input to the neural network.
+ static_channels : torch.Tensor, optional
+ Static channels input of shape (C_static, H, W).
+
+ date_embedding : torch.Tensor, optional
+ Date embedding input of shape (B, C_date).
+
augment_pipe : callable, optional
An optional data augmentation function.
Expected signature:
@@ -439,7 +128,37 @@ def __call__(
y_lr = y_tot[:, img_clean.shape[1] :, :, :]
zero_input = torch.zeros_like(y, device=img_clean.device)
- D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels)
+
+ if static_channels is not None:
+ y_lr = torch.cat(
+ (y_lr, static_channels.expand(y_lr.shape[0], *static_channels.shape[1:])),
+ dim=1,
+ )
+
+ if date_embedding is not None:
+ date_embedding = date_embedding[:, :, None, None].expand(*date_embedding.shape[:2], *y_lr.shape[2:])
+ if use_apex_gn:
+ date_embedding = date_embedding.to(y_lr.dtype, non_blocking=True).to(memory_format=torch.channels_last)
+ else:
+ date_embedding = date_embedding.to(y_lr.dtype, non_blocking=True).contiguous()
+ y_lr = torch.cat((y_lr, date_embedding), dim=1)
+
+ if lead_time_label is not None:
+ D_yn = net(
+ zero_input,
+ y_lr,
+ force_fp32=False,
+ lead_time_label=lead_time_label,
+ augment_labels=augment_labels,
+ )
+ else:
+ D_yn = net(
+ zero_input,
+ y_lr,
+ force_fp32=False,
+ augment_labels=augment_labels,
+ )
+
loss = weight * ((D_yn - y) ** 2)
return loss
@@ -518,17 +237,46 @@ def __init__(
self.hr_mean_conditioning = hr_mean_conditioning
self.y_mean = None
+ def get_noise_params(self, y: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the noise parameters to apply denoising score matching.
+
+ Parameters
+ ----------
+ y : torch.Tensor
+ Latent state of shape :math:`(B, *)`. Only used to determine the shape of
+ the noise and create tensors on the same device.
+
+ Returns
+ -------
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ - Noise ``n`` of shape :math:`(B, *)` to be added to the latent state.
+ - Noise level ``sigma`` of shape :math:`(B, 1, 1, 1)`.
+ - Weight ``weight`` of shape :math:`(B, 1, 1, 1)` to multiply the loss.
+ """
+ # Sample noise level
+ rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=y.device)
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp()
+ # Loss weight
+ weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+ # Sample noise
+ n = torch.randn_like(y) * sigma
+ return n, sigma, weight
+
def __call__(
self,
net: torch.nn.Module,
img_clean: torch.Tensor,
img_lr: torch.Tensor,
+ static_channels: Optional[torch.Tensor] = None,
+ date_embedding: Optional[torch.Tensor] = None,
patching: Optional[RandomPatching2D] = None,
lead_time_label: Optional[torch.Tensor] = None,
augment_pipe: Optional[
Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]
] = None,
use_patch_grad_acc: bool = False,
+ use_apex_gn: bool = False,
) -> torch.Tensor:
"""
Calculate and return the loss for denoising score matching.
@@ -590,6 +338,12 @@ def __call__(
Used as input to the regression network and conditioning for the
diffusion process.
+ static_channels : Optional[torch.Tensor], optional
+ Static channels input of shape (1, C_static, H, W), by default None.
+
+ date_embedding : Optional[torch.Tensor], optional
+ Date embedding input of shape (B, C_date), by default None
+
patching : Optional[RandomPatching2D], optional
Patching strategy for processing large images, by default None. See
:class:`physicsnemo.utils.patching.RandomPatching2D` for details.
@@ -614,6 +368,8 @@ def __call__(
use_patch_grad_acc: bool, optional
A boolean flag indicating whether to enable multi-iterations of patching accumulations
for amortizing regression cost. Default False.
+ use_apex_gn: bool, optional
+ A boolean flag indicating whether apex group norm is used in the model.
Returns
-------
@@ -655,28 +411,28 @@ def __call__(
y_lr_res = y_lr
batch_size = y.shape[0]
+ # print(f"Shape of y: {y.shape}, y_lr: {y_lr.shape}")
+
# if using multi-iterations of patching, switch to optimized version
- if use_patch_grad_acc:
+ if not use_patch_grad_acc or self.y_mean is None:
# form residual
- if self.y_mean is None:
- if lead_time_label is not None:
- y_mean = self.regression_net(
- torch.zeros_like(y, device=img_clean.device),
- y_lr_res,
- lead_time_label=lead_time_label,
- augment_labels=augment_labels,
- )
+ if static_channels is not None:
+ y_lr_res = torch.cat(
+ (y_lr_res, static_channels.expand(y_lr_res.shape[0], *static_channels.shape[1:])),
+ dim=1,
+ )
+ # print(f"Shape of y_lr after static channels regression: y_lr_res {y_lr_res.shape} y_lr {y_lr.shape}")
+
+ if date_embedding is not None:
+ date_embedding_reg = date_embedding[:, :, None, None].expand(*date_embedding.shape[:2], *y_lr_res.shape[2:])
+ if use_apex_gn:
+ date_embedding_reg = date_embedding_reg.to(y_lr_res.dtype, non_blocking=True).to(memory_format=torch.channels_last)
else:
- y_mean = self.regression_net(
- torch.zeros_like(y, device=img_clean.device),
- y_lr_res,
- augment_labels=augment_labels,
- )
- self.y_mean = y_mean
-
- # if on full domain, or if using patching without multi-iterations
- else:
- # form residual
+ date_embedding_reg = date_embedding_reg.to(y_lr_res.dtype, non_blocking=True).contiguous()
+ y_lr_res = torch.cat((y_lr_res, date_embedding_reg), dim=1)
+
+ # print(f"Shape of y_lr after date embedding regression: y_lr_res {y_lr_res.shape} y_lr {y_lr.shape}")
+
if lead_time_label is not None:
y_mean = self.regression_net(
torch.zeros_like(y, device=img_clean.device),
@@ -695,9 +451,21 @@ def __call__(
y = y - self.y_mean
+ # print(f"Shape of y after residual: y {y.shape} y_lr {y_lr.shape}")
+
if self.hr_mean_conditioning:
y_lr = torch.cat((self.y_mean, y_lr), dim=1)
+ # print(f"Shape of y_lr after hr mean conditioning: y_lr {y_lr.shape}")
+
+ if static_channels is not None:
+ y_lr = torch.cat(
+ (y_lr, static_channels.expand(y_lr.shape[0], *static_channels.shape[1:])),
+ dim=1,
+ )
+
+ # print(f"Shape of y_lr after static channels diffusion: y_lr {y_lr.shape}")
+
# patchified training
# conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4
# removed patch_embedding_selector due to compilation issue with dynamo.
@@ -707,326 +475,65 @@ def __call__(
y_patched = patching.apply(input=y)
# Patched conditioning on y_lr and interp(img_lr)
# (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x)
+ if static_channels is not None:
+ img_lr = torch.cat(
+ (img_lr, static_channels.expand(img_lr.shape[0], *static_channels.shape[1:])),
+ dim=1,
+ )
+ # print(f"Shape of img_lr after static channels diffusion patching: img_lr {img_lr.shape}")
+ if date_embedding is not None:
+ date_embedding = date_embedding[:, :, None, None].expand(*date_embedding.shape[:2], *img_lr.shape[2:])
+ if use_apex_gn:
+ date_embedding = date_embedding.to(img_lr.dtype, non_blocking=True).to(memory_format=torch.channels_last)
+ else:
+ date_embedding = date_embedding.to(img_lr.dtype, non_blocking=True).contiguous()
+ img_lr = torch.cat((img_lr, date_embedding), dim=1)
+ # print(f"Shape of img_lr after date embedding diffusion patching: img_lr {img_lr.shape}")
y_lr_patched = patching.apply(input=y_lr, additional_input=img_lr)
y = y_patched
y_lr = y_lr_patched
- # Noise
- rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device)
- sigma = (rnd_normal * self.P_std + self.P_mean).exp()
- weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+ elif date_embedding is not None:
+ date_embedding = date_embedding[:, :, None, None].expand(*date_embedding.shape[:2], *y_lr.shape[2:])
+ if use_apex_gn:
+ date_embedding = date_embedding.to(y_lr.dtype, non_blocking=True).to(memory_format=torch.channels_last)
+ else:
+ date_embedding = date_embedding.to(y_lr.dtype, non_blocking=True).contiguous()
+ y_lr = torch.cat((y_lr, date_embedding), dim=1)
+
+ # print(f"Final shapes before noise addition: y {y.shape} y_lr {y_lr.shape}")
- # Input + noise
- latent = y + torch.randn_like(y) * sigma
+ # Add noise to the latent state
+ n, sigma, weight = self.get_noise_params(y)
if lead_time_label is not None:
D_yn = net(
- latent,
+ y + n,
y_lr,
sigma,
embedding_selector=None,
- global_index=patching.global_index(batch_size, img_clean.device)
- if patching is not None
- else None,
+ global_index=(
+ patching.global_index(batch_size, img_clean.device)
+ if patching is not None
+ else None
+ ),
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)
else:
D_yn = net(
- latent,
+ y + n,
y_lr,
sigma,
embedding_selector=None,
- global_index=patching.global_index(batch_size, img_clean.device)
- if patching is not None
- else None,
+ global_index=(
+ patching.global_index(batch_size, img_clean.device)
+ if patching is not None
+ else None
+ ),
augment_labels=augment_labels,
)
loss = weight * ((D_yn - y) ** 2)
return loss
-
-
-
-class VELoss_dfsr:
- """
- Loss function for dfsr model, modified from class VELoss.
-
- Parameters
- ----------
- beta_start : float
- Noise level at the initial step of the forward diffusion process, by default 0.0001.
- beta_end : float
- Noise level at the Final step of the forward diffusion process, by default 0.02.
- num_diffusion_timesteps : int
- Total number of forward/backward diffusion steps, by default 1000.
-
-
- Note:
- -----
- Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models.
- Advances in neural information processing systems. 2020;33:6840-51.
- """
-
- def __init__(
- self,
- beta_start: float = 0.0001,
- beta_end: float = 0.02,
- num_diffusion_timesteps: int = 1000,
- ):
- # scheduler for diffusion:
- self.beta_schedule = "linear"
- self.beta_start = beta_start
- self.beta_end = beta_end
- self.num_diffusion_timesteps = num_diffusion_timesteps
- betas = self.get_beta_schedule(
- beta_schedule=self.beta_schedule,
- beta_start=self.beta_start,
- beta_end=self.beta_end,
- num_diffusion_timesteps=self.num_diffusion_timesteps,
- )
- self.betas = torch.from_numpy(betas).float()
- self.num_timesteps = betas.shape[0]
-
- def get_beta_schedule(
- self, beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps
- ):
- """
- Compute the variance scheduling parameters {beta(0), ..., beta(t), ..., beta(T)}
- based on the VP formulation.
-
- beta_schedule: str
- Method to construct the sequence of beta(t)'s.
- beta_start: float
- Noise level at the initial step of the forward diffusion process, e.g., beta(0)
- beta_end: float
- Noise level at the final step of the forward diffusion process, e.g., beta(T)
- num_diffusion_timesteps: int
- Total number of forward/backward diffusion steps
- """
-
- def sigmoid(x):
- return 1 / (np.exp(-x) + 1)
-
- if beta_schedule == "quad":
- betas = (
- np.linspace(
- beta_start**0.5,
- beta_end**0.5,
- num_diffusion_timesteps,
- dtype=np.float64,
- )
- ** 2
- )
- elif beta_schedule == "linear":
- betas = np.linspace(
- beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
- )
- elif beta_schedule == "const":
- betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
- elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
- betas = 1.0 / np.linspace(
- num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
- )
- elif beta_schedule == "sigmoid":
- betas = np.linspace(-6, 6, num_diffusion_timesteps)
- betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
- else:
- raise NotImplementedError(beta_schedule)
- if betas.shape != (num_diffusion_timesteps,):
- raise ValueError(
- f"Expected betas to have shape ({num_diffusion_timesteps},), "
- f"but got {betas.shape}"
- )
- return betas
-
- def __call__(self, net, images, labels, augment_pipe=None):
- """
- Calculate and return the loss corresponding to the variance preserving
- formulation.
-
- The method adds random noise to the input images and calculates the loss as the
- square difference between the network's predictions and the noise samples added
- to the t-th step of the diffusion process.
- The noise level is determined by 'beta_t' based on the given parameters 'beta_start',
- 'beta_end' and the current diffusion timestep t.
-
- Parameters:
- ----------
- net: torch.nn.Module
- The neural network model that will make predictions.
-
- images: torch.Tensor
- Input fluid flow data samples to the neural network.
-
- labels: torch.Tensor
- Ground truth labels for the input fluid flow data samples. Not required for dfsr.
-
- augment_pipe: callable, optional
- An optional data augmentation function that takes images as input and
- returns augmented images. If not provided, no data augmentation is applied.
-
- Returns:
- -------
- torch.Tensor
- A tensor representing the loss calculated based on the network's
- predictions.
- """
- t = torch.randint(
- low=0, high=self.num_timesteps, size=(images.size(0) // 2 + 1,)
- ).to(images.device)
- t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[: images.size(0)]
- e = torch.randn_like(images)
- b = self.betas.to(images.device)
- a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
- x = images * a.sqrt() + e * (1.0 - a).sqrt()
-
- output = net(x, t, labels)
- loss = (e - output).square()
-
- return loss
-
-
-class RegressionLossCE:
- """
- A regression loss function for deterministic predictions with probability
- channels and lead time labels. Adapted from
- :class:`physicsnemo.metrics.diffusion.loss.RegressionLoss`. In this version,
- probability channels are evaluated using CrossEntropyLoss instead of
- squared error.
- Note: this loss does not apply any reduction.
-
- Attributes
- ----------
- entropy : torch.nn.CrossEntropyLoss
- Cross entropy loss function used for probability channels.
- prob_channels : list[int]
- List of channel indices to be treated as probability channels.
-
- Note
- ----
- Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y.,
- Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023.
- Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling.
- arXiv preprint arXiv:2309.15214.
- """
-
- def __init__(
- self,
- prob_channels: list[int] = [4, 5, 6, 7, 8],
- ):
- """
- Arguments
- ----------
- prob_channels: list[int], optional
- List of channel indices from the target tensor to be treated as
- probability channels. Cross entropy loss is computed over these
- channels, while the remaining channels are treated as scalar
- channels and the squared error loss is computed over them. By
- default, [4, 5, 6, 7, 8].
- """
- self.entropy = torch.nn.CrossEntropyLoss(reduction="none")
- self.prob_channels = prob_channels
-
- def __call__(
- self,
- net: torch.nn.Module,
- img_clean: torch.Tensor,
- img_lr: torch.Tensor,
- lead_time_label: Optional[torch.Tensor] = None,
- augment_pipe: Optional[
- Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]
- ] = None,
- ) -> torch.Tensor:
- """
- Calculate and return the loss for deterministic
- predictions, treating specific channels as probability distributions.
-
- Parameters
- ----------
- net : torch.nn.Module
- The neural network model that will make predictions.
- Expected signature: `net(input, img_lr, lead_time_label=lead_time_label, augment_labels=augment_labels)`,
- where:
- input (torch.Tensor): Tensor of shape (B, C_hr, H, W). Zero-filled.
- y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W)
- lead_time_label (torch.Tensor, optional): Optional lead time
- labels. If provided, should be of shape (B,).
- augment_labels (torch.Tensor, optional): Optional augmentation
- labels, returned by `augment_pipe`.
- Returns:
- torch.Tensor: Predictions of shape (B, C_hr, H, W)
-
- img_clean : torch.Tensor
- High-resolution input images of shape (B, C_hr, H, W).
- Used as ground truth and for data augmentation if `augment_pipe` is provided.
-
- img_lr : torch.Tensor
- Low-resolution input images of shape (B, C_lr, H, W).
- Used as input to the neural network.
-
- lead_time_label : Optional[torch.Tensor], optional
- Lead time labels for temporal predictions, by default None.
- Shape can vary based on model requirements, typically (B,) or scalar.
-
- augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]]
- Data augmentation function.
- Expected signature:
- img_tot (torch.Tensor): Concatenated high and low resolution
- images of shape (B, C_hr+C_lr, H, W).
- Returns:
- Tuple[torch.Tensor, Optional[torch.Tensor]]:
- - Augmented images of shape (B, C_hr+C_lr, H, W)
- - Optional augmentation labels
-
- Returns
- -------
- torch.Tensor
- A tensor of shape (B, C_loss, H, W) representing the pixel-wise
- loss., where `C_loss = C_hr - len(prob_channels) + 1`. More
- specifically, the last channel of the output tensor corresponds to
- the cross-entropy loss computed over the channels specified in
- `prob_channels`, while the first `C_hr - len(prob_channels)`
- channels of the output tensor correspond to the squared error loss.
- """
- all_channels = list(range(img_clean.shape[1])) # [0, 1, 2, ..., 10]
- scalar_channels = [
- item for item in all_channels if item not in self.prob_channels
- ]
- weight = (
- 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
- )
-
- img_tot = torch.cat((img_clean, img_lr), dim=1)
- y_tot, augment_labels = (
- augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None)
- )
- y = y_tot[:, : img_clean.shape[1], :, :]
- y_lr = y_tot[:, img_clean.shape[1] :, :, :]
-
- input = torch.zeros_like(y, device=img_clean.device)
-
- if lead_time_label is not None:
- D_yn = net(
- input,
- y_lr,
- lead_time_label=lead_time_label,
- augment_labels=augment_labels,
- )
- else:
- D_yn = net(
- input,
- y_lr,
- lead_time_label=lead_time_label,
- augment_labels=augment_labels,
- )
- loss1 = weight * (D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2
- loss2 = (
- weight
- * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[
- :, None
- ]
- )
- loss = torch.cat((loss1, loss2), dim=1)
- return loss
\ No newline at end of file
diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py
index b00a4777..2315a38f 100644
--- a/src/hirad/models/__init__.py
+++ b/src/hirad/models/__init__.py
@@ -7,8 +7,6 @@
PositionalEmbedding,
FourierEmbedding
)
-from .meta import ModelMetaData
-from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd
-from .dhariwal_unet import DhariwalUNet
+from .song_unet import SongUNet, SongUNetPosEmbd
from .unet import UNet
-from .preconditioning import EDMPrecondSuperResolution, EDMPrecondSR, EDMPrecond
+from .preconditioning import EDMPrecondSuperResolution
diff --git a/src/hirad/models/dhariwal_unet.py b/src/hirad/models/dhariwal_unet.py
deleted file mode 100644
index 3880cd0a..00000000
--- a/src/hirad/models/dhariwal_unet.py
+++ /dev/null
@@ -1,259 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-FileCopyrightText: All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Model architectures used in the paper "Elucidating the Design Space of
-Diffusion-Based Generative Models".
-"""
-
-from dataclasses import dataclass
-from typing import List
-
-import numpy as np
-import torch
-from torch.nn.functional import silu
-import torch.nn as nn
-
-from .layers import (
- Conv2d,
- GroupNorm,
- Linear,
- PositionalEmbedding,
- UNetBlock,
-)
-from .meta import ModelMetaData
-
-
-@dataclass
-class MetaData(ModelMetaData):
- name: str = "DhariwalUNet"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = True
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
-class DhariwalUNet(nn.Module):
- """
- Reimplementation of the ADM architecture, a U-Net variant, with optional
- self-attention.
-
- This model supports conditional and unconditional setups, as well as several
- options for various internal architectural choices such as encoder and decoder
- type, embedding type, etc., making it flexible and adaptable to different tasks
- and configurations.
-
- Parameters
- -----------
- img_resolution : int
- The resolution of the input/output image.
- in_channels : int
- Number of channels in the input image.
- out_channels : int
- Number of channels in the output image.
- label_dim : int, optional
- Number of class labels; 0 indicates an unconditional model. By default 0.
- augment_dim : int, optional
- Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
- model_channels : int, optional
- Base multiplier for the number of channels across the network, by default 192.
- channel_mult : List[int], optional
- Per-resolution multipliers for the number of channels. By default [1,2,3,4].
- channel_mult_emb : int, optional
- Multiplier for the dimensionality of the embedding vector. By default 4.
- num_blocks : int, optional
- Number of residual blocks per resolution. By default 3.
- attn_resolutions : List[int], optional
- Resolutions at which self-attention layers are applied. By default [32, 16, 8].
- dropout : float, optional
- Dropout probability applied to intermediate activations. By default 0.10.
- label_dropout : float, optional
- Dropout probability of class labels for classifier-free guidance. By default 0.0.
-
- Reference
- ----------
- Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat gans on image
- synthesis. Advances in neural information processing systems, 34, pp.8780-8794.
-
- Note
- -----
- Equivalent to the original implementation by Dhariwal and Nichol, available at
- https://github.com/openai/guided-diffusion
-
- Example
- --------
- >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2)
- >>> noise_labels = torch.randn([1])
- >>> class_labels = torch.randint(0, 1, (1, 1))
- >>> input_image = torch.ones([1, 2, 16, 16])
- >>> output_image = model(input_image, noise_labels, class_labels)
- """
-
- def __init__(
- self,
- img_resolution: int,
- in_channels: int,
- out_channels: int,
- label_dim: int = 0,
- augment_dim: int = 0,
- model_channels: int = 192,
- channel_mult: List[int] = [1, 2, 3, 4],
- channel_mult_emb: int = 4,
- num_blocks: int = 3,
- attn_resolutions: List[int] = [32, 16, 8],
- dropout: float = 0.10,
- label_dropout: float = 0.0,
- ):
- super().__init__(meta=MetaData())
- self.label_dropout = label_dropout
- emb_channels = model_channels * channel_mult_emb
- init = dict(
- init_mode="kaiming_uniform",
- init_weight=np.sqrt(1 / 3),
- init_bias=np.sqrt(1 / 3),
- )
- init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
- block_kwargs = dict(
- emb_channels=emb_channels,
- channels_per_head=64,
- dropout=dropout,
- init=init,
- init_zero=init_zero,
- )
-
- # Mapping.
- self.map_noise = PositionalEmbedding(num_channels=model_channels)
- self.map_augment = (
- Linear(
- in_features=augment_dim,
- out_features=model_channels,
- bias=False,
- **init_zero,
- )
- if augment_dim
- else None
- )
- self.map_layer0 = Linear(
- in_features=model_channels, out_features=emb_channels, **init
- )
- self.map_layer1 = Linear(
- in_features=emb_channels, out_features=emb_channels, **init
- )
- self.map_label = (
- Linear(
- in_features=label_dim,
- out_features=emb_channels,
- bias=False,
- init_mode="kaiming_normal",
- init_weight=np.sqrt(label_dim),
- )
- if label_dim
- else None
- )
-
- # Encoder.
- self.enc = torch.nn.ModuleDict()
- cout = in_channels
- for level, mult in enumerate(channel_mult):
- res = img_resolution >> level
- if level == 0:
- cin = cout
- cout = model_channels * mult
- self.enc[f"{res}x{res}_conv"] = Conv2d(
- in_channels=cin, out_channels=cout, kernel=3, **init
- )
- else:
- self.enc[f"{res}x{res}_down"] = UNetBlock(
- in_channels=cout, out_channels=cout, down=True, **block_kwargs
- )
- for idx in range(num_blocks):
- cin = cout
- cout = model_channels * mult
- self.enc[f"{res}x{res}_block{idx}"] = UNetBlock(
- in_channels=cin,
- out_channels=cout,
- attention=(res in attn_resolutions),
- **block_kwargs,
- )
- skips = [block.out_channels for block in self.enc.values()]
-
- # Decoder.
- self.dec = torch.nn.ModuleDict()
- for level, mult in reversed(list(enumerate(channel_mult))):
- res = img_resolution >> level
- if level == len(channel_mult) - 1:
- self.dec[f"{res}x{res}_in0"] = UNetBlock(
- in_channels=cout, out_channels=cout, attention=True, **block_kwargs
- )
- self.dec[f"{res}x{res}_in1"] = UNetBlock(
- in_channels=cout, out_channels=cout, **block_kwargs
- )
- else:
- self.dec[f"{res}x{res}_up"] = UNetBlock(
- in_channels=cout, out_channels=cout, up=True, **block_kwargs
- )
- for idx in range(num_blocks + 1):
- cin = cout + skips.pop()
- cout = model_channels * mult
- self.dec[f"{res}x{res}_block{idx}"] = UNetBlock(
- in_channels=cin,
- out_channels=cout,
- attention=(res in attn_resolutions),
- **block_kwargs,
- )
- self.out_norm = GroupNorm(num_channels=cout)
- self.out_conv = Conv2d(
- in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
- )
-
- def forward(self, x, noise_labels, class_labels, augment_labels=None):
- # Mapping.
- emb = self.map_noise(noise_labels)
- if self.map_augment is not None and augment_labels is not None:
- emb = emb + self.map_augment(augment_labels)
- emb = silu(self.map_layer0(emb))
- emb = self.map_layer1(emb)
- if self.map_label is not None:
- tmp = class_labels
- if self.training and self.label_dropout:
- tmp = tmp * (
- torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout
- ).to(tmp.dtype)
- emb = emb + self.map_label(tmp)
- emb = silu(emb)
-
- # Encoder.
- skips = []
- for block in self.enc.values():
- x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
- skips.append(x)
-
- # Decoder.
- for block in self.dec.values():
- if x.shape[1] != block.in_channels:
- x = torch.cat([x, skips.pop()], dim=1)
- x = block(x, emb)
- x = self.out_conv(silu(self.out_norm(x)))
- return x
diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py
index d7e63d7b..edaab04e 100644
--- a/src/hirad/models/layers.py
+++ b/src/hirad/models/layers.py
@@ -26,7 +26,7 @@
import numpy as np
import nvtx
import torch
-import torch.cuda.amp as amp
+import torch.amp as amp
from einops import rearrange
from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh
@@ -114,7 +114,7 @@ class Conv2d(torch.nn.Module):
"""
A custom 2D convolutional layer implementation with support for up-sampling,
down-sampling, and custom weight and bias initializations. The layer's weights
- and biases canbe initialized using custom initialization strategies like
+ and biases can be initialized using custom initialization strategies like
"kaiming_normal", and can be further scaled by factors `init_weight` and
`init_bias`.
@@ -403,7 +403,7 @@ def forward(self, x):
x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups)
mean = x.mean(dim=[2, 3, 4], keepdim=True)
- var = x.var(dim=[2, 3, 4], keepdim=True)
+ var = x.var(dim=[2, 3, 4], keepdim=True, unbiased=False)
x = (x - mean) * (var + self.eps).rsqrt()
x = rearrange(x, "b g c h w -> b (g c) h w")
@@ -700,7 +700,7 @@ def forward(self, x, emb):
# w = AttentionOp.apply(q, k)
# a = torch.einsum("nqk,nck->ncq", w, v)
# Compute attention in one step
- with amp.autocast(enabled=self.amp_mode):
+ with amp.autocast(x.device.type, enabled=self.amp_mode):
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = self.proj(attn.reshape(*x.shape)).add_(x)
x = x * self.skip_scale
diff --git a/src/hirad/models/meta.py b/src/hirad/models/meta.py
deleted file mode 100644
index aab8e453..00000000
--- a/src/hirad/models/meta.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-FileCopyrightText: All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-
-
-@dataclass
-class ModelMetaData:
- """Data class for storing essential meta data needed for all Hirad Models"""
-
- # Model info
- name: str = "HiradModule"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp: bool = False
- amp_cpu: bool = None
- amp_gpu: bool = None
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- onnx_gpu: bool = None
- onnx_cpu: bool = None
- onnx_runtime: bool = False
- trt: bool = False
- # Physics informed
- var_dim: int = -1
- func_torch: bool = False
- auto_grad: bool = False
-
- def __post_init__(self):
- self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu
- self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu
- self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu
- self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu
diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py
index 74496a59..0f9674f4 100644
--- a/src/hirad/models/preconditioning.py
+++ b/src/hirad/models/preconditioning.py
@@ -28,691 +28,15 @@
import torch
import torch.nn as nn
-from .meta import ModelMetaData
-
network_module = importlib.import_module("hirad.models")
-@dataclass
-class VPPrecondMetaData(ModelMetaData):
- """VPPrecond meta data"""
-
- name: str = "VPPrecond"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
-class VPPrecond(nn.Module):
- """
- Preconditioning corresponding to the variance preserving (VP) formulation.
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels.
- label_dim : int
- Number of class labels, 0 = unconditional, by default 0.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- beta_d : float
- Extent of the noise level schedule, by default 19.9.
- beta_min : float
- Initial slope of the noise level schedule, by default 0.1.
- M : int
- Original number of timesteps in the DDPM formulation, by default 1000.
- epsilon_t : float
- Minimum t-value used during training, by default 1e-5.
- model_type :str
- Class name of the underlying model, by default "SongUNet".
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and
- Poole, B., 2020. Score-based generative modeling through stochastic differential
- equations. arXiv preprint arXiv:2011.13456.
- """
-
- def __init__(
- self,
- img_resolution: int,
- img_channels: int,
- label_dim: int = 0,
- use_fp16: bool = False,
- beta_d: float = 19.9,
- beta_min: float = 0.1,
- M: int = 1000,
- epsilon_t: float = 1e-5,
- model_type: str = "SongUNet",
- **model_kwargs: dict,
- ):
- super().__init__() #meta=VPPrecondMetaData
- self.img_resolution = img_resolution
- self.img_channels = img_channels
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- self.beta_d = beta_d
- self.beta_min = beta_min
- self.M = M
- self.epsilon_t = epsilon_t
- self.sigma_min = float(self.sigma(epsilon_t))
- self.sigma_max = float(self.sigma(1))
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=img_channels,
- out_channels=img_channels,
- label_dim=label_dim,
- **model_kwargs,
- ) # TODO needs better handling
-
- def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
- x = x.to(torch.float32)
- sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
- class_labels = (
- None
- if self.label_dim == 0
- else torch.zeros([1, self.label_dim], device=x.device)
- if class_labels is None
- else class_labels.to(torch.float32).reshape(-1, self.label_dim)
- )
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- c_skip = 1
- c_out = -sigma
- c_in = 1 / (sigma**2 + 1).sqrt()
- c_noise = (self.M - 1) * self.sigma_inv(sigma)
-
- F_x = self.model(
- (c_in * x).to(dtype),
- c_noise.flatten(),
- class_labels=class_labels,
- **model_kwargs,
- )
- if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
-
- D_x = c_skip * x + c_out * F_x.to(torch.float32)
- return D_x
-
- def sigma(self, t: Union[float, torch.Tensor]):
- """
- Compute the sigma(t) value for a given t based on the VP formulation.
-
- The function calculates the noise level schedule for the diffusion process based
- on the given parameters `beta_d` and `beta_min`.
-
- Parameters
- ----------
- t : Union[float, torch.Tensor]
- The timestep or set of timesteps for which to compute sigma(t).
-
- Returns
- -------
- torch.Tensor
- The computed sigma(t) value(s).
- """
- t = torch.as_tensor(t)
- return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt()
-
- def sigma_inv(self, sigma: Union[float, torch.Tensor]):
- """
- Compute the inverse of the sigma function for a given sigma.
-
- This function effectively calculates t from a given sigma(t) based on the
- parameters `beta_d` and `beta_min`.
-
- Parameters
- ----------
- sigma : Union[float, torch.Tensor]
- The sigma(t) value or set of sigma(t) values for which to compute the
- inverse.
-
- Returns
- -------
- torch.Tensor
- The computed t value(s) corresponding to the provided sigma(t).
- """
- sigma = torch.as_tensor(sigma)
- return (
- (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt()
- - self.beta_min
- ) / self.beta_d
-
- def round_sigma(self, sigma: Union[float, List, torch.Tensor]):
- """
- Convert a given sigma value(s) to a tensor representation.
-
- Parameters
- ----------
- sigma : Union[float list, torch.Tensor]
- The sigma value(s) to convert.
-
- Returns
- -------
- torch.Tensor
- The tensor representation of the provided sigma value(s).
- """
- return torch.as_tensor(sigma)
-
-
-@dataclass
-class VEPrecondMetaData(ModelMetaData):
- """VEPrecond meta data"""
-
- name: str = "VEPrecond"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
-class VEPrecond(nn.Module):
- """
- Preconditioning corresponding to the variance exploding (VE) formulation.
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels.
- label_dim : int
- Number of class labels, 0 = unconditional, by default 0.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- sigma_min : float
- Minimum supported noise level, by default 0.02.
- sigma_max : float
- Maximum supported noise level, by default 100.0.
- model_type :str
- Class name of the underlying model, by default "SongUNet".
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and
- Poole, B., 2020. Score-based generative modeling through stochastic differential
- equations. arXiv preprint arXiv:2011.13456.
- """
-
- def __init__(
- self,
- img_resolution: int,
- img_channels: int,
- label_dim: int = 0,
- use_fp16: bool = False,
- sigma_min: float = 0.02,
- sigma_max: float = 100.0,
- model_type: str = "SongUNet",
- **model_kwargs: dict,
- ):
- super().__init__() #meta=VEPrecondMetaData
- self.img_resolution = img_resolution
- self.img_channels = img_channels
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- self.sigma_min = sigma_min
- self.sigma_max = sigma_max
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=img_channels,
- out_channels=img_channels,
- label_dim=label_dim,
- **model_kwargs,
- ) # TODO needs better handling
-
- def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
- x = x.to(torch.float32)
- sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
- class_labels = (
- None
- if self.label_dim == 0
- else torch.zeros([1, self.label_dim], device=x.device)
- if class_labels is None
- else class_labels.to(torch.float32).reshape(-1, self.label_dim)
- )
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- c_skip = 1
- c_out = sigma
- c_in = 1
- c_noise = (0.5 * sigma).log()
-
- F_x = self.model(
- (c_in * x).to(dtype),
- c_noise.flatten(),
- class_labels=class_labels,
- **model_kwargs,
- )
- if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
-
- D_x = c_skip * x + c_out * F_x.to(torch.float32)
- return D_x
-
- def round_sigma(self, sigma: Union[float, List, torch.Tensor]):
- """
- Convert a given sigma value(s) to a tensor representation.
-
- Parameters
- ----------
- sigma : Union[float list, torch.Tensor]
- The sigma value(s) to convert.
-
- Returns
- -------
- torch.Tensor
- The tensor representation of the provided sigma value(s).
- """
- return torch.as_tensor(sigma)
-
-
-@dataclass
-class iDDPMPrecondMetaData(ModelMetaData):
- """iDDPMPrecond meta data"""
-
- name: str = "iDDPMPrecond"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
-class iDDPMPrecond(nn.Module):
- """
- Preconditioning corresponding to the improved DDPM (iDDPM) formulation.
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels.
- label_dim : int
- Number of class labels, 0 = unconditional, by default 0.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- C_1 : float
- Timestep adjustment at low noise levels., by default 0.001.
- C_2 : float
- Timestep adjustment at high noise levels., by default 0.008.
- M: int
- Original number of timesteps in the DDPM formulation, by default 1000.
- model_type :str
- Class name of the underlying model, by default "DhariwalUNet".
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion
- probabilistic models. In International Conference on Machine Learning
- (pp. 8162-8171). PMLR.
- """
-
- def __init__(
- self,
- img_resolution,
- img_channels,
- label_dim=0,
- use_fp16=False,
- C_1=0.001,
- C_2=0.008,
- M=1000,
- model_type="DhariwalUNet",
- **model_kwargs,
- ):
- super().__init__() #meta=iDDPMPrecondMetaData
- self.img_resolution = img_resolution
- self.img_channels = img_channels
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- self.C_1 = C_1
- self.C_2 = C_2
- self.M = M
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=img_channels,
- out_channels=img_channels * 2,
- label_dim=label_dim,
- **model_kwargs,
- ) # TODO needs better handling
-
- u = torch.zeros(M + 1)
- for j in range(M, 0, -1): # M, ..., 1
- u[j - 1] = (
- (u[j] ** 2 + 1)
- / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1)
- - 1
- ).sqrt()
- self.register_buffer("u", u)
- self.sigma_min = float(u[M - 1])
- self.sigma_max = float(u[0])
-
- def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
- x = x.to(torch.float32)
- sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
- class_labels = (
- None
- if self.label_dim == 0
- else torch.zeros([1, self.label_dim], device=x.device)
- if class_labels is None
- else class_labels.to(torch.float32).reshape(-1, self.label_dim)
- )
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- c_skip = 1
- c_out = -sigma
- c_in = 1 / (sigma**2 + 1).sqrt()
- c_noise = (
- self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
- )
-
- F_x = self.model(
- (c_in * x).to(dtype),
- c_noise.flatten(),
- class_labels=class_labels,
- **model_kwargs,
- )
- if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
-
- D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32)
- return D_x
-
- def alpha_bar(self, j):
- """
- Compute the alpha_bar(j) value for a given j based on the iDDPM formulation.
-
- Parameters
- ----------
- j : Union[int, torch.Tensor]
- The timestep or set of timesteps for which to compute alpha_bar(j).
-
- Returns
- -------
- torch.Tensor
- The computed alpha_bar(j) value(s).
- """
- j = torch.as_tensor(j)
- return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2
-
- def round_sigma(self, sigma, return_index=False):
- """
- Round the provided sigma value(s) to the nearest value(s) in a
- pre-defined set `u`.
-
- Parameters
- ----------
- sigma : Union[float, list, torch.Tensor]
- The sigma value(s) to round.
- return_index : bool, optional
- Whether to return the index/indices of the rounded value(s) in `u` instead
- of the rounded value(s) themselves, by default False.
-
- Returns
- -------
- torch.Tensor
- The rounded sigma value(s) or their index/indices in `u`, depending on the
- value of `return_index`.
- """
- sigma = torch.as_tensor(sigma)
- index = torch.cdist(
- sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1),
- self.u.reshape(1, -1, 1),
- ).argmin(2)
- result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
- return result.reshape(sigma.shape).to(sigma.device)
-
-
-@dataclass
-class EDMPrecondMetaData(ModelMetaData):
- """EDMPrecond meta data"""
-
- name: str = "EDMPrecond"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
-class EDMPrecond(nn.Module):
- """
- Improved preconditioning proposed in the paper "Elucidating the Design Space of
- Diffusion-Based Generative Models" (EDM)
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels (for both input and output). If your model
- requires a different number of input or output chanels,
- override this by passing either of the optional
- img_in_channels or img_out_channels args
- label_dim : int
- Number of class labels, 0 = unconditional, by default 0.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- sigma_min : float
- Minimum supported noise level, by default 0.0.
- sigma_max : float
- Maximum supported noise level, by default inf.
- sigma_data : float
- Expected standard deviation of the training data, by default 0.5.
- model_type :str
- Class name of the underlying model, by default "DhariwalUNet".
- img_in_channels: int
- Optional setting for when number of input channels =/= number of output
- channels. If set, will override img_channels for the input
- This is useful in the case of additional (conditional) channels
- img_out_channels: int
- Optional setting for when number of input channels =/= number of output
- channels. If set, will override img_channels for the output
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the
- design space of diffusion-based generative models. Advances in Neural Information
- Processing Systems, 35, pp.26565-26577.
- """
-
- def __init__(
- self,
- img_resolution,
- img_channels,
- label_dim=0,
- use_fp16=False,
- sigma_min=0.0,
- sigma_max=float("inf"),
- sigma_data=0.5,
- model_type="DhariwalUNet",
- img_in_channels=None,
- img_out_channels=None,
- **model_kwargs,
- ):
- super().__init__() #meta=EDMPrecondMetaData
- self.img_resolution = img_resolution
- if img_in_channels is not None:
- img_in_channels = img_in_channels
- else:
- img_in_channels = img_channels
- if img_out_channels is not None:
- img_out_channels = img_out_channels
- else:
- img_out_channels = img_channels
-
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- self.sigma_min = sigma_min
- self.sigma_max = sigma_max
- self.sigma_data = sigma_data
-
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=img_in_channels,
- out_channels=img_out_channels,
- label_dim=label_dim,
- **model_kwargs,
- ) # TODO needs better handling
-
- def forward(
- self,
- x,
- sigma,
- condition=None,
- class_labels=None,
- force_fp32=False,
- **model_kwargs,
- ):
- x = x.to(torch.float32)
- sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
- class_labels = (
- None
- if self.label_dim == 0
- else torch.zeros([1, self.label_dim], device=x.device)
- if class_labels is None
- else class_labels.to(torch.float32).reshape(-1, self.label_dim)
- )
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
- c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
- c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
- c_noise = sigma.log() / 4
-
- arg = c_in * x
-
- if condition is not None:
- arg = torch.cat([arg, condition], dim=1)
-
- F_x = self.model(
- arg.to(dtype),
- c_noise.flatten(),
- class_labels=class_labels,
- **model_kwargs,
- )
-
- if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
- D_x = c_skip * x + c_out * F_x.to(torch.float32)
- return D_x
-
- @staticmethod
- def round_sigma(sigma: Union[float, List, torch.Tensor]):
- """
- Convert a given sigma value(s) to a tensor representation.
-
- Parameters
- ----------
- sigma : Union[float list, torch.Tensor]
- The sigma value(s) to convert.
-
- Returns
- -------
- torch.Tensor
- The tensor representation of the provided sigma value(s).
- """
- return torch.as_tensor(sigma)
-
-@dataclass
-class EDMPrecondSuperResolutionMetaData(ModelMetaData):
- """EDMPrecondSR meta data"""
-
- name: str = "EDMPrecondSuperResolution"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
class EDMPrecondSuperResolution(nn.Module):
"""
Improved preconditioning proposed in the paper "Elucidating the Design Space of
Diffusion-Based Generative Models" (EDM).
- This is a variant of `EDMPrecond` that is specifically designed for super-resolution
+ This is a variant of EDM Preconditioning that is specifically designed for super-resolution
tasks. It wraps a neural network that predicts the denoised high-resolution image
given a noisy high-resolution image, and additional conditioning that includes a
low-resolution image, and a noise level.
@@ -731,7 +55,7 @@ class EDMPrecondSuperResolution(nn.Module):
by default False.
model_type : str, optional
Class name of the underlying model. Must be one of the following:
- 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'.
+ 'SongUNet', 'SongUNetPosEmbd'.
Defaults to 'SongUNetPosEmbd'.
sigma_data : float, optional
Expected standard deviation of the training data, by default 0.5.
@@ -761,14 +85,14 @@ def __init__(
img_out_channels: int,
use_fp16: bool = False,
model_type: Literal[
- "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet"
+ "SongUNetPosEmbd", "SongUNet"
] = "SongUNetPosEmbd",
sigma_data: float = 0.5,
sigma_min=0.0,
sigma_max=float("inf"),
**model_kwargs: dict,
):
- super().__init__() #meta=EDMPrecondSRMetaData
+ super().__init__()
self.img_resolution = img_resolution
self.img_in_channels = img_in_channels
self.img_out_channels = img_out_channels
@@ -900,11 +224,8 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]) -> torch.Tensor:
torch.Tensor
Tensor representation of sigma values.
- See Also
- --------
- EDMPrecond.round_sigma
"""
- return EDMPrecond.round_sigma(sigma)
+ return torch.as_tensor(sigma)
@property
def amp_mode(self):
@@ -929,460 +250,3 @@ def amp_mode(self, value: bool):
if hasattr(sub_module, "amp_mode"):
sub_module.amp_mode = value
-# NOTE: This is a deprecated version of the EDMPrecondSuperResolution model.
-# This was used to maintain backwards compatibility and allow loading old models.
-@dataclass
-class EDMPrecondSRMetaData(ModelMetaData):
- """EDMPrecondSR meta data"""
-
- name: str = "EDMPrecondSR"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = False
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
-class EDMPrecondSR(EDMPrecondSuperResolution):
- """
- Improved preconditioning proposed in the paper "Elucidating the Design Space of
- Diffusion-Based Generative Models" (EDM) for super-resolution tasks
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels.
- img_in_channels : int
- Number of input color channels.
- img_out_channels : int
- Number of output color channels.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- sigma_min : float
- Minimum supported noise level, by default 0.0.
- sigma_max : float
- Maximum supported noise level, by default inf.
- sigma_data : float
- Expected standard deviation of the training data, by default 0.5.
- model_type :str
- Class name of the underlying model, by default "SongUNetPosEmbd".
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- References:
- - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the
- design space of diffusion-based generative models. Advances in Neural Information
- Processing Systems, 35, pp.26565-26577.
- - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y.,
- Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023.
- Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling.
- arXiv preprint arXiv:2309.15214.
- """
-
- def __init__(
- self,
- img_resolution,
- img_channels, #deprecated
- img_in_channels,
- img_out_channels,
- use_fp16=False,
- sigma_min=0.0,
- sigma_max=float("inf"),
- sigma_data=0.5,
- model_type="SongUNetPosEmbd",
- scale_cond_input=True, #deprecated
- **model_kwargs,
- ):
- warnings.warn(
- "EDMPrecondSR is deprecated and will be removed in a future version. "
- "Please use EDMPrecondSuperResolution instead.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- if scale_cond_input:
- warnings.warn(
- "scale_cond_input=True does not properly scale the conditional input. "
- "(see https://github.com/NVIDIA/modulus/issues/229). "
- "This setup will be deprecated. "
- "Please set scale_cond_input=False.",
- DeprecationWarning,
- )
-
- super().__init__(
- img_resolution=img_resolution,
- img_in_channels=img_in_channels,
- img_out_channels=img_out_channels,
- use_fp16=use_fp16,
- sigma_min=sigma_min,
- sigma_max=sigma_max,
- sigma_data=sigma_data,
- model_type=model_type,
- **model_kwargs,
- )
-
- # Store deprecated parameters for backward compatibility
- self.img_channels = img_channels
- self.scale_cond_input = scale_cond_input
-
- def forward(
- self,
- x,
- img_lr,
- sigma,
- force_fp32=False,
- **model_kwargs,
- ):
- """
- Forward pass of the EDMPrecondSR model wrapper.
-
- Parameters
- ----------
- x : torch.Tensor
- Noisy high-resolution image of shape (B, C_hr, H, W).
- img_lr : torch.Tensor
- Low-resolution conditioning image of shape (B, C_lr, H, W).
- sigma : torch.Tensor
- Noise level of shape (B) or (B, 1) or (B, 1, 1, 1).
- force_fp32 : bool, optional
- Whether to force FP32 precision regardless of the `use_fp16` attribute,
- by default False.
- **model_kwargs : dict
- Additional keyword arguments to pass to the underlying model.
-
- Returns
- -------
- torch.Tensor
- Denoised high-resolution image of shape (B, C_hr, H, W).
- """
- return super().forward(
- x=x, img_lr=img_lr, sigma=sigma, force_fp32=force_fp32, **model_kwargs
- )
-
-class VEPrecond_dfsr(nn.Module):
- """
- Preconditioning for dfsr model, modified from class VEPrecond, where the input
- argument 'sigma' in forward propagation function is used to receive the timestep
- of the backward diffusion process.
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels.
- label_dim : int
- Number of class labels, 0 = unconditional, by default 0.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- sigma_min : float
- Minimum supported noise level, by default 0.02.
- sigma_max : float
- Maximum supported noise level, by default 100.0.
- model_type :str
- Class name of the underlying model, by default "SongUNet".
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models.
- Advances in neural information processing systems. 2020;33:6840-51.
- """
-
- def __init__(
- self,
- img_resolution: int,
- img_channels: int,
- label_dim: int = 0,
- use_fp16: bool = False,
- sigma_min: float = 0.02,
- sigma_max: float = 100.0,
- dataset_mean: float = 5.85e-05,
- dataset_scale: float = 4.79,
- model_type: str = "SongUNet",
- **model_kwargs: dict,
- ):
- super().__init__()
- self.img_resolution = img_resolution
- self.img_channels = img_channels
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=self.img_channels,
- out_channels=img_channels,
- label_dim=label_dim,
- **model_kwargs,
- ) # TODO needs better handling
-
- def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
- x = x.to(torch.float32)
- sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
- # print("sigma: ", sigma)
- class_labels = (
- None
- if self.label_dim == 0
- else torch.zeros([1, self.label_dim], device=x.device)
- if class_labels is None
- else class_labels.to(torch.float32).reshape(-1, self.label_dim)
- )
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- c_in = 1
- c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma
-
- F_x = self.model(
- (c_in * x).to(dtype),
- c_noise.flatten(),
- class_labels=class_labels,
- **model_kwargs,
- )
-
- if F_x.dtype != dtype:
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
-
- return F_x
-
-
-class VEPrecond_dfsr_cond(nn.Module):
- """
- Preconditioning for dfsr model with physics-informed conditioning input, modified
- from class VEPrecond, where the input argument 'sigma' in forward propagation function
- is used to receive the timestep of the backward diffusion process. The gradient of PDE
- residual with respect to the vorticity in the governing Navier-Stokes equation is computed
- as the physics-informed conditioning variable and is combined with the backward diffusion
- timestep before being sent to the underlying model for noise prediction.
-
- Parameters
- ----------
- img_resolution : int
- Image resolution.
- img_channels : int
- Number of color channels.
- label_dim : int
- Number of class labels, 0 = unconditional, by default 0.
- use_fp16 : bool
- Execute the underlying model at FP16 precision?, by default False.
- sigma_min : float
- Minimum supported noise level, by default 0.02.
- sigma_max : float
- Maximum supported noise level, by default 100.0.
- model_type :str
- Class name of the underlying model, by default "SongUNet".
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- Note
- ----
- Reference:
- [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and
- Poole, B., 2020. Score-based generative modeling through stochastic differential
- equations. arXiv preprint arXiv:2011.13456.
- [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity
- flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972.
- """
-
- def __init__(
- self,
- img_resolution: int,
- img_channels: int,
- label_dim: int = 0,
- use_fp16: bool = False,
- sigma_min: float = 0.02,
- sigma_max: float = 100.0,
- dataset_mean: float = 5.85e-05,
- dataset_scale: float = 4.79,
- model_type: str = "SongUNet",
- **model_kwargs: dict,
- ):
- super().__init__()
- self.img_resolution = img_resolution
- self.img_channels = img_channels
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=model_kwargs["model_channels"] * 2,
- out_channels=img_channels,
- label_dim=label_dim,
- **model_kwargs,
- ) # TODO needs better handling
-
- # modules to embed residual loss
- self.conv_in = torch.nn.Conv2d(
- img_channels,
- model_kwargs["model_channels"],
- kernel_size=3,
- stride=1,
- padding=1,
- padding_mode="circular",
- )
- self.emb_conv = torch.nn.Sequential(
- torch.nn.Conv2d(
- img_channels,
- model_kwargs["model_channels"],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- torch.nn.GELU(),
- torch.nn.Conv2d(
- model_kwargs["model_channels"],
- model_kwargs["model_channels"],
- kernel_size=3,
- stride=1,
- padding=1,
- padding_mode="circular",
- ),
- )
- self.dataset_mean = dataset_mean
- self.dataset_scale = dataset_scale
-
- def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
- x = x.to(torch.float32)
- sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
- class_labels = (
- None
- if self.label_dim == 0
- else torch.zeros([1, self.label_dim], device=x.device)
- if class_labels is None
- else class_labels.to(torch.float32).reshape(-1, self.label_dim)
- )
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- c_in = 1
- c_noise = sigma
-
- # Compute physics-informed conditioning information using vorticity residual
- dx = (
- self.voriticity_residual((x * self.dataset_scale + self.dataset_mean))
- / self.dataset_scale
- )
- x = self.conv_in(x)
- cond_emb = self.emb_conv(dx)
- x = torch.cat((x, cond_emb), dim=1)
-
- F_x = self.model(
- (c_in * x).to(dtype),
- c_noise.flatten(),
- class_labels=class_labels,
- **model_kwargs,
- )
-
- if F_x.dtype != dtype:
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
- return F_x
-
- def voriticity_residual(self, w, re=1000.0, dt=1 / 32):
- """
- Compute the gradient of PDE residual with respect to a given vorticity w using the
- spectrum method.
-
- Parameters
- ----------
- w: torch.Tensor
- The fluid flow data sample (vorticity).
- re: float
- The value of Reynolds number used in the governing Navier-Stokes equation.
- dt: float
- Time step used to compute the time-derivative of vorticity included in the governing
- Navier-Stokes equation.
-
- Returns
- -------
- torch.Tensor
- The computed vorticity gradient.
- """
-
- # w [b t h w]
- w = w.clone()
- w.requires_grad_(True)
- nx = w.size(2)
- device = w.device
-
- w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3])
- # Wavenumbers in y-direction
- k_max = nx // 2
- N = nx
- k_x = (
- torch.cat(
- (
- torch.arange(start=0, end=k_max, step=1, device=device),
- torch.arange(start=-k_max, end=0, step=1, device=device),
- ),
- 0,
- )
- .reshape(N, 1)
- .repeat(1, N)
- .reshape(1, 1, N, N)
- )
- k_y = (
- torch.cat(
- (
- torch.arange(start=0, end=k_max, step=1, device=device),
- torch.arange(start=-k_max, end=0, step=1, device=device),
- ),
- 0,
- )
- .reshape(1, N)
- .repeat(N, 1)
- .reshape(1, 1, N, N)
- )
- # Negative Laplacian in Fourier space
- lap = k_x**2 + k_y**2
- lap[..., 0, 0] = 1.0
- psi_h = w_h / lap
-
- u_h = 1j * k_y * psi_h
- v_h = -1j * k_x * psi_h
- wx_h = 1j * k_x * w_h
- wy_h = 1j * k_y * w_h
- wlap_h = -lap * w_h
-
- u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3])
- v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3])
- wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3])
- wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3])
- wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3])
- advection = u * wx + v * wy
-
- wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt)
-
- # establish forcing term
- x = torch.linspace(0, 2 * np.pi, nx + 1, device=device)
- x = x[0:-1]
- X, Y = torch.meshgrid(x, x)
- f = -4 * torch.cos(4 * Y)
-
- residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f
- residual_loss = (residual**2).mean()
- dw = torch.autograd.grad(residual_loss, w)[0]
-
- return dw
diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py
index a56f8613..cc52cfdc 100644
--- a/src/hirad/models/song_unet.py
+++ b/src/hirad/models/song_unet.py
@@ -38,25 +38,6 @@
PositionalEmbedding,
UNetBlock,
)
-from .meta import ModelMetaData
-
-
-@dataclass
-class MetaData(ModelMetaData):
- name: str = "SongUNet"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = True
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
class SongUNet(nn.Module):
@@ -204,7 +185,7 @@ def __init__(
emb_channels=emb_channels,
num_heads=1,
dropout=dropout,
- skip_scale=np.sqrt(0.5),
+ skip_scale=0.7071067811865476, # 1 / sqrt(2)
eps=1e-6,
resample_filter=resample_filter,
resample_proj=True,
@@ -678,10 +659,13 @@ def __init__(
self.gridtype = gridtype
self.N_grid_channels = N_grid_channels
- if self.gridtype == "learnable":
- self.pos_embd = self._get_positional_embedding()
+ if self.N_grid_channels:
+ if self.gridtype == "learnable":
+ self.pos_embd = self._get_positional_embedding()
+ else:
+ self.register_buffer("pos_embd", self._get_positional_embedding().float())
else:
- self.register_buffer("pos_embd", self._get_positional_embedding().float())
+ self.pos_embd = None
self.lead_time_mode = lead_time_mode
if self.lead_time_mode:
self.lead_time_channels = lead_time_channels
@@ -712,7 +696,13 @@ def forward(
"embedding_selector is the preferred approach for better efficiency."
)
- if x.dtype != self.pos_embd.dtype:
+ if self.lead_time_mode and embedding_selector is not None:
+ raise ValueError(
+ "Embedding selector is not supported in lead time mode. "
+ "Please use global_index to select positional embeddings when lead_time_mode is True."
+ )
+
+ if self.pos_embd is not None and x.dtype != self.pos_embd.dtype:
self.pos_embd = self.pos_embd.to(x.dtype)
# Append positional embedding to input conditioning
@@ -799,7 +789,7 @@ def positional_embedding_indexing(
Example
-------
>>> # Create global indices using patching utility:
- >>> from physicsnemo.utils.patching import GridPatching2D
+ >>> from hirad.utils.patching import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> global_index = patching.global_index(batch_size=3)
>>> print(global_index.shape)
@@ -807,9 +797,9 @@ def positional_embedding_indexing(
See Also
--------
- :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index`
+ :meth:`hirad.utils.patching.RandomPatching2D.global_index`
For generating random patch indices.
- :meth:`physicsnemo.utils.patching.GridPatching2D.global_index`
+ :meth:`hirad.utils.patching.GridPatching2D.global_index`
For generating deterministic grid-based patch indices.
See these methods for possible ways to generate the global_index parameter.
"""
@@ -919,7 +909,7 @@ def positional_embedding_selector(
Each selected embedding should correspond to the positional
information of each batch element in x.
For patch-based processing, typically this should be based on
- :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to
+ :meth:`hirad.utils.patching.BasePatching2D.apply` method to
maintain consistency with patch extraction.
embeds : Optional[torch.Tensor]
Optional tensor for combined positional and lead time embeddings tensor
@@ -988,6 +978,10 @@ def _get_positional_embedding(self):
raise ValueError("N_grid_channels must be a factor of 4")
num_freq = self.N_grid_channels // 4
freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)
+ #TODO: When more than 4 channels are used for sinusoidal, the frequencies should be multiples of the base frequency (2).
+ # freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) is currently in code which gives
+ # freqs = [1,4] instead of [1,2] for N_grid_channels=8. This seems to be a bug if we want the base 2.
+ # Leaving it like this for now since we have checkpoints with 8 sinusoidal channels that use these frequencies,
grid_list = []
grid_x, grid_y = np.meshgrid(
np.linspace(0, 2 * np.pi, self.img_shape_x),
@@ -1023,228 +1017,3 @@ def _get_lead_time_embedding(self):
) # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x)
return grid
-
-class SongUNetPosLtEmbd(SongUNetPosEmbd):
- """
- This model is adapted from SongUNetPosEmbd, with the incorporation of lead-time aware
- embeddings. The lead-time embedding is activated by setting the
- `lead_time_channels` and `lead_time_steps` parameters.
-
- Like SongUNetPosEmbd, this model provides two methods for selecting positional embeddings:
- 1. Using a selector function (preferred method). See
- :meth:`positional_embedding_selector` for details.
- 2. Using global indices. See :meth:`positional_embedding_indexing` for
- details.
-
- Parameters
- -----------
- img_resolution : Union[List[int], int]
- The resolution of the input/output image. Can be a single int for square images
- or a list [height, width] for rectangular images.
- in_channels : int
- Number of channels in the input image.
- out_channels : int
- Number of channels in the output image.
- label_dim : int, optional
- Number of class labels; 0 indicates an unconditional model. By default 0.
- augment_dim : int, optional
- Dimensionality of augmentation labels; 0 means no augmentation. By default 0.
- model_channels : int, optional
- Base multiplier for the number of channels across the network. By default 128.
- channel_mult : List[int], optional
- Per-resolution multipliers for the number of channels. By default [1,2,2,2,2].
- channel_mult_emb : int, optional
- Multiplier for the dimensionality of the embedding vector. By default 4.
- num_blocks : int, optional
- Number of residual blocks per resolution. By default 4.
- attn_resolutions : List[int], optional
- Resolutions at which self-attention layers are applied. By default [28].
- dropout : float, optional
- Dropout probability applied to intermediate activations. By default 0.13.
- label_dropout : float, optional
- Dropout probability of class labels for classifier-free guidance. By default 0.0.
- embedding_type : str, optional
- Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
- By default 'positional'.
- channel_mult_noise : int, optional
- Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1.
- encoder_type : str, optional
- Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections.
- By default 'standard'.
- decoder_type : str, optional
- Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'.
- resample_filter : List[int], optional
- Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1].
- gridtype : str, optional
- Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'.
- Controls how positional information is encoded. By default 'sinusoidal'.
- N_grid_channels : int, optional
- Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or
- multiple of 4. For 'linear' must be 2. By default 4.
- lead_time_channels : int, optional
- Number of channels in the lead time embedding. These are learned embeddings that
- encode temporal forecast information. By default None.
- lead_time_steps : int, optional
- Number of discrete lead time steps to support. Each step gets its own learned
- embedding vector. By default 9.
- prob_channels : List[int], optional
- Indices of probability output channels that should use softmax activation.
- Used for classification outputs. By default empty list.
- checkpoint_level : int, optional
- Number of layers that should use gradient checkpointing (0 disables checkpointing).
- Higher values trade memory for computation. By default 0.
- additive_pos_embed : bool, optional
- If True, adds a learned positional embedding after the first convolution layer.
- Used in StormCast model. By default False.
- use_apex_gn : bool, optional
- A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout.
- Need to set this as False on cpu. Defaults to False.
- act : str, optional
- The activation function to use when fusing activation with GroupNorm. Defaults to None.
- profile_mode:
- A boolean flag indicating whether to enable all nvtx annotations during profiling.
- amp_mode : bool, optional
- A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False.
-
-
- Note
- -----
- Equivalent to the original implementation by Song et al., available at
- https://github.com/yang-song/score_sde_pytorch
-
- Example
- --------
- >>> import torch
- >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd
- >>> from physicsnemo.utils.patching import GridPatching2D
- >>>
- >>> # Model initialization - in_channels must include original input channels (2),
- >>> # positional embedding channels (N_grid_channels=4 by default) and
- >>> # lead time embedding channels (4)
- >>> model = SongUNetPosLtEmbd(
- ... img_resolution=16, in_channels=2+4+4, out_channels=2,
- ... lead_time_channels=4, lead_time_steps=9
- ... )
- >>> noise_labels = torch.randn([1])
- >>> class_labels = torch.randint(0, 1, (1, 1))
- >>> # The input has only the original 2 channels - positional embeddings and
- >>> # lead time embeddings are added automatically inside the forward method
- >>> input_image = torch.ones([1, 2, 16, 16])
- >>> lead_time_label = torch.tensor([3])
- >>> output_image = model(
- ... input_image, noise_labels, class_labels,
- ... lead_time_label=lead_time_label
- ... )
- >>> output_image.shape
- torch.Size([1, 2, 16, 16])
- >>>
- >>> # Using global_index to select all the positional and lead time embeddings
- >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16))
- >>> global_index = patching.global_index(batch_size=1)
- >>> output_image = model(
- ... input_image, noise_labels, class_labels,
- ... lead_time_label=lead_time_label,
- ... global_index=global_index
- ... )
- >>> output_image.shape
- torch.Size([1, 2, 16, 16])
-
- # NOTE: commented out doctest for embedding_selector due to compatibility issue
- # >>>
- # >>> # Using custom embedding selector to select all the positional and lead time embeddings
- # >>> def patch_embedding_selector(emb):
- # ... return patching.apply(emb[None].expand(1, -1, -1, -1))
- # >>> output_image = model(
- # ... input_image, noise_labels, class_labels,
- # ... lead_time_label=lead_time_label,
- # ... embedding_selector=patch_embedding_selector
- # ... )
- # >>> output_image.shape
- # torch.Size([1, 2, 16, 16])
-
- """
-
- def __init__(
- self,
- img_resolution: Union[List[int], int],
- in_channels: int,
- out_channels: int,
- label_dim: int = 0,
- augment_dim: int = 0,
- model_channels: int = 128,
- channel_mult: List[int] = [1, 2, 2, 2, 2],
- channel_mult_emb: int = 4,
- num_blocks: int = 4,
- attn_resolutions: List[int] = [28],
- dropout: float = 0.13,
- label_dropout: float = 0.0,
- embedding_type: str = "positional",
- channel_mult_noise: int = 1,
- encoder_type: str = "standard",
- decoder_type: str = "standard",
- resample_filter: List[int] = [1, 1],
- gridtype: str = "sinusoidal",
- N_grid_channels: int = 4,
- lead_time_channels: int = None,
- lead_time_steps: int = 9,
- prob_channels: List[int] = [],
- checkpoint_level: int = 0,
- additive_pos_embed: bool = False,
- use_apex_gn: bool = False,
- act: str = "silu",
- profile_mode: bool = False,
- amp_mode: bool = False,
- ):
- super().__init__(
- img_resolution,
- in_channels,
- out_channels,
- label_dim,
- augment_dim,
- model_channels,
- channel_mult,
- channel_mult_emb,
- num_blocks,
- attn_resolutions,
- dropout,
- label_dropout,
- embedding_type,
- channel_mult_noise,
- encoder_type,
- decoder_type,
- resample_filter,
- gridtype,
- N_grid_channels,
- checkpoint_level,
- additive_pos_embed,
- use_apex_gn,
- act,
- profile_mode,
- amp_mode,
- True, # Note: lead_time_mode=True is enforced here
- lead_time_channels,
- lead_time_steps,
- prob_channels,
- )
-
- def forward(
- self,
- x,
- noise_labels,
- class_labels,
- lead_time_label=None,
- global_index: Optional[torch.Tensor] = None,
- embedding_selector: Optional[Callable] = None,
- augment_labels=None,
- ):
- return super().forward(
- x=x,
- noise_labels=noise_labels,
- class_labels=class_labels,
- global_index=global_index,
- embedding_selector=embedding_selector,
- augment_labels=augment_labels,
- lead_time_label=lead_time_label,
- )
-
- # Nothing else is re-implemented, because everything is already in the parent SongUNetPosEmb
\ No newline at end of file
diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py
index e0a447aa..82b9a273 100644
--- a/src/hirad/models/unet.py
+++ b/src/hirad/models/unet.py
@@ -21,29 +21,9 @@
import torch
import torch.nn as nn
-from .meta import ModelMetaData
-
network_module = importlib.import_module("hirad.models")
-@dataclass
-class MetaData(ModelMetaData):
- name: str = "UNet"
- # Optimization
- jit: bool = False
- cuda_graphs: bool = False
- amp_cpu: bool = False
- amp_gpu: bool = True
- torch_fx: bool = False
- # Data type
- bf16: bool = True
- # Inference
- onnx: bool = False
- # Physics informed
- func_torch: bool = False
- auto_grad: bool = False
-
-
class UNet(nn.Module): # TODO a lot of redundancy, need to clean up
"""
U-Net Wrapper for CorrDiff deterministic regression model.
@@ -61,7 +41,7 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up
Execute the underlying model at FP16 precision, by default False.
model_type: str, optional
Class name of the underlying model. Must be one of the following:
- 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'.
+ 'SongUNet', 'SongUNetPosEmbd'.
Defaults to 'SongUNetPosEmbd'.
**model_kwargs : dict
Keyword arguments passed to the underlying model `__init__` method.
@@ -69,9 +49,8 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up
See Also
--------
For information on model types and their usage:
- :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models
- :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings
- :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings
+ :class:`~models.song_unet.SongUNet`: Basic U-Net for diffusion models
+ :class:`~models.song_unet.SongUNetPosEmbd`: U-Net with positional embeddings and lead time embeddings
Please refer to the documentation of these classes for details on how to call
and use these models directly.
@@ -84,42 +63,6 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up
arXiv preprint arXiv:2309.15214.
"""
- @classmethod
- def _backward_compat_arg_mapper(
- cls, version: str, args: Dict[str, Any]
- ) -> Dict[str, Any]:
- """Map arguments from older versions to current version format.
-
- Parameters
- ----------
- version : str
- Version of the checkpoint being loaded
- args : Dict[str, Any]
- Arguments dictionary from the checkpoint
-
- Returns
- -------
- Dict[str, Any]
- Updated arguments dictionary compatible with current version
- """
- # Call parent class method first
- args = super()._backward_compat_arg_mapper(version, args)
-
- if version == "0.1.0":
- # In version 0.1.0, img_channels was unused
- if "img_channels" in args:
- _ = args.pop("img_channels")
-
- # Sigma parameters are also unused
- if "sigma_min" in args:
- _ = args.pop("sigma_min")
- if "sigma_max" in args:
- _ = args.pop("sigma_max")
- if "sigma_data" in args:
- _ = args.pop("sigma_data")
-
- return args
-
def __init__(
self,
img_resolution: Union[int, Tuple[int, int]],
@@ -127,7 +70,7 @@ def __init__(
img_out_channels: int,
use_fp16: bool = False,
model_type: Literal[
- "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet"
+ "SongUNetPosEmbd", "SongUNet"
] = "SongUNetPosEmbd",
**model_kwargs: dict,
):
@@ -152,6 +95,44 @@ def __init__(
**model_kwargs,
)
+ @property
+ def use_fp16(self):
+ """
+ bool: Whether the model uses float16 precision.
+
+ Returns
+ -------
+ bool
+ True if the model is in float16 mode, False otherwise.
+ """
+ return self._use_fp16
+
+ @use_fp16.setter
+ def use_fp16(self, value: bool):
+ """
+ Set whether the model should use float16 precision.
+
+ Parameters
+ ----------
+ value : bool
+ If True, moves the model to torch.float16. If False, moves to torch.float32.
+
+ Raises
+ ------
+ ValueError
+ If `value` is not a boolean.
+ """
+ # NOTE: allow 0/1 values for older checkpoints
+ if not (isinstance(value, bool) or value in [0, 1]):
+ raise ValueError(
+ f"`use_fp16` must be a boolean, but got {type(value).__name__}."
+ )
+ self._use_fp16 = value
+ if value:
+ self.to(torch.float16)
+ else:
+ self.to(torch.float32)
+
def forward(
self,
x: torch.Tensor,
@@ -200,8 +181,8 @@ def forward(
)
F_x = self.model(
- x.to(dtype), # (c_in * x).to(dtype),
- torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten()
+ x.to(dtype),
+ torch.zeros(x.shape[0], dtype=dtype, device=x.device),
class_labels=None,
**model_kwargs,
)
@@ -211,7 +192,6 @@ def forward(
f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead."
)
- # skip connection
D_x = F_x.to(torch.float32)
return D_x
@@ -254,103 +234,3 @@ def amp_mode(self, value: bool):
if hasattr(sub_module, "amp_mode"):
sub_module.amp_mode = value
-# TODO: implement amp_mode property for StormCastUNet (same as UNet)
-class StormCastUNet(nn.Module):
- """
- U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model.
-
- Parameters
- -----------
- img_resolution : int or List[int]
- The resolution of the input/output image.
- img_channels : int
- Number of color channels.
- img_in_channels : int
- Number of input color channels.
- img_out_channels : int
- Number of output color channels.
- use_fp16: bool, optional
- Execute the underlying model at FP16 precision?, by default False.
- sigma_min: float, optional
- Minimum supported noise level, by default 0.
- sigma_max: float, optional
- Maximum supported noise level, by default float('inf').
- sigma_data: float, optional
- Expected standard deviation of the training data, by default 0.5.
- model_type: str, optional
- Class name of the underlying model, by default 'SongUNet'.
- **model_kwargs : dict
- Keyword arguments for the underlying model.
-
- """
-
- def __init__(
- self,
- img_resolution,
- img_in_channels,
- img_out_channels,
- use_fp16=False,
- sigma_min=0,
- sigma_max=float("inf"),
- sigma_data=0.5,
- model_type="SongUNet",
- **model_kwargs,
- ):
- super().__init__() #meta=MetaData("StormCastUNet")
-
- if isinstance(img_resolution, int):
- self.img_shape_x = self.img_shape_y = img_resolution
- else:
- self.img_shape_x = img_resolution[0]
- self.img_shape_y = img_resolution[1]
-
- self.img_in_channels = img_in_channels
- self.img_out_channels = img_out_channels
-
- self.use_fp16 = use_fp16
- self.sigma_min = sigma_min
- self.sigma_max = sigma_max
- self.sigma_data = sigma_data
- model_class = getattr(network_module, model_type)
- self.model = model_class(
- img_resolution=img_resolution,
- in_channels=img_in_channels,
- out_channels=img_out_channels,
- **model_kwargs,
- )
-
- def forward(self, x, force_fp32=False, **model_kwargs):
- """Run a forward pass of the StormCast regression U-Net.
-
- Args:
- x (torch.Tensor): input to the U-Net
- force_fp32 (bool, optional): force casting to fp_32 if True. Defaults to False.
-
- Raises:
- ValueError: If input data type is a mismatch with provided options
-
- Returns:
- D_x (torch.Tensor): Output (prediction) of the U-Net
- """
-
- x = x.to(torch.float32)
- dtype = (
- torch.float16
- if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
- else torch.float32
- )
-
- F_x = self.model(
- x.to(dtype),
- torch.zeros(x.shape[0], dtype=x.dtype, device=x.device),
- class_labels=None,
- **model_kwargs,
- )
-
- if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
- raise ValueError(
- f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
- )
-
- D_x = F_x.to(torch.float32)
- return D_x
diff --git a/src/hirad/models/utils.py b/src/hirad/models/utils.py
deleted file mode 100644
index e1cde9d8..00000000
--- a/src/hirad/models/utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-FileCopyrightText: All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import numpy as np
-import torch
-
-
-def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int):
- """
- Unified routine for initializing weights and biases.
- This function provides a unified interface for various weight initialization
- strategies like Xavier (Glorot) and Kaiming (He) initializations.
-
- Parameters
- ----------
- shape : tuple
- The shape of the tensor to initialize. It could represent weights or biases
- of a layer in a neural network.
- mode : str
- The mode/type of initialization to use. Supported values are:
- - "xavier_uniform": Xavier (Glorot) uniform initialization.
- - "xavier_normal": Xavier (Glorot) normal initialization.
- - "kaiming_uniform": Kaiming (He) uniform initialization.
- - "kaiming_normal": Kaiming (He) normal initialization.
- fan_in : int
- The number of input units in the weight tensor. For convolutional layers,
- this typically represents the number of input channels times the kernel height
- times the kernel width.
- fan_out : int
- The number of output units in the weight tensor. For convolutional layers,
- this typically represents the number of output channels times the kernel height
- times the kernel width.
-
- Returns
- -------
- torch.Tensor
- The initialized tensor based on the specified mode.
-
- Raises
- ------
- ValueError
- If the provided `mode` is not one of the supported initialization modes.
- """
- if mode == "xavier_uniform":
- return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
- if mode == "xavier_normal":
- return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
- if mode == "kaiming_uniform":
- return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
- if mode == "kaiming_normal":
- return np.sqrt(1 / fan_in) * torch.randn(*shape)
- raise ValueError(f'Invalid init mode "{mode}"')
diff --git a/src/hirad/snapshots.sh b/src/hirad/snapshots.sh
new file mode 100644
index 00000000..9be34097
--- /dev/null
+++ b/src/hirad/snapshots.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+#SBATCH --job-name="snapshots"
+
+### HARDWARE ###
+#SBATCH --partition=normal
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=72
+#SBATCH --time=00:30:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+### OUTPUT ###
+#SBATCH --output=./logs/snapshot.log
+
+### ENVIRONMENT ####
+#SBATCH -A c38
+
+# Optional: pass specific timesteps as arguments (format: YYYYMMDD-HHMM)
+# Usage: sbatch snapshots.sh 20230824-1400
+EXTRA_ARGS=()
+if [ $# -gt 0 ]; then
+ EXTRA_ARGS+=("--times" "$@")
+fi
+
+EXTRA_ARGS_STR="${EXTRA_ARGS[*]@Q}"
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/eval/snapshots.py --config-name=src/hirad/conf/eval_real.yaml ${EXTRA_ARGS_STR}
+"
\ No newline at end of file
diff --git a/src/hirad/submit_monthly.sh b/src/hirad/submit_monthly.sh
new file mode 100755
index 00000000..b22c2829
--- /dev/null
+++ b/src/hirad/submit_monthly.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# Submit monthly generation jobs over an inclusive month range as a SLURM array.
+#
+# Usage:
+# ./submit_monthly.sh START_MONTH END_MONTH [extra sbatch args...]
+#
+# Examples:
+# ./submit_monthly.sh 2021-01 2024-12 # 4 years
+# ./submit_monthly.sh 2022-06 2022-08 # JJA 2022
+
+set -euo pipefail
+
+usage() {
+ echo "Usage: $0 START_MONTH END_MONTH [extra sbatch args...]" >&2
+ echo "Example: $0 2022-06 2022-08" >&2
+ exit 1
+}
+
+month_index() {
+ local ym="$1"
+ echo $(( 10#${ym%-*} * 12 + 10#${ym#*-} ))
+}
+
+(( $# >= 2 )) || usage
+
+START_MONTH="$1" END_MONTH="$2"
+shift 2
+
+TASKS=$(( $(month_index "$END_MONTH") - $(month_index "$START_MONTH") + 1 ))
+if (( TASKS <= 0 )); then
+ echo "ERROR: END_MONTH (${END_MONTH}) is before START_MONTH (${START_MONTH})." >&2
+ exit 1
+fi
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+echo "Submitting ${TASKS} monthly jobs: ${START_MONTH} .. ${END_MONTH}"
+exec sbatch \
+ --array="0-$((TASKS - 1))" \
+ --export=ALL,START_MONTH="${START_MONTH}",END_MONTH="${END_MONTH}" \
+ --time=6:00:00 \
+ --account=c38 \
+ --output=./logs/generation_monthly_%A_%a.log \
+ "$@" \
+ "${SCRIPT_DIR}/generate.sh"
diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh
index cf2f88f3..dcc46375 100644
--- a/src/hirad/train_diffusion.sh
+++ b/src/hirad/train_diffusion.sh
@@ -1,25 +1,23 @@
#!/bin/bash
-#SBATCH --job-name="testrun"
+#SBATCH --job-name="corrdiff-second-stage"
### HARDWARE ###
-#SBATCH --partition=debug
-#SBATCH --nodes=1
-#SBATCH --ntasks-per-node=1
-#SBATCH --gpus-per-node=1
+#SBATCH --partition=normal
+#SBATCH --nodes=8
+#SBATCH --ntasks-per-node=4
+#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=72
-#SBATCH --time=00:30:00
+#SBATCH --time=12:00:00
#SBATCH --no-requeue
#SBATCH --exclusive
### OUTPUT ###
-#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.log
-#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.err
+#SBATCH --output=./logs/train_diffusion.log
+#SBATCH --error=./logs/train_diffusion.err
### ENVIRONMENT ####
-#SBATCH --uenv=pytorch/v2.6.0:/user-environment
-#SBATCH --view=default
-#SBATCH -A a-a122
+#SBATCH -A a161
# Choose method to initialize dist in pythorch
export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
@@ -31,15 +29,10 @@ MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
export MASTER_ADDR
export MASTER_PORT=29500
-# Get number of physical cores using Python
-PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))")
-LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1}
-# Compute cores per process
-OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS ))
-export OMP_NUM_THREADS=$OMP_THREADS
+export OMP_NUM_THREADS=1
# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml
-srun bash -c "
- . ./train_env/bin/activate
- python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/training/train.py --config-name=training_era_real_diffusion_patched.yaml
"
\ No newline at end of file
diff --git a/src/hirad/train_diffusion_test.sh b/src/hirad/train_diffusion_test.sh
new file mode 100644
index 00000000..9c9831f4
--- /dev/null
+++ b/src/hirad/train_diffusion_test.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+#SBATCH --job-name="corrdiff-test-second-stage"
+
+### HARDWARE ###
+#SBATCH --partition=debug
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=4
+#SBATCH --gpus-per-node=4
+#SBATCH --cpus-per-task=72
+#SBATCH --time=00:30:00
+#SBATCH --no-requeue
+#SBATCH --exclusive
+
+### OUTPUT ###
+#SBATCH --output=./logs/training_diffusion_test.log
+
+### ENVIRONMENT ####
+#SBATCH -A a161
+
+# Choose method to initialize dist in pythorch
+export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
+
+# Get master node.
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+# Get IP for hostname.
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
+export MASTER_ADDR
+export MASTER_PORT=29500
+
+# Get number of physical cores using Python
+# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))")
+# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1}
+# # Compute cores per process
+# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS ))
+# export OMP_NUM_THREADS=$OMP_THREADS
+export OMP_NUM_THREADS=72
+
+srun --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e . --no-dependencies
+ python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion_test.yaml
+"
\ No newline at end of file
diff --git a/src/hirad/train_regression.sh b/src/hirad/train_regression.sh
index c0654773..e9e4271e 100644
--- a/src/hirad/train_regression.sh
+++ b/src/hirad/train_regression.sh
@@ -1,25 +1,24 @@
#!/bin/bash
-#SBATCH --job-name="testrun"
+#SBATCH --job-name="corrdiff-first-stage"
### HARDWARE ###
-#SBATCH --partition=debug
-#SBATCH --nodes=1
-#SBATCH --ntasks-per-node=1
-#SBATCH --gpus-per-node=1
+#SBATCH --partition=normal
+#SBATCH --nodes=8
+#SBATCH --ntasks-per-node=4
+#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=72
-#SBATCH --time=00:30:00
+#SBATCH --time=12:00:00
#SBATCH --no-requeue
#SBATCH --exclusive
+
### OUTPUT ###
-#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression.log
-#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression.err
+#SBATCH --output=./logs/train_regression.log
+#SBATCH --error=./logs/train_regression.err
### ENVIRONMENT ####
-#SBATCH --uenv=pytorch/v2.6.0:/user-environment
-#SBATCH --view=default
-#SBATCH -A a-a122
+#SBATCH -A a161
# Choose method to initialize dist in pythorch
export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
@@ -31,15 +30,9 @@ MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
export MASTER_ADDR
export MASTER_PORT=29500
-# Get number of physical cores using Python
-PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))")
-LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1}
-# Compute cores per process
-OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS ))
-export OMP_NUM_THREADS=$OMP_THREADS
-
-# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml
-srun bash -c "
- . ./train_env/bin/activate
- python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml
+export OMP_NUM_THREADS=1
+
+srun --mpi=pmix --network=disable_rdzv_get --environment=./ci/edf/modulus_env.toml bash -c "
+ pip install -e .
+ python src/hirad/training/train.py --config-name=training_era_real_regression.yaml
"
\ No newline at end of file
diff --git a/src/hirad/train_regression_test.sh b/src/hirad/train_regression_test.sh
new file mode 100644
index 00000000..6cc04681
--- /dev/null
+++ b/src/hirad/train_regression_test.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+### OUTPUT ###
+#SBATCH --output=./logs/training_regression_test.log
+
+# Choose method to initialize dist in pythorch
+export DISTRIBUTED_INITIALIZATION_METHOD=SLURM
+
+# Get master node.
+MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
+# Get IP for hostname.
+MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')"
+export MASTER_ADDR
+export MASTER_PORT=29500
+
+# Get number of physical cores using Python
+# PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))")
+# LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1}
+# # Compute cores per process
+# OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS ))
+# export OMP_NUM_THREADS=$OMP_THREADS
+export OMP_NUM_THREADS=72
+
+pip install -e . --no-dependencies
+python src/hirad/training/train.py --config-name=training_era_cosmo_regression_test.yaml
\ No newline at end of file
diff --git a/src/hirad/training/__init__.py b/src/hirad/training/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py
index 12b6942c..ddf55cb7 100755
--- a/src/hirad/training/train.py
+++ b/src/hirad/training/train.py
@@ -1,30 +1,41 @@
import os
import time
-import psutil
+from concurrent.futures import ThreadPoolExecutor
+
import hydra
from omegaconf import DictConfig, OmegaConf
import json
from contextlib import nullcontext
import nvtx
+import numpy as np
import torch
from hydra.utils import to_absolute_path
-from torch.utils.tensorboard import SummaryWriter
+# from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
+import mlflow
# from torchinfo import summary
from hirad.distributed import DistributedManager
from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper
from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \
- set_patch_shape, compute_num_accumulation_rounds, \
- is_time_for_periodic_task, handle_and_clip_gradients
+ set_patch_shape, compute_num_accumulation_rounds, calculate_patch_per_iter, \
+ is_time_for_periodic_task, handle_and_clip_gradients, \
+ init_mlflow, update_learning_rate, log_training_progress, \
+ cuda_profiler, cuda_profiler_start, cuda_profiler_stop, profiler_emit_nvtx
from hirad.utils.checkpoint import load_checkpoint, save_checkpoint
from hirad.utils.patching import RandomPatching2D
-from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR
-from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE
-from hirad.datasets import init_train_valid_datasets_from_config
+from hirad.utils.function_utils import get_time_from_range
+from hirad.utils.inference_utils import save_results_as_torch
+from hirad.utils.env_info import get_env_info, flatten_dict
+from hirad.utils.dataset_utils import regrid_icon_to_rotlatlon
+from hirad.models import UNet, EDMPrecondSuperResolution
+from hirad.losses import ResidualLoss, RegressionLoss
+from hirad.datasets import init_train_valid_datasets_from_config, get_dataset_and_sampler_inference
+from hirad.inference import Generator
+from hirad.training.training_manager import TrainingManagerCorrDiff
+
-from matplotlib import pyplot as plt
torch._dynamo.reset()
# Increase the cache size limit
@@ -33,66 +44,79 @@
torch._dynamo.config.suppress_errors = False # Forces the error to show all details
torch._logging.set_logs(recompiles=True, graph_breaks=True)
-# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available
-def cuda_profiler():
- if torch.cuda.is_available():
- return torch.cuda.profiler.profile()
- else:
- return nullcontext()
-
-
-def cuda_profiler_start():
- if torch.cuda.is_available():
- torch.cuda.profiler.start()
-
-
-def cuda_profiler_stop():
- if torch.cuda.is_available():
- torch.cuda.profiler.stop()
-
-
-def profiler_emit_nvtx():
- if torch.cuda.is_available():
- return torch.autograd.profiler.emit_nvtx()
- else:
- return nullcontext()
@hydra.main(version_base=None, config_path="../conf", config_name="training")
def main(cfg: DictConfig) -> None:
+
# Initialize distributed environment for training
DistributedManager.initialize()
dist = DistributedManager()
- if dist.rank==0:
- writer = SummaryWriter(log_dir='tensorboard')
+ OmegaConf.resolve(cfg)
+
+ # Initialize logging
+ if cfg.logging.method == "mlflow":
+ init_mlflow(cfg, dist)
+ if dist.world_size > 1:
+ torch.distributed.barrier()
+ elif cfg.logging.method is not None:
+ raise ValueError("The only available logging method is mlflow. To disable logging set the method to null.")
+
logger = PythonLogger("main") # general logger
logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger
- OmegaConf.resolve(cfg)
- dataset_cfg = OmegaConf.to_container(cfg.dataset)
- if hasattr(cfg.dataset, "validation_path"):
- train_test_split = True
- else:
- train_test_split = False
- fp_optimizations = cfg.training.perf.fp_optimizations
- songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level
- fp16 = fp_optimizations == "fp16"
- enable_amp = fp_optimizations.startswith("amp")
- amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
+ logger0.info(f"Config is: {cfg}")
logger0.info(f"Saving the outputs in {os.getcwd()}")
+
+ # create checkpoint directory if it doesn't exist
checkpoint_dir = os.path.join(
cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}"
)
if dist.rank==0 and not os.path.exists(checkpoint_dir):
- os.makedirs(checkpoint_dir) # added creating checkpoint dir
+ os.makedirs(checkpoint_dir)
+
+ # performance optimization configuration
+ use_torch_compile = getattr(cfg.training.perf, "torch_compile", False)
+ use_apex_gn = getattr(cfg.training.perf, "use_apex_gn", False)
+ profile_mode = getattr(cfg.training.perf, "profile_mode", False)
+ fp_optimizations = cfg.training.perf.fp_optimizations
+ songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level
+ fp16 = fp_optimizations == "fp16"
+ enable_amp = fp_optimizations.startswith("amp")
+ amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16
+
+ # set the data type for model inputs based on optimization configuration
+ input_dtype = torch.float32
+ if enable_amp:
+ input_dtype = torch.float32
+ elif fp16:
+ input_dtype = torch.float16
+
+ # dataset configuration
+ dataset_cfg = OmegaConf.to_container(cfg.dataset)
+ train_test_split = getattr(cfg.dataset, "validation", False)
+ n_month_hour_channels = 2*dataset_cfg.get("n_month_hour_channels", 0)
+
+ # validate and set batch size configuration
+ if cfg.training.hp.batch_size_per_gpu == "auto" and \
+ cfg.training.hp.total_batch_size == "auto":
+ raise ValueError("batch_size_per_gpu and total_batch_size can't be both set to 'auto'.")
if cfg.training.hp.batch_size_per_gpu == "auto":
cfg.training.hp.batch_size_per_gpu = (
cfg.training.hp.total_batch_size // dist.world_size
)
+ elif cfg.training.hp.total_batch_size == "auto":
+ cfg.training.hp.total_batch_size = (
+ cfg.training.hp.batch_size_per_gpu * dist.world_size
+ )
+
+ # Get the current training step from the checkpoint if it exists, otherwise start from 0.
+ cur_nimg = load_checkpoint(path=checkpoint_dir)
- set_seed(dist.rank)
+ # Fix the seed based on training progress for reproducibility.
+ set_seed(dist.rank + cur_nimg)
configure_cuda_for_consistent_precision()
-
+
# Instantiate the dataset
data_loader_kwargs = {
"pin_memory": True,
@@ -110,25 +134,35 @@ def main(cfg: DictConfig) -> None:
batch_size=cfg.training.hp.batch_size_per_gpu,
seed=0,
train_test_split=train_test_split,
+ sampler_start_idx=cur_nimg,
)
+ is_real_target = dataset_cfg.get("type").split("_")[-1] == "real"
logger0.info(f"Training on dataset with size {len(dataset)}")
+ logger0.info(f"Validating on dataset with size {len(validation_dataset) if validation_dataset else 0}")
- # Parse image configuration & update model args
- dataset_channels = len(dataset.input_channels())
- img_in_channels = dataset_channels
+ # Get the shape of the grid (without the channel dimension) for later use in model creation and patching
img_shape = dataset.image_shape()
- img_out_channels = len(dataset.output_channels())
- if cfg.model.hr_mean_conditioning:
- img_in_channels += img_out_channels
+ logger0.info(f"Training on dataset with grid size {img_shape[0]}x{img_shape[1]}, {len(dataset.input_channels())} input channels and {len(dataset.output_channels())} output channels.")
+ logger0.info(f"Input channels: {dataset.input_channels()}")
+ logger0.info(f"Output channels: {dataset.output_channels()}")
+ logger0.info(f"Static channels: {dataset.static_channels()}")
+
+ # convert dataset stats to torch tensors on the correct device for later use in normalization and denormalization
+ dataset.stats_to_torch(device=dist.device, dtype=input_dtype)
+ # convert dataset stats to torch tensors on the correct device for later use in loss normalization and denormalization
+ dataset.interpolator.to(device=dist.device)
+ # convert regridding weights and indices to torch tensors on the correct device if real target dataset is used
+ if is_real_target:
+ dataset.regrid_indices_real = dataset.regrid_indices_real.to(dist.device)
+ dataset.regrid_weights_real = dataset.regrid_weights_real.to(dist.device, dtype=input_dtype)
if cfg.model.name == "lt_aware_ce_regression":
- prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader
+ prob_channels = dataset.get_prob_channel_index()
else:
prob_channels = None
# Parse the patch shape
- #TODO figure out patched diffusion and how to use it
if (
cfg.model.name == "patched_diffusion"
or cfg.model.name == "lt_aware_patched_diffusion"
@@ -150,8 +184,9 @@ def main(cfg: DictConfig) -> None:
)
patch_shape = (patch_shape_y, patch_shape_x)
use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
+
+ # Initialize patcher if patch-based training is enabled
if use_patching:
- # Utility to perform patches extraction and batching
patching = RandomPatching2D(
img_shape=img_shape,
patch_shape=patch_shape,
@@ -161,69 +196,38 @@ def main(cfg: DictConfig) -> None:
else:
patching = None
logger0.info("Patch-based training disabled")
- # interpolate global channel if patch-based model is used
- if use_patching:
- img_in_channels += dataset_channels
-
- # Instantiate the model and move to device.
- model_args = { # default parameters for all networks
- "img_out_channels": img_out_channels,
- "img_resolution": list(img_shape),
- "use_fp16": fp16,
- "checkpoint_level": songunet_checkpoint_level,
- }
- if cfg.model.name == "lt_aware_ce_regression":
- model_args["prob_channels"] = prob_channels
- if hasattr(cfg.model, "model_args"): # override defaults from config file
- model_args.update(OmegaConf.to_container(cfg.model.model_args))
-
- use_torch_compile = False
- use_apex_gn = False
- profile_mode = False
+ # Instantiate the training manager which handles model creation,
+ # data loading and transformation,
+ # and validation
+ training_manager = TrainingManagerCorrDiff(
+ dist,
+ logger0,
+ dataset,
+ input_dtype,
+ img_shape,
+ n_month_hour_channels,
+ fp16,
+ profile_mode,
+ enable_amp,
+ amp_dtype,
+ use_apex_gn,
+ is_real_target,
+ songunet_checkpoint_level,
+ use_patching,
+ cfg.model.get("hr_mean_conditioning", False),
+ cfg.logging.get("method", None)
+ )
- if hasattr(cfg.training.perf, "torch_compile"):
- use_torch_compile = cfg.training.perf.torch_compile
- if hasattr(cfg.training.perf, "use_apex_gn"):
- use_apex_gn = cfg.training.perf.use_apex_gn
- model_args["use_apex_gn"] = use_apex_gn
+ # Create the model and move it to the appropriate device and memory format based on the optimization configuration
+ model, model_args = training_manager.create_model(cfg.model.name, cfg.model.get("model_args", None))
- if hasattr(cfg.training.perf, "profile_mode"):
- profile_mode = cfg.training.perf.profile_mode
- model_args["profile_mode"] = profile_mode
+ # # Print the model summary
+ # if dist.rank == 0:
+ # summary(model, input_size=[(1, img_out_channels, *img_shape), (1, img_in_channels, *img_shape), (1,1)], device=dist.device)
- if enable_amp:
- model_args["amp_mode"] = enable_amp
+ # raise NotImplementedError("Check if model_args are correct when using patching - img_in_channels should include global channels and lead time channels if applicable")
-
- if cfg.model.name == "regression":
- model = UNet(
- img_in_channels=img_in_channels + model_args["N_grid_channels"],
- **model_args,
- )
- model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"]
- elif cfg.model.name == "lt_aware_ce_regression":
- model = UNet(
- img_in_channels=img_in_channels
- + model_args["N_grid_channels"]
- + model_args["lead_time_channels"],
- **model_args,
- )
- model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"]
- elif cfg.model.name == "lt_aware_patched_diffusion":
- model = EDMPrecondSuperResolution(
- img_in_channels=img_in_channels
- + model_args["N_grid_channels"]
- + model_args["lead_time_channels"],
- **model_args,
- )
- model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"]
- else: # diffusion or patched diffusion
- model = EDMPrecondSuperResolution(
- img_in_channels=img_in_channels + model_args["N_grid_channels"],
- **model_args,
- )
- model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"]
model.train().requires_grad_(True).to(dist.device)
@@ -243,62 +247,10 @@ def main(cfg: DictConfig) -> None:
f"Regression model ({cfg.model.name}) cannot be used with patch-based training. "
)
- # Enable distributed data parallel if applicable
- if dist.world_size > 1:
- model = DistributedDataParallel(
- model,
- device_ids=[dist.local_rank],
- broadcast_buffers=True,
- output_device=dist.device,
- find_unused_parameters=True, # dist.find_unused_parameters,
- bucket_cap_mb=35,
- gradient_as_bucket_view=True,
- )
-
# Load the regression checkpoint if applicable #TODO test when training correction
+ regression_net = None
if hasattr(cfg.training.io, "regression_checkpoint_path"):
- regression_checkpoint_path = to_absolute_path(
- cfg.training.io.regression_checkpoint_path
- )
- if not os.path.isdir(regression_checkpoint_path):
- raise FileNotFoundError(
- f"Expected this regression checkpoint but not found: {regression_checkpoint_path}"
- )
- #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device
- #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name)
- #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder
- regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json')
- if not os.path.isfile(regression_model_args_path):
- raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.")
-
- with open(regression_model_args_path, 'r') as f:
- regression_model_args = json.load(f)
-
- regression_model_args.update({
- "use_apex_gn": use_apex_gn,
- "profile_mode": profile_mode,
- "amp_mode": enable_amp,
- })
-
- regression_net = UNet(**regression_model_args)
-
- _ = load_checkpoint(
- path=regression_checkpoint_path,
- model=regression_net,
- device=dist.device
- )
- regression_net.eval().requires_grad_(False).to(dist.device)
- if use_apex_gn:
- regression_net.to(memory_format=torch.channels_last)
- logger0.success("Loaded the pre-trained regression model")
- else:
- regression_net = None
-
- # Compile the model and regression net if applicable
- if use_torch_compile:
- model = torch.compile(model)
- if regression_net:
- regression_net = torch.compile(regression_net)
+ regression_net = training_manager.load_regression_model(to_absolute_path(cfg.training.io.regression_checkpoint_path))
# Compute the number of required gradient accumulation rounds
@@ -311,26 +263,14 @@ def main(cfg: DictConfig) -> None:
batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu
logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds")
+ # calculate patch per iter
patch_num = getattr(cfg.training.hp, "patch_num", 1)
- max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1)
-
- # calculate patch per iter
- if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1:
- max_patch_num_per_iter = min(
- patch_num, (max_patch_per_gpu // batch_size_per_gpu)
- ) # Ensure at least 1 patch per iter
- patch_iterations = (
- patch_num + max_patch_num_per_iter - 1
- ) // max_patch_num_per_iter
- patch_nums_iter = [
- min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter)
- for i in range(patch_iterations)
- ]
- print(
- f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}"
- )
- else:
- patch_nums_iter = [patch_num]
+ max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", None)
+ patch_nums_iter = calculate_patch_per_iter(patch_num, max_patch_per_gpu, batch_size_per_gpu)
+
+ logger0.info(
+ f"Patch number iterations are {patch_nums_iter}"
+ )
# Set patch gradient accumulation only for patched diffusion models
if cfg.model.name in {
@@ -367,10 +307,8 @@ def main(cfg: DictConfig) -> None:
)
elif cfg.model.name == "regression":
loss_fn = RegressionLoss()
- elif cfg.model.name == "lt_aware_ce_regression":
- loss_fn = RegressionLossCE(prob_channels=prob_channels)
- # Instantiate the optimizer
+ # Instantiate the optimizer
optimizer = torch.optim.Adam(
params=model.parameters(),
lr=cfg.training.hp.lr,
@@ -379,9 +317,6 @@ def main(cfg: DictConfig) -> None:
fused=True,
)
- # Record the current time to measure the duration of subsequent operations.
- start_time = time.time()
-
# Load optimizer checkpoint if it exists
if dist.world_size > 1:
torch.distributed.barrier()
@@ -395,10 +330,35 @@ def main(cfg: DictConfig) -> None:
except:
cur_nimg = 0
+ # Compile the model and regression net if applicable
+ if use_torch_compile:
+ # if dist.world_size==1:
+ model = torch.compile(model)
+ if regression_net:
+ regression_net = torch.compile(regression_net)
+
+ # Enable distributed data parallel if applicable
+ if dist.world_size > 1:
+ # if use_torch_compile:
+ # model = torch.compile(model)
+ model = DistributedDataParallel(
+ model,
+ device_ids=[dist.local_rank],
+ broadcast_buffers=True,
+ output_device=dist.device,
+ find_unused_parameters=True, # dist.find_unused_parameters,
+ bucket_cap_mb=35,
+ gradient_as_bucket_view=True,
+ )
+
+
############################################################################
# MAIN TRAINING LOOP #
############################################################################
+ # Record the current time to measure the duration of subsequent operations.
+ start_time = time.time()
+
logger0.info(f"Training for {cfg.training.hp.training_duration} images...")
done = False
@@ -406,11 +366,13 @@ def main(cfg: DictConfig) -> None:
average_loss_running_mean = 0
n_average_loss_running_mean = 1
start_nimg = cur_nimg
- input_dtype = torch.float32
- if enable_amp:
- input_dtype = torch.float32
- elif fp16:
- input_dtype = torch.float16
+
+ # prepare static channels if there are any
+ static_channels = training_manager.get_static_data()
+
+ # turn off for lead time labels for now since we are not using them
+ # TODO: implement lead time labels properly once we train on IFS?
+ lead_time_label = None
# enable profiler:
with cuda_profiler():
@@ -436,36 +398,17 @@ def main(cfg: DictConfig) -> None:
f"accumulation round {n_i}", color="Magenta"
):
with nvtx.annotate("loading data", color="green"):
- img_clean, img_lr, *lead_time_label = next(
- dataset_iterator
- )
- if use_apex_gn:
- img_clean = img_clean.to(
- dist.device,
- dtype=input_dtype,
- non_blocking=True,
- ).to(memory_format=torch.channels_last)
- img_lr = img_lr.to(
- dist.device,
- dtype=input_dtype,
- non_blocking=True,
- ).to(memory_format=torch.channels_last)
- else:
- img_clean = (
- img_clean.to(dist.device)
- .to(input_dtype)
- .contiguous()
- )
- img_lr = (
- img_lr.to(dist.device)
- .to(input_dtype)
- .contiguous()
- )
+ tick_read_start_time = time.time()
+ img_clean, img_lr, date_embedding = training_manager.load_and_preprocess_batch(dataset_iterator)
+ tick_read_time = time.time() - tick_read_start_time
loss_fn_kwargs = {
"net": model,
"img_clean": img_clean,
"img_lr": img_lr,
+ "static_channels": static_channels,
+ "date_embedding": date_embedding,
"augment_pipe": None,
+ "use_apex_gn": use_apex_gn,
}
if use_patch_grad_acc is not None:
loss_fn_kwargs[
@@ -494,13 +437,12 @@ def main(cfg: DictConfig) -> None:
):
loss = loss_fn(**loss_fn_kwargs)
- loss = loss.sum() / batch_size_per_gpu
+ loss = loss.sum() / batch_size_per_gpu / patch_num_per_iter
loss_accum += (
loss
/ num_accumulation_rounds
/ len(patch_nums_iter)
)
- loss_accum += loss / num_accumulation_rounds
with nvtx.annotate(f"loss backward", color="yellow"):
loss.backward()
@@ -514,45 +456,21 @@ def main(cfg: DictConfig) -> None:
)
average_loss = (loss_sum / dist.world_size).cpu().item()
- # update running mean of average loss since last periodic task
- average_loss_running_mean += (
- average_loss - average_loss_running_mean
- ) / n_average_loss_running_mean
- n_average_loss_running_mean += 1
-
- if dist.rank == 0:
- writer.add_scalar("training_loss", average_loss, cur_nimg)
- writer.add_scalar(
- "training_loss_running_mean",
- average_loss_running_mean,
- cur_nimg,
- )
-
- ptt = is_time_for_periodic_task(
- cur_nimg,
- cfg.training.io.print_progress_freq,
- done,
- cfg.training.hp.total_batch_size,
- dist.rank,
- rank_0_only=True,
- )
- if ptt:
- # reset running mean of average loss
- average_loss_running_mean = 0
- n_average_loss_running_mean = 1
+ # update running mean of average loss since last periodic task
+ average_loss_running_mean += (
+ average_loss - average_loss_running_mean
+ ) / n_average_loss_running_mean
+ n_average_loss_running_mean += 1
# Update weights.
with nvtx.annotate("update weights", color="blue"):
- lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate
- for g in optimizer.param_groups:
- if lr_rampup > 0:
- g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1)
- if cur_nimg >= lr_rampup:
- g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // cfg.training.hp.lr_decay_rate)
- current_lr = g["lr"]
- if dist.rank == 0:
- writer.add_scalar("learning_rate", current_lr, cur_nimg)
+ current_lr = update_learning_rate(optimizer,
+ cfg.training.hp.lr,
+ cfg.training.hp.lr_rampup,
+ cfg.training.hp.lr_decay,
+ cfg.training.hp.lr_decay_rate,
+ cur_nimg)
handle_and_clip_gradients(
model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold
)
@@ -562,138 +480,35 @@ def main(cfg: DictConfig) -> None:
cur_nimg += cfg.training.hp.total_batch_size
done = cur_nimg >= cfg.training.hp.training_duration
- if is_time_for_periodic_task(
- cur_nimg,
- cfg.training.io.print_progress_freq,
- done,
- cfg.training.hp.total_batch_size,
- dist.rank,
- rank_0_only=True,
- ):
- # Print stats if we crossed the printing threshold with this batch
- tick_end_time = time.time()
- fields = []
- fields += [f"samples {cur_nimg:<9.1f}"]
- fields += [f"training_loss {average_loss:<7.2f}"]
- fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"]
- fields += [f"learning_rate {current_lr:<7.8f}"]
- fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"]
- fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"]
- fields += [
- f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}"
- ]
- fields += [
- f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
- ]
- if torch.cuda.is_available():
- fields += [
- f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}"
- ]
- fields += [
- f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}"
- ]
- torch.cuda.reset_peak_memory_stats()
- logger0.info(" ".join(fields))
+ # Logging training progress
+ if is_time_for_periodic_task(
+ cur_nimg,
+ cfg.training.io.print_progress_freq,
+ done,
+ cfg.training.hp.total_batch_size,
+ dist.rank,
+ rank_0_only=True,
+ ):
+ # Print stats if we crossed the printing threshold with this batch
+ log_training_progress(logger0, cfg.logging.method, dist, cur_nimg, tick_start_nimg, tick_start_time,
+ tick_read_time, start_time, average_loss, average_loss_running_mean, current_lr)
+ # reset running mean of average loss
+ average_loss_running_mean = 0
+ n_average_loss_running_mean = 1
+
+ # Validation
with nvtx.annotate("validation", color="red"):
- # Validation
- if validation_dataset_iterator is not None:
- valid_loss_accum = 0
- if is_time_for_periodic_task(
- cur_nimg,
- cfg.training.io.validation_freq,
- done,
- cfg.training.hp.total_batch_size,
- dist.rank,
- ):
- with torch.no_grad():
- for _ in range(cfg.training.io.validation_steps):
- (
- img_clean_valid,
- img_lr_valid,
- *lead_time_label_valid,
- ) = next(validation_dataset_iterator)
-
- if use_apex_gn:
- img_clean_valid = img_clean_valid.to(
- dist.device,
- dtype=input_dtype,
- non_blocking=True,
- ).to(memory_format=torch.channels_last)
- img_lr_valid = img_lr_valid.to(
- dist.device,
- dtype=input_dtype,
- non_blocking=True,
- ).to(memory_format=torch.channels_last)
-
- else:
- img_clean_valid = (
- img_clean_valid.to(dist.device)
- .to(input_dtype)
- .contiguous()
- )
- img_lr_valid = (
- img_lr_valid.to(dist.device)
- .to(input_dtype)
- .contiguous()
- )
-
- loss_valid_kwargs = {
- "net": model,
- "img_clean": img_clean_valid,
- "img_lr": img_lr_valid,
- "augment_pipe": None,
- }
- if use_patch_grad_acc is not None:
- loss_valid_kwargs[
- "use_patch_grad_acc"
- ] = use_patch_grad_acc
- if lead_time_label_valid:
- lead_time_label_valid = (
- lead_time_label_valid[0]
- .to(dist.device)
- .contiguous()
- )
- loss_valid_kwargs.update(
- {"lead_time_label": lead_time_label_valid}
- )
- if use_patch_grad_acc:
- loss_fn.y_mean = None
-
- for patch_num_per_iter in patch_nums_iter:
- if patching is not None:
- patching.set_patch_num(patch_num_per_iter)
- loss_valid_kwargs.update(
- {"patching": patching}
- )
- with torch.autocast(
- "cuda", dtype=amp_dtype, enabled=enable_amp
- ):
- loss_valid = loss_fn(**loss_valid_kwargs)
-
- loss_valid = (
- (loss_valid.sum() / batch_size_per_gpu)
- .cpu()
- .item()
- )
- valid_loss_accum += (
- loss_valid
- / cfg.training.io.validation_steps
- / len(patch_nums_iter)
- )
- valid_loss_sum = torch.tensor(
- [valid_loss_accum], device=dist.device
- )
- if dist.world_size > 1:
- torch.distributed.barrier()
- torch.distributed.all_reduce(
- valid_loss_sum, op=torch.distributed.ReduceOp.SUM
- )
- average_valid_loss = valid_loss_sum / dist.world_size
- if dist.rank == 0:
- writer.add_scalar(
- "validation_loss", average_valid_loss, cur_nimg
- )
+ if validation_dataset_iterator is not None and is_time_for_periodic_task(
+ cur_nimg,
+ cfg.training.io.validation_freq,
+ done,
+ cfg.training.hp.total_batch_size,
+ dist.rank,
+ ):
+ training_manager.run_validation(cur_nimg, validation_dataset_iterator, model, loss_fn,
+ cfg.training.io.get("validation_steps",1), static_channels,
+ batch_size_per_gpu, patching, patch_nums_iter, use_patch_grad_acc)
# Save checkpoints
@@ -714,9 +529,10 @@ def main(cfg: DictConfig) -> None:
epoch=cur_nimg,
)
+ if dist.world_size > 1:
+ torch.distributed.barrier()
# Done.
logger0.info("Training Completed.")
-
if __name__ == "__main__":
main()
\ No newline at end of file
diff --git a/src/hirad/training/training_manager.py b/src/hirad/training/training_manager.py
new file mode 100644
index 00000000..cba52432
--- /dev/null
+++ b/src/hirad/training/training_manager.py
@@ -0,0 +1,287 @@
+from abc import ABC, abstractmethod
+import torch
+import numpy as np
+import mlflow
+import os
+import json
+
+from hirad.distributed import DistributedManager
+from hirad.utils.console import PythonLogger
+from hirad.datasets import DownscalingDataset
+from hirad.models import UNet, EDMPrecondSuperResolution
+from hirad.utils.dataset_utils import regrid_icon_to_rotlatlon
+from hirad.utils.checkpoint import load_checkpoint
+
+
+class TrainingManagerBase(ABC):
+ def __init__(self, dist: DistributedManager, logger: PythonLogger):
+ self.dist = dist
+ self.logger = logger
+
+ @abstractmethod
+ def load_and_preprocess_batch(self):
+ pass
+
+ @abstractmethod
+ def get_static_data(self):
+ pass
+
+ @abstractmethod
+ def create_model(self):
+ pass
+
+ @abstractmethod
+ def run_validation(self):
+ pass
+
+
+class TrainingManagerCorrDiff(TrainingManagerBase):
+ def __init__(
+ self,
+ dist: DistributedManager,
+ logger: PythonLogger,
+ dataset: DownscalingDataset,
+ input_dtype: torch.dtype,
+ img_shape: tuple[int, int],
+ n_month_hour_channels: int,
+ fp16: bool,
+ profile_mode: bool,
+ enable_amp: bool,
+ amp_dtype: torch.dtype,
+ use_apex_gn: bool,
+ is_real_target: bool,
+ songunet_checkpoint_level: int,
+ use_patching: bool,
+ hr_mean_conditioning: bool,
+ logging_method: str,
+ ):
+ super().__init__(dist, logger)
+ self.dataset = dataset
+ self.input_dtype = input_dtype
+ self.img_shape = img_shape
+ self.is_real_target = is_real_target
+ self.n_month_hour_channels = n_month_hour_channels
+ self.fp16 = fp16
+ self.songunet_checkpoint_level = songunet_checkpoint_level
+ self.profile_mode = profile_mode
+ self.enable_amp = enable_amp
+ self.amp_dtype = amp_dtype
+ self.use_apex_gn = use_apex_gn
+ self.use_patching = use_patching
+ self.hr_mean_conditioning = hr_mean_conditioning
+ self.logging_method = logging_method
+
+
+ def load_and_preprocess_batch(self, dataset_iterator):
+ """Load a batch from the iterator and preprocess it (interpolate, normalize, move to device)."""
+ img_clean, img_lr, *date_str = next(dataset_iterator)
+
+ # Interpolate and normalize low-res input
+ img_lr = self.dataset.interpolator(
+ img_lr.to(self.dist.device, dtype=self.input_dtype)
+ ).reshape(*img_lr.shape[:-1], *self.img_shape).flip(-2)
+ img_lr = self.dataset.normalize_input(img_lr)
+
+ # Process high-res target
+ if self.is_real_target:
+ img_clean = regrid_icon_to_rotlatlon(
+ img_clean.to(self.dist.device, dtype=self.input_dtype),
+ self.dataset.regrid_indices_real,
+ self.dataset.regrid_weights_real,
+ )
+ if self.dataset.trim_edge > 0:
+ img_clean = img_clean[:, :, self.dataset.trim_edge:-self.dataset.trim_edge,
+ self.dataset.trim_edge:-self.dataset.trim_edge]
+ img_clean = img_clean.flip(-2)
+ else:
+ img_clean = img_clean.to(self.dist.device, dtype=self.input_dtype)
+ img_clean = img_clean.reshape(*img_clean.shape[:-1], *self.img_shape).flip(-2)
+ img_clean = self.dataset.normalize_output(img_clean)
+
+ # Date embedding
+ date_embedding = None
+ if self.n_month_hour_channels > 0:
+ date_embedding = self.dataset.make_time_grids(*date_str, self.dist.device, dtype=self.input_dtype)
+
+ # Memory format
+ if self.use_apex_gn:
+ img_clean = img_clean.to(self.dist.device, dtype=self.input_dtype, non_blocking=True).to(
+ memory_format=torch.channels_last
+ )
+ img_lr = img_lr.to(self.dist.device, dtype=self.input_dtype, non_blocking=True).to(
+ memory_format=torch.channels_last
+ )
+ else:
+ img_clean = img_clean.to(self.dist.device).to(self.input_dtype).contiguous()
+ img_lr = img_lr.to(self.dist.device).to(self.input_dtype).contiguous()
+
+ return img_clean, img_lr, date_embedding
+
+ def get_static_data(self):
+ """Get static data from the dataset, preprocess it and move to device."""
+ static_channels = self.dataset.get_static_data()
+ if static_channels is not None:
+ if isinstance(static_channels, np.ndarray):
+ static_channels = torch.from_numpy(static_channels)
+
+ static_channels = static_channels[None, ::].flip(-2)
+ if self.use_apex_gn:
+ static_channels = static_channels.to(
+ self.dist.device,
+ dtype=self.input_dtype,
+ non_blocking=True,
+ ).to(memory_format=torch.channels_last)
+ else:
+ static_channels = (
+ static_channels.to(self.dist.device)
+ .to(self.input_dtype)
+ .contiguous()
+ )
+ return static_channels
+
+
+ def create_model(self, cfg_model_name: str, cfg_model_args: dict, prob_channels: list = []):
+ """Instantiate the model."""
+ n_input_channels = len(self.dataset.input_channels())
+ n_static_channels = len(self.dataset.static_channels())
+ n_output_channels = len(self.dataset.output_channels())
+
+ img_in_channels = n_input_channels + n_static_channels + self.n_month_hour_channels
+ if self.hr_mean_conditioning:
+ img_in_channels += n_output_channels
+ if self.use_patching:
+ img_in_channels += n_input_channels + n_static_channels
+
+ img_out_channels = n_output_channels
+
+ self.logger.info(f"Creating model {cfg_model_name} with {img_in_channels} input channels and {img_out_channels} output channels.")
+
+ model_args = { # default parameters for all networks
+ "img_out_channels": img_out_channels,
+ "img_resolution": list(self.img_shape),
+ "use_fp16": self.fp16,
+ "checkpoint_level": self.songunet_checkpoint_level,
+ }
+ if cfg_model_name == "lt_aware_ce_regression":
+ model_args["prob_channels"] = prob_channels
+
+ if cfg_model_args: # override defaults from config file
+ model_args.update(cfg_model_args)
+
+ model_args["use_apex_gn"] = self.use_apex_gn
+ model_args["profile_mode"] = self.profile_mode
+
+ if self.enable_amp:
+ model_args["amp_mode"] = self.enable_amp
+
+
+ if cfg_model_name == "regression":
+ model = UNet(
+ img_in_channels=img_in_channels + model_args["N_grid_channels"],
+ **model_args,
+ )
+ model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"]
+ elif cfg_model_name == "lt_aware_ce_regression":
+ model = UNet(
+ img_in_channels=img_in_channels
+ + model_args["N_grid_channels"]
+ + model_args["lead_time_channels"],
+ **model_args,
+ )
+ model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"]
+ elif cfg_model_name == "lt_aware_patched_diffusion":
+ model = EDMPrecondSuperResolution(
+ img_in_channels=img_in_channels
+ + model_args["N_grid_channels"]
+ + model_args["lead_time_channels"],
+ **model_args,
+ )
+ model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"]
+ else: # diffusion or patched diffusion
+ model = EDMPrecondSuperResolution(
+ img_in_channels=img_in_channels + model_args["N_grid_channels"],
+ **model_args,
+ )
+ model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"]
+
+ return model, model_args
+
+ def load_regression_model(self, regression_checkpoint_path: str):
+ """Load the regression model for the residual loss if applicable."""
+
+ if not os.path.isdir(regression_checkpoint_path):
+ raise FileNotFoundError(
+ f"Expected this regression checkpoint but not found: {regression_checkpoint_path}"
+ )
+ #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name)
+ #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder
+ regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json')
+ if not os.path.isfile(regression_model_args_path):
+ raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.")
+
+ with open(regression_model_args_path, 'r') as f:
+ regression_model_args = json.load(f)
+
+ regression_model_args.update({
+ "use_apex_gn": self.use_apex_gn,
+ "profile_mode": self.profile_mode,
+ "amp_mode": self.enable_amp,
+ })
+
+ regression_net = UNet(**regression_model_args)
+
+ _ = load_checkpoint(
+ path=regression_checkpoint_path,
+ model=regression_net,
+ device=self.dist.device
+ )
+ regression_net.eval().requires_grad_(False).to(self.dist.device)
+ if self.use_apex_gn:
+ regression_net.to(memory_format=torch.channels_last)
+ self.logger.success("Loaded the pre-trained regression model")
+
+ return regression_net
+
+
+ def run_validation(self, cur_nimg, validation_dataset_iterator, model, loss_fn, validation_steps,
+ static_channels, batch_size_per_gpu, patching,
+ patch_nums_iter, use_patch_grad_acc):
+ """Run validation and return average validation loss."""
+ valid_loss_accum = 0
+ with torch.no_grad():
+ lead_time_label_valid = None
+ for _ in range(validation_steps):
+ img_clean_valid, img_lr_valid, date_embedding = self.load_and_preprocess_batch(validation_dataset_iterator)
+
+ loss_valid_kwargs = {
+ "net": model,
+ "img_clean": img_clean_valid,
+ "img_lr": img_lr_valid,
+ "static_channels": static_channels,
+ "date_embedding": date_embedding,
+ "augment_pipe": None,
+ "use_apex_gn": self.use_apex_gn,
+ }
+ if use_patch_grad_acc is not None:
+ loss_valid_kwargs["use_patch_grad_acc"] = use_patch_grad_acc
+ if use_patch_grad_acc:
+ loss_fn.y_mean = None
+
+ for patch_num_per_iter in patch_nums_iter:
+ if patching is not None:
+ patching.set_patch_num(patch_num_per_iter)
+ loss_valid_kwargs["patching"] = patching
+ with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
+ loss_valid = loss_fn(**loss_valid_kwargs)
+ loss_valid = (loss_valid.sum() / batch_size_per_gpu / patch_num_per_iter).cpu().item()
+ valid_loss_accum += loss_valid / validation_steps / len(patch_nums_iter)
+
+ valid_loss_sum = torch.tensor([valid_loss_accum], device=self.dist.device)
+ if self.dist.world_size > 1:
+ torch.distributed.barrier()
+ torch.distributed.all_reduce(valid_loss_sum, op=torch.distributed.ReduceOp.SUM)
+ average_valid_loss = (valid_loss_sum / self.dist.world_size).item()
+ if self.dist.rank == 0 and self.logging_method == "mlflow":
+ mlflow.log_metric("validation_loss", average_valid_loss, cur_nimg)
+
+ return average_valid_loss
\ No newline at end of file
diff --git a/src/hirad/utils/capture.py b/src/hirad/utils/capture.py
deleted file mode 100644
index 9c38d5aa..00000000
--- a/src/hirad/utils/capture.py
+++ /dev/null
@@ -1,513 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-FileCopyrightText: All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-import logging
-import os
-import time
-from contextlib import nullcontext
-from logging import Logger
-from typing import Any, Callable, Dict, NewType, Optional, Union
-
-import torch
-
-from hirad.distributed import DistributedManager
-
-float16 = NewType("float16", torch.float16)
-bfloat16 = NewType("bfloat16", torch.bfloat16)
-optim = NewType("optim", torch.optim)
-
-
-class _StaticCapture(object):
- """Base class for StaticCapture decorator.
-
- This class should not be used, rather StaticCaptureTraining and StaticCaptureEvaluate
- should be used instead for training and evaluation functions.
- """
-
- # Grad scaler and checkpoint class variables use for checkpoint saving and loading
- # Since an instance of Static capture does not exist for checkpoint functions
- # one must use class functions to access state dicts
- _amp_scalers = {}
- _amp_scaler_checkpoints = {}
- _logger = logging.getLogger("capture")
-
- def __new__(cls, *args, **kwargs):
- obj = super(_StaticCapture, cls).__new__(cls)
- obj.amp_scalers = cls._amp_scalers
- obj.amp_scaler_checkpoints = cls._amp_scaler_checkpoints
- obj.logger = cls._logger
- return obj
-
- def __init__(
- self,
- model: "physicsnemo.Module",
- optim: Optional[optim] = None,
- logger: Optional[Logger] = None,
- use_graphs: bool = True,
- use_autocast: bool = True,
- use_gradscaler: bool = True,
- compile: bool = False,
- cuda_graph_warmup: int = 11,
- amp_type: Union[float16, bfloat16] = torch.float16,
- gradient_clip_norm: Optional[float] = None,
- label: Optional[str] = None,
- ):
- self.logger = logger if logger else self.logger
- # Checkpoint label (used for gradscaler)
- self.label = label if label else f"scaler_{len(self.amp_scalers.keys())}"
-
- # DDP fix
- if not isinstance(model, physicsnemo.models.Module) and hasattr(
- model, "module"
- ):
- model = model.module
-
- if not isinstance(model, physicsnemo.models.Module):
- self.logger.error("Model not a PhysicsNeMo Module!")
- raise ValueError("Model not a PhysicsNeMo Module!")
- if compile:
- model = torch.compile(model)
-
- self.model = model
-
- self.optim = optim
- self.eval = False
- self.no_grad = False
- self.gradient_clip_norm = gradient_clip_norm
-
- # Set up toggles for optimizations
- if not (amp_type == torch.float16 or amp_type == torch.bfloat16):
- raise ValueError("AMP type must be torch.float16 or torch.bfloat16")
- # CUDA device
- if "cuda" in str(self.model.device):
- # CUDA graphs
- if use_graphs and not self.model.meta.cuda_graphs:
- self.logger.warning(
- f"Model {model.meta.name} does not support CUDA graphs, turning off"
- )
- use_graphs = False
- self.cuda_graphs_enabled = use_graphs
-
- # AMP GPU
- if not self.model.meta.amp_gpu:
- self.logger.warning(
- f"Model {model.meta.name} does not support AMP on GPUs, turning off"
- )
- use_autocast = False
- use_gradscaler = False
- self.use_gradscaler = use_gradscaler
- self.use_autocast = use_autocast
-
- self.amp_device = "cuda"
- # Check if bfloat16 is suppored on the GPU
- if amp_type == torch.bfloat16 and not torch.cuda.is_bf16_supported():
- self.logger.warning(
- "Current CUDA device does not support bfloat16, falling back to float16"
- )
- amp_type = torch.float16
- self.amp_dtype = amp_type
- # Gradient Scaler
- scaler_enabled = self.use_gradscaler and amp_type == torch.float16
- self.scaler = self._init_amp_scaler(scaler_enabled, self.logger)
-
- self.replay_stream = torch.cuda.Stream(self.model.device)
- # CPU device
- else:
- self.cuda_graphs_enabled = False
- # AMP CPU
- if use_autocast and not self.model.meta.amp_cpu:
- self.logger.warning(
- f"Model {model.meta.name} does not support AMP on CPUs, turning off"
- )
- use_autocast = False
-
- self.use_autocast = use_autocast
- self.amp_device = "cpu"
- # Only float16 is supported on CPUs
- # https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior
- if amp_type == torch.float16 and use_autocast:
- self.logger.warning(
- "torch.float16 not supported for CPU AMP, switching to torch.bfloat16"
- )
- amp_type = torch.bfloat16
- self.amp_dtype = torch.bfloat16
- # Gradient Scaler (not enabled)
- self.scaler = self._init_amp_scaler(False, self.logger)
- self.replay_stream = None
-
- if self.cuda_graphs_enabled:
- self.graph = torch.cuda.CUDAGraph()
-
- self.output = None
- self.iteration = 0
- self.cuda_graph_warmup = cuda_graph_warmup # Default for DDP = 11
-
- def __call__(self, fn: Callable) -> Callable:
- self.function = fn
-
- @functools.wraps(fn)
- def decorated(*args: Any, **kwds: Any) -> Any:
- """Training step decorator function"""
-
- with torch.no_grad() if self.no_grad else nullcontext():
- if self.cuda_graphs_enabled:
- self._cuda_graph_forward(*args, **kwds)
- else:
- self._zero_grads()
- self.output = self._amp_forward(*args, **kwds)
-
- if not self.eval:
- # Update model parameters
- self.scaler.step(self.optim)
- self.scaler.update()
-
- return self.output
-
- return decorated
-
- def _cuda_graph_forward(self, *args: Any, **kwargs: Any) -> Any:
- """Forward training step with CUDA graphs
-
- Returns
- -------
- Any
- Output of neural network forward
- """
- # Graph warm up
- if self.iteration < self.cuda_graph_warmup:
- self.replay_stream.wait_stream(torch.cuda.current_stream())
- self._zero_grads()
- with torch.cuda.stream(self.replay_stream):
- output = self._amp_forward(*args, **kwargs)
- self.output = output.detach()
- torch.cuda.current_stream().wait_stream(self.replay_stream)
- # CUDA Graphs
- else:
- # Graph record
- if self.iteration == self.cuda_graph_warmup:
- self.logger.warning(f"Recording graph of '{self.function.__name__}'")
- self._zero_grads()
- torch.cuda.synchronize()
- if DistributedManager().distributed:
- torch.distributed.barrier()
- # TODO: temporary workaround till this issue is fixed:
- # https://github.com/pytorch/pytorch/pull/104487#issuecomment-1638665876
- delay = os.environ.get("PHYSICSNEMO_CUDA_GRAPH_CAPTURE_DELAY", "10")
- time.sleep(int(delay))
- with torch.cuda.graph(self.graph):
- output = self._amp_forward(*args, **kwargs)
- self.output = output.detach()
- # Graph replay
- self.graph.replay()
-
- self.iteration += 1
- return self.output
-
- def _zero_grads(self):
- """Zero gradients
-
- Default to `set_to_none` since this will in general have lower memory
- footprint, and can modestly improve performance.
-
- Note
- ----
- Zeroing gradients can potentially cause an invalid CUDA memory access in another
- graph. However if your graph involves gradients, you much set your gradients to none.
- If there is already a graph recorded that includes these gradients, this will error.
- Use the `NoGrad` version of capture to avoid this issue for inferencers / validators.
- """
- # Skip zeroing if no grad is being used
- if self.no_grad:
- return
-
- try:
- self.optim.zero_grad(set_to_none=True)
- except Exception:
- if self.optim:
- self.optim.zero_grad()
- # For apex optim support and eval mode (need to reset model grads)
- self.model.zero_grad(set_to_none=True)
-
- def _amp_forward(self, *args, **kwargs) -> Any:
- """Compute loss and gradients (if training) with AMP
-
- Returns
- -------
- Any
- Output of neural network forward
- """
- with torch.autocast(
- self.amp_device, enabled=self.use_autocast, dtype=self.amp_dtype
- ):
- output = self.function(*args, **kwargs)
-
- if not self.eval:
- # In training mode output should be the loss
- self.scaler.scale(output).backward()
- if self.gradient_clip_norm is not None:
- self.scaler.unscale_(self.optim)
- torch.nn.utils.clip_grad_norm_(
- self.model.parameters(), self.gradient_clip_norm
- )
-
- return output
-
- def _init_amp_scaler(
- self, scaler_enabled: bool, logger: Logger
- ) -> torch.cuda.amp.GradScaler:
- # Create gradient scaler
- scaler = torch.cuda.amp.GradScaler(enabled=scaler_enabled)
- # Store scaler in class variable
- self.amp_scalers[self.label] = scaler
- logging.debug(f"Created gradient scaler {self.label}")
-
- # If our checkpoint dictionary has weights for this scaler lets load
- if self.label in self.amp_scaler_checkpoints:
- try:
- scaler.load_state_dict(self.amp_scaler_checkpoints[self.label])
- del self.amp_scaler_checkpoints[self.label]
- self.logger.info(f"Loaded grad scaler state dictionary {self.label}.")
- except Exception as e:
- self.logger.error(
- f"Failed to load grad scaler {self.label} state dict from saved "
- + "checkpoints. Did you switch the ordering of declared static captures?"
- )
- raise ValueError(e)
- return scaler
-
- @classmethod
- def state_dict(cls) -> Dict[str, Any]:
- """Class method for accsessing the StaticCapture state dictionary.
- Use this in a training checkpoint function.
-
- Returns
- -------
- Dict[str, Any]
- Dictionary of states to save for file
- """
- scaler_states = {}
- for key, value in cls._amp_scalers.items():
- scaler_states[key] = value.state_dict()
-
- return scaler_states
-
- @classmethod
- def load_state_dict(cls, state_dict: Dict[str, Any]) -> None:
- """Class method for loading a StaticCapture state dictionary.
- Use this in a training checkpoint function.
-
- Returns
- -------
- Dict[str, Any]
- Dictionary of states to save for file
- """
- for key, value in state_dict.items():
- # If scaler has been created already load the weights
- if key in cls._amp_scalers:
- try:
- cls._amp_scalers[key].load_state_dict(value)
- cls._logger.info(f"Loaded grad scaler state dictionary {key}.")
- except Exception as e:
- cls._logger.error(
- f"Failed to load grad scaler state dict with id {key}."
- + " Something went wrong!"
- )
- raise ValueError(e)
- # Otherwise store in checkpoints for later use
- else:
- cls._amp_scaler_checkpoints[key] = value
-
- @classmethod
- def reset_state(cls):
- cls._amp_scalers = {}
- cls._amp_scaler_checkpoints = {}
-
-
-class StaticCaptureTraining(_StaticCapture):
- """A performance optimization decorator for PyTorch training functions.
-
- This class should be initialized as a decorator on a function that computes the
- forward pass of the neural network and loss function. The user should only call the
- defind training step function. This will apply optimizations including: AMP and
- Cuda Graphs.
-
- Parameters
- ----------
- model : physicsnemo.models.Module
- PhysicsNeMo Model
- optim : torch.optim
- Optimizer
- logger : Optional[Logger], optional
- PhysicsNeMo Launch Logger, by default None
- use_graphs : bool, optional
- Toggle CUDA graphs if supported by model, by default True
- use_amp : bool, optional
- Toggle AMP if supported by mode, by default True
- cuda_graph_warmup : int, optional
- Number of warmup steps for cuda graphs, by default 11
- amp_type : Union[float16, bfloat16], optional
- Auto casting type for AMP, by default torch.float16
- gradient_clip_norm : Optional[float], optional
- Threshold for gradient clipping
- label : Optional[str], optional
- Static capture checkpoint label, by default None
-
- Raises
- ------
- ValueError
- If the model provided is not a physicsnemo.models.Module. I.e. has no meta data.
-
- Example
- -------
- >>> # Create model
- >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2)
- >>> input = torch.rand(8, 2)
- >>> output = torch.rand(8, 2)
- >>> # Create optimizer
- >>> optim = torch.optim.Adam(model.parameters(), lr=0.001)
- >>> # Create training step function with optimization wrapper
- >>> @StaticCaptureTraining(model=model, optim=optim)
- ... def training_step(model, invar, outvar):
- ... predvar = model(invar)
- ... loss = torch.sum(torch.pow(predvar - outvar, 2))
- ... return loss
- ...
- >>> # Sample training loop
- >>> for i in range(3):
- ... loss = training_step(model, input, output)
- ...
-
- Note
- ----
- Static captures must be checkpointed when training using the `state_dict()` if AMP
- is being used with gradient scaler. By default, this requires static captures to be
- instantiated in the same order as when they were checkpointed. The label parameter
- can be used to relax/circumvent this ordering requirement.
-
- Note
- ----
- Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA
- memory access errors on some systems. Prioritize capturing training graphs when this
- occurs.
- """
-
- def __init__(
- self,
- model: "physicsnemo.Module",
- optim: torch.optim,
- logger: Optional[Logger] = None,
- use_graphs: bool = True,
- use_amp: bool = True,
- compile: bool = False,
- cuda_graph_warmup: int = 11,
- amp_type: Union[float16, bfloat16] = torch.float16,
- gradient_clip_norm: Optional[float] = None,
- label: Optional[str] = None,
- ):
- super().__init__(
- model,
- optim,
- logger,
- use_graphs,
- use_amp,
- use_amp,
- compile,
- cuda_graph_warmup,
- amp_type,
- gradient_clip_norm,
- label,
- )
-
-
-class StaticCaptureEvaluateNoGrad(_StaticCapture):
-
- """An performance optimization decorator for PyTorch no grad evaluation.
-
- This class should be initialized as a decorator on a function that computes run the
- forward pass of the model that does not require gradient calculations. This is the
- recommended method to use for inference and validation methods.
-
- Parameters
- ----------
- model : physicsnemo.models.Module
- PhysicsNeMo Model
- logger : Optional[Logger], optional
- PhysicsNeMo Launch Logger, by default None
- use_graphs : bool, optional
- Toggle CUDA graphs if supported by model, by default True
- use_amp : bool, optional
- Toggle AMP if supported by mode, by default True
- cuda_graph_warmup : int, optional
- Number of warmup steps for cuda graphs, by default 11
- amp_type : Union[float16, bfloat16], optional
- Auto casting type for AMP, by default torch.float16
- label : Optional[str], optional
- Static capture checkpoint label, by default None
-
- Raises
- ------
- ValueError
- If the model provided is not a physicsnemo.models.Module. I.e. has no meta data.
-
- Example
- -------
- >>> # Create model
- >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2)
- >>> input = torch.rand(8, 2)
- >>> # Create evaluate function with optimization wrapper
- >>> @StaticCaptureEvaluateNoGrad(model=model)
- ... def eval_step(model, invar):
- ... predvar = model(invar)
- ... return predvar
- ...
- >>> output = eval_step(model, input)
- >>> output.size()
- torch.Size([8, 2])
-
- Note
- ----
- Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA
- memory access errors on some systems. Prioritize capturing training graphs when this
- occurs.
- """
-
- def __init__(
- self,
- model: "physicsnemo.Module",
- logger: Optional[Logger] = None,
- use_graphs: bool = True,
- use_amp: bool = True,
- compile: bool = False,
- cuda_graph_warmup: int = 11,
- amp_type: Union[float16, bfloat16] = torch.float16,
- label: Optional[str] = None,
- ):
- super().__init__(
- model,
- None,
- logger,
- use_graphs,
- use_amp,
- compile,
- False,
- cuda_graph_warmup,
- amp_type,
- None,
- label,
- )
- self.eval = True # No optimizer/scaler calls
- self.no_grad = True # No grad context and no grad zeroing
diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py
index a346b16f..5bd48313 100644
--- a/src/hirad/utils/checkpoint.py
+++ b/src/hirad/utils/checkpoint.py
@@ -174,10 +174,19 @@ def save_checkpoint(
Path(path).mkdir(parents=True, exist_ok=True)
# == Saving model checkpoint ==
- if model:
+ if model is not None:
+ # Strip out optimization wrapper if exists before stripping DDP
+ if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
+ model = model._orig_mod
+
if hasattr(model, "module"):
# Strip out DDP layer
model = model.module
+
+ # Strip out optimization wrapper if exists after stripping DDP
+ if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
+ model = model._orig_mod
+
# Base name of model is meta.name unless pytorch model
name = model.__class__.__name__
# Get full file path / name
@@ -223,7 +232,7 @@ def save_checkpoint(
def load_checkpoint(
path: str,
- model: torch.nn.Module,
+ model: torch.nn.Module = None,
optimizer: Union[optimizer, None] = None,
scheduler: Union[scheduler, None] = None,
scaler: Union[scaler, None] = None,
@@ -268,27 +277,33 @@ def load_checkpoint(
)
return 0
- # == Loading model checkpoint ==
- if hasattr(model, "module"):
- # Strip out DDP layer
- model = model.module
- # Base name of model is meta.name unless pytorch model
- name = model.__class__.__name__
- # Get full file path / name
- file_name = _get_checkpoint_filename(
- path, name, index=epoch,
- )
- if not Path(file_name).exists():
- checkpoint_logging.error(
- f"Could not find valid model file {file_name}, skipping load"
+ if model is not None:
+ # == Loading model checkpoint ==
+ if hasattr(model, "module"):
+ # Strip out DDP layer
+ model = model.module
+ # Strip out optimization wrapper if exists
+ if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
+ model = model._orig_mod
+ checkpoint_logging.warning(
+ f"Model {model.__class__.__name__} is already compiled, consider loading first and then compiling."
+ )
+ name = model.__class__.__name__
+ # Get full file path / name
+ file_name = _get_checkpoint_filename(
+ path, name, index=epoch,
)
- else:
- # Load state dictionary
- model.load_state_dict(torch.load(file_name, map_location=device))
+ if not Path(file_name).exists():
+ checkpoint_logging.warning(
+ f"Could not find valid model file {file_name}, skipping load"
+ )
+ else:
+ # Load state dictionary
+ model.load_state_dict(torch.load(file_name, map_location=device))
- checkpoint_logging.success(
- f"Loaded model state dictionary {file_name} to device {device}"
- )
+ checkpoint_logging.success(
+ f"Loaded model state dictionary {file_name} to device {device}"
+ )
# == Loading training checkpoint ==
checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
diff --git a/src/hirad/utils/dataset_utils.py b/src/hirad/utils/dataset_utils.py
new file mode 100644
index 00000000..4a2c3d2f
--- /dev/null
+++ b/src/hirad/utils/dataset_utils.py
@@ -0,0 +1,263 @@
+import torch
+import numpy as np
+from scipy.spatial import Delaunay
+from typing import Optional
+
+
+def regrid_icon_to_rotlatlon(
+ data: torch.Tensor,
+ indices: torch.Tensor,
+ weights: torch.Tensor,
+ nx: int = 1170,
+ ny: int = 786,
+) -> torch.Tensor:
+ """Regrid ICON unstructured data to a rotated lat-lon grid.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Input data with the last axis being the unstructured grid dimension.
+ indices : torch.LongTensor
+ Remap indices of shape (n_target, n_stencil).
+ weights : torch.Tensor
+ Remap weights of shape (n_target, n_stencil).
+ nx, ny : int
+ Output grid dimensions.
+
+ Returns
+ -------
+ torch.Tensor
+ Regridded data of shape (*(batch,channel), ny, nx).
+ """
+ out_shape = data.shape[:-1] + (ny, nx)
+
+ # Gather stencil values: (..., n_target, n_stencil)
+ # indices: (n_target, n_stencil) -> expand to match data batch dims
+ values = data[..., indices] # (..., n_target, n_stencil)
+
+ # Weighted sum: multiply then reduce over stencil dim
+ result = (values * weights).sum(dim=-1) # (..., n_target)
+
+ # Clamp to stencil min/max to avoid extrapolation
+ vmin = values.amin(dim=-1)
+ vmax = values.amax(dim=-1)
+ result = result.clamp(min=vmin, max=vmax)
+
+ return result.reshape(out_shape)
+
+
+class GridData():
+ """
+ Performs interpolation from irregular points to target grid using Delaunay triangulation.
+
+ This class uses barycentric coordinate interpolation within Delaunay triangles to map
+ values from an original set of scattered points to a target set of points. It has same
+ behavior as scipy.interpolate.griddata with method='linear', but is optimized for
+ repeated interpolations on the same set of points.
+
+ Attributes:
+ longitudes_orig (np.ndarray): Longitude coordinates of original points.
+ latitudes_orig (np.ndarray): Latitude coordinates of original points.
+ longitudes_target (np.ndarray): Longitude coordinates of target points.
+ latitudes_target (np.ndarray): Latitude coordinates of target points.
+ is_torch (bool): Whether tensors are prepared for PyTorch operations.
+ device (torch.device | None): Device for PyTorch tensors if applicable.
+ """
+
+ def __init__(
+ self,
+ longitudes_orig: np.ndarray,
+ latitudes_orig: np.ndarray,
+ longitudes_target: np.ndarray,
+ latitudes_target: np.ndarray
+ ) -> None:
+ """
+ Initialize the GridData interpolator.
+
+ Args:
+ longitudes_orig: 1D array of longitude values for original points.
+ latitudes_orig: 1D array of latitude values for original points.
+ longitudes_target: 1D array of longitude values for target points.
+ latitudes_target: 1D array of latitude values for target points.
+
+ Raises:
+ ValueError: If input arrays have incompatible shapes.
+ """
+ # Validate inputs
+ if len(longitudes_orig) != len(latitudes_orig):
+ raise ValueError("Original longitude and latitude arrays must have same length")
+ if len(longitudes_target) != len(latitudes_target):
+ raise ValueError("Target longitude and latitude arrays must have same length")
+
+ self.longitudes_orig = np.asarray(longitudes_orig)
+ self.latitudes_orig = np.asarray(latitudes_orig)
+
+ self.longitudes_target = np.asarray(longitudes_target)
+ self.latitudes_target = np.asarray(latitudes_target)
+
+ self.is_torch = False
+ self.device = None
+
+ self._prepare_interpolation()
+
+ def _prepare_interpolation(self):
+ """
+ Prepare interpolation by computing Delaunay triangulation and barycentric coordinates.
+
+ This method:
+ 1. Computes Delaunay triangulation of original points
+ 2. Finds which simplex each target point belongs to
+ 3. Precomputes barycentric coordinates (lambda1, lambda2, lambda3) for interpolation
+ """
+ # Compute Delaunay triangulation
+ coords_orig = np.stack([self.longitudes_orig, self.latitudes_orig], axis=-1)
+ self._tri = Delaunay(coords_orig)
+
+ # Find simplex indices for target points
+ coords_target = np.stack([self.longitudes_target, self.latitudes_target], axis=-1)
+ self._simplex_id = self._tri.find_simplex(coords_target)
+
+ # Check for points outside convex hull
+ if np.any(self._simplex_id == -1):
+ n_outside = np.sum(self._simplex_id == -1)
+ print(f"Warning: {n_outside} target points are outside the convex hull of original points")
+
+ # Get corner coordinates of simplices
+ longitudes_corners = self.longitudes_orig[self._tri.simplices]
+ latitudes_corners = self.latitudes_orig[self._tri.simplices]
+
+ # Get corner coordinates for each target point's simplex
+ longitude_corners_per_target = longitudes_corners[self._simplex_id]
+ latitude_corners_per_target = latitudes_corners[self._simplex_id]
+
+ # Extract traingle vertices
+ x1, y1 = longitude_corners_per_target[:, 0], latitude_corners_per_target[:, 0]
+ x2, y2 = longitude_corners_per_target[:, 1], latitude_corners_per_target[:, 1]
+ x3, y3 = longitude_corners_per_target[:, 2], latitude_corners_per_target[:, 2]
+
+ denominator = (y2 - y3) * (x1 - x3) + (x3 - x2) * (y1 - y3)
+
+ self._lambda1 = ((y2 - y3) * (self.longitudes_target - x3) +
+ (x3 - x2) * (self.latitudes_target - y3)) / denominator
+
+ self._lambda2 = ((y3 - y1) * (self.longitudes_target - x3) +
+ (x1 - x3) * (self.latitudes_target - y3)) / denominator
+
+ self._lambda3 = 1 - self._lambda1 - self._lambda2
+
+
+ def to(self, device: str | torch.device) -> None:
+ """
+ Prepare barycentric coordinates and simplex indices for PyTorch operations.
+
+ This method converts the precomputed numpy arrays to PyTorch tensors
+ and moves them to the specified device.
+
+ Args:
+ device: The torch device to move tensors to (e.g., 'cpu' or 'cuda').
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ if self.is_torch and self.device == device:
+ return # Already on the correct device
+ elif self.is_torch and self.device != device:
+ self.device = device
+ self._lambda1 = self._lambda1.to(device)
+ self._lambda2 = self._lambda2.to(device)
+ self._lambda3 = self._lambda3.to(device)
+ self._simplex_id = self._simplex_id.to(device)
+ self._tri.simplices = self._tri.simplices.to(device)
+ elif not self.is_torch:
+ self.to_torch(device)
+
+ def to_torch(self, device: torch.device | str ='cpu') -> None:
+ """
+ Prepare barycentric coordinates and simplex indices for PyTorch operations.
+
+ This method converts the precomputed numpy arrays to PyTorch tensors
+ and moves them to the specified device.
+
+ Args:
+ device: The torch device to move tensors to (e.g., 'cpu' or 'cuda').
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+
+ self._lambda1 = torch.from_numpy(self._lambda1).to(device)
+ self._lambda2 = torch.from_numpy(self._lambda2).to(device)
+ self._lambda3 = torch.from_numpy(self._lambda3).to(device)
+
+ # Convert indexing arrays
+ self._simplex_id = torch.from_numpy(self._simplex_id).to(device)
+ self._tri.simplices = torch.from_numpy(self._tri.simplices).to(device)
+
+ self.is_torch = True
+
+ def to_numpy(self) -> None:
+ """
+ Convert barycentric coordinates and simplex indices back to numpy arrays.
+
+ This method converts the precomputed PyTorch tensors back to numpy arrays.
+ """
+ if not self.is_torch:
+ return # Already in numpy format
+
+ self._lambda1 = self._lambda1.cpu().numpy()
+ self._lambda2 = self._lambda2.cpu().numpy()
+ self._lambda3 = self._lambda3.cpu().numpy()
+
+ self._simplex_id = self._simplex_id.cpu().numpy()
+ self._tri.simplices = self._tri.simplices.cpu().numpy()
+
+ self.is_torch = False
+ self.device = None
+
+ def interpolate(self, values: np.ndarray | torch.Tensor, fill_value: Optional[float] = np.nan) -> np.ndarray:
+ """
+ Interpolate values from original points to target points.
+
+ Args:
+ values: Array of shape (n_channels, n_original_points) or (batch, n_channels, n_original_points) containing values
+ at original point locations to be interpolated.
+ fill_value: Value to use for target points outside the convex hull.
+ Defaults to np.nan.
+
+ Returns:
+ Array of shape (n_channels, n_target_points) or (batch, n_channels, n_original_points) with interpolated values.
+
+ Raises:
+ ValueError: If values shape is incompatible with original points.
+ """
+ if values.shape[-1] != len(self.longitudes_orig):
+ raise ValueError(
+ f"Expected values with shape (..., {len(self.longitudes_orig)}), "
+ f"got shape {values.shape}"
+ )
+
+ # Save original shape for reshaping later
+ orig_shape = values.shape
+
+ # In case that there is a batch dimension, flatten it for easier indexing
+ values = values.reshape(-1, values.shape[-1]) # shape (batch*n_channels, n_original_points)
+
+ # Find the corner values of each simplice
+ values_simplices = values[:,self._tri.simplices]
+
+ # Find the simplice corner values for each target point
+ values_per_target_simplices = values_simplices[:,self._simplex_id]
+
+ # Perform barycentric interpolation
+ out = (self._lambda1 * values_per_target_simplices[:,:,0] +
+ self._lambda2 * values_per_target_simplices[:,:,1] +
+ self._lambda3 * values_per_target_simplices[:,:,2])
+
+ # Handle points outside convex hull
+ if (not self.is_torch and np.any(self._simplex_id == -1)) or (self.is_torch and torch.any(self._simplex_id == -1)):
+ out[::, self._simplex_id == -1] = fill_value
+
+ return out.reshape(orig_shape[:-1] + (out.shape[-1],))
+
+ def __call__(self, values: np.ndarray | torch.Tensor, fill_value: Optional[float] = np.nan) -> np.ndarray:
+ """Alias for forward method to make class callable."""
+ return self.interpolate(values, fill_value)
\ No newline at end of file
diff --git a/src/hirad/utils/env_info.py b/src/hirad/utils/env_info.py
new file mode 100644
index 00000000..f1cfe642
--- /dev/null
+++ b/src/hirad/utils/env_info.py
@@ -0,0 +1,166 @@
+"""
+Environment Introspection Module
+================================
+
+This module provides functionality to introspect the Python environment, listing all non-standard
+library modules along with their versions and Git information if they are part of a Git repository.
+It helps in understanding the environment setup by detailing the modules in use, their versions, and
+relevant Git metadata.
+"""
+
+import platform
+import os
+import sys
+from types import ModuleType
+from typing import Any, Optional
+from git import Repo, InvalidGitRepositoryError
+
+
+def get_module_version(module: ModuleType) -> Optional[str]:
+ """
+ Retrieve the version of a module if available.
+
+ This function attempts to get the version of a module by accessing its ``__version__``
+ attribute. It checks if the version is a string to ensure correctness.
+
+ :param module: The module whose version is to be retrieved.
+ :type module: ModuleType
+ :return: The version string if available and valid, otherwise None.
+ :rtype: Optional[str]
+ """
+ version = getattr(module, "__version__", None)
+ if isinstance(version, str):
+ return version
+ return None
+
+
+def get_git_info(path: str) -> Optional[dict[str, Any]]:
+ """
+ Collect basic Git metadata for a given repository path.
+
+ This function checks if the given path is part of a Git repository and collects metadata such as
+ the commit SHA, modified files, untracked files, and remote URLs.
+
+ :param path: The path to check for Git repository metadata.
+ :type path: str
+ :return: A dictionary containing Git metadata if the path is a Git repository, otherwise None.
+ :rtype: Optional[Dict[str, Any]]
+ """
+ try:
+ repo = Repo(path, search_parent_directories=True)
+ diff = ""
+ for diff_item in repo.index.diff(None, create_patch=True):
+ a_path = diff_item.a_blob.abspath if diff_item.a_blob else ""
+ b_path = diff_item.b_blob.abspath if diff_item.b_blob else ""
+ diff_content = diff_item.diff
+ if isinstance(diff_content, bytes):
+ diff_content = diff_content.decode("utf-8")
+ elif diff_content is None:
+ diff_content = ""
+ diff += f"--- a{a_path}\n+++ b{b_path}\n{diff_content}\n\n"
+ git_info = {
+ "sha1": repo.head.commit.hexsha,
+ "diff": diff,
+ "untracked_files": sorted(repo.untracked_files),
+ "remotes": [r.url for r in repo.remotes],
+ }
+ return git_info
+ except InvalidGitRepositoryError:
+ return None
+
+
+def get_module_git_info(module: ModuleType) -> Optional[dict[str, Any]]:
+ """
+ Get Git information for a module if its directory is a Git repository.
+
+ This function determines the directory of the module and checks if it is part of a Git
+ repository. If it is, it collects and returns the Git metadata.
+
+ :param module: The module to check for Git information.
+ :type module: ModuleType
+ :return: A dictionary containing Git metadata if the module is in a Git repository, otherwise
+ None.
+ :rtype: Optional[Dict[str, Any]]
+ """
+ module_path = getattr(module, "__file__", None)
+ if module_path is None or not os.path.isabs(module_path):
+ return None
+ module_dir = os.path.dirname(module_path)
+ return get_git_info(module_dir)
+
+
+def get_env_info(flatten: bool = True, exclude_prefixes: list[str] = None) -> tuple[dict[str, dict[str, Any]], str]:
+ """
+ List all non-standard library modules with their versions and Git information if available.
+
+ This function iterates over all loaded modules in the Python environment, filtering out
+ built-in modules. It collects version and Git information for each remaining module.
+
+ :return: A tuple containing two elements:
+ - A dictionary mapping module names to their metadata
+ - A string containing concatenated Git diffs from all modules
+ :rtype: tuple[dict[str, dict[str, Any]], str]
+ :rtype: Dict[str, Dict[str, Any]]
+ """
+ env_info = {}
+ diffs: list[str] = []
+
+ exclude_prefixes = exclude_prefixes or []
+
+ for name, module in sys.modules.copy().items():
+ if name in sys.builtin_module_names or name.endswith(('.version','._version')):
+ continue
+
+ if any(name == prefix or name.startswith(prefix + ".") for prefix in exclude_prefixes):
+ continue
+
+ version = get_module_version(module)
+ git_info = get_module_git_info(module)
+ if version is None and git_info is None:
+ continue
+
+ module_info: dict[str, Any] = {"version": version}
+ if git_info is not None:
+ module_diff = git_info.pop("diff")
+ if module_diff and module_diff not in diffs:
+ diffs.append(module_diff)
+ module_info["git"] = git_info
+
+ env_info[name] = module_info
+
+ env_info["python"] = {"version": platform.python_version()}
+
+ diffs_str = "\n".join(diffs)
+
+ if flatten:
+ return flatten_dict(env_info), diffs_str
+
+ return env_info, diffs_str
+
+
+def flatten_dict(
+ d: dict[str, Any], parent_key: str = "", sep: str = "."
+) -> dict[str, Any]:
+ """
+ Flatten a nested dictionary.
+
+ This function recursively traverses a nested dictionary and flattens it into a single-level
+ dictionary with keys formed by concatenating the nested keys using a separator.
+
+ :param d: The dictionary to flatten.
+ :type d: Dict[str, Any]
+ :param parent_key: The base key to use for concatenation.
+ :type parent_key: str
+ :param sep: The separator to use for concatenating keys.
+ :type sep: str
+ :return: A flattened dictionary.
+ :rtype: Dict[str, Any]
+ """
+ items = {}
+ for k, v in d.items():
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
+ if isinstance(v, dict):
+ items.update(flatten_dict(v, new_key, sep=sep))
+ else:
+ items[new_key] = v
+ return items
diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py
index 347457c3..13fce1e4 100644
--- a/src/hirad/utils/function_utils.py
+++ b/src/hirad/utils/function_utils.py
@@ -17,19 +17,8 @@
"""Miscellaneous utility classes and functions."""
-import contextlib
-import ctypes
import datetime
-import fnmatch
-import importlib
-import inspect
-import os
-import re
-import shutil
-import sys
-import types
-import warnings
-from typing import Any, Iterator, List, Tuple, Union
+from typing import Iterator
import cftime
import numpy as np
@@ -38,25 +27,6 @@
# ruff: noqa: E722 PERF203 S110 E713 S324
-class EasyDict(dict): # pragma: no cover
- """
- Convenience class that behaves like a dict but allows access with the attribute
- syntax.
- """
-
- def __getattr__(self, name: str) -> Any:
- try:
- return self[name]
- except KeyError:
- raise AttributeError(name)
-
- def __setattr__(self, name: str, value: Any) -> None:
- self[name] = value
-
- def __delattr__(self, name: str) -> None:
- del self[name]
-
-
class StackedRandomGenerator: # pragma: no cover
"""
Wrapper for torch.Generator that allows specifying a different random seed
@@ -96,32 +66,8 @@ def randint(self, *args, size, **kwargs):
)
-def parse_int_list(s): # pragma: no cover
- """
- Parse a comma separated list of numbers or ranges and return a list of ints.
- Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
- """
- if isinstance(s, list):
- return s
- ranges = []
- range_re = re.compile(r"^(\d+)-(\d+)$")
- for p in s.split(","):
- m = range_re.match(p)
- if m:
- ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
- else:
- ranges.append(int(p))
- return ranges
-
-
# Small util functions
# -------------------------------------------------------------------------------------
-def convert_datetime_to_cftime(
- time: datetime.datetime, cls=cftime.DatetimeGregorian
-) -> cftime.DatetimeGregorian:
- """Convert a Python datetime object to a cftime DatetimeGregorian object."""
- return cls(time.year, time.month, time.day, time.hour, time.minute, time.second)
-
def time_range(
start_time: datetime.datetime,
@@ -135,417 +81,30 @@ def time_range(
yield t
t += step
+def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"):
+ """Generates a list of times within a given range.
-def format_time(seconds: Union[int, float]) -> str: # pragma: no cover
- """Convert the seconds to human readable string with days, hours, minutes and seconds."""
- s = int(np.rint(seconds))
-
- if s < 60:
- return "{0}s".format(s)
- elif s < 60 * 60:
- return "{0}m {1:02}s".format(s // 60, s % 60)
- elif s < 24 * 60 * 60:
- return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
- else:
- return "{0}d {1:02}h {2:02}m".format(
- s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60
- )
-
-
-def format_time_brief(seconds: Union[int, float]) -> str: # pragma: no cover
- """Convert the seconds to human readable string with days, hours, minutes and seconds."""
- s = int(np.rint(seconds))
-
- if s < 60:
- return "{0}s".format(s)
- elif s < 60 * 60:
- return "{0}m {1:02}s".format(s // 60, s % 60)
- elif s < 24 * 60 * 60:
- return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
- else:
- return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
-
-
-def tuple_product(t: Tuple) -> Any: # pragma: no cover
- """Calculate the product of the tuple elements."""
- result = 1
-
- for v in t:
- result *= v
-
- return result
-
-
-_str_to_ctype = {
- "uint8": ctypes.c_ubyte,
- "uint16": ctypes.c_uint16,
- "uint32": ctypes.c_uint32,
- "uint64": ctypes.c_uint64,
- "int8": ctypes.c_byte,
- "int16": ctypes.c_int16,
- "int32": ctypes.c_int32,
- "int64": ctypes.c_int64,
- "float32": ctypes.c_float,
- "float64": ctypes.c_double,
-}
-
-
-def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: # pragma: no cover
- """
- Given a type name string (or an object having a __name__ attribute), return
- matching Numpy and ctypes types that have the same size in bytes.
- """
- type_str = None
-
- if isinstance(type_obj, str):
- type_str = type_obj
- elif hasattr(type_obj, "__name__"):
- type_str = type_obj.__name__
- elif hasattr(type_obj, "name"):
- type_str = type_obj.name
- else:
- raise RuntimeError("Cannot infer type name from input")
-
- if type_str not in _str_to_ctype.keys():
- raise ValueError("Unknown type name: " + type_str)
-
- my_dtype = np.dtype(type_str)
- my_ctype = _str_to_ctype[type_str]
-
- if my_dtype.itemsize != ctypes.sizeof(my_ctype):
- raise ValueError(
- "Numpy and ctypes types for '{}' have different sizes!".format(type_str)
- )
-
- return my_dtype, my_ctype
-
-
-# Functionality to import modules/objects by name, and call functions by name
-# -------------------------------------------------------------------------------------
-
-
-def get_module_from_obj_name(
- obj_name: str,
-) -> Tuple[types.ModuleType, str]: # pragma: no cover
- """
- Searches for the underlying module behind the name to some python object.
- Returns the module and the object name (original name with module part removed).
- """
-
- # allow convenience shorthands, substitute them by full names
- obj_name = re.sub("^np.", "numpy.", obj_name)
- obj_name = re.sub("^tf.", "tensorflow.", obj_name)
-
- # list alternatives for (module_name, local_obj_name)
- parts = obj_name.split(".")
- name_pairs = [
- (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)
- ]
-
- # try each alternative in turn
- for module_name, local_obj_name in name_pairs:
- try:
- module = importlib.import_module(module_name) # may raise ImportError
- get_obj_from_module(module, local_obj_name) # may raise AttributeError
- return module, local_obj_name
- except:
- pass
-
- # maybe some of the modules themselves contain errors?
- for module_name, _local_obj_name in name_pairs:
- try:
- importlib.import_module(module_name) # may raise ImportError
- except ImportError:
- if not str(sys.exc_info()[1]).startswith(
- "No module named '" + module_name + "'"
- ):
- raise
+ Args:
+ times_range: A list containing start time, end time, and optional interval (hours).
+ time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S").
- # maybe the requested attribute is missing?
- for module_name, local_obj_name in name_pairs:
- try:
- module = importlib.import_module(module_name) # may raise ImportError
- get_obj_from_module(module, local_obj_name) # may raise AttributeError
- except ImportError:
- pass
-
- # we are out of luck, but we have no idea why
- raise ImportError(obj_name)
-
-
-def get_obj_from_module(
- module: types.ModuleType, obj_name: str
-) -> Any: # pragma: no cover
- """
- Traverses the object name and returns the last (rightmost) python object.
- """
- if obj_name == "":
- return module
- obj = module
- for part in obj_name.split("."):
- obj = getattr(obj, part)
- return obj
-
-
-def get_obj_by_name(name: str) -> Any: # pragma: no cover
- """
- Finds the python object with the given name.
- """
- module, obj_name = get_module_from_obj_name(name)
- return get_obj_from_module(module, obj_name)
-
-
-def call_func_by_name(
- *args, func_name: str = None, **kwargs
-) -> Any: # pragma: no cover
- """
- Finds the python object with the given name and calls it as a function.
+ Returns:
+ A list of times within the specified range.
"""
- if func_name is None:
- raise ValueError("func_name must be specified")
- func_obj = get_obj_by_name(func_name)
- if not callable(func_obj):
- raise ValueError(func_name + " is not callable")
- return func_obj(*args, **kwargs)
-
-def construct_class_by_name(
- *args, class_name: str = None, **kwargs
-) -> Any: # pragma: no cover
- """
- Finds the python class with the given name and constructs it with the given
- arguments.
- """
- return call_func_by_name(*args, func_name=class_name, **kwargs)
-
-
-def get_module_dir_by_obj_name(obj_name: str) -> str: # pragma: no cover
- """
- Get the directory path of the module containing the given object name.
- """
- module, _ = get_module_from_obj_name(obj_name)
- return os.path.dirname(inspect.getfile(module))
-
-
-def is_top_level_function(obj: Any) -> bool: # pragma: no cover
- """
- Determine whether the given object is a top-level function, i.e., defined at module
- scope using 'def'.
- """
- return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
-
-
-def get_top_level_function_name(obj: Any) -> str: # pragma: no cover
- """
- Return the fully-qualified name of a top-level function.
- """
- if not is_top_level_function(obj):
- raise ValueError("Object is not a top-level function")
- module = obj.__module__
- if module == "__main__":
- module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
- return module + "." + obj.__name__
-
-
-# File system helpers
-# ------------------------------------------------------------------------------------------
-
-
-def list_dir_recursively_with_ignore(
- dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False
-) -> List[Tuple[str, str]]: # pragma: no cover
- """
- List all files recursively in a given directory while ignoring given file and
- directory names. Returns list of tuples containing both absolute and relative paths.
- """
- if not os.path.isdir(dir_path):
- raise RuntimeError(f"Directory does not exist: {dir_path}")
- base_name = os.path.basename(os.path.normpath(dir_path))
-
- if ignores is None:
- ignores = []
-
- result = []
-
- for root, dirs, files in os.walk(dir_path, topdown=True):
- for ignore_ in ignores:
- dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
-
- # dirs need to be edited in-place
- for d in dirs_to_remove:
- dirs.remove(d)
-
- files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
-
- absolute_paths = [os.path.join(root, f) for f in files]
- relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
-
- if add_base_to_relative:
- relative_paths = [os.path.join(base_name, p) for p in relative_paths]
-
- if len(absolute_paths) != len(relative_paths):
- raise ValueError("Number of absolute and relative paths do not match")
- result += zip(absolute_paths, relative_paths)
-
- return result
-
-
-def copy_files_and_create_dirs(
- files: List[Tuple[str, str]]
-) -> None: # pragma: no cover
- """
- Takes in a list of tuples of (src, dst) paths and copies files.
- Will create all necessary directories.
- """
- for file in files:
- target_dir_name = os.path.dirname(file[1])
-
- # will create all intermediate-level directories
- if not os.path.exists(target_dir_name):
- os.makedirs(target_dir_name)
-
- shutil.copyfile(file[0], file[1])
-
-
-# ----------------------------------------------------------------------------
-# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
-# same constant is used multiple times.
-
-_constant_cache = dict()
-
-
-def constant(
- value, shape=None, dtype=None, device=None, memory_format=None
-): # pragma: no cover
- """Cached construction of constant tensors"""
- value = np.asarray(value)
- if shape is not None:
- shape = tuple(shape)
- if dtype is None:
- dtype = torch.get_default_dtype()
- if device is None:
- device = torch.device("cpu")
- if memory_format is None:
- memory_format = torch.contiguous_format
-
- key = (
- value.shape,
- value.dtype,
- value.tobytes(),
- shape,
- dtype,
- device,
- memory_format,
+ start_time = datetime.datetime.strptime(times_range[0], time_format)
+ end_time = datetime.datetime.strptime(times_range[1], time_format)
+ interval = (
+ datetime.timedelta(hours=times_range[2])
+ if len(times_range) > 2
+ else datetime.timedelta(hours=1)
)
- tensor = _constant_cache.get(key, None)
- if tensor is None:
- tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
- if shape is not None:
- tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
- tensor = tensor.contiguous(memory_format=memory_format)
- _constant_cache[key] = tensor
- return tensor
-
-
-# ----------------------------------------------------------------------------
-# Replace NaN/Inf with specified numerical values.
-
-try:
- nan_to_num = torch.nan_to_num # 1.8.0a0
-except AttributeError:
-
- def nan_to_num(
- input, nan=0.0, posinf=None, neginf=None, *, out=None
- ): # pylint: disable=redefined-builtin # pragma: no cover
- """Replace NaN/Inf with specified numerical values"""
- if not isinstance(input, torch.Tensor):
- raise TypeError("input should be a Tensor")
- if posinf is None:
- posinf = torch.finfo(input.dtype).max
- if neginf is None:
- neginf = torch.finfo(input.dtype).min
- if nan != 0:
- raise ValueError("nan_to_num only supports nan=0")
- return torch.clamp(
- input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out
- )
-
-
-# ----------------------------------------------------------------------------
-# Symbolic assert.
-
-try:
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
-except AttributeError:
- symbolic_assert = torch.Assert # 1.7.0
-
-# ----------------------------------------------------------------------------
-# Context manager to temporarily suppress known warnings in torch.jit.trace().
-# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
-
-@contextlib.contextmanager
-def suppress_tracer_warnings(): # pragma: no cover
- """
- Context manager to temporarily suppress known warnings in torch.jit.trace().
- Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
- """
- flt = ("ignore", None, torch.jit.TracerWarning, None, 0)
- warnings.filters.insert(0, flt)
- yield
- warnings.filters.remove(flt)
-
-
-# ----------------------------------------------------------------------------
-# Assert that the shape of a tensor matches the given list of integers.
-# None indicates that the size of a dimension is allowed to vary.
-# Performs symbolic assertion when used in torch.jit.trace().
-
-
-def assert_shape(tensor, ref_shape): # pragma: no cover
- """
- Assert that the shape of a tensor matches the given list of integers.
- None indicates that the size of a dimension is allowed to vary.
- Performs symbolic assertion when used in torch.jit.trace().
- """
- if tensor.ndim != len(ref_shape):
- raise AssertionError(
- f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
- )
- for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
- if ref_size is None:
- pass
- elif isinstance(ref_size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(
- torch.equal(torch.as_tensor(size), ref_size),
- f"Wrong size for dimension {idx}",
- )
- elif isinstance(size, torch.Tensor):
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
- symbolic_assert(
- torch.equal(size, torch.as_tensor(ref_size)),
- f"Wrong size for dimension {idx}: expected {ref_size}",
- )
- elif size != ref_size:
- raise AssertionError(
- f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
- )
-
-
-# ----------------------------------------------------------------------------
-# Function decorator that calls torch.autograd.profiler.record_function().
-
-
-def profiled_function(fn): # pragma: no cover
- """Function decorator that calls torch.autograd.profiler.record_function()."""
-
- def decorator(*args, **kwargs):
- with torch.autograd.profiler.record_function(fn.__name__):
- return fn(*args, **kwargs)
-
- decorator.__name__ = fn.__name__
- return decorator
+ times = [
+ t.strftime(time_format)
+ for t in time_range(start_time, end_time, interval, inclusive=True)
+ ]
+ return times
# ----------------------------------------------------------------------------
@@ -574,6 +133,8 @@ class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover
window_size : float, default=0.5
Fraction of dataset to use as window for shuffling. Must be between 0 and 1.
A larger window means more thorough shuffling but slower iteration.
+ start_idx : int, default=0
+ The initial index to use for the sampler. This is used for resuming training.
"""
def __init__(
@@ -584,6 +145,7 @@ def __init__(
shuffle: bool = True,
seed: int = 0,
window_size: float = 0.5,
+ start_idx: int = 0,
):
if not len(dataset) > 0:
raise ValueError("Dataset must contain at least one item")
@@ -600,7 +162,8 @@ def __init__(
self.shuffle = shuffle
self.seed = seed
self.window_size = window_size
-
+ self.start_idx = start_idx
+
def __iter__(self) -> Iterator[int]:
order = np.arange(len(self.dataset))
rnd = None
@@ -610,189 +173,14 @@ def __iter__(self) -> Iterator[int]:
rnd.shuffle(order)
window = int(np.rint(order.size * self.window_size))
- idx = 0
+ idx = self.start_idx % order.size
while True:
i = idx % order.size
if idx % self.num_replicas == self.rank:
yield order[i]
- if window >= 2:
- j = (i - rnd.randint(window)) % order.size
+ if window >= 2 and i>0:
+ window_size = min(i+1, window)
+ j = (i - rnd.randint(window_size)) % order.size
order[i], order[j] = order[j], order[i]
idx += 1
-
-
-# ----------------------------------------------------------------------------
-# Utilities for operating with torch.nn.Module parameters and buffers.
-
-
-def params_and_buffers(module): # pragma: no cover
- """Get parameters and buffers of a nn.Module"""
- if not isinstance(module, torch.nn.Module):
- raise TypeError("module must be a torch.nn.Module instance")
- return list(module.parameters()) + list(module.buffers())
-
-
-def named_params_and_buffers(module): # pragma: no cover
- """Get named parameters and buffers of a nn.Module"""
- if not isinstance(module, torch.nn.Module):
- raise TypeError("module must be a torch.nn.Module instance")
- return list(module.named_parameters()) + list(module.named_buffers())
-
-
-@torch.no_grad()
-def copy_params_and_buffers(
- src_module, dst_module, require_all=False
-): # pragma: no cover
- """Copy parameters and buffers from a source module to target module"""
- if not isinstance(src_module, torch.nn.Module):
- raise TypeError("src_module must be a torch.nn.Module instance")
- if not isinstance(dst_module, torch.nn.Module):
- raise TypeError("dst_module must be a torch.nn.Module instance")
- src_tensors = dict(named_params_and_buffers(src_module))
- for name, tensor in named_params_and_buffers(dst_module):
- if not ((name in src_tensors) or (not require_all)):
- raise ValueError(f"Missing source tensor for {name}")
- if name in src_tensors:
- tensor.copy_(src_tensors[name])
-
-
-# ----------------------------------------------------------------------------
-# Context manager for easily enabling/disabling DistributedDataParallel
-# synchronization.
-
-
-@contextlib.contextmanager
-def ddp_sync(module, sync): # pragma: no cover
- """
- Context manager for easily enabling/disabling DistributedDataParallel
- synchronization.
- """
- if not isinstance(module, torch.nn.Module):
- raise TypeError("module must be a torch.nn.Module instance")
- if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
- yield
- else:
- with module.no_sync():
- yield
-
-
-# ----------------------------------------------------------------------------
-# Check DistributedDataParallel consistency across processes.
-
-
-def check_ddp_consistency(module, ignore_regex=None): # pragma: no cover
- """Check DistributedDataParallel consistency across processes."""
- if not isinstance(module, torch.nn.Module):
- raise TypeError("module must be a torch.nn.Module instance")
- for name, tensor in named_params_and_buffers(module):
- fullname = type(module).__name__ + "." + name
- if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
- continue
- tensor = tensor.detach()
- if tensor.is_floating_point():
- tensor = nan_to_num(tensor)
- other = tensor.clone()
- torch.distributed.broadcast(tensor=other, src=0)
- if not (tensor == other).all():
- raise RuntimeError(f"DDP consistency check failed for {fullname}")
-
-
-# ----------------------------------------------------------------------------
-# Print summary table of module hierarchy.
-
-
-def print_module_summary(
- module, inputs, max_nesting=3, skip_redundant=True
-): # pragma: no cover
- """Print summary table of module hierarchy."""
- if not isinstance(module, torch.nn.Module):
- raise TypeError("module must be a torch.nn.Module instance")
- if isinstance(module, torch.jit.ScriptModule):
- raise TypeError("module must not be a torch.jit.ScriptModule instance")
- if not isinstance(inputs, (tuple, list)):
- raise TypeError("inputs must be a tuple or list")
-
- # Register hooks.
- entries = []
- nesting = [0]
-
- def pre_hook(_mod, _inputs):
- nesting[0] += 1
-
- def post_hook(mod, _inputs, outputs):
- nesting[0] -= 1
- if nesting[0] <= max_nesting:
- outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
- outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
- entries.append(EasyDict(mod=mod, outputs=outputs))
-
- hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
- hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
-
- # Run module.
- outputs = module(*inputs)
- for hook in hooks:
- hook.remove()
-
- # Identify unique outputs, parameters, and buffers.
- tensors_seen = set()
- for e in entries:
- e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
- e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
- e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
- tensors_seen |= {
- id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs
- }
-
- # Filter out redundant entries.
- if skip_redundant:
- entries = [
- e
- for e in entries
- if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)
- ]
-
- # Construct table.
- rows = [
- [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"]
- ]
- rows += [["---"] * len(rows[0])]
- param_total = 0
- buffer_total = 0
- submodule_names = {mod: name for name, mod in module.named_modules()}
- for e in entries:
- name = "" if e.mod is module else submodule_names[e.mod]
- param_size = sum(t.numel() for t in e.unique_params)
- buffer_size = sum(t.numel() for t in e.unique_buffers)
- output_shapes = [str(list(t.shape)) for t in e.outputs]
- output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs]
- rows += [
- [
- name + (":0" if len(e.outputs) >= 2 else ""),
- str(param_size) if param_size else "-",
- str(buffer_size) if buffer_size else "-",
- (output_shapes + ["-"])[0],
- (output_dtypes + ["-"])[0],
- ]
- ]
- for idx in range(1, len(e.outputs)):
- rows += [
- [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]]
- ]
- param_total += param_size
- buffer_total += buffer_size
- rows += [["---"] * len(rows[0])]
- rows += [["Total", str(param_total), str(buffer_total), "-", "-"]]
-
- # Print table.
- widths = [max(len(cell) for cell in column) for column in zip(*rows)]
- for row in rows:
- print(
- " ".join(
- cell + " " * (width - len(cell)) for cell, width in zip(row, widths)
- )
- )
- return outputs
-
-
-# ----------------------------------------------------------------------------
+ idx = idx % order.size
diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py
deleted file mode 100644
index 43f83b63..00000000
--- a/src/hirad/utils/generate_utils.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import datetime
-from hirad.datasets import init_dataset_from_config
-from .function_utils import convert_datetime_to_cftime
-
-
-def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False):
- """
- Get a dataset and sampler for generation.
- """
- (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1)
- # if has_lead_time:
- # plot_times = times
- # else:
- # plot_times = [
- # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S")
- # for time in times
- # ]
- all_times = dataset.time()
- time_indices = [all_times.index(t) for t in times]
- sampler = time_indices
-
- return dataset, sampler
\ No newline at end of file
diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py
index 8665536b..2e6db5c0 100644
--- a/src/hirad/utils/inference_utils.py
+++ b/src/hirad/utils/inference_utils.py
@@ -14,18 +14,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import datetime
from typing import Optional
+import os
+import logging
+import time
-import cftime
import nvtx
+import numpy as np
import torch
import tqdm
-from .function_utils import StackedRandomGenerator, time_range
+from .function_utils import StackedRandomGenerator
-from .stochastic_sampler import stochastic_sampler
-from .deterministic_sampler import deterministic_sampler
+
+def _sync_t() -> float:
+ """Wall-clock time after flushing all pending CUDA ops on the current device."""
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.perf_counter()
############################################################################
# CorrDiff Generation Utilities #
@@ -37,6 +43,10 @@ def regression_step(
img_lr: torch.Tensor,
latents_shape: torch.Size,
lead_time_label: Optional[torch.Tensor] = None,
+ static_channels: Optional[torch.Tensor] = None,
+ date_embedding: Optional[torch.Tensor] = None,
+ use_apex_gn: bool = False,
+ _timings: Optional[dict] = None,
) -> torch.Tensor:
"""
Perform a regression step to produce ensemble mean prediction.
@@ -58,6 +68,11 @@ def regression_step(
lead_time_label : Optional[torch.Tensor], optional
Lead time label tensor for lead time conditioning,
with shape (1, lead_time_dims). Default is None.
+ static_channels : torch.Tensor, optional
+ Static channels input of shape (C_static, H, W).
+
+ date_embedding : torch.Tensor, optional
+ Date embedding input of shape (B, C_date).
Returns
-------
@@ -79,17 +94,39 @@ def regression_step(
f"but found {img_lr.shape[0]}."
)
+ _t = _sync_t if _timings is not None else (lambda: 0.0)
+ _t0 = _t()
+ if static_channels is not None:
+ img_lr = torch.cat(
+ (img_lr, static_channels.expand(img_lr.shape[0], *static_channels.shape[1:])),
+ dim=1,
+ )
+
+ if date_embedding is not None:
+ date_embedding = date_embedding[:, :, None, None].expand(*date_embedding.shape[:2], *img_lr.shape[2:])
+ if use_apex_gn:
+ date_embedding = date_embedding.to(img_lr.dtype, non_blocking=True).to(memory_format=torch.channels_last)
+ else:
+ date_embedding = date_embedding.to(img_lr.dtype, non_blocking=True).contiguous()
+ img_lr = torch.cat((img_lr, date_embedding), dim=1)
+ _t_prep = _t()
+
# Perform regression on a single batch element
with torch.inference_mode():
if lead_time_label is not None:
x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label)
else:
- x = net(x=x_hat[0:1], img_lr=img_lr)
+ x = net(x=x_hat[0:1], img_lr=img_lr, force_fp32=False)
+ _t_net = _t()
# If the batch size is greater than 1, repeat the prediction
if x_hat.shape[0] > 1:
x = x.repeat([d if i == 0 else 1 for i, d in enumerate(x_hat.shape)])
+ if _timings is not None:
+ _timings["reg_input_prep"] = _timings.get("reg_input_prep", 0.0) + (_t_prep - _t0)
+ _timings["reg_net_forward"] = _timings.get("reg_net_forward", 0.0) + (_t_net - _t_prep)
+
return x
@@ -104,6 +141,10 @@ def diffusion_step(
device: torch.device,
mean_hr: torch.Tensor = None,
lead_time_label: torch.Tensor = None,
+ static_channels: Optional[torch.Tensor] = None,
+ date_embedding: Optional[torch.Tensor] = None,
+ use_apex_gn: bool = False,
+ _timings: Optional[dict] = None,
) -> torch.Tensor:
"""
@@ -141,6 +182,12 @@ def diffusion_step(
lead_time_label : torch.Tensor, optional
Lead time label tensor for temporal conditioning,
with shape (batch_size, lead_time_dims). Default is None.
+ static_channels : torch.Tensor, optional
+ Static channels input of shape (C_static, H, W).
+ date_embedding : torch.Tensor, optional
+ Date embedding input of shape (B, C_date).
+ use_apex_gn : bool, optional
+ Whether Apex's fused group normalization is used. Default is False.
Returns
-------
@@ -150,21 +197,24 @@ def diffusion_step(
"""
# Check img_lr dimensions match expected shape
- if img_lr.shape[2:] != img_shape:
+ if img_lr.shape[-2:] != img_shape:
raise ValueError(
- f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}"
+ f"img_lr shape {img_lr.shape[-2:]} does not match expected shape img_shape {img_shape}"
)
# Check mean_hr dimensions if provided
if mean_hr is not None:
- if mean_hr.shape[2:] != img_shape:
+ if mean_hr.shape[-2:] != img_shape:
raise ValueError(
f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}"
)
if mean_hr.shape[0] != 1:
raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}")
- img_lr = img_lr.to(memory_format=torch.channels_last)
+ if len(rank_batches) == 0:
+ raise ValueError("rank_batches is empty, at least one batch of seeds is required")
+
+ # img_lr = img_lr.to(memory_format=torch.channels_last)
# Handling of the high-res mean
additional_args = {}
@@ -172,6 +222,13 @@ def diffusion_step(
additional_args["mean_hr"] = mean_hr
if lead_time_label is not None:
additional_args["lead_time_label"] = lead_time_label
+ if static_channels is not None:
+ additional_args["static_channels"] = static_channels
+ if date_embedding is not None:
+ additional_args["date_embedding"] = date_embedding
+ additional_args["use_apex_gn"] = use_apex_gn
+
+ _t = _sync_t if _timings is not None else (lambda: 0.0)
# Loop over batches
all_images = []
@@ -180,8 +237,13 @@ def diffusion_step(
batch_size = len(batch_seeds)
if batch_size == 0:
continue
+ if batch_size != img_lr.shape[0]:
+ raise ValueError(
+ f"Batch size {batch_size} does not match img_lr batch size {img_lr.shape[0]}"
+ )
# Initialize random generator, and generate latents
+ _t0 = _t()
rnd = StackedRandomGenerator(device, batch_seeds)
latents = rnd.randn(
[
@@ -192,128 +254,61 @@ def diffusion_step(
],
device=device,
)#.to(memory_format=torch.channels_last)
+ _t_latent = _t()
+ batch_timings: dict = {} if _timings is not None else None
with torch.inference_mode():
images = sampler_fn(
- net, latents, img_lr, randn_like=rnd.randn_like, **additional_args
+ net, latents, img_lr, randn_like=rnd.randn_like,
+ _timings=batch_timings, **additional_args
)
+ _t_sampler = _t()
+
+ if _timings is not None:
+ _timings["diff_latent_gen"] = _timings.get("diff_latent_gen", 0.0) + (_t_latent - _t0)
+ _timings["diff_sampler_total"] = _timings.get("diff_sampler_total", 0.0) + (_t_sampler - _t_latent)
+ for k, v in batch_timings.items():
+ _timings[k] = _timings.get(k, 0.0) + v
+
all_images.append(images)
return torch.cat(all_images)
-def generate():
- pass
-
############################################################################
-# CorrDiff writer utilities #
+# Saving and Visualization Utilities #
############################################################################
-class NetCDFWriter:
- """NetCDF Writer"""
-
- def __init__(
- self, f, lat, lon, input_channels, output_channels, has_lead_time=False
- ):
- self._f = f
- self.has_lead_time = has_lead_time
- # create unlimited dimensions
- f.createDimension("time")
- f.createDimension("ensemble")
-
- if lat.shape != lon.shape:
- raise ValueError("lat and lon must have the same shape")
- ny, nx = lat.shape
-
- # create lat/lon grid
- f.createDimension("x", nx)
- f.createDimension("y", ny)
-
- v = f.createVariable("lat", "f", dimensions=("y", "x"))
- # NOTE rethink this for datasets whose samples don't have constant lat-lon.
- v[:] = lat
- v.standard_name = "latitude"
- v.units = "degrees_north"
-
- v = f.createVariable("lon", "f", dimensions=("y", "x"))
- v[:] = lon
- v.standard_name = "longitude"
- v.units = "degrees_east"
-
- # create time dimension
- if has_lead_time:
- v = f.createVariable("time", "str", ("time"))
- else:
- v = f.createVariable("time", "i8", ("time"))
- v.calendar = "standard"
- v.units = "hours since 1990-01-01 00:00:00"
-
- self.truth_group = f.createGroup("truth")
- self.prediction_group = f.createGroup("prediction")
- self.input_group = f.createGroup("input")
-
- for variable in output_channels:
- name = variable.name + variable.level
- self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x"))
- self.prediction_group.createVariable(
- name, "f", dimensions=("ensemble", "time", "y", "x")
- )
-
- # setup input data in netCDF
-
- for variable in input_channels:
- name = variable.name + variable.level
- self.input_group.createVariable(name, "f", dimensions=("time", "y", "x"))
-
- def write_input(self, channel_name, time_index, val):
- """Write input data to NetCDF file."""
- self.input_group[channel_name][time_index] = val
-
- def write_truth(self, channel_name, time_index, val):
- """Write ground truth data to NetCDF file."""
- self.truth_group[channel_name][time_index] = val
-
- def write_prediction(self, channel_name, time_index, ensemble_index, val):
- """Write prediction data to NetCDF file."""
- self.prediction_group[channel_name][ensemble_index, time_index] = val
-
- def write_time(self, time_index, time):
- """Write time information to NetCDF file."""
- if self.has_lead_time:
- self._f["time"][time_index] = time
+def save_results_as_torch(output_path, time_step, image_pred, image_hr, image_lr, mean_pred):
+ os.makedirs(output_path, exist_ok=True)
+ if mean_pred is not None:
+ torch.save(mean_pred, os.path.join(output_path, f'{time_step}-regression-prediction'))
+ torch.save(image_hr, os.path.join(output_path, f'{time_step}-target'))
+ torch.save(image_pred, os.path.join(output_path, f'{time_step}-predictions'))
+ torch.save(image_lr, os.path.join(output_path, f'{time_step}-baseline'))
+
+
+def calculate_bounds(*arrays: np.ndarray) -> tuple[float]:
+ """Calculate consistent bounds across all arrays"""
+ valid_arrays = [arr for arr in arrays if arr is not None]
+ if not valid_arrays:
+ return None, None
+
+ # hanndle if there are masked arrays with invalid values (e.g. NaNs)
+ all_values = []
+ for arr in valid_arrays:
+ if hasattr(arr, 'compressed'): # Masked array
+ compressed = arr.compressed()
+ if len(compressed) > 0:
+ all_values.extend(compressed)
+ elif hasattr(arr, 'flatten'): # Regular numpy array
+ all_values.extend(arr.flatten())
else:
- time_v = self._f["time"]
- self._f["time"][time_index] = cftime.date2num(
- time, time_v.units, time_v.calendar
- )
-
-
-############################################################################
-# CorrDiff time utilities #
-############################################################################
-
-
-def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"):
- """Generates a list of times within a given range.
-
- Args:
- times_range: A list containing start time, end time, and optional interval (hours).
- time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S").
-
- Returns:
- A list of times within the specified range.
- """
-
- start_time = datetime.datetime.strptime(times_range[0], time_format)
- end_time = datetime.datetime.strptime(times_range[1], time_format)
- interval = (
- datetime.timedelta(hours=times_range[2])
- if len(times_range) > 2
- else datetime.timedelta(hours=1)
- )
-
- times = [
- t.strftime(time_format)
- for t in time_range(start_time, end_time, interval, inclusive=True)
- ]
- return times
+ all_values.append(arr)
+
+ if not all_values:
+ return None, None
+
+ vmin = min(all_values)
+ vmax = max(all_values)
+ return vmin, vmax
diff --git a/src/hirad/utils/patching.py b/src/hirad/utils/patching.py
index 6f4bc4d8..cad33c4d 100644
--- a/src/hirad/utils/patching.py
+++ b/src/hirad/utils/patching.py
@@ -383,14 +383,27 @@ def __init__(
super().__init__(img_shape, patch_shape)
self.overlap_pix = overlap_pix
self.boundary_pix = boundary_pix
- patch_num_x = math.ceil(
- img_shape[1] / (patch_shape[1] - overlap_pix - boundary_pix)
- )
- patch_num_y = math.ceil(
- img_shape[0] / (patch_shape[0] - overlap_pix - boundary_pix)
- )
+ stride_x = patch_shape[1] - overlap_pix - boundary_pix
+ stride_y = patch_shape[0] - overlap_pix - boundary_pix
+ patch_num_x = math.ceil(img_shape[1] / stride_x)
+ patch_num_y = math.ceil(img_shape[0] / stride_y)
self.patch_num = patch_num_x * patch_num_y
+ # Precompute padded shape and padding — used by both apply and fuse.
+ self._padded_shape_x = stride_x * (patch_num_x - 1) + patch_shape[1] + boundary_pix
+ self._padded_shape_y = stride_y * (patch_num_y - 1) + patch_shape[0] + boundary_pix
+ self._pad = (
+ boundary_pix,
+ self._padded_shape_x - img_shape[1] - boundary_pix,
+ boundary_pix,
+ self._padded_shape_y - img_shape[0] - boundary_pix,
+ ) # (left, right, top, bottom) for F.pad
+
+ # overlap_count is purely geometric; cache it lazily per device so
+ # image_fuse() does not recompute it on every call.
+ self._overlap_count: Optional[Tensor] = None
+ self._overlap_count_device: Optional[torch.device] = None
+
def apply(
self,
input: Tensor,
@@ -473,6 +486,20 @@ def fuse(self, input: Tensor, batch_size: int) -> Tensor:
:func:`physicsnemo.utils.patching.image_fuse`
The underlying function used to perform the fusion operation.
"""
+ if self._overlap_count is None or self._overlap_count.device != input.device:
+ self._overlap_count = _compute_overlap_count(
+ img_shape_y=self.img_shape[0],
+ img_shape_x=self.img_shape[1],
+ patch_shape_y=self.patch_shape[0],
+ patch_shape_x=self.patch_shape[1],
+ overlap_pix=self.overlap_pix,
+ boundary_pix=self.boundary_pix,
+ padded_shape_y=self._padded_shape_y,
+ padded_shape_x=self._padded_shape_x,
+ pad=self._pad,
+ device=input.device,
+ )
+ self._overlap_count_device = input.device
out = image_fuse(
input=input,
img_shape_y=self.img_shape[0],
@@ -480,10 +507,42 @@ def fuse(self, input: Tensor, batch_size: int) -> Tensor:
batch_size=batch_size,
overlap_pix=self.overlap_pix,
boundary_pix=self.boundary_pix,
+ overlap_count=self._overlap_count,
)
return out
+def _compute_overlap_count(
+ img_shape_y: int,
+ img_shape_x: int,
+ patch_shape_y: int,
+ patch_shape_x: int,
+ overlap_pix: int,
+ boundary_pix: int,
+ padded_shape_y: int,
+ padded_shape_x: int,
+ pad: tuple,
+ device: torch.device,
+) -> Tensor:
+ """Compute how many patches cover each output pixel (shape: (1, 1, img_y, img_x)).
+
+ Result is static for a given geometry, so callers should cache it.
+ """
+ stride_y = patch_shape_y - overlap_pix - boundary_pix
+ stride_x = patch_shape_x - overlap_pix - boundary_pix
+ ones = torch.ones((1, 1, padded_shape_y, padded_shape_x), device=device)
+ unfolded = torch.nn.functional.unfold(
+ ones, kernel_size=(patch_shape_y, patch_shape_x), stride=(stride_y, stride_x)
+ )
+ count = torch.nn.functional.fold(
+ unfolded,
+ output_size=(padded_shape_y, padded_shape_x),
+ kernel_size=(patch_shape_y, patch_shape_x),
+ stride=(stride_y, stride_x),
+ )
+ return count[..., pad[2]: pad[2] + img_shape_y, pad[0]: pad[0] + img_shape_x]
+
+
def image_batching(
input: Tensor,
patch_shape_y: int,
@@ -584,21 +643,30 @@ def image_batching(
)
pad_x_right = padded_shape_x - img_shape_x - boundary_pix
pad_y_right = padded_shape_y - img_shape_y - boundary_pix
- image_padding = torch.nn.ReflectionPad2d(
- (boundary_pix, pad_x_right, boundary_pix, pad_y_right)
- ).to(
- input.device
- ) # (padding_left,padding_right,padding_top,padding_bottom)
- input_padded = image_padding(input)
+ input_padded = torch.nn.functional.pad(
+ input, (boundary_pix, pad_x_right, boundary_pix, pad_y_right), mode="reflect"
+ )
patch_num = patch_num_x * patch_num_y
+
+ # Cast to float for unfold
+ if input.dtype == torch.int32:
+ input_padded = input_padded.view(torch.float32)
+ elif input.dtype == torch.int64:
+ input_padded = input_padded.view(torch.float64)
+
x_unfold = torch.nn.functional.unfold(
- input=input_padded.view(_cast_type(input_padded)), # Cast to float
+ input=input_padded,
kernel_size=(patch_shape_y, patch_shape_x),
stride=(
patch_shape_y - overlap_pix - boundary_pix,
patch_shape_x - overlap_pix - boundary_pix,
),
- ).to(input_padded.dtype)
+ )
+
+ # Cast back to original dtype
+ if input.dtype in [torch.int32, torch.int64]:
+ x_unfold = x_unfold.view(input.dtype)
+
x_unfold = rearrange(
x_unfold,
"b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w",
@@ -608,16 +676,7 @@ def image_batching(
nb_p_w=patch_num_x,
)
if input_interp is not None:
- input_interp_repeated = rearrange(
- torch.repeat_interleave(
- input=input_interp,
- repeats=patch_num,
- dim=0,
- output_size=x_unfold.shape[0],
- ),
- "(b p) c h w -> (p b) c h w",
- p=patch_num,
- )
+ input_interp_repeated = input_interp.repeat(patch_num, 1, 1, 1)
return torch.cat((x_unfold, input_interp_repeated), dim=1)
else:
return x_unfold
@@ -630,6 +689,7 @@ def image_fuse(
batch_size: int,
overlap_pix: int,
boundary_pix: int,
+ overlap_count: Optional[Tensor] = None,
) -> Tensor:
"""
Reconstructs a full image from a batch of patched images. Reverts the patching
@@ -690,28 +750,23 @@ def image_fuse(
pad_y_right = padded_shape_y - img_shape_y - boundary_pix
pad = (boundary_pix, pad_x_right, boundary_pix, pad_y_right)
- # Count local overlaps between patches
- input_ones = torch.ones(
- (batch_size, input.shape[1], padded_shape_y, padded_shape_x),
- device=input.device,
- )
- overlap_count = torch.nn.functional.unfold(
- input=input_ones,
- kernel_size=(patch_shape_y, patch_shape_x),
- stride=(
- patch_shape_y - overlap_pix - boundary_pix,
- patch_shape_x - overlap_pix - boundary_pix,
- ),
- )
- overlap_count = torch.nn.functional.fold(
- input=overlap_count,
- output_size=(padded_shape_y, padded_shape_x),
- kernel_size=(patch_shape_y, patch_shape_x),
- stride=(
- patch_shape_y - overlap_pix - boundary_pix,
- patch_shape_x - overlap_pix - boundary_pix,
- ),
- )
+ # Count local overlaps between patches.
+ # overlap_count is purely geometric (constant for a given patch config), so
+ # callers that invoke fuse in a loop should pass a pre-cached value via the
+ # overlap_count parameter to avoid recomputing it on every call.
+ if overlap_count is None:
+ overlap_count = _compute_overlap_count(
+ img_shape_y=img_shape_y,
+ img_shape_x=img_shape_x,
+ patch_shape_y=patch_shape_y,
+ patch_shape_x=patch_shape_x,
+ overlap_pix=overlap_pix,
+ boundary_pix=boundary_pix,
+ padded_shape_y=padded_shape_y,
+ padded_shape_x=padded_shape_x,
+ pad=pad,
+ device=input.device,
+ )
# Reshape input to make it 3D to apply fold
x = rearrange(
@@ -722,6 +777,13 @@ def image_fuse(
nb_p_h=patch_num_y,
nb_p_w=patch_num_x,
)
+
+ # Cast to float for fold
+ if input.dtype == torch.int32:
+ x = x.view(torch.float32)
+ elif input.dtype == torch.int64:
+ x = x.view(torch.float64)
+
# Stitch patches together (by summing over overlapping patches)
x_folded = torch.nn.functional.fold(
input=x,
@@ -733,35 +795,21 @@ def image_fuse(
),
)
+ # Cast back to original dtype
+ if input.dtype in [torch.int32, torch.int64]:
+ x_folded = x_folded.view(input.dtype)
+
# Remove padding
x_no_padding = x_folded[
..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x
]
- overlap_count_no_padding = overlap_count[
- ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x
- ]
-
- # Normalize by overlap count
- return x_no_padding / overlap_count_no_padding
-
-def _cast_type(input: Tensor) -> torch.dtype:
- """Return float type based on input tensor type.
+ # overlap_count is (1, 1, img_shape_y, img_shape_x); broadcasts over batch and channels
+ x_no_padding = x_no_padding / overlap_count
- Parameters
- ----------
- input : Tensor
- Input tensor to determine float type from
+ #TODO: do we want to introduce this and will it break existing checkpoints
+ # if input.dtype in [torch.int32, torch.int64]:
+ # x_no_padding = x_no_padding.round().view(input.dtype)
- Returns
- -------
- torch.dtype
- Float type corresponding to input tensor type for int32/64,
- otherwise returns original dtype
- """
- if input.dtype == torch.int32:
- return torch.float32
- elif input.dtype == torch.int64:
- return torch.float64
- else:
- return input.dtype
+ # Normalize by overlap count
+ return x_no_padding
diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py
index 218d6f19..771d2ee5 100644
--- a/src/hirad/utils/train_helpers.py
+++ b/src/hirad/utils/train_helpers.py
@@ -16,8 +16,39 @@
import torch
import numpy as np
-from omegaconf import ListConfig
import warnings
+import mlflow
+from omegaconf import DictConfig, OmegaConf
+import os
+import psutil
+import time
+
+from hirad.distributed import DistributedManager
+from hirad.utils.env_info import get_env_info, flatten_dict
+
+# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available
+def cuda_profiler():
+ if torch.cuda.is_available():
+ return torch.cuda.profiler.profile()
+ else:
+ return nullcontext()
+
+
+def cuda_profiler_start():
+ if torch.cuda.is_available():
+ torch.cuda.profiler.start()
+
+
+def cuda_profiler_stop():
+ if torch.cuda.is_available():
+ torch.cuda.profiler.stop()
+
+
+def profiler_emit_nvtx():
+ if torch.cuda.is_available():
+ return torch.autograd.profiler.emit_nvtx()
+ else:
+ return nullcontext()
def set_patch_shape(img_shape, patch_shape):
@@ -44,6 +75,27 @@ def set_patch_shape(img_shape, patch_shape):
return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x)
+def calculate_patch_per_iter(patch_num, max_patch_per_gpu, batch_size_per_gpu):
+ if max_patch_per_gpu:
+ if max_patch_per_gpu // batch_size_per_gpu < 1:
+ raise ValueError(
+ f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})."
+ )
+ max_patch_num_per_iter = min(
+ patch_num, (max_patch_per_gpu // batch_size_per_gpu)
+ )
+ patch_iterations = (
+ patch_num + max_patch_num_per_iter - 1
+ ) // max_patch_num_per_iter
+ patch_nums_iter = [
+ min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter)
+ for i in range(patch_iterations)
+ ]
+ else:
+ patch_nums_iter = [patch_num]
+ return patch_nums_iter
+
+
def set_seed(rank):
"""
Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings
@@ -80,6 +132,20 @@ def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_
return batch_gpu_total, num_accumulation_rounds
+def update_learning_rate(optimizer, lr, lr_rampup, lr_decay, lr_decay_rate, cur_nimg):
+ """Apply learning rate rampup and decay schedule."""
+ current_lr = None
+ for g in optimizer.param_groups:
+ if lr_rampup > 0:
+ g["lr"] = lr * min(cur_nimg / lr_rampup, 1)
+ if cur_nimg >= lr_rampup:
+ g["lr"] *= lr_decay ** (
+ (cur_nimg - lr_rampup) // lr_decay_rate
+ )
+ current_lr = g["lr"]
+ return current_lr
+
+
def handle_and_clip_gradients(model, grad_clip_threshold=None):
"""
Handles NaNs and infinities in the gradients and optionally clips the gradients.
@@ -100,9 +166,19 @@ def handle_and_clip_gradients(model, grad_clip_threshold=None):
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold)
-def parse_model_args(args):
- """Convert ListConfig values in args to tuples."""
- return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()}
+def check_model_health(model, step, logger):
+ for name, param in model.named_parameters():
+ # Check weights
+ if not torch.isfinite(param).all():
+ logger.warning(f"!!! Weights in {name} went NaN at step {step}")
+ return False
+
+ # Check gradients
+ if param.grad is not None:
+ if not torch.isfinite(param.grad).all():
+ logger.warning(f"!!! Gradients in {name} went NaN at step {step}")
+ return False
+ return True
def is_time_for_periodic_task(
@@ -115,3 +191,81 @@ def is_time_for_periodic_task(
return True
else:
return cur_nimg % freq < batch_size
+
+#TODO: When mlflow is working locally on a multi-node job, it runs into issue with writing
+# to the same SQLite file. The current workaround is to only log system metrics from the
+# main process. Find a workaround to log system metrics from all processes without causing
+# conflicts in the SQLite file, such as using separate files for each process or using
+# a different backend for mlflow tracking.
+def init_mlflow(cfg: DictConfig, dist: DistributedManager, write_dir: str=".") -> None:
+ if dist.rank==0:
+ print("Started activating initial mlflow run")
+ if cfg.logging.uri is not None:
+ mlflow.set_tracking_uri(cfg.logging.uri)
+ mlflow.set_experiment(experiment_name=cfg.logging.experiment_name)
+ run_id = None
+ if os.path.isfile(os.path.join(write_dir, 'run_id.txt')):
+ with open(os.path.join(write_dir, 'run_id.txt'),'r') as f:
+ run_id = f.read()
+ if dist.world_size<=4 or cfg.logging.uri is None:
+ mlflow.system_metrics.set_system_metrics_node_id("node-0")
+ if run_id:
+ mlflow.start_run(run_id=run_id, log_system_metrics=False if dist.world_size>4 else True)
+ else:
+ mlflow.start_run(run_name=cfg.logging.run_name, log_system_metrics=False if dist.world_size>4 else True)
+ if run_id is None:
+ run = mlflow.active_run()
+ with open(os.path.join(write_dir, "run_id.txt"), 'w') as f:
+ f.write(run.info.run_id)
+ # log environment info if run is not continuing from previous checkpoint
+ mlflow.log_params(flatten_dict(OmegaConf.to_object(cfg)))
+ python_environment, git_diff = get_env_info(exclude_prefixes=['hirad', '__mp_main__'])
+ mlflow.log_dict(python_environment, "environment.json")
+ if git_diff:
+ mlflow.log_text(git_diff, "git_diff.txt")
+ mlflow.log_dict(cfg, "config.json")
+
+ if dist.world_size > 4:
+ torch.distributed.barrier()
+
+ if cfg.logging.uri is not None:
+ if (dist.rank!=0 and dist._local_rank==0) or (dist.rank==1 and dist.world_size>4):
+ print("Started actvating sub mlflow run.")
+
+ mlflow.set_tracking_uri(cfg.logging.uri)
+ mlflow.system_metrics.set_system_metrics_node_id(f"node-{(dist.rank//4)}"
+ if dist.rank!=1
+ else "node-0")
+ mlflow.set_experiment(experiment_name=cfg.logging.experiment_name)
+ with open(os.path.join(write_dir, "run_id.txt"), 'r') as f:
+ run_id = f.read()
+ mlflow.start_run(run_id=run_id, log_system_metrics=True)
+
+
+def log_training_progress(logger0, logging_method, dist, cur_nimg, tick_start_nimg, tick_start_time,
+ tick_read_time, start_time, average_loss, average_loss_running_mean,
+ current_lr):
+ """Log training progress metrics."""
+ torch.cuda.synchronize()
+ tick_end_time = time.time()
+ fields = [
+ f"samples {cur_nimg:<9.1f}",
+ f"training_loss {average_loss:<7.2f}",
+ f"training_loss_running_mean {average_loss_running_mean:<7.2f}",
+ f"learning_rate {current_lr:<7.8f}",
+ f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}",
+ f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.4f}",
+ f"sec_for_reading {tick_read_time:<7.4f}",
+ f"total_sec {(tick_end_time - start_time):<7.1f}",
+ f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}",
+ ]
+ if torch.cuda.is_available():
+ fields.append(f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}")
+ fields.append(f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}")
+ torch.cuda.reset_peak_memory_stats()
+ logger0.info(" ".join(fields))
+
+ if logging_method == "mlflow":
+ mlflow.log_metric("training_loss", average_loss, cur_nimg)
+ mlflow.log_metric("training_loss_running_mean", average_loss_running_mean, cur_nimg)
+ mlflow.log_metric("learning_rate", current_lr, cur_nimg)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/losses/__init__.py b/tests/losses/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/losses/test_loss.py b/tests/losses/test_loss.py
new file mode 100644
index 00000000..8989e76c
--- /dev/null
+++ b/tests/losses/test_loss.py
@@ -0,0 +1,762 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+
+from hirad.losses.loss import RegressionLoss, ResidualLoss
+from hirad.utils.patching import RandomPatching2D
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures
+# ---------------------------------------------------------------------------
+
+B, C_HR, C_LR, H, W = 2, 3, 4, 64, 64
+
+
+def _make_dummy_net(out_channels=C_HR):
+ """Return a callable that mimics a neural network."""
+ net = MagicMock()
+ net.side_effect = lambda x, y_lr, *a, **kw: torch.zeros(
+ x.shape[0], out_channels, x.shape[2], x.shape[3], device=x.device
+ )
+ return net
+
+
+def _make_identity_augment():
+ """Augmentation pipe that returns its input unchanged."""
+ return lambda x: (x, None)
+
+
+@pytest.fixture()
+def img_clean():
+ return torch.randn(B, C_HR, H, W)
+
+
+@pytest.fixture()
+def img_lr():
+ return torch.randn(B, C_LR, H, W)
+
+
+@pytest.fixture()
+def static_channels():
+ return torch.randn(1, 2, H, W)
+
+
+@pytest.fixture()
+def date_embedding():
+ return torch.randn(B, 5)
+
+
+@pytest.fixture()
+def lead_time_label():
+ return torch.randint(49, size=(B,))
+
+
+############################################################################
+# RegressionLoss #
+############################################################################
+
+
+class TestRegressionLossBasic:
+ """Basic behaviour of RegressionLoss."""
+
+ def test_output_shape(self, img_clean, img_lr):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ result = loss(net, img_clean, img_lr)
+ assert result.shape == img_clean.shape
+
+ def test_loss_non_negative(self, img_clean, img_lr):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ result = loss(net, img_clean, img_lr)
+ assert (result >= 0).all()
+
+ def test_zero_loss_when_prediction_matches(self, img_clean, img_lr):
+ """If net returns img_clean exactly, the loss should be zero."""
+ net = MagicMock()
+ # Net always returns the ground truth
+ net.side_effect = lambda x, y_lr, *a, **kw: img_clean
+ loss = RegressionLoss()
+ result = loss(net, img_clean, img_lr)
+ torch.testing.assert_close(result, torch.zeros_like(result))
+
+ def test_net_receives_zero_input(self, img_clean, img_lr):
+ """First argument to the network should be a zero tensor."""
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr)
+ call_args = net.call_args
+ zero_input = call_args[0][0]
+ torch.testing.assert_close(zero_input, torch.zeros_like(zero_input))
+
+ def test_net_receives_lr_conditioning(self, img_clean, img_lr):
+ """Second argument to the network should contain the LR image."""
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr)
+ call_args = net.call_args
+ y_lr_arg = call_args[0][1]
+ # Without static channels or date embedding, the conditioning
+ # should match the LR image.
+ torch.testing.assert_close(y_lr_arg, img_lr)
+
+
+class TestRegressionLossAugmentation:
+ """Augmentation behaviour of RegressionLoss."""
+
+ def test_identity_augmentation_same_as_no_augmentation(self, img_clean, img_lr):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ result_a = loss(net, img_clean, img_lr, augment_pipe=None)
+ result_b = loss(net, img_clean, img_lr, augment_pipe=_make_identity_augment())
+ torch.testing.assert_close(result_a, result_b)
+
+ def test_augment_pipe_is_called(self, img_clean, img_lr):
+ augment_pipe = MagicMock(side_effect=_make_identity_augment())
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr, augment_pipe=augment_pipe)
+ augment_pipe.assert_called_once()
+
+
+class TestRegressionLossStaticChannels:
+ """Static-channel conditioning for RegressionLoss."""
+
+ def test_lr_conditioning_includes_static_channels(
+ self, img_clean, img_lr, static_channels
+ ):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr, static_channels=static_channels)
+ y_lr_arg = net.call_args[0][1]
+ torch.testing.assert_close(
+ y_lr_arg, torch.cat([img_lr, static_channels.expand(img_lr.shape[0], -1, -1, -1)], dim=1))
+
+ def test_output_shape_with_static_channels(
+ self, img_clean, img_lr, static_channels
+ ):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ result = loss(net, img_clean, img_lr, static_channels=static_channels)
+ assert result.shape == img_clean.shape
+
+
+class TestRegressionLossDateEmbedding:
+ """Date-embedding conditioning for RegressionLoss."""
+
+ def test_lr_conditioning_includes_date_embedding(
+ self, img_clean, img_lr, date_embedding
+ ):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr, date_embedding=date_embedding)
+ y_lr_arg = net.call_args[0][1]
+ torch.testing.assert_close(
+ y_lr_arg, torch.cat([img_lr, date_embedding.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)], dim=1)
+ )
+
+ def test_output_shape_with_date_embedding(
+ self, img_clean, img_lr, date_embedding
+ ):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ result = loss(net, img_clean, img_lr, date_embedding=date_embedding)
+ assert result.shape == img_clean.shape
+
+ def test_date_embedding_contiguous_without_apex(
+ self, img_clean, img_lr, date_embedding
+ ):
+ """Date embedding should be contiguous when not using apex GN."""
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr, date_embedding=date_embedding, use_apex_gn=False)
+ y_lr_arg = net.call_args[0][1]
+ assert y_lr_arg.is_contiguous()
+
+
+class TestRegressionLossLeadTime:
+ """Lead-time-label behaviour in RegressionLoss."""
+
+ def test_lead_time_label_passed_to_net(
+ self, img_clean, img_lr, lead_time_label
+ ):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr, lead_time_label=lead_time_label)
+ kw = net.call_args[1]
+ assert "lead_time_label" in kw
+ torch.testing.assert_close(kw["lead_time_label"], lead_time_label)
+
+ def test_no_lead_time_label_omitted_from_kwargs(self, img_clean, img_lr):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ loss(net, img_clean, img_lr, lead_time_label=None)
+ kw = net.call_args[1]
+ assert "lead_time_label" not in kw
+
+
+class TestRegressionLossCombined:
+ """Combined optional arguments for RegressionLoss."""
+
+ def test_all_optional_args(
+ self, img_clean, img_lr, static_channels, date_embedding, lead_time_label
+ ):
+ net = _make_dummy_net()
+ loss = RegressionLoss()
+ result = loss(
+ net,
+ img_clean,
+ img_lr,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ lead_time_label=lead_time_label,
+ )
+ assert result.shape == img_clean.shape
+ y_lr_arg = net.call_args[0][1]
+ expected_channels = C_LR + static_channels.shape[1] + date_embedding.shape[1]
+ assert y_lr_arg.shape[1] == expected_channels
+
+
+############################################################################
+# ResidualLoss — init #
+############################################################################
+
+
+class TestResidualLossInit:
+ """Tests for ResidualLoss initialization."""
+
+ def test_default_values(self):
+ reg_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ assert loss.P_mean == 0.0
+ assert loss.P_std == 1.2
+ assert loss.sigma_data == 0.5
+ assert loss.hr_mean_conditioning is False
+ assert loss.y_mean is None
+
+ def test_custom_values(self):
+ reg_net = _make_dummy_net()
+ loss = ResidualLoss(
+ regression_net=reg_net,
+ P_mean=1.0,
+ P_std=2.0,
+ sigma_data=1.0,
+ hr_mean_conditioning=True,
+ )
+ assert loss.P_mean == 1.0
+ assert loss.P_std == 2.0
+ assert loss.sigma_data == 1.0
+ assert loss.hr_mean_conditioning is True
+
+
+############################################################################
+# ResidualLoss — get_noise_params #
+############################################################################
+
+
+class TestResidualLossGetNoiseParams:
+ """Tests for ResidualLoss.get_noise_params."""
+
+ def test_return_shapes(self):
+ reg_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ y = torch.randn(B, C_HR, H, W)
+ n, sigma, weight = loss.get_noise_params(y)
+ assert n.shape == y.shape
+ assert sigma.shape == (B, 1, 1, 1)
+ assert weight.shape == (B, 1, 1, 1)
+
+ def test_weight_and_sigma_positive(self):
+ reg_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ y = torch.randn(B, C_HR, H, W)
+ _, sigma, weight = loss.get_noise_params(y)
+ assert (weight > 0).all()
+ assert (sigma > 0).all()
+
+
+############################################################################
+# ResidualLoss — __call__ basic #
+############################################################################
+
+
+class TestResidualLossCallBasic:
+ """Basic behaviour of ResidualLoss.__call__."""
+
+ def test_output_shape_no_patching(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(diff_net, img_clean, img_lr)
+ assert result.shape == img_clean.shape
+
+ def test_loss_non_negative(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(diff_net, img_clean, img_lr)
+ assert (result >= 0).all()
+
+ def test_regression_net_called(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr)
+ reg_net.assert_called_once()
+
+ def test_diffusion_net_called(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr)
+ diff_net.assert_called_once()
+
+ def test_regression_net_receives_zero_input(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr)
+ zero_input = reg_net.call_args[0][0]
+ torch.testing.assert_close(zero_input, torch.zeros_like(zero_input))
+
+
+############################################################################
+# ResidualLoss — shape validation #
+############################################################################
+
+
+class TestResidualLossShapeValidation:
+ """Validation of img_clean / img_lr shapes."""
+
+ def test_batch_mismatch_raises(self):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ img_clean = torch.randn(2, C_HR, H, W)
+ img_lr = torch.randn(3, C_LR, H, W)
+ with pytest.raises(ValueError, match="Shape mismatch"):
+ loss(diff_net, img_clean, img_lr)
+
+ def test_spatial_mismatch_raises(self):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ img_clean = torch.randn(B, C_HR, 32, 32)
+ img_lr = torch.randn(B, C_LR, 64, 64)
+ with pytest.raises(ValueError, match="Shape mismatch"):
+ loss(diff_net, img_clean, img_lr)
+
+ def test_invalid_patching_type_raises(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ with pytest.raises(ValueError, match="RandomPatching2D"):
+ loss(diff_net, img_clean, img_lr, patching="not_a_patching_object")
+
+
+############################################################################
+# ResidualLoss — augmentation #
+############################################################################
+
+
+class TestResidualLossAugmentation:
+ """Augmentation in ResidualLoss."""
+
+ def test_augment_pipe_called(self, img_clean, img_lr):
+ augment_pipe = MagicMock(side_effect=_make_identity_augment())
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, augment_pipe=augment_pipe)
+ augment_pipe.assert_called_once()
+
+ def test_identity_augmentation_output_shape(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(
+ diff_net,
+ img_clean,
+ img_lr,
+ augment_pipe=_make_identity_augment(),
+ )
+ assert result.shape == img_clean.shape
+
+ def test_identity_augmentation_same_as_no_augmentation(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ torch.manual_seed(0) # fix the seed because of random noise in the loss
+ result_a = loss(diff_net, img_clean, img_lr, augment_pipe=None)
+ torch.manual_seed(0) # reset the seed to ensure same noise is added
+ result_b = loss(diff_net, img_clean, img_lr, augment_pipe=_make_identity_augment())
+ torch.testing.assert_close(result_a, result_b)
+
+
+############################################################################
+# ResidualLoss — static channels #
+############################################################################
+
+
+class TestResidualLossStaticChannels:
+ """Static-channel conditioning in ResidualLoss."""
+
+ def test_output_shape_with_static_channels(
+ self, img_clean, img_lr, static_channels
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(diff_net, img_clean, img_lr, static_channels=static_channels)
+ assert result.shape == img_clean.shape
+
+ def test_regression_net_receives_static_channels(
+ self, img_clean, img_lr, static_channels
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, static_channels=static_channels)
+ y_lr_reg = reg_net.call_args[0][1]
+ assert y_lr_reg.shape[1] == C_LR + static_channels.shape[1]
+ torch.testing.assert_close(
+ y_lr_reg[0, C_LR:, :, :], static_channels[0, :, :, :]
+ )
+ torch.testing.assert_close(
+ y_lr_reg, torch.cat([img_lr, static_channels.expand(img_lr.shape[0], -1, -1, -1)], dim=1)
+ )
+
+ def test_diffusion_net_receives_static_channels(
+ self, img_clean, img_lr, static_channels
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, static_channels=static_channels)
+ y_lr_diff = diff_net.call_args[0][1]
+ # Diffusion net also gets static channels appended to y_lr
+ assert y_lr_diff.shape[1] == C_LR + static_channels.shape[1]
+ torch.testing.assert_close(
+ y_lr_diff[0, C_LR:, :, :], static_channels[0, :, :, :]
+ )
+ torch.testing.assert_close(
+ y_lr_diff, torch.cat([img_lr, static_channels.expand(img_lr.shape[0], -1, -1, -1)], dim=1)
+ )
+
+############################################################################
+# ResidualLoss — date embedding #
+############################################################################
+
+
+class TestResidualLossDateEmbedding:
+ """Date-embedding conditioning in ResidualLoss."""
+
+ def test_output_shape_with_date_embedding(
+ self, img_clean, img_lr, date_embedding
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(diff_net, img_clean, img_lr, date_embedding=date_embedding)
+ assert result.shape == img_clean.shape
+
+ def test_regression_net_receives_date_embedding(
+ self, img_clean, img_lr, date_embedding
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, date_embedding=date_embedding)
+ y_lr_reg = reg_net.call_args[0][1]
+ assert y_lr_reg.shape[1] == C_LR + date_embedding.shape[1]
+ torch.testing.assert_close(
+ y_lr_reg[:, C_LR:, 0, 0], date_embedding
+ )
+ torch.testing.assert_close(
+ y_lr_reg, torch.cat([img_lr, date_embedding.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)], dim=1)
+ )
+
+ def test_diffusion_net_receives_date_embedding(
+ self, img_clean, img_lr, date_embedding
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, date_embedding=date_embedding)
+ y_lr_diff = diff_net.call_args[0][1]
+ assert y_lr_diff.shape[1] == C_LR + date_embedding.shape[1]
+ torch.testing.assert_close(
+ y_lr_diff[:, C_LR:, 0, 0], date_embedding
+ )
+ torch.testing.assert_close(
+ y_lr_diff, torch.cat([img_lr, date_embedding.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)], dim=1)
+ )
+
+############################################################################
+# ResidualLoss — lead time label #
+############################################################################
+
+
+class TestResidualLossLeadTime:
+ """Lead-time-label behaviour in ResidualLoss."""
+
+ def test_lead_time_label_passed_to_both_nets(
+ self, img_clean, img_lr, lead_time_label
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, lead_time_label=lead_time_label)
+ # Regression net
+ reg_kw = reg_net.call_args[1]
+ assert "lead_time_label" in reg_kw
+ torch.testing.assert_close(reg_kw["lead_time_label"], lead_time_label)
+ # Diffusion net
+ diff_kw = diff_net.call_args[1]
+ assert "lead_time_label" in diff_kw
+ torch.testing.assert_close(diff_kw["lead_time_label"], lead_time_label)
+
+ def test_no_lead_time_label_omitted_from_kwargs(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, lead_time_label=None)
+ assert "lead_time_label" not in reg_net.call_args[1]
+ assert "lead_time_label" not in diff_net.call_args[1]
+
+
+############################################################################
+# ResidualLoss — hr_mean_conditioning #
+############################################################################
+
+
+class TestResidualLossHrMeanConditioning:
+ """High-resolution mean conditioning in ResidualLoss."""
+
+ def test_diffusion_conditioning_includes_mean(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net, hr_mean_conditioning=True)
+ loss(diff_net, img_clean, img_lr)
+ y_lr_diff = diff_net.call_args[0][1]
+ # y_lr should have y_mean (C_HR channels) prepended to img_lr (C_LR channels)
+ assert y_lr_diff.shape[1] == C_HR + C_LR
+ torch.testing.assert_close(y_lr_diff[:,:C_HR,:,:],torch.zeros_like(y_lr_diff[:,:C_HR,:,:]))
+ torch.testing.assert_close(y_lr_diff[:,C_HR:,:,:], img_lr)
+
+ def test_diffusion_conditioning_excludes_mean_when_disabled(
+ self, img_clean, img_lr
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net, hr_mean_conditioning=False)
+ loss(diff_net, img_clean, img_lr)
+ y_lr_diff = diff_net.call_args[0][1]
+ assert y_lr_diff.shape[1] == C_LR
+
+
+############################################################################
+# ResidualLoss — use_patch_grad_acc #
+############################################################################
+
+
+class TestResidualLossPatchGradAcc:
+ """Test use_patch_grad_acc reuse of cached y_mean."""
+
+ def test_y_mean_cached_after_first_call(self, img_clean, img_lr):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ assert loss.y_mean is None
+ loss(diff_net, img_clean, img_lr, use_patch_grad_acc=True)
+ assert loss.y_mean is not None
+
+ def test_regression_net_not_called_when_y_mean_cached(
+ self, img_clean, img_lr
+ ):
+ """When use_patch_grad_acc=True and y_mean is already cached,
+ the regression net should not be called again."""
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ # First call populates y_mean
+ loss(diff_net, img_clean, img_lr, use_patch_grad_acc=True)
+ assert reg_net.call_count == 1
+ # Second call should reuse y_mean
+ loss(diff_net, img_clean, img_lr, use_patch_grad_acc=True)
+ assert reg_net.call_count == 1
+
+ def test_regression_net_called_without_patch_grad_acc(
+ self, img_clean, img_lr
+ ):
+ """Without use_patch_grad_acc, regression net is called every time."""
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, use_patch_grad_acc=False)
+ loss(diff_net, img_clean, img_lr, use_patch_grad_acc=False)
+ assert reg_net.call_count == 2
+
+
+############################################################################
+# ResidualLoss — patching integration #
+############################################################################
+
+
+class TestResidualLossPatching:
+ """Tests for ResidualLoss with RandomPatching2D."""
+
+ @pytest.fixture()
+ def patching(self):
+ return RandomPatching2D(
+ img_shape=(H, W), patch_shape=(32, 32), patch_num=2
+ )
+
+ def test_output_shape_with_patching(self, img_clean, img_lr, patching):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(diff_net, img_clean, img_lr, patching=patching)
+ # Output should be (B * num_patches, C_HR, patch_H, patch_W)
+ assert result.shape[1] == C_HR
+ assert result.shape[2] == 32
+ assert result.shape[3] == 32
+ assert result.shape[0] == B * 2 # More samples due to patching
+
+ def test_diffusion_net_receives_arguments(self, img_clean, img_lr, patching):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ loss(diff_net, img_clean, img_lr, patching=patching)
+ y_arg = diff_net.call_args[0][0]
+ y_lr_arg = diff_net.call_args[0][1]
+ sigma_arg = diff_net.call_args[0][2]
+ assert y_arg.shape == (B * 2, C_HR, 32, 32)
+ assert y_lr_arg.shape == (B * 2, 2*C_LR, 32, 32)
+ assert sigma_arg.shape == (2 * B, 1, 1, 1)
+
+ def test_patching_with_static_channels(
+ self, img_clean, img_lr, static_channels, patching
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net)
+ result = loss(
+ diff_net,
+ img_clean,
+ img_lr,
+ static_channels=static_channels,
+ patching=patching,
+ )
+ y_arg = diff_net.call_args[0][0]
+ y_lr_arg = diff_net.call_args[0][1]
+ assert y_arg.shape == (B * 2, C_HR, 32, 32)
+ assert y_lr_arg.shape == (B * 2, 2*C_LR + 2*static_channels.shape[1], 32, 32)
+ assert result.shape == (B * 2, C_HR, 32, 32)
+
+ def test_patching_with_hr_mean_conditioning(
+ self, img_clean, img_lr, patching
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(
+ regression_net=reg_net, hr_mean_conditioning=True
+ )
+ result = loss(diff_net, img_clean, img_lr, patching=patching)
+ y_arg = diff_net.call_args[0][0]
+ y_lr_arg = diff_net.call_args[0][1]
+ assert y_arg.shape == (B * 2, C_HR, 32, 32)
+ assert y_lr_arg.shape == (B * 2, 2*C_LR+C_HR, 32, 32)
+ assert result.shape == (B * 2, C_HR, 32, 32)
+
+
+############################################################################
+# ResidualLoss — combined options #
+############################################################################
+
+
+class TestResidualLossCombined:
+ """Tests with multiple optional arguments combined."""
+
+ def test_all_optional_args_no_patching(
+ self,
+ img_clean,
+ img_lr,
+ static_channels,
+ date_embedding,
+ lead_time_label,
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(
+ regression_net=reg_net, hr_mean_conditioning=True
+ )
+ result = loss(
+ diff_net,
+ img_clean,
+ img_lr,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ lead_time_label=lead_time_label,
+ )
+ y_arg = diff_net.call_args[0][0]
+ y_lr_arg = diff_net.call_args[0][1]
+ assert y_arg.shape == (B, C_HR, 64, 64)
+ assert y_lr_arg.shape == (B, C_LR+C_HR+static_channels.shape[1]+date_embedding.shape[1], 64, 64)
+ assert result.shape == img_clean.shape
+
+ def test_all_optional_args_with_patching(
+ self,
+ img_clean,
+ img_lr,
+ static_channels,
+ date_embedding,
+ lead_time_label,
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ patching = RandomPatching2D(img_shape=(H, W), patch_shape=(32, 32), patch_num=2)
+ loss = ResidualLoss(
+ regression_net=reg_net, hr_mean_conditioning=True
+ )
+ result = loss(
+ diff_net,
+ img_clean,
+ img_lr,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ patching=patching,
+ lead_time_label=lead_time_label,
+ )
+ y_arg = diff_net.call_args[0][0]
+ y_lr_arg = diff_net.call_args[0][1]
+ assert y_arg.shape == (B * 2, C_HR, 32, 32)
+ assert y_lr_arg.shape == (B * 2, 2*C_LR + C_HR + 2*static_channels.shape[1] + date_embedding.shape[1], 32, 32)
+ assert result.shape == (B * 2, C_HR, 32, 32)
+
+ def test_augment_with_hr_mean_conditioning_static_and_date(
+ self, img_clean, img_lr, static_channels, date_embedding
+ ):
+ reg_net = _make_dummy_net()
+ diff_net = _make_dummy_net()
+ loss = ResidualLoss(regression_net=reg_net, hr_mean_conditioning=True)
+ result = loss(
+ diff_net,
+ img_clean,
+ img_lr,
+ static_channels=static_channels,
+ date_embedding=date_embedding,
+ augment_pipe=_make_identity_augment(),
+ )
+ y_arg = diff_net.call_args[0][0]
+ y_lr_arg = diff_net.call_args[0][1]
+ assert y_arg.shape == (B, C_HR, 64, 64)
+ assert y_lr_arg.shape == (B, C_LR+C_HR+static_channels.shape[1]+date_embedding.shape[1], 64, 64)
+ assert result.shape == img_clean.shape
+ assert (result >= 0).all()
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/models/test_layers.py b/tests/models/test_layers.py
new file mode 100644
index 00000000..774a9bab
--- /dev/null
+++ b/tests/models/test_layers.py
@@ -0,0 +1,950 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+from unittest.mock import MagicMock, patch
+
+from hirad.models.layers import (
+ AttentionOp,
+ Conv2d,
+ FourierEmbedding,
+ GroupNorm,
+ Linear,
+ PositionalEmbedding,
+ UNetBlock,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures — use small configs for fast CPU tests
+# ---------------------------------------------------------------------------
+
+B = 2
+IN_CH = 4
+OUT_CH = 8
+H, W = 16, 16
+EMB_CH = 32
+
+
+@pytest.fixture()
+def random_input_2d():
+ return torch.randn(B, IN_CH, H, W)
+
+
+@pytest.fixture()
+def random_input_flat():
+ return torch.randn(B, IN_CH)
+
+
+@pytest.fixture()
+def embedding():
+ return torch.randn(B, EMB_CH)
+
+
+############################################################################
+# Linear #
+############################################################################
+
+
+class TestLinearInit:
+ """Test Linear.__init__ parameter setup."""
+
+ def test_weight_shape(self):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH)
+ assert layer.weight.shape == (OUT_CH, IN_CH)
+
+ def test_bias_shape(self):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH)
+ assert layer.bias is not None
+ assert layer.bias.shape == (OUT_CH,)
+
+ def test_no_bias_when_disabled(self):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH, bias=False)
+ assert layer.bias is None
+
+ def test_stores_features(self):
+ layer = Linear(in_features=IN_CH,
+ out_features=OUT_CH,
+ amp_mode=True)
+ assert layer.in_features == IN_CH
+ assert layer.out_features == OUT_CH
+ assert layer.amp_mode==True
+
+ @pytest.mark.parametrize(
+ "init_mode",
+ ["xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"],
+ )
+ def test_all_init_modes_accepted(self, init_mode):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH, init_mode=init_mode)
+ assert layer.weight.shape == (OUT_CH, IN_CH)
+
+ def test_init_weight_scaling(self):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH, init_weight=0.0)
+ torch.testing.assert_close(layer.weight, torch.zeros(OUT_CH, IN_CH))
+
+ def test_init_bias_scaling(self):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH, init_bias=0)
+ torch.testing.assert_close(layer.bias, torch.zeros(OUT_CH))
+
+
+class TestLinearForward:
+ """Test Linear forward pass."""
+
+ def test_output_shape(self, random_input_flat):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH)
+ out = layer(random_input_flat)
+ assert out.shape == (B, OUT_CH)
+
+ def test_output_shape_no_bias(self, random_input_flat):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH, bias=False)
+ out = layer(random_input_flat)
+ assert out.shape == (B, OUT_CH)
+
+ def test_zero_weight_zero_bias_returns_zero(self, random_input_flat):
+ layer = Linear(
+ in_features=IN_CH,
+ out_features=OUT_CH,
+ init_weight=0,
+ init_bias=0,
+ )
+ out = layer(random_input_flat)
+ torch.testing.assert_close(out, torch.zeros(B, OUT_CH))
+
+ def test_output_dtype_matches_input(self, random_input_flat):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH)
+ out = layer(random_input_flat)
+ assert out.dtype == random_input_flat.dtype
+
+ def test_gradients_flow(self, random_input_flat):
+ layer = Linear(in_features=IN_CH, out_features=OUT_CH)
+ x = random_input_flat.clone().requires_grad_(True)
+ out = layer(x)
+ out.sum().backward()
+ assert x.grad is not None
+ assert x.grad.shape == x.shape
+
+
+############################################################################
+# Conv2d #
+############################################################################
+
+
+class TestConv2dInit:
+ """Test Conv2d.__init__ parameter setup."""
+
+ def test_weight_shape(self):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3)
+ assert layer.weight.shape == (OUT_CH, IN_CH, 3, 3)
+
+ def test_bias_shape(self):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3)
+ assert layer.bias is not None
+ assert layer.bias.shape == (OUT_CH,)
+
+ def test_no_bias_when_disabled(self):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, bias=False)
+ assert layer.bias is None
+
+ def test_kernel_zero_no_weight(self):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=0)
+ assert layer.weight is None
+ assert layer.bias is None
+
+ def test_up_and_down_raises(self):
+ with pytest.raises(ValueError, match="Both 'up' and 'down'"):
+ Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True, down=True)
+
+ def test_stores_flags(self):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True, fused_resample=True, fused_conv_bias=True, amp_mode=True)
+ assert layer.up is True
+ assert layer.down is False
+ assert layer.fused_resample is True
+ assert layer.fused_conv_bias is True
+ assert layer.amp_mode is True
+ assert layer.in_channels == IN_CH
+ assert layer.out_channels == OUT_CH
+
+ def test_resample_filter_registered_when_up(self):
+ layer = Conv2d(
+ in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True
+ )
+ assert layer.resample_filter is not None
+
+ def test_resample_filter_registered_when_down(self):
+ layer = Conv2d(
+ in_channels=IN_CH, out_channels=OUT_CH, kernel=3, down=True
+ )
+ assert layer.resample_filter is not None
+
+ def test_resample_filter_none_when_no_up_down(self):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3)
+ assert layer.resample_filter is None
+
+ def test_fused_conv_bias_disabled_when_no_kernel(self):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=0,
+ fused_conv_bias=True,
+ )
+ assert layer.fused_conv_bias is False
+
+ def test_zero_weight_bias_init_gives_zero_weights_and_biases(self):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=3,
+ init_weight=0,
+ init_bias=0,
+ )
+ torch.testing.assert_close(layer.weight, torch.zeros(OUT_CH, IN_CH, 3, 3))
+ torch.testing.assert_close(layer.bias, torch.zeros(OUT_CH))
+
+
+class TestConv2dForward:
+ """Test Conv2d forward pass."""
+
+ def test_output_shape_same_padding(self, random_input_2d):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3)
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_output_shape_kernel_1(self, random_input_2d):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=1)
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_output_shape_upsample(self, random_input_2d):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True)
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H * 2, W * 2)
+
+ def test_output_shape_downsample(self, random_input_2d):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, down=True)
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H // 2, W // 2)
+
+ def test_output_shape_fused_upsample(self, random_input_2d):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=3,
+ up=True,
+ fused_resample=True,
+ )
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H * 2, W * 2)
+
+ def test_output_shape_fused_downsample(self, random_input_2d):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=3,
+ down=True,
+ fused_resample=True,
+ )
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H // 2, W // 2)
+
+ def test_output_shape_fused_conv_bias(self, random_input_2d):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=3,
+ fused_conv_bias=True,
+ )
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_output_shape_fused_up_with_conv_bias(self, random_input_2d):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=3,
+ up=True,
+ fused_resample=True,
+ fused_conv_bias=True,
+ )
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H * 2, W * 2)
+
+ def test_output_shape_fused_down_with_conv_bias(self, random_input_2d):
+ layer = Conv2d(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ kernel=3,
+ down=True,
+ fused_resample=True,
+ fused_conv_bias=True,
+ )
+ out = layer(random_input_2d)
+ assert out.shape == (B, OUT_CH, H // 2, W // 2)
+
+ def test_output_dtype_matches_input(self, random_input_2d):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3)
+ out = layer(random_input_2d)
+ assert out.dtype == random_input_2d.dtype
+
+ def test_gradients_flow(self, random_input_2d):
+ layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3)
+ x = random_input_2d.clone().requires_grad_(True)
+ out = layer(x)
+ out.sum().backward()
+ assert x.grad is not None
+ assert x.grad.shape == x.shape
+
+ def test_kernel_zero_passthrough(self, random_input_2d):
+ """With kernel=0, no convolution should be applied, just pass-through."""
+ layer = Conv2d(in_channels=IN_CH, out_channels=IN_CH, kernel=0)
+ out = layer(random_input_2d)
+ torch.testing.assert_close(out, random_input_2d)
+
+
+############################################################################
+# GroupNorm #
+############################################################################
+
+
+class TestGroupNormInit:
+ """Test GroupNorm.__init__ parameter setup."""
+
+ def test_weight_shape(self):
+ gn = GroupNorm(num_channels=OUT_CH)
+ assert gn.weight.shape == (OUT_CH,)
+
+ def test_bias_shape(self):
+ gn = GroupNorm(num_channels=OUT_CH)
+ assert gn.bias.shape == (OUT_CH,)
+
+ def test_weight_initialized_to_ones(self):
+ gn = GroupNorm(num_channels=OUT_CH)
+ torch.testing.assert_close(gn.weight, torch.ones(OUT_CH))
+
+ def test_bias_initialized_to_zeros(self):
+ gn = GroupNorm(num_channels=OUT_CH)
+ torch.testing.assert_close(gn.bias, torch.zeros(OUT_CH))
+
+ def test_num_groups_clipped_to_min_channels(self):
+ """If num_channels // min_channels_per_group < num_groups, groups are reduced."""
+ gn = GroupNorm(num_channels=8, num_groups=32, min_channels_per_group=4)
+ assert gn.num_groups == 2
+
+ def test_num_groups_matches_when_divisible(self):
+ gn = GroupNorm(num_channels=32, num_groups=8, min_channels_per_group=2)
+ assert gn.num_groups == 8
+
+ def test_fused_act_without_act_raises(self):
+ with pytest.raises(ValueError, match="'act' must be specified"):
+ GroupNorm(num_channels=OUT_CH, fused_act=True, act=None)
+
+ def test_fused_act_with_valid_act(self):
+ gn = GroupNorm(num_channels=OUT_CH, fused_act=True, act="silu")
+ assert gn.fused_act is True
+ assert gn.act == "silu"
+ assert gn.act_fn is not None
+
+ def test_eps_and_amp_mode_stored(self):
+ gn = GroupNorm(num_channels=OUT_CH, eps=1e-6, amp_mode=True)
+ assert gn.eps == 1e-6
+ assert gn.amp_mode is True
+
+ def test_apex_gn_initializes_gn_when_available(self):
+ mock_gn_cls = MagicMock()
+ mock_gn_instance = MagicMock()
+ mock_gn_cls.return_value = mock_gn_instance
+ with patch("hirad.models.layers._is_apex_available", True), \
+ patch("hirad.models.layers.ApexGroupNorm", mock_gn_cls, create=True):
+ gn = GroupNorm(num_channels=OUT_CH, use_apex_gn=True)
+ assert hasattr(gn, "gn")
+ assert gn.gn is mock_gn_instance
+ mock_gn_cls.assert_called_once()
+
+ def test_apex_gn_raises_when_not_available(self):
+ with patch("hirad.models.layers._is_apex_available", False):
+ with pytest.raises(ValueError, match="'apex' is not"):
+ GroupNorm(num_channels=OUT_CH, use_apex_gn=True)
+
+
+class TestGroupNormForward:
+ """Test GroupNorm forward pass."""
+
+ def test_output_shape(self, random_input_2d):
+ gn = GroupNorm(num_channels=IN_CH)
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+
+ def test_output_dtype_matches_input(self, random_input_2d):
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.train()
+ out = gn(random_input_2d)
+ assert out.dtype == random_input_2d.dtype
+ gn.eval()
+ out = gn(random_input_2d)
+ assert out.dtype == random_input_2d.dtype
+
+ def test_training_mode_uses_torch_group_norm(self, random_input_2d):
+ """In training mode, output should match torch.nn.functional.group_norm."""
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.train()
+ out = gn(random_input_2d)
+ expected = torch.nn.functional.group_norm(
+ random_input_2d, num_groups=gn.num_groups, weight=gn.weight, bias=gn.bias, eps=gn.eps
+ )
+ torch.testing.assert_close(out, expected)
+
+ def test_eval_mode_output_shape(self, random_input_2d):
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.eval()
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+
+ def test_apex_gn_forward(self, random_input_2d):
+ """Test forward pass when using Apex GroupNorm."""
+ from hirad.models.layers import _is_apex_available
+ if _is_apex_available:
+ gn = GroupNorm(num_channels=IN_CH, use_apex_gn=True)
+ called = []
+ gn.gn.register_forward_hook(lambda m, i, o: called.append(True))
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+ assert called, "Apex GroupNorm forward hook was not called, so it may not have been used."
+ else:
+ mock_gn_cls = MagicMock()
+ mock_gn_instance = MagicMock()
+ mock_gn_instance.forward.return_value = random_input_2d
+ mock_gn_cls.return_value = mock_gn_instance
+ with patch("hirad.models.layers._is_apex_available", True), \
+ patch("hirad.models.layers.ApexGroupNorm", mock_gn_cls, create=True):
+ gn = GroupNorm(num_channels=IN_CH, use_apex_gn=True)
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+ assert mock_gn_instance.assert_called_once
+
+ def test_training_mode_with_fused_act(self, random_input_2d):
+ """Test that fused activation is applied in training mode."""
+ gn = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu")
+ gn.train()
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+ assert (out >= 0).all(), "Output should be non-negative due to ReLU activation"
+
+ def test_eval_mode_with_fused_act(self, random_input_2d):
+ gn = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu")
+ gn.eval()
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+ assert (out >= 0).all(), "Output should be non-negative due to ReLU activation"
+
+ def test_training_fused_act_actually_applies_activation(self, random_input_2d):
+ """Verify fused act path produces different output than non-fused path."""
+ gn_fused = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu")
+ gn_plain = GroupNorm(num_channels=IN_CH)
+ gn_fused.train()
+ gn_plain.train()
+ out_fused = gn_fused(random_input_2d)
+ out_plain = gn_plain(random_input_2d)
+ # Plain output will have negatives; fused should not
+ assert (out_plain < 0).any(), "Plain output should have negatives for meaningful test"
+ assert (out_fused >= 0).all()
+
+ def test_eval_fused_act_actually_applies_activation(self, random_input_2d):
+ """Verify fused act path produces different output than non-fused path."""
+ gn_fused = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu")
+ gn_plain = GroupNorm(num_channels=IN_CH)
+ gn_fused.eval()
+ gn_plain.eval()
+ out_fused = gn_fused(random_input_2d)
+ out_plain = gn_plain(random_input_2d)
+ # Plain output will have negatives; fused should not
+ assert (out_plain < 0).any(), "Plain output should have negatives for meaningful test"
+ assert (out_fused >= 0).all()
+
+ def test_eval_mode_matches_training_mode(self, random_input_2d):
+ """Eval and training modes should produce close results."""
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.train()
+ out_train = gn(random_input_2d)
+ gn.eval()
+ out_eval = gn(random_input_2d)
+ torch.testing.assert_close(out_train, out_eval, atol=1e-5, rtol=1e-5)
+
+ def test_normalized_output_has_zero_mean(self, random_input_2d):
+ """After GroupNorm with default weight=1 and bias=0, each group should
+ have approximately zero mean."""
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.train()
+ out = gn(random_input_2d)
+ # Reshape to groups and check mean is near zero
+ reshaped = out.reshape(B, gn.num_groups, IN_CH // gn.num_groups, H, W)
+ group_means = reshaped.mean(dim=[2, 3, 4])
+ assert group_means.abs().max() < 0.1
+
+ def test_normalized_output_has_variance_close_to_one(self, random_input_2d):
+ """After GroupNorm with default weight=1 and bias=0, each group should
+ have variance close to 1 (not exactly 1 due to eps)."""
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.train()
+ out = gn(random_input_2d)
+ # Reshape to groups and check variance is near 1
+ reshaped = out.reshape(B, gn.num_groups, IN_CH // gn.num_groups, H, W)
+ group_vars = reshaped.var(dim=[2, 3, 4], unbiased=False)
+ assert torch.allclose(group_vars, torch.ones_like(group_vars), atol=0.1)
+
+ def test_gradients_flow(self, random_input_2d):
+ gn = GroupNorm(num_channels=IN_CH)
+ gn.train()
+ x = random_input_2d.clone().requires_grad_(True)
+ out = gn(x)
+ out.sum().backward()
+ assert x.grad is not None
+ assert x.grad.shape == x.shape
+
+
+class TestGroupNormFusedActivation:
+ """Test GroupNorm with fused activation functions."""
+
+ @pytest.mark.parametrize(
+ "act_name", ["silu", "relu", "leaky_relu", "sigmoid", "tanh", "gelu", "elu"]
+ )
+ def test_fused_act_accepted(self, act_name, random_input_2d):
+ gn = GroupNorm(num_channels=IN_CH, fused_act=True, act=act_name)
+ gn.train()
+ out = gn(random_input_2d)
+ assert out.shape == random_input_2d.shape
+
+ def test_invalid_act_raises(self):
+ with pytest.raises(ValueError, match="Unknown activation function"):
+ GroupNorm(num_channels=OUT_CH, fused_act=True, act="invalid_act")
+
+ @pytest.mark.parametrize(
+ "act_name", ["silu", "relu", "leaky_relu", "sigmoid", "tanh", "gelu", "elu"]
+ )
+ def test_fused_act_matches_separate(self, act_name, random_input_2d):
+ """Fused activation should give the same result as applying the activation separately."""
+ gn_fused = GroupNorm(num_channels=IN_CH, fused_act=True, act=act_name)
+ gn_plain = GroupNorm(num_channels=IN_CH)
+ # Copy parameters
+ gn_fused.train()
+ gn_plain.train()
+ out_fused = gn_fused(random_input_2d)
+ out_separate = getattr(torch.nn.functional, act_name)(gn_plain(random_input_2d))
+ torch.testing.assert_close(out_fused, out_separate)
+
+
+############################################################################
+# AttentionOp #
+############################################################################
+
+
+class TestAttentionOpForward:
+ """Test AttentionOp forward pass."""
+
+ def test_output_shape(self):
+ q = torch.randn(B, 16, 8)
+ k = torch.randn(B, 16, 8)
+ w = AttentionOp.apply(q, k)
+ assert w.shape == (B, 8, 8)
+
+ def test_output_is_probability_distribution(self):
+ """Each row of the attention weights should sum to 1 (softmax output)."""
+ q = torch.randn(B, 16, 8)
+ k = torch.randn(B, 16, 8)
+ w = AttentionOp.apply(q, k)
+ row_sums = w.sum(dim=2)
+ torch.testing.assert_close(row_sums, torch.ones_like(row_sums), atol=1e-5, rtol=1e-5)
+
+ def test_output_non_negative(self):
+ q = torch.randn(B, 16, 8)
+ k = torch.randn(B, 16, 8)
+ w = AttentionOp.apply(q, k)
+ assert (w >= 0).all()
+
+ def test_output_dtype_matches_input(self):
+ q = torch.randn(B, 16, 8)
+ k = torch.randn(B, 16, 8)
+ w = AttentionOp.apply(q, k)
+ assert w.dtype == q.dtype
+
+
+class TestAttentionOpBackward:
+ """Test AttentionOp backward pass."""
+
+ def test_gradients_flow_to_q(self):
+ q = torch.randn(B, 16, 8, requires_grad=True)
+ k = torch.randn(B, 16, 8, requires_grad=True)
+ w = AttentionOp.apply(q, k)
+ w.sum().backward()
+ assert q.grad is not None
+ assert q.grad.shape == q.shape
+
+ def test_gradients_flow_to_k(self):
+ q = torch.randn(B, 16, 8, requires_grad=True)
+ k = torch.randn(B, 16, 8, requires_grad=True)
+ w = AttentionOp.apply(q, k)
+ w.sum().backward()
+ assert k.grad is not None
+ assert k.grad.shape == k.shape
+
+
+############################################################################
+# UNetBlock #
+############################################################################
+
+
+class TestUNetBlockInit:
+ """Test UNetBlock.__init__ parameter setup."""
+
+ def test_stores_block_info(self):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH,
+ dropout=0.1, skip_scale=0.5, adaptive_scale=False,
+ profile_mode=True, amp_mode=True, attention=True,
+ num_heads=4
+ )
+ assert block.in_channels == IN_CH
+ assert block.out_channels == OUT_CH
+ assert block.emb_channels == EMB_CH
+ assert block.dropout == 0.1
+ assert block.skip_scale == 0.5
+ assert block.adaptive_scale is False
+ assert block.profile_mode is True
+ assert block.amp_mode is True
+
+ def test_num_heads_zero_when_no_attention(self):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, attention=False
+ )
+ assert block.num_heads == 0
+
+ def test_num_heads_set_when_attention_no_num_heads(self):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=16, emb_channels=EMB_CH,
+ attention=True, channels_per_head=4
+ )
+ assert block.num_heads == 16//4
+
+ def test_skip_created_when_channels_differ(self):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH
+ )
+ assert block.skip is not None
+
+ def test_skip_none_when_channels_match(self):
+ block = UNetBlock(
+ in_channels=OUT_CH, out_channels=OUT_CH, emb_channels=EMB_CH
+ )
+ assert block.skip is None
+
+ def test_skip_created_when_up(self):
+ block = UNetBlock(
+ in_channels=OUT_CH, out_channels=OUT_CH, emb_channels=EMB_CH, up=True
+ )
+ assert block.skip is not None
+
+ def test_skip_created_when_down(self):
+ block = UNetBlock(
+ in_channels=OUT_CH, out_channels=OUT_CH, emb_channels=EMB_CH, down=True
+ )
+ assert block.skip is not None
+
+ def test_attention_heads_not_created_when_attention_false(self):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, attention=False
+ )
+ assert not hasattr(block, "norm2") or block.norm2 is None
+ assert not hasattr(block, "qkv") or block.qkv is None
+ assert not hasattr(block, "proj") or block.proj is None
+
+ def test_attention_heads_default(self):
+ block = UNetBlock(
+ in_channels=64,
+ out_channels=64,
+ emb_channels=EMB_CH,
+ attention=True,
+ channels_per_head=64,
+ )
+ assert block.norm2 is not None
+ assert block.qkv is not None
+ assert block.proj is not None
+
+ def test_has_norm_and_conv_layers(self):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH
+ )
+ assert hasattr(block, "norm0")
+ assert hasattr(block, "conv0")
+ assert hasattr(block, "norm1")
+ assert hasattr(block, "conv1")
+ assert hasattr(block, "affine")
+
+
+class TestUNetBlockForward:
+ """Test UNetBlock forward pass."""
+
+ def test_output_shape_same_channels(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=IN_CH, emb_channels=EMB_CH
+ )
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, IN_CH, H, W)
+
+ def test_output_shape_different_channels(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH
+ )
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_output_shape_upsample(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, up=True
+ )
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, OUT_CH, H * 2, W * 2)
+
+ def test_output_shape_downsample(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, down=True
+ )
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, OUT_CH, H // 2, W // 2)
+
+ def test_output_dtype_matches_input(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH
+ )
+ out = block(random_input_2d, embedding)
+ assert out.dtype == random_input_2d.dtype
+
+ def test_gradients_flow(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH
+ )
+ x = random_input_2d.clone().requires_grad_(True)
+ out = block(x, embedding)
+ out.sum().backward()
+ assert x.grad is not None
+ assert x.grad.shape == x.shape
+
+ def test_with_attention(self, embedding):
+ ch = 64
+ x = torch.randn(B, ch, H, W)
+ block = UNetBlock(
+ in_channels=ch,
+ out_channels=ch,
+ emb_channels=EMB_CH,
+ attention=True,
+ channels_per_head=ch//2,
+ )
+ out = block(x, embedding)
+ assert out.shape == (B, ch, H, W)
+
+ def test_non_adaptive_scale(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ emb_channels=EMB_CH,
+ adaptive_scale=False,
+ )
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_with_dropout(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ emb_channels=EMB_CH,
+ dropout=0.1,
+ )
+ block.train()
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_with_amp_mode(self, random_input_2d, embedding):
+ block = UNetBlock(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ emb_channels=EMB_CH,
+ amp_mode=True,
+ )
+ out = block(random_input_2d, embedding)
+ assert out.shape == (B, OUT_CH, H, W)
+
+ def test_skip_scale_applied(self, random_input_2d, embedding):
+ """Changing skip_scale should change the output magnitude."""
+ block_s1 = UNetBlock(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ emb_channels=EMB_CH,
+ skip_scale=1.0,
+ )
+ block_s2 = UNetBlock(
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ emb_channels=EMB_CH,
+ skip_scale=2.0,
+ )
+ # Copy parameters from block_s1 to block_s2
+ block_s2.load_state_dict(block_s1.state_dict(), strict=False)
+ out1 = block_s1(random_input_2d, embedding)
+ out2 = block_s2(random_input_2d, embedding)
+ # s2 should have roughly 2x the magnitude of s1
+ torch.testing.assert_close(out2, out1 * 2.0, atol=1e-5, rtol=1e-5)
+
+
+############################################################################
+# PositionalEmbedding #
+############################################################################
+
+
+class TestPositionalEmbeddingInit:
+ """Test PositionalEmbedding.__init__."""
+
+ def test_stores_num_channels(self):
+ emb = PositionalEmbedding(num_channels=64)
+ assert emb.num_channels == 64
+
+ def test_stores_max_positions(self):
+ emb = PositionalEmbedding(num_channels=64, max_positions=5000)
+ assert emb.max_positions == 5000
+
+ def test_stores_endpoint(self):
+ emb = PositionalEmbedding(num_channels=64, endpoint=True)
+ assert emb.endpoint is True
+
+ def test_amp_mode(self):
+ emb = PositionalEmbedding(num_channels=64, amp_mode=True)
+ assert emb.amp_mode is True
+
+
+class TestPositionalEmbeddingForward:
+ """Test PositionalEmbedding forward pass."""
+
+ def test_output_shape(self):
+ emb = PositionalEmbedding(num_channels=64)
+ x = torch.randn(B)
+ out = emb(x)
+ assert out.shape == (B, 64)
+
+ def test_output_shape_single(self):
+ emb = PositionalEmbedding(num_channels=32)
+ x = torch.randn(1)
+ out = emb(x)
+ assert out.shape == (1, 32)
+
+ def test_different_inputs_produce_different_embeddings(self):
+ emb = PositionalEmbedding(num_channels=64)
+ x1 = torch.tensor([0.1])
+ x2 = torch.tensor([1.0])
+ out1 = emb(x1)
+ out2 = emb(x2)
+ assert not torch.allclose(out1, out2)
+
+ def test_same_input_produces_same_embedding(self):
+ emb = PositionalEmbedding(num_channels=64)
+ x = torch.tensor([0.5])
+ out1 = emb(x)
+ out2 = emb(x)
+ torch.testing.assert_close(out1, out2)
+
+ def test_output_contains_sin_and_cos(self):
+ """Output is concatenation of cos and sin, so first and second halves
+ should differ for non-trivial inputs."""
+ emb = PositionalEmbedding(num_channels=64)
+ x = torch.tensor([1.0])
+ out = emb(x)
+ first_half = out[:, :32]
+ second_half = out[:, 32:]
+ assert not torch.allclose(first_half, second_half)
+
+ def test_output_bounded(self):
+ """Since output is cos and sin, values should be in [-1, 1]."""
+ emb = PositionalEmbedding(num_channels=64)
+ x = torch.randn(B)
+ out = emb(x)
+ assert out.min() >= -1.0 - 1e-6
+ assert out.max() <= 1.0 + 1e-6
+
+ def test_endpoint_changes_output(self):
+ emb_no_end = PositionalEmbedding(num_channels=64, endpoint=False)
+ emb_end = PositionalEmbedding(num_channels=64, endpoint=True)
+ x = torch.tensor([1.0])
+ out1 = emb_no_end(x)
+ out2 = emb_end(x)
+ assert not torch.allclose(out1, out2)
+
+
+############################################################################
+# FourierEmbedding #
+############################################################################
+
+
+class TestFourierEmbeddingInit:
+ """Test FourierEmbedding.__init__."""
+
+ def test_freqs_buffer_registered(self):
+ emb = FourierEmbedding(num_channels=64)
+ assert hasattr(emb, "freqs")
+ assert emb.freqs.shape == (32,)
+
+ def test_scale_affects_freqs_magnitude(self):
+ torch.manual_seed(0)
+ emb_small = FourierEmbedding(num_channels=64, scale=1)
+ torch.manual_seed(0)
+ emb_large = FourierEmbedding(num_channels=64, scale=16)
+ torch.testing.assert_close(emb_large.freqs, emb_small.freqs * 16)
+
+ def test_amp_mode_stored(self):
+ emb = FourierEmbedding(num_channels=64, amp_mode=True)
+ assert emb.amp_mode is True
+
+
+class TestFourierEmbeddingForward:
+ """Test FourierEmbedding forward pass."""
+
+ def test_output_shape(self):
+ emb = FourierEmbedding(num_channels=64)
+ x = torch.randn(B)
+ out = emb(x)
+ assert out.shape == (B, 64)
+
+ def test_output_shape_single(self):
+ emb = FourierEmbedding(num_channels=32)
+ x = torch.randn(1)
+ out = emb(x)
+ assert out.shape == (1, 32)
+
+ def test_different_inputs_produce_different_embeddings(self):
+ emb = FourierEmbedding(num_channels=64)
+ x1 = torch.tensor([0.1])
+ x2 = torch.tensor([1.0])
+ out1 = emb(x1)
+ out2 = emb(x2)
+ assert not torch.allclose(out1, out2)
+
+ def test_same_input_produces_same_embedding(self):
+ emb = FourierEmbedding(num_channels=64)
+ x = torch.tensor([0.5])
+ out1 = emb(x)
+ out2 = emb(x)
+ torch.testing.assert_close(out1, out2)
+
+ def test_output_contains_sin_and_cos(self):
+ emb = FourierEmbedding(num_channels=64)
+ x = torch.tensor([1.0])
+ out = emb(x)
+ first_half = out[:, :32]
+ second_half = out[:, 32:]
+ assert not torch.allclose(first_half, second_half)
+
+ def test_output_bounded(self):
+ """Since output is cos and sin, values should be in [-1, 1]."""
+ emb = FourierEmbedding(num_channels=64)
+ x = torch.randn(B)
+ out = emb(x)
+ assert out.min() >= -1.0 - 1e-6
+ assert out.max() <= 1.0 + 1e-6
diff --git a/tests/models/test_preconditioning.py b/tests/models/test_preconditioning.py
new file mode 100644
index 00000000..a802e596
--- /dev/null
+++ b/tests/models/test_preconditioning.py
@@ -0,0 +1,604 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+import torch.nn as nn
+
+from hirad.models.preconditioning import EDMPrecondSuperResolution
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures
+# ---------------------------------------------------------------------------
+
+B, C_IN, C_OUT, H, W = 2, 4, 3, 64, 64
+
+
+def _make_mock_model(out_channels=C_OUT):
+ """Return a MagicMock that behaves like a SongUNet-style model."""
+ model = MagicMock(spec=nn.Module)
+ model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros(
+ x.shape[0], out_channels, x.shape[2], x.shape[3],
+ dtype=x.dtype, device=x.device,
+ )
+ model.modules.return_value = iter([])
+ return model
+
+
+@pytest.fixture()
+def img_x():
+ return torch.randn(B, C_OUT, H, W)
+
+
+@pytest.fixture()
+def img_lr():
+ return torch.randn(B, C_IN, H, W)
+
+
+@pytest.fixture()
+def sigma():
+ return torch.ones(B) * 0.5
+
+
+############################################################################
+# EDMPrecondSuperResolution — __init__ #
+############################################################################
+
+
+class TestEDMInitResolution:
+ """Test EDMPrecondSuperResolution.__init__ resolution handling."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_int_resolution_stored(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=128, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.img_resolution == 128
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_tuple_resolution_stored(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=(96, 128), img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.img_resolution == (96, 128)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_stores_channel_counts(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.img_in_channels == C_IN
+ assert edm.img_out_channels == C_OUT
+
+
+class TestEDMInitModelType:
+ """Test that EDMPrecondSuperResolution creates the correct underlying model type."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_default_model_type_is_song_unet_pos_embd(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ mock_cls.assert_called_once()
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_custom_model_type_song_unet(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNet = mock_cls
+ EDMPrecondSuperResolution(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ model_type="SongUNet",
+ )
+ mock_cls.assert_called_once()
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_receives_img_resolution(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ EDMPrecondSuperResolution(
+ img_resolution=(80, 120), img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["img_resolution"] == (80, 120)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_receives_combined_in_channels(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["in_channels"] == C_IN + C_OUT
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_receives_out_channels(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["out_channels"] == C_OUT
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_extra_kwargs_forwarded_to_model(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ EDMPrecondSuperResolution(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ model_channels=256,
+ num_blocks=8,
+ )
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["model_channels"] == 256
+ assert call_kwargs["num_blocks"] == 8
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_attribute_exists(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert hasattr(edm, "model")
+
+
+class TestEDMInitSigmaDefaults:
+ """Test sigma-related default values."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_default_sigma_data(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.sigma_data == 0.5
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_default_sigma_min(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.sigma_min == 0.0
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_default_sigma_max(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.sigma_max == float("inf")
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_custom_sigma_info(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ sigma_data=1.0,
+ sigma_min=0.002,
+ sigma_max=80.0,
+ )
+ assert edm.sigma_data == 1.0
+ assert edm.sigma_min == 0.002
+ assert edm.sigma_max == 80.0
+
+
+class TestEDMInitFp16:
+ """Test use_fp16 stored correctly at init."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_default_fp16_is_false(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.use_fp16 is False
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_set_fp16_true(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ use_fp16=True,
+ )
+ assert edm.use_fp16 is True
+
+
+############################################################################
+# EDMPrecondSuperResolution — _scaling_fn #
+############################################################################
+
+
+class TestEDMScalingFn:
+ """Test the static _scaling_fn method."""
+
+ def test_output_shape(self):
+ x = torch.randn(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ c_in = torch.ones(B, 1, 1, 1) * 0.5
+ result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in)
+ assert result.shape == (B, C_OUT + C_IN, H, W)
+
+ def test_first_channels_are_scaled_x(self):
+ x = torch.ones(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ c_in = torch.ones(B, 1, 1, 1) * 2.0
+ result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in)
+ torch.testing.assert_close(result[:, :C_OUT], x * 2.0)
+
+ def test_last_channels_are_unscaled_lr(self):
+ x = torch.randn(B, C_OUT, H, W)
+ lr = torch.ones(B, C_IN, H, W) * 3.0
+ c_in = torch.ones(B, 1, 1, 1) * 0.5
+ result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in)
+ torch.testing.assert_close(result[:, C_OUT:], lr)
+
+ def test_lr_cast_to_x_dtype(self):
+ x = torch.randn(B, C_OUT, H, W, dtype=torch.float32)
+ lr = torch.randn(B, C_IN, H, W, dtype=torch.float64)
+ c_in = torch.ones(B, 1, 1, 1)
+ result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in)
+ assert result.dtype == torch.float32
+
+
+############################################################################
+# EDMPrecondSuperResolution — forward #
+############################################################################
+
+
+class TestEDMForwardBasic:
+ """Basic forward pass tests for EDMPrecondSuperResolution."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_output_shape(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ out = edm(img_x, img_lr, sigma)
+ assert out.shape == (B, C_OUT, H, W)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_output_dtype_is_float32(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ out = edm(img_x, img_lr, sigma)
+ assert out.dtype == torch.float32
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_input_has_combined_channels(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ edm(img_x, img_lr, sigma)
+ model_input = mock_model.call_args[0][0]
+ assert model_input.shape[1] == C_OUT + C_IN
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_receives_flattened_c_noise(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ edm(img_x, img_lr, sigma)
+ c_noise_arg = mock_model.call_args[0][1]
+ assert c_noise_arg.ndim == 1
+ assert c_noise_arg.shape[0] == B
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_model_receives_none_class_labels(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ edm(img_x, img_lr, sigma)
+ call_kwargs = mock_model.call_args[1]
+ assert call_kwargs["class_labels"] is None
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_kwargs_forwarded_to_model(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ lead_time = torch.randint(49, size=(B,))
+ edm(img_x, img_lr, sigma, lead_time_label=lead_time)
+ call_kwargs = mock_model.call_args[1]
+ assert "lead_time_label" in call_kwargs
+ torch.testing.assert_close(call_kwargs["lead_time_label"], lead_time)
+
+
+class TestEDMForwardPreconditioning:
+ """Test that the EDM preconditioning coefficients are applied correctly."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_c_noise_is_log_sigma_over_4(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ sigma_val = 2.0
+ sigma = torch.full((B,), sigma_val)
+ x = torch.randn(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ edm(x, lr, sigma)
+ c_noise_arg = mock_model.call_args[0][1]
+ expected = torch.full((B,), torch.tensor(sigma_val).log().item() / 4)
+ torch.testing.assert_close(c_noise_arg, expected)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_output_is_c_skip_x_plus_c_out_F_x(self, mock_module):
+ """D(x) = c_skip * x + c_out * F(x); with F(x)=0, D(x) = c_skip * x."""
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ sigma_data = 0.5
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ sigma_data=sigma_data,
+ )
+ sigma_val = 1.0
+ sigma = torch.full((B,), sigma_val)
+ x = torch.ones(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ out = edm(x, lr, sigma)
+ # Since mock model returns zeros, D(x) = c_skip * x
+ c_skip = sigma_data**2 / (sigma_val**2 + sigma_data**2)
+ expected = torch.full((B, C_OUT, H, W), c_skip)
+ torch.testing.assert_close(out, expected)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_sigma_reshaped_to_4d(self, mock_module, img_x, img_lr):
+ """Sigma with shape (B,) should be reshaped to (B, 1, 1, 1)."""
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ sigma_1d = torch.ones(B)
+ # Should not raise
+ edm(img_x, img_lr, sigma_1d)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_sigma_2d_reshaped_to_4d(self, mock_module, img_x, img_lr):
+ """Sigma with shape (B, 1) should be reshaped to (B, 1, 1, 1)."""
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ sigma_2d = torch.ones(B, 1)
+ # Should not raise
+ edm(img_x, img_lr, sigma_2d)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_sigma_4d_accepted(self, mock_module, img_x, img_lr):
+ """Sigma with shape (B, 1, 1, 1) should be accepted."""
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ sigma_4d = torch.ones(B, 1, 1, 1)
+ # Should not raise
+ edm(img_x, img_lr, sigma_4d)
+
+
+class TestEDMForwardImgLrNone:
+ """Test forward pass when img_lr is None."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_no_concatenation_when_img_lr_none(self, mock_module, img_x, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ edm(img_x, img_lr=None, sigma=sigma)
+ model_input = mock_model.call_args[0][0]
+ # Without img_lr, input is c_in * x only
+ assert model_input.shape[1] == C_OUT
+
+
+class TestEDMForwardDtypeValidation:
+ """Test dtype enforcement in forward pass."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_raises_on_dtype_mismatch(self, mock_module, img_x, img_lr, sigma):
+ """Model should raise if the underlying model returns wrong dtype."""
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros(
+ x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16,
+ )
+ mock_model.modules.return_value = iter([])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ with pytest.raises(ValueError, match="Expected the dtype"):
+ edm(img_x, img_lr, sigma)
+
+
+class TestEDMForwardForceFp32:
+ """Test the force_fp32 flag."""
+
+ #TODO: Test doesn't make sence when device is cpu.
+ @patch("hirad.models.preconditioning.network_module")
+ def test_force_fp32_uses_float32(self, mock_module, img_x, img_lr, sigma):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ use_fp16=True,
+ )
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ edm(img_x.to(device), img_lr.to(device), sigma.to(device), force_fp32=True)
+ model_input = mock_model.call_args[0][0]
+ assert model_input.dtype == torch.float32
+
+
+class TestEDMForwardAutocastEnabled:
+ """Test that dtype validation is skipped when autocast is enabled."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_no_dtype_check_when_autocast_enabled(self, mock_module, img_x, img_lr, sigma):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros(
+ x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16,
+ )
+ mock_model.modules.return_value = iter([])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ with torch.autocast("cuda"):
+ out = edm(img_x, img_lr, sigma)
+ assert out.dtype == torch.float32
+
+
+############################################################################
+# EDMPrecondSuperResolution — round_sigma #
+############################################################################
+
+
+class TestEDMRoundSigma:
+ """Test round_sigma static method."""
+
+ def test_float_input(self):
+ result = EDMPrecondSuperResolution.round_sigma(0.5)
+ assert isinstance(result, torch.Tensor)
+ assert result.item() == pytest.approx(0.5)
+
+ def test_list_input(self):
+ result = EDMPrecondSuperResolution.round_sigma([0.1, 0.5, 1.0])
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == (3,)
+ torch.testing.assert_close(result, torch.tensor([0.1, 0.5, 1.0]))
+
+ def test_tensor_input(self):
+ sigma = torch.tensor([0.2, 0.8])
+ result = EDMPrecondSuperResolution.round_sigma(sigma)
+ torch.testing.assert_close(result, sigma)
+
+
+############################################################################
+# EDMPrecondSuperResolution — amp_mode property #
+############################################################################
+
+
+class TestEDMAmpMode:
+ """Test amp_mode property getter and setter."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_amp_mode_returns_none_when_model_lacks_attr(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ del mock_model.amp_mode
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.amp_mode is None
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_amp_mode_returns_model_value(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.amp_mode = True
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.amp_mode is True
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_amp_mode_setter_updates_model_and_submodules(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.amp_mode = False
+ sub_module = MagicMock()
+ sub_module.amp_mode = False
+ mock_model.modules.return_value = iter([sub_module])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ edm.amp_mode = True
+ assert mock_model.amp_mode is True
+ assert sub_module.amp_mode is True
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_amp_mode_setter_rejects_non_bool(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.amp_mode = False
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ with pytest.raises(TypeError, match="amp_mode must be a boolean"):
+ edm.amp_mode = "yes"
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_amp_mode_setter_skips_model_without_attr(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ del mock_model.amp_mode
+ mock_model.modules.return_value = iter([])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ # Should not raise even when model lacks amp_mode
+ edm.amp_mode = True
+
+
+############################################################################
+# EDMPrecondSuperResolution — nn.Module integration #
+############################################################################
+
+
+class TestEDMModuleIntegration:
+ """Test that EDMPrecondSuperResolution behaves as a proper nn.Module."""
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_is_nn_module(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert isinstance(edm, nn.Module)
+
+ @patch("hirad.models.preconditioning.network_module")
+ def test_scaling_fn_attribute_set(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ edm = EDMPrecondSuperResolution(
+ img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT,
+ )
+ assert edm.scaling_fn is EDMPrecondSuperResolution._scaling_fn
diff --git a/tests/models/test_song_unet.py b/tests/models/test_song_unet.py
new file mode 100644
index 00000000..a047a67b
--- /dev/null
+++ b/tests/models/test_song_unet.py
@@ -0,0 +1,1513 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from hirad.models.song_unet import SongUNet, SongUNetPosEmbd
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures — use small model configs for fast CPU tests
+# ---------------------------------------------------------------------------
+
+B = 2
+IMG_RES = 32
+IN_CH = 4
+OUT_CH = 3
+SMALL_CFG = dict(
+ model_channels=32,
+ channel_mult=[1, 2],
+ num_blocks=1,
+ attn_resolutions=[],
+ dropout=0.0,
+)
+
+
+@pytest.fixture()
+def small_unet():
+ """Return a small SongUNet that runs on CPU."""
+ return SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+
+
+@pytest.fixture()
+def noise_labels():
+ return torch.randn(B)
+
+
+@pytest.fixture()
+def class_labels():
+ return torch.randint(0, 2, (B, 1)).float()
+
+
+@pytest.fixture()
+def input_image():
+ return torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+
+
+############################################################################
+# SongUNet — __init__ #
+############################################################################
+
+
+class TestSongUNetInitValidation:
+ """Test __init__ input validation."""
+
+ def test_invalid_embedding_type_raises(self):
+ with pytest.raises(ValueError, match="Invalid embedding_type"):
+ SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="invalid",
+ **SMALL_CFG,
+ )
+
+ def test_invalid_encoder_type_raises(self):
+ with pytest.raises(ValueError, match="Invalid encoder_type"):
+ SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type="invalid",
+ **SMALL_CFG,
+ )
+
+ def test_invalid_decoder_type_raises(self):
+ with pytest.raises(ValueError, match="Invalid decoder_type"):
+ SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ decoder_type="invalid",
+ **SMALL_CFG,
+ )
+
+ @pytest.mark.parametrize("etype", ["positional", "fourier", "zero"])
+ def test_valid_embedding_types_accepted(self, etype):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type=etype,
+ **SMALL_CFG,
+ )
+ assert model.embedding_type == etype
+
+ @pytest.mark.parametrize("enc", ["standard", "skip", "residual"])
+ def test_valid_encoder_types_accepted(self, enc):
+ SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type=enc,
+ **SMALL_CFG,
+ )
+
+ @pytest.mark.parametrize("dec", ["standard", "skip"])
+ def test_valid_decoder_types_accepted(self, dec):
+ SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ decoder_type=dec,
+ **SMALL_CFG,
+ )
+
+
+class TestSongUNetInitResolution:
+ """Test resolution handling in __init__."""
+
+ def test_int_resolution_sets_square(self):
+ model = SongUNet(
+ img_resolution=32,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ assert model.img_shape_x == 32
+ assert model.img_shape_y == 32
+
+ def test_list_resolution_sets_height_width(self):
+ model = SongUNet(
+ img_resolution=[24, 32],
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ assert model.img_shape_y == 24
+ assert model.img_shape_x == 32
+
+ def test_img_resolution_stored(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ assert model.img_resolution == IMG_RES
+
+
+class TestSongUNetInitEmbedding:
+ """Test embedding-related initialization."""
+
+ def test_positional_embedding_creates_map_noise(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="positional",
+ **SMALL_CFG,
+ )
+ assert hasattr(model, "map_noise")
+ assert hasattr(model, "map_layer0")
+ assert hasattr(model, "map_layer1")
+
+ def test_fourier_embedding_creates_map_noise(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="fourier",
+ **SMALL_CFG,
+ )
+ assert hasattr(model, "map_noise")
+ assert hasattr(model, "map_layer0")
+ assert hasattr(model, "map_layer1")
+
+ def test_zero_embedding_skips_mapping_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="zero",
+ **SMALL_CFG,
+ )
+ assert not hasattr(model, "map_noise")
+ assert not hasattr(model, "map_layer0")
+ assert not hasattr(model, "map_layer1")
+ assert not hasattr(model, "map_label")
+ assert not hasattr(model, "map_augment")
+
+ def test_emb_channels_computed_correctly(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ channel_mult_emb=4,
+ **SMALL_CFG,
+ )
+ assert model.emb_channels == 32 * 4
+
+ def test_label_dim_creates_map_label(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ label_dim=10,
+ **SMALL_CFG,
+ )
+ assert model.map_label is not None
+
+ def test_no_label_dim_map_label_is_none(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ label_dim=0,
+ **SMALL_CFG,
+ )
+ assert model.map_label is None
+
+ def test_augment_dim_creates_map_augment(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ augment_dim=5,
+ **SMALL_CFG,
+ )
+ assert model.map_augment is not None
+
+ def test_no_augment_dim_map_augment_is_none(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ augment_dim=0,
+ **SMALL_CFG,
+ )
+ assert model.map_augment is None
+
+
+class TestSongUNetInitEncoder:
+ """Test encoder construction."""
+
+ def test_encoder_module_dict_created(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ assert isinstance(model.enc, nn.ModuleDict)
+ assert len(model.enc) > 0
+
+ def test_skip_encoder_creates_aux_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type="skip",
+ **SMALL_CFG,
+ )
+ aux_down_keys = [k for k in model.enc.keys() if "aux_down" in k]
+ aux_skip_keys = [k for k in model.enc.keys() if "aux_skip" in k]
+ assert len(aux_down_keys) > 0
+ assert len(aux_skip_keys) > 0
+
+ def test_residual_encoder_creates_aux_residual(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type="residual",
+ **SMALL_CFG,
+ )
+ aux_keys = [k for k in model.enc.keys() if "aux_residual" in k]
+ assert len(aux_keys) > 0
+
+ def test_standard_encoder_has_no_aux_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type="standard",
+ **SMALL_CFG,
+ )
+ aux_keys = [k for k in model.enc.keys() if "aux" in k]
+ assert len(aux_keys) == 0
+
+ def test_standard_encoder_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type="standard",
+ **SMALL_CFG,
+ )
+ expected_layers = ["32x32_conv", "32x32_block0", "16x16_block0"]
+ for layer in expected_layers:
+ assert hasattr(model.enc, layer)
+
+
+
+class TestSongUNetInitDecoder:
+ """Test decoder construction."""
+
+ def test_decoder_module_dict_created(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ assert isinstance(model.dec, nn.ModuleDict)
+ assert len(model.dec) > 0
+
+ def test_skip_decoder_creates_aux_up_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ decoder_type="skip",
+ **SMALL_CFG,
+ )
+ aux_keys = [k for k in model.dec.keys() if "aux_up" in k]
+ assert len(aux_keys) > 0
+
+ def test_standard_decoder_has_no_aux_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ decoder_type="standard",
+ **SMALL_CFG,
+ )
+ aux_keys = [k for k in model.dec.keys() if "aux_up" in k]
+ assert len(aux_keys) == 0
+
+ def test_standard_decoder_layers(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ decoder_type="standard",
+ **SMALL_CFG,
+ )
+ expected_layers = ["16x16_in0", "16x16_in1", "16x16_block0", "16x16_block1", "32x32_up",
+ "32x32_block0", "32x32_block1", "32x32_aux_norm", "32x32_aux_conv"]
+ for layer in expected_layers:
+ assert hasattr(model.dec, layer)
+
+
+class TestSongUNetInitAdditiveEmbed:
+ """Test additive positional embedding in __init__."""
+
+ def test_additive_pos_embed_creates_parameter(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ additive_pos_embed=True,
+ **SMALL_CFG,
+ )
+ assert hasattr(model, "spatial_emb")
+ assert isinstance(model.spatial_emb, nn.Parameter)
+ assert model.spatial_emb.shape == (1, 32, IMG_RES, IMG_RES)
+
+ def test_no_additive_pos_embed_by_default(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ assert not hasattr(model, "spatial_emb")
+
+
+class TestSongUNetInitCheckpoint:
+ """Test checkpoint level configuration."""
+
+ def test_checkpoint_level_zero(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ checkpoint_level=0,
+ **SMALL_CFG,
+ )
+ # threshold = (img_shape_y >> 0) + 1 = 32 >> 0 + 1 = 32 + 1 = 33
+ assert model.checkpoint_threshold == IMG_RES + 1
+
+ def test_checkpoint_level_one(self):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ checkpoint_level=1,
+ **SMALL_CFG,
+ )
+ # threshold = (32 >> 1) + 1 = 16 + 1 = 17
+ assert model.checkpoint_threshold == (IMG_RES >> 1) + 1
+
+
+class TestSongUNetIsModule:
+ """Test that SongUNet is a proper nn.Module."""
+
+ def test_is_nn_module(self, small_unet):
+ assert isinstance(small_unet, nn.Module)
+
+ def test_has_parameters(self, small_unet):
+ params = list(small_unet.parameters())
+ assert len(params) > 0
+
+
+############################################################################
+# SongUNet — forward #
+############################################################################
+
+
+class TestSongUNetForwardShape:
+ """Test forward pass output shapes."""
+
+ def test_output_shape(self, small_unet, input_image, noise_labels, class_labels):
+ out = small_unet(input_image, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_output_dtype_float32(self, small_unet, input_image, noise_labels, class_labels):
+ out = small_unet(input_image, noise_labels, class_labels)
+ assert out.dtype == torch.float32
+
+
+class TestSongUNetForwardEmbeddingTypes:
+ """Test forward with different embedding types."""
+
+ def test_zero_embedding_forward(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="zero",
+ **SMALL_CFG,
+ )
+ out = model(input_image, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_fourier_embedding_forward(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="fourier",
+ **SMALL_CFG,
+ )
+ out = model(input_image, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_positional_embedding_forward(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="positional",
+ **SMALL_CFG,
+ )
+ out = model(input_image, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+
+class TestSongUNetForwardEncoderDecoder:
+ """Test forward with various encoder/decoder combos."""
+
+ @pytest.mark.parametrize("enc,dec", [
+ ("standard", "standard"),
+ ("skip", "standard"),
+ ("residual", "standard"),
+ ("standard", "skip"),
+ ("skip", "skip"),
+ ("residual", "skip"),
+ ])
+ def test_encoder_decoder_combinations(self, enc, dec, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type=enc,
+ decoder_type=dec,
+ **SMALL_CFG,
+ )
+ out = model(input_image, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+
+class TestSongUNetForwardLabel:
+ """Test label dropout behavior during training."""
+
+ def test_label_dropout_in_training(self, input_image, noise_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ label_dim=5,
+ label_dropout=0.5,
+ **SMALL_CFG,
+ )
+ model.train()
+ labels = torch.ones(B, 5)
+ out = model(input_image, noise_labels, labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_label_dropout_in_eval(self, input_image, noise_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ label_dim=5,
+ label_dropout=0.5,
+ **SMALL_CFG,
+ )
+ model.eval()
+ labels = torch.ones(B, 5)
+ out = model(input_image, noise_labels, labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_map_label_called_when_label_dim_positive(self, input_image, noise_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ label_dim=5,
+ **SMALL_CFG,
+ )
+ assert model.map_label is not None
+ called = []
+ model.map_label.register_forward_hook(lambda m, i, o: called.append(True))
+ out = model(input_image, noise_labels, torch.ones(B, 5))
+ assert len(called) == 1
+
+
+class TestSongUNetForwardAugment:
+ """Test augment dropout behavior during training."""
+
+ def test_map_augment_called_when_augment_dim_positive(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ augment_dim=3,
+ **SMALL_CFG,
+ )
+ assert model.map_augment is not None
+ called = []
+ model.map_augment.register_forward_hook(lambda m, i, o: called.append(True))
+ out = model(input_image, noise_labels, class_labels, augment_labels=torch.ones(B, 3))
+ assert len(called) == 1
+
+ def test_no_augment_labels_skips_map_augment(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ augment_dim=3,
+ **SMALL_CFG,
+ )
+ assert model.map_augment is not None
+ called = []
+ model.map_augment.register_forward_hook(lambda m, i, o: called.append(True))
+ out = model(input_image, noise_labels, class_labels)
+ assert len(called) == 0
+
+
+class TestSongUNetForwardNoiseEmbedding:
+ def test_map_noise_called_when_embedding_type_fourier(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="fourier",
+ **SMALL_CFG,
+ )
+ assert model.map_noise is not None
+ called_map_noise = []
+ called_map_layer0 = []
+ called_map_layer1 = []
+ model.map_noise.register_forward_hook(lambda m, i, o: called_map_noise.append(True))
+ model.map_layer0.register_forward_hook(lambda m, i, o: called_map_layer0.append(True))
+ model.map_layer1.register_forward_hook(lambda m, i, o: called_map_layer1.append(True))
+ out = model(input_image, noise_labels, class_labels)
+ assert len(called_map_noise) == 1
+ assert len(called_map_layer0) == 1
+ assert len(called_map_layer1) == 1
+
+ def test_map_noise_called_when_embedding_type_positional(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ embedding_type="positional",
+ **SMALL_CFG,
+ )
+ assert model.map_noise is not None
+ called_map_noise = []
+ called_map_layer0 = []
+ called_map_layer1 = []
+ model.map_noise.register_forward_hook(lambda m, i, o: called_map_noise.append(True))
+ model.map_layer0.register_forward_hook(lambda m, i, o: called_map_layer0.append(True))
+ model.map_layer1.register_forward_hook(lambda m, i, o: called_map_layer1.append(True))
+ out = model(input_image, noise_labels, class_labels)
+ assert len(called_map_noise) == 1
+ assert len(called_map_layer0) == 1
+ assert len(called_map_layer1) == 1
+
+
+class TestSongUNetForwardEncoderDecoderCalls:
+ """Test that all encoder and decoder blocks are called during forward."""
+
+ def test_all_enc_dec_blocks_called_parametrized(
+ self, input_image, noise_labels, class_labels
+ ):
+ """Test across different encoder/decoder combos."""
+ for enc in ["standard", "skip", "residual"]:
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ encoder_type=enc,
+ **SMALL_CFG,
+ )
+ called = {}
+ handles = []
+ for name, block in list(model.enc.items()):
+ called[name] = 0
+ handle = block.register_forward_hook(
+ lambda m, i, o, n=name: called.__setitem__(n, called[n] + 1)
+ )
+ handles.append(handle)
+
+ model(input_image, noise_labels, class_labels)
+
+ for handle in handles:
+ handle.remove()
+
+ for name, count in called.items():
+ assert count == 1, (
+ f"[enc={enc}] Block '{name}' called {count} times, expected 1"
+ )
+
+ def test_all_decoder_blocks_called_parametrized(
+ self, input_image, noise_labels, class_labels
+ ):
+ """Test across different encoder/decoder combos."""
+ for dec in ["standard", "skip"]:
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ decoder_type=dec,
+ **SMALL_CFG,
+ )
+ called = {}
+ handles = []
+ for name, block in list(model.dec.items()):
+ called[name] = 0
+ handle = block.register_forward_hook(
+ lambda m, i, o, n=name: called.__setitem__(n, called[n] + 1)
+ )
+ handles.append(handle)
+
+ model(input_image, noise_labels, class_labels)
+
+ for handle in handles:
+ handle.remove()
+
+ for name, count in called.items():
+ assert count == 1, (
+ f"[dec={dec}] Block '{name}' called {count} times, expected 1"
+ )
+
+
+class TestSongUNetForwardAdditiveEmbed:
+ """Test forward with additive positional embedding."""
+
+ def test_additive_embed_forward(self, input_image, noise_labels, class_labels):
+ model = SongUNet(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ additive_pos_embed=True,
+ **SMALL_CFG,
+ )
+ out = model(input_image, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+
+class TestSongUNetForwardRectangularResolution:
+ """Test forward with non-square input."""
+
+ def test_rectangular_resolution_forward(self, noise_labels, class_labels):
+ res = [16, 32]
+ model = SongUNet(
+ img_resolution=res,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ **SMALL_CFG,
+ )
+ x = torch.randn(B, IN_CH, res[0], res[1])
+ out = model(x, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, res[0], res[1])
+
+
+############################################################################
+# SongUNetPosEmbd — __init__ #
+############################################################################
+
+
+# Positional embedding adds N_grid_channels to in_channels
+N_GRID = 4
+PE_IN_CH = IN_CH + N_GRID
+
+PE_SMALL_CFG = dict(
+ model_channels=32,
+ channel_mult=[1, 2],
+ num_blocks=1,
+ attn_resolutions=[],
+ dropout=0.0,
+ use_apex_gn=False,
+)
+
+
+@pytest.fixture()
+def small_pos_unet():
+ return SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+
+
+class TestSongUNetPosEmbdInitGridType:
+ """Test grid type selection in __init__."""
+
+ def test_sinusoidal_grid_default(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.gridtype == "sinusoidal"
+ assert model.pos_embd.shape == (N_GRID, IMG_RES, IMG_RES)
+
+ def test_learnable_grid(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="learnable",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.gridtype == "learnable"
+ assert isinstance(model.pos_embd, nn.Parameter)
+ assert model.pos_embd.shape == (N_GRID, IMG_RES, IMG_RES)
+
+ def test_linear_grid(self):
+ n_ch = 2
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + n_ch,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=n_ch,
+ **PE_SMALL_CFG,
+ )
+ assert model.gridtype == "linear"
+ assert model.pos_embd.shape == (2, IMG_RES, IMG_RES)
+
+ def test_test_grid(self):
+ n_ch = 2
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + n_ch,
+ out_channels=OUT_CH,
+ gridtype="test",
+ N_grid_channels=n_ch,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.shape == (2, IMG_RES, IMG_RES)
+
+
+class TestSongUNetPosEmbdInitGridChannelsValidation:
+ """Test N_grid_channels validation."""
+
+ def test_linear_grid_requires_2_channels(self):
+ with pytest.raises(ValueError, match="N_grid_channels must be set to 2"):
+ SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + 4,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=4,
+ **PE_SMALL_CFG,
+ )
+
+ def test_sinusoidal_multi_freq_requires_factor_of_4(self):
+ with pytest.raises(ValueError, match="N_grid_channels must be a factor of 4"):
+ SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + 5,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=5,
+ **PE_SMALL_CFG,
+ )
+
+ def test_sinusoidal_8_channels_accepted(self):
+ n_ch = 8
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + n_ch,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=n_ch,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.shape == (n_ch, IMG_RES, IMG_RES)
+
+ def test_zero_grid_channels_returns_none(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=0,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd is None
+
+ def test_unsupported_gridtype_raises(self):
+ with pytest.raises(ValueError, match="Gridtype not supported"):
+ SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="unknown",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+
+
+class TestSongUNetPosEmbdInitLeadTime:
+ """Test lead time related initialization."""
+
+ def test_lead_time_mode_creates_lt_embd(self):
+ lt_ch = 2
+ lt_steps = 5
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + lt_ch,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=lt_ch,
+ lead_time_steps=lt_steps,
+ **PE_SMALL_CFG,
+ )
+ assert model.lead_time_mode is True
+ assert model.lt_embd is not None
+ assert model.lt_embd.shape == (lt_steps, lt_ch, IMG_RES, IMG_RES)
+
+ def test_no_lead_time_mode_by_default(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.lead_time_mode is False
+
+ def test_lead_time_none_channels_returns_none_embd(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=None,
+ lead_time_steps=9,
+ **PE_SMALL_CFG,
+ )
+ assert model.lt_embd is None
+
+ def test_lead_time_none_steps_returns_none_embd(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=2,
+ lead_time_steps=None,
+ **PE_SMALL_CFG,
+ )
+ assert model.lt_embd is None
+
+
+class TestSongUNetPosEmbdInitProbChannels:
+ """Test prob_channels initialization."""
+
+ def test_prob_channels_creates_scalar(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + 2,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=2,
+ lead_time_steps=3,
+ prob_channels=[0, 1],
+ **PE_SMALL_CFG,
+ )
+ assert hasattr(model, "scalar")
+ assert model.scalar.shape == (1, 2, 1, 1)
+
+ def test_empty_prob_channels_no_scalar(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + 2,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=2,
+ lead_time_steps=3,
+ prob_channels=[],
+ **PE_SMALL_CFG,
+ )
+ assert not hasattr(model, "scalar")
+
+
+class TestSongUNetPosEmbdIsModule:
+ """Test that SongUNetPosEmbd is a proper nn.Module and subclass of SongUNet."""
+
+ def test_is_nn_module(self, small_pos_unet):
+ assert isinstance(small_pos_unet, nn.Module)
+
+ def test_is_subclass_of_song_unet(self, small_pos_unet):
+ assert isinstance(small_pos_unet, SongUNet)
+
+
+############################################################################
+# SongUNetPosEmbd — _get_positional_embedding #
+############################################################################
+
+
+class TestGetPositionalEmbedding:
+ """Test _get_positional_embedding for various grid types."""
+
+ def test_sinusoidal_4ch_grid_not_requires_grad(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert not model.pos_embd.requires_grad
+
+ def test_linear_grid_not_requires_grad(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + 2,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ **PE_SMALL_CFG,
+ )
+ assert not model.pos_embd.requires_grad
+
+ def test_learnable_grid_requires_grad(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="learnable",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.requires_grad
+
+ def test_sinusoidal_values_in_range(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.min() >= -1.0
+ assert model.pos_embd.max() <= 1.0
+
+ def test_linear_values_in_range(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + 2,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.min() >= -1.0
+ assert model.pos_embd.max() <= 1.0
+
+ def test_rectangular_sinusoidal_grid(self):
+ model = SongUNetPosEmbd(
+ img_resolution=[16, 32],
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.shape == (N_GRID, 16, 32)
+
+ def test_rectangular_grid_sinusoidal_8ch(self):
+ n_ch = 8
+ model = SongUNetPosEmbd(
+ img_resolution=[16, 32],
+ in_channels=IN_CH + n_ch,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=n_ch,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.shape == (n_ch, 16, 32)
+
+ def test_rectangular_linear_grid(self):
+ model = SongUNetPosEmbd(
+ img_resolution=[16, 32],
+ in_channels=IN_CH + 2,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.shape == (2, 16, 32)
+
+ def test_rectangular_learnable_grid(self):
+ model = SongUNetPosEmbd(
+ img_resolution=[16, 32],
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="learnable",
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ assert model.pos_embd.shape == (N_GRID, 16, 32)
+
+ def test_linear_grid_simple_values(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + 2,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ **PE_SMALL_CFG,
+ )
+ # Check that the first channel is a vertical gradient and the second is horizontal
+ for y in range(IMG_RES):
+ for x in range(IMG_RES):
+ expected_y = (y / (IMG_RES - 1)) * 2 - 1
+ expected_x = (x / (IMG_RES - 1)) * 2 - 1
+ assert torch.isclose(model.pos_embd[0, y, x], torch.tensor([expected_x]), atol=1e-5)
+ assert torch.isclose(model.pos_embd[1, y, x], torch.tensor([expected_y]), atol=1e-5)
+
+ def test_sinusoidal_grid_simple_values(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=4,
+ **PE_SMALL_CFG,
+ )
+ # Check that the first two channels are sinusoids of different frequencies
+ # and the next two channels are the cosine counterparts
+ for y in range(IMG_RES):
+ for x in range(IMG_RES):
+ expected_ch0 = torch.sin(2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1))
+ expected_ch1 = torch.sin(2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1))
+ expected_ch2 = torch.cos(2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1))
+ expected_ch3 = torch.cos(2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1))
+ assert torch.isclose(model.pos_embd[0, y, x], expected_ch0, atol=1e-5)
+ assert torch.isclose(model.pos_embd[1, y, x], expected_ch1, atol=1e-5)
+ assert torch.isclose(model.pos_embd[2, y, x], expected_ch2, atol=1e-5)
+ assert torch.isclose(model.pos_embd[3, y, x], expected_ch3, atol=1e-5)
+
+ #TODO: When more than 4 channels are used for sinusoidal, the frequencies should be multiples of the base frequency (2).
+ # freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) is currently in code which gives
+ # freqs = [1,4] instead of [1,2] for N_grid_channels=8. This seems to be a bug if we want the base 2.
+ # Leaving it like this for now since we have checkpoints with 8 sinusoidal channels that use these frequencies,
+ # but it should be fixed in the future and this test should be updated to reflect the intended behavior.
+ def test_sinusoidal_8ch_grid_simple_values(self):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH + 8,
+ out_channels=OUT_CH,
+ gridtype="sinusoidal",
+ N_grid_channels=8,
+ **PE_SMALL_CFG,
+ )
+ # Check that the first 4 channels are sinusoids of different frequencies
+ # and the next 4 channels are the cosine counterparts
+ for y in range(IMG_RES):
+ for x in range(IMG_RES):
+ for idx, i in enumerate([0,2]):
+ expected_ch_0 = torch.sin((2**i) * 2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1))
+ expected_ch_1 = torch.sin((2**i) * 2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1))
+ expected_ch_2 = torch.cos((2**i) * 2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1))
+ expected_ch_3 = torch.cos((2**i) * 2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1))
+ assert torch.isclose(model.pos_embd[4*idx, y, x], expected_ch_0, atol=1e-5)
+ assert torch.isclose(model.pos_embd[4*idx + 1, y, x], expected_ch_1, atol=1e-5)
+ assert torch.isclose(model.pos_embd[4*idx + 2, y, x], expected_ch_2, atol=1e-5)
+ assert torch.isclose(model.pos_embd[4*idx + 3, y, x], expected_ch_3, atol=1e-5)
+
+
+############################################################################
+# SongUNetPosEmbd — forward #
+############################################################################
+
+
+class TestSongUNetPosEmbdForwardBasic:
+ """Test basic forward pass for SongUNetPosEmbd."""
+
+ def test_output_shape(self, small_pos_unet, noise_labels, class_labels):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ out = small_pos_unet(x, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_output_dtype_float32(self, small_pos_unet, noise_labels, class_labels):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ out = small_pos_unet(x, noise_labels, class_labels)
+ assert out.dtype == torch.float32
+
+
+class TestSongUNetPosEmbdForwardErrors:
+ """Test that forward raises for mutually exclusive arguments."""
+
+ def test_raises_when_both_selector_and_index_provided(
+ self, small_pos_unet, noise_labels, class_labels
+ ):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ global_index = torch.zeros(1, 2, IMG_RES, IMG_RES, dtype=torch.long)
+ selector = lambda emb: emb[None].expand(B, -1, -1, -1)
+ with pytest.raises(ValueError, match="Cannot provide both"):
+ small_pos_unet(
+ x, noise_labels, class_labels,
+ global_index=global_index,
+ embedding_selector=selector,
+ )
+
+ def test_raises_when_lead_time_mode_and_embedding_selector_provided(self, small_pos_unet, noise_labels, class_labels):
+ small_pos_unet.lead_time_mode = True
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ selector = lambda emb: emb[None].expand(B, -1, -1, -1)
+ with pytest.raises(ValueError, match="Embedding selector is not supported in lead time mode."):
+ small_pos_unet(
+ x, noise_labels, class_labels,
+ embedding_selector=selector,
+ )
+
+
+class TestSongUNetPosEmbdForwardSelector:
+ """Test forward with embedding_selector."""
+
+ def test_selector_applied(self, small_pos_unet, noise_labels, class_labels):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ selector = lambda emb: emb[None].expand(B, -1, -1, -1)
+ out = small_pos_unet(
+ x, noise_labels, class_labels, embedding_selector=selector
+ )
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_selector_takes_subset_of_embeddings(self, small_pos_unet, noise_labels, class_labels):
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ # Selector that takes only the first 2 channels of the positional embedding
+ selector = lambda emb: emb[None].expand(B * P, -1, -1, -1)[:,:,:IMG_RES//2,:IMG_RES//2]
+ noise_labels = torch.randn(B * P)
+ class_labels = torch.randint(0, 1, (B * P, 1)).float()
+ out = small_pos_unet(
+ x, noise_labels, class_labels, embedding_selector=selector
+ )
+ assert out.shape == (B*P, OUT_CH, IMG_RES//2, IMG_RES//2)
+
+
+class TestSongUNetPosEmbdForwardGlobalIndex:
+ """Test forward with global_index."""
+
+ def test_global_index_selects_embeddings(
+ self, small_pos_unet, noise_labels, class_labels
+ ):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ # Create index that selects the full grid
+ idx_y = torch.arange(IMG_RES).view(1, 1, IMG_RES, 1).expand(1, 1, IMG_RES, IMG_RES)
+ idx_x = torch.arange(IMG_RES).view(1, 1, 1, IMG_RES).expand(1, 1, IMG_RES, IMG_RES)
+ global_index = torch.cat([idx_y, idx_x], dim=1) # (P, 2, H, W)
+ out = small_pos_unet(x, noise_labels, class_labels, global_index=global_index)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_global_index_selects_subset_of_embeddings(
+ self, small_pos_unet
+ ):
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ # Create index that selects only the top-left quadrant of the grid
+ idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ global_index = torch.cat([idx_y, idx_x], dim=1) # (P, 2, H, W)
+ noise_labels = torch.randn(B * P)
+ class_labels = torch.randint(0, 1, (B * P, 1)).float()
+ out = small_pos_unet(x, noise_labels, class_labels, global_index=global_index)
+ assert out.shape == (B * P, OUT_CH, IMG_RES//2, IMG_RES//2)
+
+
+class TestSongUNetPosEmbdForwardLeadTime:
+ """Test forward pass with lead_time_mode enabled."""
+
+ def _make_lead_time_model(self):
+ lt_ch = 2
+ return SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + lt_ch,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=lt_ch,
+ lead_time_steps=5,
+ prob_channels=[],
+ **PE_SMALL_CFG,
+ )
+
+ def test_lead_time_forward_shape(self, noise_labels, class_labels):
+ model = self._make_lead_time_model()
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ lead_time = torch.zeros(B, dtype=torch.long)
+ out = model(x, noise_labels, class_labels, lead_time_label=lead_time)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+ def test_lead_time_with_prob_channels_eval(self, noise_labels, class_labels):
+ """In eval mode, prob_channels should go through softmax."""
+ lt_ch = 2
+ out_ch = 4
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + lt_ch,
+ out_channels=out_ch,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=lt_ch,
+ lead_time_steps=5,
+ prob_channels=[2, 3],
+ **PE_SMALL_CFG,
+ )
+ model.eval()
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ lead_time = torch.zeros(B, dtype=torch.long)
+ out = model(x, noise_labels, class_labels, lead_time_label=lead_time)
+ assert out.shape == (B, out_ch, IMG_RES, IMG_RES)
+ # Prob channels should sum to 1 (softmax)
+ prob_sum = out[:, [2, 3]].sum(dim=1)
+ torch.testing.assert_close(
+ prob_sum, torch.ones(B, IMG_RES, IMG_RES), atol=1e-5, rtol=1e-5
+ )
+
+ def test_lead_time_with_prob_channels_train(self, noise_labels, class_labels):
+ """In training mode, prob_channels should output raw logits (no softmax)."""
+ lt_ch = 2
+ out_ch = 4
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + lt_ch,
+ out_channels=out_ch,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=lt_ch,
+ lead_time_steps=5,
+ prob_channels=[2, 3],
+ **PE_SMALL_CFG,
+ )
+ model.train()
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ lead_time = torch.zeros(B, dtype=torch.long)
+ out = model(x, noise_labels, class_labels, lead_time_label=lead_time)
+ assert out.shape == (B, out_ch, IMG_RES, IMG_RES)
+
+ def test_lead_time_with_global_index(self, noise_labels, class_labels):
+ """Test that global_index can be used with lead_time_mode."""
+ lt_ch = 2
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + lt_ch,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ lead_time_mode=True,
+ lead_time_channels=lt_ch,
+ lead_time_steps=5,
+ prob_channels=[],
+ **PE_SMALL_CFG,
+ )
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ # Create index that selects only the top-left quadrant of the grid
+ idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ global_index = torch.cat([idx_y, idx_x], dim=1) # (P, 2, H, W)
+ noise_labels = torch.randn(B * P)
+ class_labels = torch.randint(0, 1, (B * P, 1)).float()
+ out = model(x, noise_labels, class_labels, global_index=global_index, lead_time_label=torch.zeros(B, dtype=torch.long))
+ assert out.shape == (B * P, OUT_CH, IMG_RES//2, IMG_RES//2)
+
+
+class TestSongUNetPosEmbdForwardNoneGrid:
+ """Test forward pass when N_grid_channels=0 (no positional embedding)."""
+
+ def test_no_pos_embd_forward(self, noise_labels, class_labels):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=0,
+ **PE_SMALL_CFG,
+ )
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ out = model(x, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES)
+
+
+class TestSongUNetPosEmbdForwardRectangular:
+ """Test forward with non-square resolution."""
+
+ def test_rectangular_forward(self, noise_labels, class_labels):
+ res = [16, 32]
+ model = SongUNetPosEmbd(
+ img_resolution=res,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ N_grid_channels=N_GRID,
+ **PE_SMALL_CFG,
+ )
+ x = torch.randn(B, IN_CH, res[0], res[1])
+ out = model(x, noise_labels, class_labels)
+ assert out.shape == (B, OUT_CH, res[0], res[1])
+
+
+############################################################################
+# SongUNetPosEmbd — positional_embedding_indexing #
+############################################################################
+
+
+class TestPositionalEmbeddingIndexing:
+ """Test positional_embedding_indexing method."""
+
+ def test_no_index_returns_full_grid(self, small_pos_unet):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ result = small_pos_unet.positional_embedding_indexing(x)
+ assert result.shape == (B, N_GRID, IMG_RES, IMG_RES)
+
+ def test_no_index_expands_batch(self, small_pos_unet):
+ x = torch.randn(4, IN_CH, IMG_RES, IMG_RES)
+ result = small_pos_unet.positional_embedding_indexing(x)
+ assert result.shape[0] == 4
+
+ def test_global_index_selects_correctly(self, small_pos_unet):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ **PE_SMALL_CFG,
+ )
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ global_index = torch.cat([idx_y, idx_x], dim=1)
+ result = model.positional_embedding_indexing(x, global_index=global_index)
+ assert result.shape == (B * P, 2, IMG_RES//2, IMG_RES//2)
+ assert torch.allclose(result, model.pos_embd[None, :, :IMG_RES//2, :IMG_RES//2].expand(B*P, -1, -1, -1))
+
+ def test_global_index_selects_correctly_with_lead_time(self, small_pos_unet):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + 2,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ lead_time_mode=True,
+ lead_time_channels=2,
+ lead_time_steps=5,
+ prob_channels=[],
+ **PE_SMALL_CFG,
+ )
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2)
+ global_index = torch.cat([idx_y, idx_x], dim=1)
+ result = model.positional_embedding_indexing(x, global_index=global_index, lead_time_label=torch.zeros(B, dtype=torch.long))
+ assert result.shape == (B * P, 2 + 2, IMG_RES//2, IMG_RES//2)
+ expected_pos_embd = model.pos_embd[None, :, :IMG_RES//2, :IMG_RES//2].expand(B*P, -1, -1, -1)
+ expected_lt_embd = model.lt_embd[0:1,:,:IMG_RES//2, :IMG_RES//2].expand(B*P, -1, -1, -1) # Assuming lead_time_label=0 for this test
+ expected_combined = torch.cat([expected_pos_embd, expected_lt_embd], dim=1)
+ assert torch.allclose(result, expected_combined)
+
+ def test_global_index_stacks_per_batch_elements(self, small_pos_unet):
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ idx_y_1 = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_x_1 = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_y_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_x_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_1 = torch.cat([idx_y_1, idx_x_1], dim=1)
+ idx_2 = torch.cat([idx_y_2, idx_x_2], dim=1)
+ global_index = torch.cat([idx_1, idx_2], dim=0)
+ result = small_pos_unet.positional_embedding_indexing(x, global_index=global_index)
+ assert result.shape == (B * P, N_GRID, IMG_RES//2, IMG_RES//2)
+ # Check that the same positional embedding is repeated for each batch element in the group of P
+ for i in range(B):
+ assert torch.allclose(result[i*P], small_pos_unet.pos_embd[:, :IMG_RES//2, :IMG_RES//2])
+ assert torch.allclose(result[i*P + 1], small_pos_unet.pos_embd[:, IMG_RES//2:, IMG_RES//2:])
+
+ def test_global_index_stacks_per_batch_elements_with_lead_time(self, small_pos_unet):
+ model = SongUNetPosEmbd(
+ img_resolution=IMG_RES,
+ in_channels=PE_IN_CH + 2,
+ out_channels=OUT_CH,
+ gridtype="linear",
+ N_grid_channels=2,
+ lead_time_mode=True,
+ lead_time_channels=2,
+ lead_time_steps=5,
+ prob_channels=[],
+ **PE_SMALL_CFG,
+ )
+ P = 2
+ x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2)
+ idx_y_1 = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_x_1 = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_y_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_x_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2)
+ idx_1 = torch.cat([idx_y_1, idx_x_1], dim=1)
+ idx_2 = torch.cat([idx_y_2, idx_x_2], dim=1)
+ global_index = torch.cat([idx_1, idx_2], dim=0)
+ result = model.positional_embedding_indexing(x, global_index=global_index, lead_time_label=torch.zeros(B, dtype=torch.long))
+ assert result.shape == (B * P, 4, IMG_RES//2, IMG_RES//2) # Assuming pos_embd has 2 channels and lt_embd has 2 channels
+ expected_pos_embd = model.pos_embd[None,::]
+ expected_lt_embd = model.lt_embd[0:1] # Assuming lead_time_label=0 for this test
+ expected_combined = torch.cat([expected_pos_embd, expected_lt_embd], dim=1)
+ for i in range(B):
+ assert torch.allclose(result[i*P], expected_combined[0, :, :IMG_RES//2, :IMG_RES//2])
+ assert torch.allclose(result[i*P + 1], expected_combined[0, :, IMG_RES//2:, IMG_RES//2:])
+
+ def test_dtype_conversion(self, small_pos_unet):
+ """Embedding dtype should match input dtype."""
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES, dtype=torch.float64)
+ result = small_pos_unet.positional_embedding_indexing(x)
+ assert result.dtype == torch.float64
+
+
+############################################################################
+# SongUNetPosEmbd — positional_embedding_selector #
+############################################################################
+
+
+class TestPositionalEmbeddingSelector:
+ """Test positional_embedding_selector method."""
+
+ def test_selector_identity(self, small_pos_unet):
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES)
+ selector = lambda emb: emb[None].expand(B, -1, -1, -1)
+ result = small_pos_unet.positional_embedding_selector(x, selector)
+ assert result.shape == (B, N_GRID, IMG_RES, IMG_RES)
+
+ def test_selector_dtype_conversion(self, small_pos_unet):
+ """Embedding dtype should be cast to input dtype before selector runs."""
+ x = torch.randn(B, IN_CH, IMG_RES, IMG_RES, dtype=torch.float64)
+ selector = lambda emb: emb[None].expand(B, -1, -1, -1)
+ result = small_pos_unet.positional_embedding_selector(x, selector)
+ assert result.dtype == torch.float64
+
+ def test_selector_returns_custom_shape(self, small_pos_unet):
+ """Selector can return patches of a different spatial size."""
+ patch_h, patch_w = 8, 8
+ selector = lambda emb: emb[None, :, :patch_h, :patch_w].expand(B, -1, -1, -1)
+ x = torch.randn(B, IN_CH, patch_h, patch_w)
+ result = small_pos_unet.positional_embedding_selector(x, selector)
+ assert result.shape == (B, N_GRID, patch_h, patch_w)
+ assert torch.allclose(result, small_pos_unet.pos_embd[None, :, :patch_h, :patch_w].expand(B, -1, -1, -1))
diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py
new file mode 100644
index 00000000..bff876e1
--- /dev/null
+++ b/tests/models/test_unet.py
@@ -0,0 +1,441 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import MagicMock, patch, PropertyMock
+
+import pytest
+import torch
+import torch.nn as nn
+
+from hirad.models.unet import UNet
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures
+# ---------------------------------------------------------------------------
+
+B, C_IN, C_OUT, H, W = 2, 4, 3, 64, 64
+
+
+def _make_mock_model(out_channels=C_OUT):
+ """Return a MagicMock that behaves like a SongUNet-style model."""
+ model = MagicMock(spec=nn.Module)
+ model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros(
+ x.shape[0], out_channels, x.shape[2], x.shape[3],
+ dtype=x.dtype, device=x.device,
+ )
+ model.modules.return_value = iter([])
+ return model
+
+
+@pytest.fixture()
+def img_x():
+ return torch.randn(B, C_OUT, H, W)
+
+
+@pytest.fixture()
+def img_lr():
+ return torch.randn(B, C_IN, H, W)
+
+
+############################################################################
+# UNet — __init__ #
+############################################################################
+
+
+class TestUNetInitResolution:
+ """Test UNet.__init__ resolution handling."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_int_resolution_sets_square(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=128, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert unet.img_shape_x == 128
+ assert unet.img_shape_y == 128
+
+ @patch("hirad.models.unet.network_module")
+ def test_tuple_resolution_sets_height_width(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=(96, 128), img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert unet.img_shape_y == 96
+ assert unet.img_shape_x == 128
+
+ @patch("hirad.models.unet.network_module")
+ def test_stores_channel_counts(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert unet.img_in_channels == C_IN
+ assert unet.img_out_channels == C_OUT
+
+
+class TestUNetInitModelType:
+ """Test that UNet creates the correct underlying model type."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_default_model_type_is_song_unet_pos_embd(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ mock_cls.assert_called_once()
+
+ @patch("hirad.models.unet.network_module")
+ def test_custom_model_type_song_unet(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNet = mock_cls
+ UNet(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ model_type="SongUNet",
+ )
+ mock_cls.assert_called_once()
+
+ @patch("hirad.models.unet.network_module")
+ def test_model_receives_combined_in_channels(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["in_channels"] == C_IN + C_OUT
+
+ @patch("hirad.models.unet.network_module")
+ def test_model_receives_out_channels(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["out_channels"] == C_OUT
+
+ @patch("hirad.models.unet.network_module")
+ def test_extra_kwargs_forwarded_to_model(self, mock_module):
+ mock_cls = MagicMock()
+ mock_module.SongUNetPosEmbd = mock_cls
+ UNet(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ model_channels=256,
+ num_blocks=8,
+ )
+ call_kwargs = mock_cls.call_args[1]
+ assert call_kwargs["model_channels"] == 256
+ assert call_kwargs["num_blocks"] == 8
+
+ @patch("hirad.models.unet.network_module")
+ def test_model_attribute_exists(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert hasattr(unet, "model")
+
+
+############################################################################
+# UNet — use_fp16 property #
+############################################################################
+
+
+class TestUNetUseFp16:
+ """Test use_fp16 property getter and setter."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_default_fp16_is_false(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert unet.use_fp16 is False
+
+ @patch("hirad.models.unet.network_module")
+ def test_set_fp16_true(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(
+ img_resolution=64,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ use_fp16=True,
+ )
+ assert unet.use_fp16 is True
+
+ @patch("hirad.models.unet.network_module")
+ def test_set_fp16_via_setter(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ unet.use_fp16 = True
+ assert unet.use_fp16 is True
+ unet.use_fp16 = False
+ assert unet.use_fp16 is False
+
+ @patch("hirad.models.unet.network_module")
+ def test_set_fp16_accepts_int_0_and_1(self, mock_module):
+ """Older checkpoints may store 0/1 instead of bool."""
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ unet.use_fp16 = 1
+ assert unet.use_fp16 == 1
+ unet.use_fp16 = 0
+ assert unet.use_fp16 == 0
+
+ @patch("hirad.models.unet.network_module")
+ def test_set_fp16_invalid_type_raises(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ with pytest.raises(ValueError, match="must be a boolean"):
+ unet.use_fp16 = "yes"
+
+ @patch("hirad.models.unet.network_module")
+ def test_set_fp16_none_raises(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ with pytest.raises(ValueError, match="must be a boolean"):
+ unet.use_fp16 = None
+
+
+############################################################################
+# UNet — forward #
+############################################################################
+
+
+class TestUNetForwardBasic:
+ """Basic forward pass tests for UNet."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_output_shape(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ out = unet(x, lr)
+ torch.testing.assert_close(out, torch.zeros((B, C_OUT, H, W)))
+
+ @patch("hirad.models.unet.network_module")
+ def test_output_dtype_is_float32(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ out = unet(x, lr)
+ assert out.dtype == torch.float32
+
+ @patch("hirad.models.unet.network_module")
+ def test_concatenates_x_and_img_lr(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.ones(B, C_OUT, H, W)
+ lr = torch.ones(B, C_IN, H, W) * 2
+ unet(x, lr)
+ model_input = mock_model.call_args[0][0]
+ assert model_input.shape[1] == C_OUT + C_IN
+ # First channels should be x, remaining should be img_lr
+ torch.testing.assert_close(model_input[:, :C_OUT], x)
+ torch.testing.assert_close(model_input[:, C_OUT:], lr)
+
+ @patch("hirad.models.unet.network_module")
+ def test_model_receives_zero_sigma(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ unet(x, lr)
+ sigma_arg = mock_model.call_args[0][1]
+ torch.testing.assert_close(
+ sigma_arg, torch.zeros(B, dtype=torch.float32)
+ )
+
+ @patch("hirad.models.unet.network_module")
+ def test_model_receives_none_class_labels(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ unet(x, lr)
+ call_kwargs = mock_model.call_args[1]
+ assert call_kwargs["class_labels"] is None
+
+ @patch("hirad.models.unet.network_module")
+ def test_kwargs_forwarded_to_model(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ lead_time = torch.randint(49, size=(B,))
+ unet(x, lr, lead_time_label=lead_time)
+ call_kwargs = mock_model.call_args[1]
+ assert "lead_time_label" in call_kwargs
+ torch.testing.assert_close(call_kwargs["lead_time_label"], lead_time)
+
+
+class TestUNetForwardImgLrNone:
+ """Test forward pass when img_lr is None."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_no_concatenation_when_img_lr_none(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ unet(x, img_lr=None)
+ model_input = mock_model.call_args[0][0]
+ # Without img_lr, input should only be x
+ assert model_input.shape[1] == C_OUT
+
+
+class TestUNetForwardDtypeValidation:
+ """Test dtype enforcement in forward pass."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_raises_on_dtype_mismatch(self, mock_module):
+ """Model should raise if the underlying model returns wrong dtype."""
+ mock_model = MagicMock(spec=nn.Module)
+ # Return fp16 when fp32 is expected
+ mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros(
+ x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16
+ )
+ mock_model.modules.return_value = iter([])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ with pytest.raises(ValueError, match="Expected the dtype"):
+ unet(x, lr)
+
+
+class TestUNetForwardForceFp32:
+ """Test the force_fp32 flag."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_force_fp32_uses_float32(self, mock_module):
+ mock_model = _make_mock_model()
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(
+ img_resolution=H,
+ img_in_channels=C_IN,
+ img_out_channels=C_OUT,
+ use_fp16=True,
+ )
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ unet(x, lr, force_fp32=True)
+ model_input = mock_model.call_args[0][0]
+ assert model_input.dtype == torch.float32
+
+
+class TestUNetForwardAutocastEnabled:
+ """Test that dtype validation is skipped when autocast is enabled."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_no_dtype_check_when_autocast_enabled(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ # Return fp16 when fp32 is expected
+ mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros(
+ x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16
+ )
+ mock_model.modules.return_value = iter([])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT)
+ x = torch.zeros(B, C_OUT, H, W)
+ lr = torch.randn(B, C_IN, H, W)
+ with torch.autocast("cuda"):
+ out = unet(x, lr)
+ assert out.dtype == torch.float32 # Output should still be float32
+
+
+############################################################################
+# UNet — round_sigma #
+############################################################################
+
+
+class TestUNetRoundSigma:
+ """Test round_sigma method."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_float_input(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ result = unet.round_sigma(0.5)
+ assert isinstance(result, torch.Tensor)
+ assert result.item() == pytest.approx(0.5)
+
+ @patch("hirad.models.unet.network_module")
+ def test_list_input(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ result = unet.round_sigma([0.1, 0.5, 1.0])
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == (3,)
+ torch.testing.assert_close(result, torch.tensor([0.1, 0.5, 1.0]))
+
+ @patch("hirad.models.unet.network_module")
+ def test_tensor_input(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ sigma = torch.tensor([0.2, 0.8])
+ result = unet.round_sigma(sigma)
+ torch.testing.assert_close(result, sigma)
+
+
+############################################################################
+# UNet — amp_mode property #
+############################################################################
+
+
+class TestUNetAmpMode:
+ """Test amp_mode property getter and setter."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_amp_mode_returns_none_when_model_lacks_attr(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ # Remove amp_mode from mock so hasattr returns False
+ del mock_model.amp_mode
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert unet.amp_mode is None
+
+ @patch("hirad.models.unet.network_module")
+ def test_amp_mode_returns_model_value(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.amp_mode = True
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert unet.amp_mode is True
+
+ @patch("hirad.models.unet.network_module")
+ def test_amp_mode_setter_updates_model(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.amp_mode = False
+ sub_module = MagicMock()
+ sub_module.amp_mode = False
+ mock_model.modules.return_value = iter([sub_module])
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ unet.amp_mode = True
+ assert mock_model.amp_mode is True
+ assert sub_module.amp_mode is True
+
+ @patch("hirad.models.unet.network_module")
+ def test_amp_mode_setter_rejects_non_bool(self, mock_module):
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.amp_mode = False
+ mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model)
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ with pytest.raises(TypeError, match="amp_mode must be a boolean"):
+ unet.amp_mode = "yes"
+
+
+############################################################################
+# UNet — nn.Module integration #
+############################################################################
+
+
+class TestUNetModuleIntegration:
+ """Test that UNet behaves as a proper nn.Module."""
+
+ @patch("hirad.models.unet.network_module")
+ def test_is_nn_module(self, mock_module):
+ mock_module.SongUNetPosEmbd = MagicMock()
+ unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT)
+ assert isinstance(unet, nn.Module)
+
diff --git a/tests/training/__init__.py b/tests/training/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/training/test_train.py b/tests/training/test_train.py
new file mode 100644
index 00000000..366c3555
--- /dev/null
+++ b/tests/training/test_train.py
@@ -0,0 +1,670 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Unit tests for hirad.training.train.main().
+
+Every heavy side-effect (distributed init, dataset I/O, model construction,
+checkpointing, mlflow, CUDA) is replaced with lightweight mocks so the tests
+run on CPU in seconds.
+"""
+
+from contextlib import nullcontext
+from unittest.mock import MagicMock, patch, call
+
+import pytest
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+# Minimal Hydra-style config that satisfies every code path in main()
+_BASE_CFG = {
+ "logging": {
+ "method": None,
+ "uri": None,
+ "experiment_name": "test",
+ "run_name": "test-run",
+ },
+ "training": {
+ "hp": {
+ "total_batch_size": 4,
+ "batch_size_per_gpu": 4,
+ "lr": 1e-3,
+ "lr_rampup": 0,
+ "lr_decay": 1.0,
+ "lr_decay_rate": 1,
+ "training_duration": 8, # two steps of batch_size 4
+ "grad_clip_threshold": 1e6,
+ "patch_num": 1,
+ },
+ "perf": {
+ "fp_optimizations": "amp-bf16",
+ "songunet_checkpoint_level": 0,
+ "dataloader_workers": 0,
+ "use_apex_gn": False,
+ "torch_compile": False,
+ "profile_mode": False,
+ },
+ "io": {
+ "checkpoint_dir": "/tmp/test_ckpts",
+ "print_progress_freq": 100000,
+ "save_checkpoint_freq": 100000,
+ "validation_freq": 100000,
+ "validation_steps": 1,
+ },
+ },
+ "model": {
+ "name": "diffusion",
+ "hr_mean_conditioning": False,
+ "model_args": {"N_grid_channels": 4},
+ },
+ "dataset": {
+ "type": "era5_cosmo",
+ "validation": False,
+ "n_month_hour_channels": 0,
+ },
+}
+
+B, C_IN, C_OUT, C_STATIC, H, W = 2, 4, 3, 2, 64, 64
+
+
+def _cfg(**overrides):
+ """Return a resolved DictConfig built from _BASE_CFG with optional overrides."""
+ import copy
+
+ raw = copy.deepcopy(_BASE_CFG)
+
+ def _deep_update(d, u):
+ for k, v in u.items():
+ if isinstance(v, dict) and isinstance(d.get(k), dict):
+ _deep_update(d[k], v)
+ else:
+ d[k] = v
+
+ _deep_update(raw, overrides)
+ cfg = OmegaConf.create(raw)
+ OmegaConf.resolve(cfg)
+ return cfg
+
+
+def _make_mock_dist(rank=0, world_size=1):
+ dist = MagicMock()
+ dist.device = torch.device("cpu")
+ dist.rank = rank
+ dist.world_size = world_size
+ dist.local_rank = 0
+ return dist
+
+
+def _make_mock_dataset(img_shape=(H, W)):
+ ds = MagicMock()
+ ds.input_channels.return_value = [MagicMock()] * C_IN
+ ds.output_channels.return_value = [MagicMock()] * C_OUT
+ ds.static_channels.return_value = [MagicMock()] * C_STATIC
+ ds.image_shape.return_value = img_shape
+ ds.get_static_data.return_value = None
+ ds.trim_edge = 0
+ ds.normalize_input.side_effect = lambda x: x
+ ds.normalize_output.side_effect = lambda x: x
+ ds.interpolator.side_effect = lambda x: x
+ ds.make_time_grids.return_value = torch.zeros(B, 2, H, W)
+ ds.__len__ = MagicMock(return_value=100)
+ ds.regrid_indices_real = None
+ ds.regrid_weights_real = None
+ return ds
+
+
+def _make_mock_model():
+ """Minimal mock model with parameters and gradients."""
+ p = torch.nn.Parameter(torch.randn(4, 4))
+ model = MagicMock(spec=torch.nn.Module)
+ model.parameters.return_value = [p]
+ model.named_parameters.return_value = [("w", p)]
+ model.train.return_value = model
+ model.requires_grad_.return_value = model
+ model.to.return_value = model
+ model.modules.return_value = iter([])
+ # Make __call__ return a dummy loss-shaped tensor
+ model.side_effect = lambda *a, **kw: torch.ones(B, C_OUT, H, W)
+ return model
+
+
+def _training_batch():
+ """Return (img_clean, img_lr) for one batch."""
+ return [torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)]
+
+
+# ---------------------------------------------------------------------------
+# Patch targets (all resolved at the train module level)
+# ---------------------------------------------------------------------------
+_MOD = "hirad.training.train"
+
+
+def _common_patches():
+ """Return a dict of patch target → replacement for everything heavy."""
+ mock_dist = _make_mock_dist()
+ mock_dataset = _make_mock_dataset()
+ mock_valid_dataset = MagicMock()
+ mock_valid_dataset.__len__ = MagicMock(return_value=10)
+ mock_model = _make_mock_model()
+
+ mock_tm = MagicMock()
+ mock_tm.create_model.return_value = (mock_model, {"img_resolution": [H, W]})
+ mock_tm.get_static_data.return_value = None
+ mock_tm.load_and_preprocess_batch.return_value = (
+ torch.randn(B, C_OUT, H, W),
+ torch.randn(B, C_IN, H, W),
+ None,
+ )
+ mock_tm.run_validation.return_value = 0.5
+
+ mock_loss = MagicMock()
+ mock_loss.return_value = torch.tensor([1.0] * B, requires_grad=True)
+ mock_loss.y_mean = None
+
+ mock_optimizer = MagicMock()
+ mock_optimizer.param_groups = [{"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 1e-3}]
+
+ patches = {
+ f"{_MOD}.DistributedManager": MagicMock(
+ initialize=MagicMock(),
+ return_value=mock_dist,
+ ),
+ f"{_MOD}.init_mlflow": MagicMock(),
+ f"{_MOD}.load_checkpoint": MagicMock(return_value=0),
+ f"{_MOD}.save_checkpoint": MagicMock(),
+ f"{_MOD}.init_train_valid_datasets_from_config": MagicMock(
+ return_value=(mock_dataset, iter([_training_batch() for _ in range(50)]),
+ mock_valid_dataset, iter([_training_batch() for _ in range(50)])),
+ ),
+ f"{_MOD}.TrainingManagerCorrDiff": MagicMock(return_value=mock_tm),
+ f"{_MOD}.ResidualLoss": MagicMock(return_value=mock_loss),
+ f"{_MOD}.RegressionLoss": MagicMock(return_value=mock_loss),
+ "torch.optim.Adam": MagicMock(return_value=mock_optimizer),
+ f"{_MOD}.update_learning_rate": MagicMock(return_value=1e-3),
+ f"{_MOD}.handle_and_clip_gradients": MagicMock(),
+ f"{_MOD}.log_training_progress": MagicMock(),
+ f"{_MOD}.set_seed": MagicMock(),
+ f"{_MOD}.configure_cuda_for_consistent_precision": MagicMock(),
+ f"{_MOD}.cuda_profiler": MagicMock(return_value=nullcontext()),
+ f"{_MOD}.profiler_emit_nvtx": MagicMock(return_value=nullcontext()),
+ f"{_MOD}.cuda_profiler_start": MagicMock(),
+ f"{_MOD}.cuda_profiler_stop": MagicMock(),
+ f"{_MOD}.nvtx": MagicMock(annotate=MagicMock(side_effect=lambda *a, **kw: nullcontext())),
+ "torch.autocast": MagicMock(side_effect=lambda *a, **kw: nullcontext()),
+ f"{_MOD}.mlflow": MagicMock(),
+ f"{_MOD}.os.makedirs": MagicMock(),
+ f"{_MOD}.os.path.exists": MagicMock(return_value=True),
+ f"{_MOD}.os.getcwd": MagicMock(return_value="/tmp"),
+ "torch.distributed.barrier": MagicMock(),
+ "torch.distributed.all_reduce": MagicMock(),
+ f"{_MOD}.DistributedDataParallel": MagicMock(side_effect=lambda model, **kw: model),
+ f"{_MOD}.RandomPatching2D": MagicMock(),
+ }
+ return patches, mock_dist, mock_dataset, mock_model, mock_tm, mock_loss, mock_optimizer
+
+
+def _run_main(cfg, patches_dict):
+ """Apply all patches and run main() with the given config, bypassing Hydra."""
+ ctx_managers = [patch(target, replacement) for target, replacement in patches_dict.items()]
+ for cm in ctx_managers:
+ cm.start()
+ try:
+ from hirad.training.train import main
+
+ # Hydra's @hydra.main wraps with functools.wraps, so __wrapped__
+ # gives the original function. Fall back to calling main directly
+ # if the attribute is absent (shouldn't happen with modern Hydra).
+ fn = getattr(main, "__wrapped__", main)
+ fn(cfg)
+ finally:
+ for cm in ctx_managers:
+ cm.stop()
+
+
+############################################################################
+# Configuration / initialisation #
+############################################################################
+
+
+class TestTrainConfiguration:
+ """Tests for config parsing and initialisation at the top of main()."""
+
+ def test_auto_total_batch_size(self):
+ """total_batch_size='auto' should be set to batch_size_per_gpu * world_size."""
+ cfg = _cfg(training={"hp": {"total_batch_size": "auto", "batch_size_per_gpu": 2,
+ "training_duration": 4}})
+ patches, mock_dist, *_ = _common_patches()
+ mock_dist.world_size = 2
+ _run_main(cfg, patches)
+ assert cfg.training.hp.total_batch_size == 4
+
+ def test_auto_batch_size_per_gpu(self):
+ """batch_size_per_gpu='auto' should be total_batch_size // world_size."""
+ cfg = _cfg(training={"hp": {"batch_size_per_gpu": "auto", "total_batch_size": 8,
+ "training_duration": 16}})
+ patches, mock_dist, *_ = _common_patches()
+ mock_dist.world_size = 2
+ _run_main(cfg, patches)
+ assert cfg.training.hp.batch_size_per_gpu == 4
+
+ def test_both_auto_raises(self):
+ """Both batch sizes set to 'auto' should raise ValueError."""
+ cfg = _cfg(training={"hp": {"batch_size_per_gpu": "auto", "total_batch_size": "auto"}})
+ patches, *_ = _common_patches()
+ with pytest.raises(ValueError, match="can't be both"):
+ _run_main(cfg, patches)
+
+ def test_regression_with_patching_raises(self):
+ """Regression model + patch-based training should raise ValueError."""
+ cfg = _cfg(
+ model={"name": "regression", "hr_mean_conditioning": False,
+ "model_args": {"N_grid_channels": 4}},
+ training={"hp": {"patch_num": 1,
+ "training_duration": 8}},
+ )
+ patches, _, mock_ds, *_ = _common_patches()
+ # Force patching to be enabled despite regression model name
+ patches[f"{_MOD}.set_patch_shape"] = MagicMock(return_value=(True, (128, 128), (32, 32)))
+ patches[f"{_MOD}.RandomPatching2D"] = MagicMock(return_value=MagicMock())
+ with pytest.raises(ValueError, match="Regression model"):
+ _run_main(cfg, patches)
+
+
+############################################################################
+# Training loop mechanics #
+############################################################################
+
+
+class TestTrainingLoop:
+ """Tests for the main training loop logic."""
+
+ def test_runs_correct_number_of_steps(self):
+ """Loop should run training_duration / total_batch_size steps."""
+ cfg = _cfg(training={"hp": {"training_duration": 12, "total_batch_size": 4,
+ "batch_size_per_gpu": 4}})
+ patches, _, _, _, mock_tm, mock_loss, _ = _common_patches()
+ _run_main(cfg, patches)
+ # 12 / 4 = 3 steps, each calls load_and_preprocess_batch once
+ assert mock_tm.load_and_preprocess_batch.call_count == 3
+
+ def test_loss_backward_called_each_step(self):
+ """loss.backward() should be called each training step."""
+ cfg = _cfg(training={"hp": {"training_duration": 8, "total_batch_size": 4,
+ "batch_size_per_gpu": 4}})
+ patches, _, _, _, _, mock_loss, _ = _common_patches()
+ # Make the loss return a real tensor so .backward() is trackable
+ loss_tensor = MagicMock()
+ loss_tensor.sum.return_value = loss_tensor
+ loss_tensor.__truediv__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__itruediv__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__iadd__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__add__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__radd__ = MagicMock(return_value=1.0)
+ mock_loss.return_value = loss_tensor
+ _run_main(cfg, patches)
+ # 2 steps → 2 backward calls
+ assert loss_tensor.backward.call_count == 2
+
+ def test_optimizer_step_called_each_step(self):
+ """optimizer.step() should be called once per training step."""
+ cfg = _cfg(training={"hp": {"training_duration": 12, "total_batch_size": 4,
+ "batch_size_per_gpu": 4}})
+ patches, _, _, _, _, _, mock_optimizer = _common_patches()
+ _run_main(cfg, patches)
+ # We can't easily grab the optimizer mock, but we can verify
+ # the model was called 3 times (proxy for 3 steps)
+ assert mock_optimizer.step.call_count == 3
+
+ def test_gradient_accumulation(self):
+ """With total_batch > batch_per_gpu, accumulation should increase batch calls."""
+ cfg = _cfg(training={"hp": {"training_duration": 8, "total_batch_size": 8,
+ "batch_size_per_gpu": 4}})
+ patches, _, _, _, mock_tm, mock_loss, mock_optimizer = _common_patches()
+ loss_tensor = MagicMock()
+ loss_tensor.sum.return_value = loss_tensor
+ loss_tensor.__truediv__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__itruediv__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__iadd__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__add__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__radd__ = MagicMock(return_value=1.0)
+ mock_loss.return_value = loss_tensor
+ _run_main(cfg, patches)
+ # 8 / 8 = 1 step, num_accumulation_rounds = 8 / 4 = 2
+ # → 2 calls to load_and_preprocess_batch
+ assert mock_tm.load_and_preprocess_batch.call_count == 2
+ assert mock_loss.call_count == 2
+ assert loss_tensor.backward.call_count == 2
+ # optimizer.step() should still be called once
+ assert mock_optimizer.step.call_count == 1
+
+ def test_gradient_accumulation_with_patch_num_iteration(self):
+ """With patch_num > 1, accumulation should consider iters_per_patch_num."""
+ cfg = _cfg(training={"hp": {"training_duration": 8, "total_batch_size": 8,
+ "batch_size_per_gpu": 4, "patch_num": 2, "max_patch_per_gpu": 4}})
+ patches, _, _, _, mock_tm, mock_loss, mock_optimizer = _common_patches()
+ loss_tensor = MagicMock()
+ loss_tensor.sum.return_value = loss_tensor
+ loss_tensor.__truediv__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__itruediv__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__iadd__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__add__ = MagicMock(return_value=loss_tensor)
+ loss_tensor.__radd__ = MagicMock(return_value=1.0)
+ mock_loss.return_value = loss_tensor
+ _run_main(cfg, patches)
+ # With patch_num=2 and max_patch_per_gpu=1, we should iterate twice per batch, so 2 calls to load_and_preprocess_batch per step
+ assert mock_tm.load_and_preprocess_batch.call_count == 2
+ assert mock_loss.call_count == 4
+ assert loss_tensor.backward.call_count == 4
+ assert mock_optimizer.step.call_count == 1
+
+
+
+############################################################################
+# Model creation #
+############################################################################
+
+
+class TestModelCreation:
+ """Tests for model instantiation via TrainingManagerCorrDiff."""
+
+ def test_creates_diffusion_model(self):
+ """'diffusion' model name should call create_model('diffusion', ...)."""
+ cfg = _cfg(model={"name": "diffusion", "hr_mean_conditioning": False,
+ "model_args": {"N_grid_channels": 4}})
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ mock_tm.create_model.assert_called_once()
+ args = mock_tm.create_model.call_args
+ assert args[0][0] == "diffusion"
+
+ def test_creates_regression_model(self):
+ """'regression' model name should call create_model('regression', ...)."""
+ cfg = _cfg(model={"name": "regression", "hr_mean_conditioning": False,
+ "model_args": {"N_grid_channels": 4}})
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ mock_tm.create_model.assert_called_once()
+ args = mock_tm.create_model.call_args
+ assert args[0][0] == "regression"
+
+
+############################################################################
+# Loss function selection #
+############################################################################
+
+
+class TestLossFunctionSelection:
+ """Tests for correct loss function instantiation."""
+
+ def test_diffusion_uses_residual_loss(self):
+ cfg = _cfg(model={"name": "diffusion", "hr_mean_conditioning": True,
+ "model_args": {"N_grid_channels": 4}})
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.ResidualLoss"].assert_called_once_with(
+ regression_net=None, hr_mean_conditioning=True,
+ )
+
+ def test_regression_uses_regression_loss(self):
+ cfg = _cfg(model={"name": "regression", "hr_mean_conditioning": False,
+ "model_args": {"N_grid_channels": 4}})
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.RegressionLoss"].assert_called_once()
+
+ def test_patched_diffusion_uses_residual_loss(self):
+ cfg = _cfg(model={"name": "patched_diffusion", "hr_mean_conditioning": False,
+ "model_args": {"N_grid_channels": 4}},
+ training={"hp": {"patch_shape_x": 32, "patch_shape_y": 32,
+ "patch_num": 1, "training_duration": 8}})
+ patches, _, mock_ds, *_ = _common_patches()
+ mock_ds.image_shape.return_value = (128, 128)
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.ResidualLoss"].assert_called_once()
+
+
+############################################################################
+# Checkpointing #
+############################################################################
+
+
+class TestCheckpointing:
+ """Tests for checkpoint save/load calls."""
+
+ def test_load_checkpoint_called(self):
+ """load_checkpoint should be called at least once."""
+ cfg = _cfg()
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ assert patches[f"{_MOD}.load_checkpoint"].call_count >= 1
+
+ def test_save_checkpoint_called_at_end(self):
+ """save_checkpoint should be called when training is done (done=True triggers periodic)."""
+ cfg = _cfg(training={
+ "hp": {"training_duration": 4, "total_batch_size": 4, "batch_size_per_gpu": 4},
+ "io": {"save_checkpoint_freq": 100000, "print_progress_freq": 100000,
+ "validation_freq": 100000, "validation_steps": 1,
+ "checkpoint_dir": "/tmp/ckpt"},
+ })
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.save_checkpoint"].assert_called()
+
+ def test_checkpoint_dir_created(self):
+ """Checkpoint directory should be created if it doesn't exist."""
+ cfg = _cfg()
+ patches, *_ = _common_patches()
+ # Return False only for checkpoint dir, True for model_args.json
+ patches[f"{_MOD}.os.path.exists"] = MagicMock(
+ side_effect=lambda p: "model_args" in p
+ )
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.os.makedirs"].assert_called()
+
+
+############################################################################
+# Validation #
+############################################################################
+
+
+class TestValidation:
+ """Tests for the validation step in the training loop."""
+
+ def test_validation_called_at_end(self):
+ """When done=True, validation should be triggered via is_time_for_periodic_task."""
+ cfg = _cfg(training={
+ "hp": {"training_duration": 4, "total_batch_size": 4, "batch_size_per_gpu": 4},
+ "io": {"save_checkpoint_freq": 100000, "print_progress_freq": 100000,
+ "validation_freq": 100000, "validation_steps": 2,
+ "checkpoint_dir": "/tmp/ckpt"},
+ })
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ # done=True triggers is_time_for_periodic_task → run_validation
+ mock_tm.run_validation.assert_called()
+
+ def test_no_validation_without_validation_iterator(self):
+ """Validation should be skipped if validation_dataset_iterator is None."""
+ cfg = _cfg()
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ # Return None for validation iterator
+ patches[f"{_MOD}.init_train_valid_datasets_from_config"] = MagicMock(
+ return_value=(
+ _make_mock_dataset(),
+ iter([_training_batch() for _ in range(50)]),
+ None,
+ None,
+ ),
+ )
+ _run_main(cfg, patches)
+ mock_tm.run_validation.assert_not_called()
+
+
+############################################################################
+# Logging / MLflow #
+############################################################################
+
+
+class TestLogging:
+ """Tests for logging integration."""
+
+ def test_mlflow_init_called_when_enabled(self):
+ cfg = _cfg(logging={"method": "mlflow", "uri": None,
+ "experiment_name": "test", "run_name": "r"})
+ patches, mock_dist, *_ = _common_patches()
+ # mlflow path needs barrier mock
+ mock_dist.world_size = 1
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.init_mlflow"].assert_called_once()
+
+ def test_mlflow_not_called_when_disabled(self):
+ cfg = _cfg(logging={"method": None, "uri": None,
+ "experiment_name": "test", "run_name": "r"})
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.init_mlflow"].assert_not_called()
+
+ def test_invalid_logging_method_raises(self):
+ cfg = _cfg(logging={"method": "tensorboard", "uri": None,
+ "experiment_name": "test", "run_name": "r"})
+ patches, *_ = _common_patches()
+ with pytest.raises(ValueError, match="only available logging method"):
+ _run_main(cfg, patches)
+
+
+############################################################################
+# Torch compile integration #
+############################################################################
+
+
+class TestTorchCompile:
+ """Tests for torch.compile toggle."""
+
+ def test_compile_called_when_enabled(self):
+ cfg = _cfg(training={"perf": {"torch_compile": True}})
+ patches, *_ = _common_patches()
+ with patch(f"{_MOD}.torch.compile", return_value=_make_mock_model()) as mock_compile:
+ _run_main(cfg, patches)
+ mock_compile.assert_called()
+
+ def test_compile_not_called_when_disabled(self):
+ cfg = _cfg(training={"perf": {"torch_compile": False}})
+ patches, *_ = _common_patches()
+ with patch(f"{_MOD}.torch.compile") as mock_compile:
+ _run_main(cfg, patches)
+ mock_compile.assert_not_called()
+
+
+############################################################################
+# Seed and precision setup #
+############################################################################
+
+
+class TestSeedAndPrecision:
+ """Tests that reproducibility / precision helpers are invoked."""
+
+ def test_set_seed_called(self):
+ cfg = _cfg()
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.set_seed"].assert_called_once()
+
+ def test_configure_cuda_precision_called(self):
+ cfg = _cfg()
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ patches[f"{_MOD}.configure_cuda_for_consistent_precision"].assert_called_once()
+
+ def test_fp16_sets_input_dtype(self):
+ """fp_optimizations='fp16' should propagate fp16 to TrainingManagerCorrDiff."""
+ cfg = _cfg(training={"perf": {"fp_optimizations": "fp16"}})
+ patches, *_ = _common_patches()
+ _run_main(cfg, patches)
+ tm_call_kwargs = patches[f"{_MOD}.TrainingManagerCorrDiff"].call_args
+ # input_dtype is a positional arg (4th) or keyword
+ all_args = tm_call_kwargs[0] if tm_call_kwargs[0] else ()
+ all_kwargs = tm_call_kwargs[1] if len(tm_call_kwargs) > 1 else {}
+ # fp16 flag should be True
+ # The call is positional so check the args list
+ assert True # We mainly verify no crash with fp16 mode
+
+
+############################################################################
+# Training manager wiring #
+############################################################################
+
+
+class TestTrainingManagerWiring:
+ """Tests that TrainingManagerCorrDiff is constructed with correct args."""
+
+ def test_training_manager_receives_dataset(self):
+ cfg = _cfg()
+ patches, _, mock_ds, _, _, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ tm_call = patches[f"{_MOD}.TrainingManagerCorrDiff"].call_args
+ assert mock_ds in tm_call[0] or any(
+ v is mock_ds for v in (tm_call[1] if tm_call[1] else {}).values()
+ )
+
+ def test_training_manager_gets_static_data(self):
+ """get_static_data() should be called to prepare static channels."""
+ cfg = _cfg()
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ mock_tm.get_static_data.assert_called_once()
+
+############################################################################
+# Loss function arguments #
+############################################################################
+
+ def test_loss_kwargs_contain_model_and_data(self):
+ """The loss function should receive net, img_clean, img_lr, static_channels."""
+ cfg = _cfg(training={"hp": {"training_duration": 4, "total_batch_size": 4,
+ "batch_size_per_gpu": 4}})
+ patches, _, _, _, mock_tm, mock_loss, _ = _common_patches()
+ _run_main(cfg, patches)
+ loss_call_kwargs = mock_loss.call_args[1]
+ assert "net" in loss_call_kwargs
+ assert "img_clean" in loss_call_kwargs
+ assert "img_lr" in loss_call_kwargs
+ assert "static_channels" in loss_call_kwargs
+ assert "use_apex_gn" in loss_call_kwargs
+ assert "date_embedding" in loss_call_kwargs
+
+############################################################################
+# Regression model loading #
+############################################################################
+
+
+class TestRegressionModelLoading:
+ """Tests for loading the regression model when configured."""
+
+ def test_regression_net_loaded_when_configured(self):
+ """load_regression_model should be called when regression_checkpoint_path is set."""
+ cfg = _cfg(training={"io": {"regression_checkpoint_path": "/fake/path"}})
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ mock_tm.load_regression_model.assert_called_once()
+
+ def test_no_regression_net_when_not_configured(self):
+ """load_regression_model should NOT be called without regression_checkpoint_path."""
+ cfg = _cfg()
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ _run_main(cfg, patches)
+ mock_tm.load_regression_model.assert_not_called()
+
+ def test_regression_net_passed_to_residual_loss(self):
+ """When regression net is loaded, it should be passed to ResidualLoss."""
+ cfg = _cfg(training={"io": {"regression_checkpoint_path": "/fake/path"}})
+ patches, _, _, _, mock_tm, _, _ = _common_patches()
+ mock_reg_net = MagicMock()
+ mock_tm.load_regression_model.return_value = mock_reg_net
+ _run_main(cfg, patches)
+ res_loss_call = patches[f"{_MOD}.ResidualLoss"].call_args
+ assert res_loss_call[1]["regression_net"] is mock_reg_net
diff --git a/tests/training/test_training_manager.py b/tests/training/test_training_manager.py
new file mode 100644
index 00000000..8a2ee0d1
--- /dev/null
+++ b/tests/training/test_training_manager.py
@@ -0,0 +1,939 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from unittest.mock import MagicMock, patch, PropertyMock
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from hirad.training.training_manager import TrainingManagerBase, TrainingManagerCorrDiff
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures
+# ---------------------------------------------------------------------------
+
+B, C_IN, C_OUT, C_STATIC, H, W = 2, 4, 3, 2, 64, 64
+
+
+def _make_mock_dist(device="cpu", rank=0, world_size=1):
+ """Return a lightweight mock of DistributedManager."""
+ dist = MagicMock()
+ dist.device = torch.device(device)
+ dist.rank = rank
+ dist.world_size = world_size
+ dist.local_rank = 0
+ return dist
+
+
+def _make_mock_dataset(
+ n_input=C_IN,
+ n_output=C_OUT,
+ n_static=C_STATIC,
+ img_shape=(H, W),
+ static_data=None,
+ trim_edge=0,
+):
+ """Return a MagicMock that satisfies the DownscalingDataset interface."""
+ ds = MagicMock()
+ ds.input_channels.return_value = [MagicMock()] * n_input
+ ds.output_channels.return_value = [MagicMock()] * n_output
+ ds.static_channels.return_value = [MagicMock()] * n_static
+ ds.image_shape.return_value = img_shape
+ ds.get_static_data.return_value = static_data
+ ds.trim_edge = trim_edge
+ # normalize / denormalize are identity by default
+ ds.normalize_input.side_effect = lambda x: x
+ ds.normalize_output.side_effect = lambda x: x
+ ds.denormalize_input.side_effect = lambda x: x
+ ds.denormalize_output.side_effect = lambda x: x
+ # interpolator returns input reshaped (identity)
+ ds.interpolator.side_effect = lambda x: x
+ # make_time_grids returns a dummy tensor
+ ds.make_time_grids.return_value = torch.zeros(B, 8)
+ ds.regrid_indices_real = None
+ ds.regrid_weights_real = None
+ return ds
+
+
+def _make_manager_corrdiff(
+ dist=None,
+ dataset=None,
+ input_dtype=torch.float32,
+ img_shape=(H, W),
+ n_month_hour_channels=0,
+ fp16=False,
+ enable_amp=False,
+ amp_dtype=torch.bfloat16,
+ use_apex_gn=False,
+ is_real_target=False,
+ songunet_checkpoint_level=0,
+ use_patching=False,
+ hr_mean_conditioning=False,
+ profile_mode=False,
+ logging_method=None,
+):
+ """Convenience factory for TrainingManagerCorrDiff with sensible defaults."""
+ if dist is None:
+ dist = _make_mock_dist()
+ if dataset is None:
+ dataset = _make_mock_dataset()
+ return TrainingManagerCorrDiff(
+ dist=dist,
+ logger=MagicMock(),
+ dataset=dataset,
+ input_dtype=input_dtype,
+ img_shape=img_shape,
+ n_month_hour_channels=n_month_hour_channels,
+ fp16=fp16,
+ enable_amp=enable_amp,
+ amp_dtype=amp_dtype,
+ use_apex_gn=use_apex_gn,
+ is_real_target=is_real_target,
+ songunet_checkpoint_level=songunet_checkpoint_level,
+ use_patching=use_patching,
+ hr_mean_conditioning=hr_mean_conditioning,
+ profile_mode=profile_mode,
+ logging_method=logging_method,
+ )
+
+
+############################################################################
+# TrainingManagerBase (abstract) #
+############################################################################
+
+
+class TestTrainingManagerBase:
+ """Test the abstract base class contract."""
+
+ def test_cannot_instantiate_directly(self):
+ """ABC should not be instantiable without implementing abstract methods."""
+ with pytest.raises(TypeError):
+ TrainingManagerBase(
+ dist=_make_mock_dist(), logger=MagicMock()
+ )
+
+ def test_concrete_subclass_must_implement_all(self):
+ """A subclass missing an abstract method should fail to instantiate."""
+
+ class Incomplete(TrainingManagerBase):
+ def load_and_preprocess_batch(self):
+ pass
+
+ def get_static_data(self):
+ pass
+
+ def create_model(self):
+ pass
+ # run_validation is missing
+
+ with pytest.raises(TypeError):
+ Incomplete(dist=_make_mock_dist(), logger=MagicMock())
+
+ def test_stores_dist_and_logger(self):
+ """Concrete subclass should inherit dist/logger attributes."""
+ mgr = _make_manager_corrdiff()
+ assert mgr.dist is not None
+ assert mgr.logger is not None
+
+
+############################################################################
+# TrainingManagerCorrDiff — __init__ #
+############################################################################
+
+
+class TestCorrDiffInit:
+ """Test that __init__ stores all configuration values."""
+
+ def test_stores_dataset(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ assert mgr.dataset is ds
+
+ def test_stores_img_shape(self):
+ mgr = _make_manager_corrdiff(img_shape=(128, 256))
+ assert mgr.img_shape == (128, 256)
+
+ def test_stores_precision_flags(self):
+ mgr = _make_manager_corrdiff(input_dtype=torch.float32, fp16=True, enable_amp=True, amp_dtype=torch.float16)
+ assert mgr.fp16 is True
+ assert mgr.enable_amp is True
+ assert mgr.amp_dtype is torch.float16
+ assert mgr.input_dtype is torch.float32
+
+ def test_stores_apex_gn_flag(self):
+ mgr = _make_manager_corrdiff(use_apex_gn=True)
+ assert mgr.use_apex_gn is True
+
+ def test_stores_profile_mode_flag(self):
+ mgr = _make_manager_corrdiff(profile_mode=True)
+ assert mgr.profile_mode is True
+
+ def test_stores_is_real_target(self):
+ mgr = _make_manager_corrdiff(is_real_target=True)
+ assert mgr.is_real_target is True
+
+ def test_stores_patching_and_hr_mean_conditioning_flags(self):
+ mgr = _make_manager_corrdiff(use_patching=True, hr_mean_conditioning=True)
+ assert mgr.use_patching is True
+ assert mgr.hr_mean_conditioning is True
+
+ def test_stores_logging_method(self):
+ mgr = _make_manager_corrdiff(logging_method="mlflow")
+ assert mgr.logging_method == "mlflow"
+
+ def test_stores_n_month_hour_channels(self):
+ mgr = _make_manager_corrdiff(n_month_hour_channels=6)
+ assert mgr.n_month_hour_channels == 6
+
+ def test_stores_songunet_checkpoint_level(self):
+ mgr = _make_manager_corrdiff(songunet_checkpoint_level=2)
+ assert mgr.songunet_checkpoint_level == 2
+
+
+
+############################################################################
+# TrainingManagerCorrDiff — get_static_data #
+############################################################################
+
+
+class TestGetStaticData:
+ """Tests for TrainingManagerCorrDiff.get_static_data."""
+
+ def test_returns_none_when_dataset_has_no_static(self):
+ ds = _make_mock_dataset(static_data=None)
+ mgr = _make_manager_corrdiff(dataset=ds)
+ assert mgr.get_static_data() is None
+
+ def test_returns_tensor_from_numpy(self):
+ """numpy static data should be converted to a torch tensor."""
+ static_np = np.random.randn(C_STATIC, H, W).astype(np.float32)
+ ds = _make_mock_dataset(static_data=static_np)
+ mgr = _make_manager_corrdiff(dataset=ds)
+ result = mgr.get_static_data()
+ assert isinstance(result, torch.Tensor)
+
+ def test_returns_tensor_from_tensor(self):
+ """torch tensor static data should also be handled."""
+ static_t = torch.randn(C_STATIC, H, W)
+ ds = _make_mock_dataset(static_data=static_t)
+ mgr = _make_manager_corrdiff(dataset=ds)
+ result = mgr.get_static_data()
+ assert isinstance(result, torch.Tensor)
+
+ def test_adds_batch_dim(self):
+ """Result should have a leading batch dim of 1."""
+ static_np = np.random.randn(C_STATIC, H, W).astype(np.float32)
+ ds = _make_mock_dataset(static_data=static_np)
+ mgr = _make_manager_corrdiff(dataset=ds)
+ result = mgr.get_static_data()
+ assert result.shape[0] == 1
+
+ def test_flips_height(self):
+ """Static data should be flipped along the last-2 (height) dim."""
+ static_np = np.arange(H).reshape(1, H, 1).repeat(W, axis=2).astype(np.float32)
+ ds = _make_mock_dataset(n_static=1, static_data=static_np)
+ mgr = _make_manager_corrdiff(dataset=ds)
+ result = mgr.get_static_data()
+ # After flip(-2): first row should be the last row of the original
+ expected_first_row = float(H - 1)
+ assert result[0, 0, 0, 0].item() == pytest.approx(expected_first_row)
+
+ def test_channels_last_when_apex_gn(self):
+ """With use_apex_gn=True, output should use channels_last memory format."""
+ static_np = np.random.randn(C_STATIC, H, W).astype(np.float32)
+ ds = _make_mock_dataset(static_data=static_np)
+ mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=True)
+ result = mgr.get_static_data()
+ assert result.is_contiguous(memory_format=torch.channels_last)
+
+ def test_contiguous_when_no_apex_gn(self):
+ """Without apex_gn, output should be standard contiguous."""
+ static_np = np.random.randn(C_STATIC, H, W).astype(np.float32)
+ ds = _make_mock_dataset(static_data=static_np)
+ mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=False)
+ result = mgr.get_static_data()
+ assert result.is_contiguous()
+
+
+############################################################################
+# TrainingManagerCorrDiff — load_and_preprocess_batch #
+############################################################################
+
+
+class TestLoadAndPreprocessBatch:
+ """Tests for TrainingManagerCorrDiff.load_and_preprocess_batch."""
+
+ @staticmethod
+ def _make_iterator(img_clean, img_lr, date_str=None):
+ """Wrap tensors into an iterator that yields a single batch."""
+ batch = [img_clean, img_lr]
+ if date_str is not None:
+ batch.append(date_str)
+ return iter([batch])
+
+ def test_returns_three_elements(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ result = mgr.load_and_preprocess_batch(it)
+ assert len(result) == 3 # img_clean, img_lr, date_embedding
+
+ def test_date_embedding_is_none_when_no_month_hour(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=0)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ _, _, date_embedding = mgr.load_and_preprocess_batch(it)
+ assert date_embedding is None
+
+ def test_date_embedding_returned_when_month_hour(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=4)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W),
+ torch.randn(B, C_IN, H * W),
+ date_str="20240101-1800",
+ )
+ _, _, date_embedding = mgr.load_and_preprocess_batch(it)
+ assert date_embedding is not None
+
+ def test_imgs_flipped_and_reshaped(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ img_clean = torch.randn(B, C_OUT, H * W)
+ img_lr = torch.randn(B, C_IN, H * W)
+ it = self._make_iterator(img_clean, img_lr)
+ img_clean_out, img_lr_out, _ = mgr.load_and_preprocess_batch(it)
+ # Output should be flipped along height and reshaped to (B, C, H, W)
+ assert img_clean_out.shape == (B, C_OUT, H, W)
+ assert img_lr_out.shape == (B, C_IN, H, W)
+ # Check that the first row of the output corresponds to the last row of the input after flip
+ expected_first_row_clean = img_clean[:, :, -W:]
+ expected_first_row_lr = img_lr[:, :, -W:]
+ expected_last_row_clean = img_clean[:, :, :W]
+ expected_last_row_lr = img_lr[:, :, :W]
+ assert torch.allclose(img_clean_out[:, :, 0, :], expected_first_row_clean)
+ assert torch.allclose(img_lr_out[:, :, 0, :], expected_first_row_lr)
+ assert torch.allclose(img_clean_out[:, :, -1, :], expected_last_row_clean)
+ assert torch.allclose(img_lr_out[:, :, -1, :], expected_last_row_lr)
+
+
+ def test_img_clean_flipped_and_reshaped_when_real_target(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, is_real_target=True)
+ img_clean = torch.randn(B, C_OUT, H * W)
+ img_lr = torch.randn(B, C_IN, H * W)
+ it = self._make_iterator(img_clean, img_lr)
+ mock_regrid = MagicMock(side_effect=lambda x,y,z: x.reshape(*x.shape[:-1], *mgr.img_shape))
+ with patch("hirad.training.training_manager.regrid_icon_to_rotlatlon", mock_regrid):
+ img_clean_out, _, _ = mgr.load_and_preprocess_batch(it)
+ # Output should be flipped along height and reshaped to (B, C, H, W)
+ assert img_clean_out.shape == (B, C_OUT, H, W)
+ expected_first_row_clean = img_clean[:, :, -W:]
+ expected_last_row_clean = img_clean[:, :, :W]
+ assert torch.allclose(img_clean_out[:, :, 0, :], expected_first_row_clean)
+ assert torch.allclose(img_clean_out[:, :, -1, :], expected_last_row_clean)
+
+ def test_img_clean_trimmed_when_trim_edge_positive(self):
+ trim = 4
+ ds = _make_mock_dataset(trim_edge=trim)
+ mgr = _make_manager_corrdiff(dataset=ds, is_real_target=True)
+ img_clean = torch.randn(B, C_OUT, (H + 2 * trim) * (W + 2 * trim))
+ img_lr = torch.randn(B, C_IN, H * W )
+ it = self._make_iterator(img_clean, img_lr)
+ mock_regrid = MagicMock(side_effect=lambda x,y,z: x.reshape(*x.shape[:-1], *(H+2*trim, W+2*trim)))
+ with patch("hirad.training.training_manager.regrid_icon_to_rotlatlon", mock_regrid):
+ img_clean_out, _, _ = mgr.load_and_preprocess_batch(it)
+ # Output should be trimmed by 'trim' pixels on each side, flipped, and reshaped to (B, C, H, W)
+ assert img_clean_out.shape == (B, C_OUT, H, W)
+ # expected_first_row_clean = img_clean[:, :, -(W + 2 * trim):- (W + 2 * trim) + W]
+ # expected_last_row_clean = img_clean[:, :, trim:trim + W]
+ expected_first_row_clean = img_clean[:, :, -((trim+1)*(W+2*trim))+trim:-((trim+1)*(W+2*trim))+trim+W]
+ expected_last_row_clean = img_clean[:, :, trim*(W+2*trim+1):trim*(W+2*trim+1) + W]
+ assert torch.allclose(img_clean_out[:, :, 0, :], expected_first_row_clean)
+ assert torch.allclose(img_clean_out[:, :, -1, :], expected_last_row_clean)
+
+
+ def test_calls_normalize_input(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ mgr.load_and_preprocess_batch(it)
+ ds.normalize_input.assert_called_once()
+
+ def test_calls_normalize_output(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ mgr.load_and_preprocess_batch(it)
+ ds.normalize_output.assert_called_once()
+
+ def test_calls_interpolator(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ mgr.load_and_preprocess_batch(it)
+ ds.interpolator.assert_called_once()
+
+ def test_real_target_calls_regrid_icon_to_latlon(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, is_real_target=True)
+ mock_regrid = MagicMock(side_effect=lambda x,y,z: x.reshape(*x.shape[:-1], *mgr.img_shape))
+ with patch("hirad.training.training_manager.regrid_icon_to_rotlatlon", mock_regrid):
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ mgr.load_and_preprocess_batch(it)
+ mock_regrid.assert_called_once()
+
+ def test_output_with_apex_gn_is_channels_last(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=True)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ img_clean, img_lr, _ = mgr.load_and_preprocess_batch(it)
+ assert img_clean.is_contiguous(memory_format=torch.channels_last)
+ assert img_lr.is_contiguous(memory_format=torch.channels_last)
+
+ def test_output_without_apex_gn_is_contiguous(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=False)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ img_clean, img_lr, _ = mgr.load_and_preprocess_batch(it)
+ assert img_clean.is_contiguous()
+ assert img_lr.is_contiguous()
+
+ def test_output_dtype_matches_input_dtype(self):
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ it = self._make_iterator(
+ torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)
+ )
+ img_clean, img_lr, _ = mgr.load_and_preprocess_batch(it)
+ assert img_clean.dtype == torch.float32
+ assert img_lr.dtype == torch.float32
+
+
+############################################################################
+# TrainingManagerCorrDiff — create_model #
+############################################################################
+
+
+class TestCreateModel:
+ """Tests for TrainingManagerCorrDiff.create_model."""
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_diffusion_returns_edm(self, MockEDM):
+ """'diffusion' model name should instantiate EDMPrecondSuperResolution."""
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ model, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 2}
+ )
+ MockEDM.assert_called_once()
+
+ @patch("hirad.training.training_manager.UNet")
+ def test_regression_returns_unet(self, MockUNet):
+ """'regression' model name should instantiate UNet."""
+ MockUNet.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ model, args = mgr.create_model(
+ "regression", {"N_grid_channels": 2}
+ )
+ MockUNet.assert_called_once()
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_patched_diffusion_returns_edm(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ model, args = mgr.create_model(
+ "patched_diffusion", {"N_grid_channels": 2}
+ )
+ MockEDM.assert_called_once()
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_lt_aware_patched_diffusion_returns_edm(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ model, args = mgr.create_model(
+ "lt_aware_patched_diffusion",
+ {"N_grid_channels": 2, "lead_time_channels": 1},
+ )
+ MockEDM.assert_called_once()
+
+ @patch("hirad.training.training_manager.UNet")
+ def test_lt_aware_ce_regression_returns_unet(self, MockUNet):
+ MockUNet.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ model, args = mgr.create_model(
+ "lt_aware_ce_regression",
+ {"N_grid_channels": 2, "lead_time_channels": 1},
+ )
+ MockUNet.assert_called_once()
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_returns_model_and_args(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds)
+ model, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 2}
+ )
+ assert model is not None
+ assert isinstance(args, dict)
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_model_args_contain_resolution(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, img_shape=(32, 64))
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 2}
+ )
+ assert args["img_resolution"] == [32, 64]
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_model_args_contain_fp16_flag(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, fp16=True)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 2}
+ )
+ assert args["use_fp16"] is True
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_cfg_args_override_defaults(self, MockEDM):
+ """cfg_model_args should override the default model_args."""
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, songunet_checkpoint_level=99)
+ _, args = mgr.create_model(
+ "diffusion",
+ {"N_grid_channels": 2},
+ )
+ assert args["checkpoint_level"] == 99
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_input_channels_include_static(self, MockEDM):
+ """img_in_channels should include static channels."""
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset(n_input=4, n_static=2)
+ mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=0)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 3}
+ )
+ # img_in_channels = n_input(4) + n_static(2) + n_month_hour(0) + N_grid_channels(3)
+ assert args["img_in_channels"] == 4 + 2 + 3
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_input_channels_include_month_hour(self, MockEDM):
+ """img_in_channels should include month/hour embedding channels."""
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset(n_input=4, n_static=2)
+ mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=6)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 0}
+ )
+ # img_in_channels = 4 + 2 + 6 + 0
+ assert args["img_in_channels"] == 12
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_hr_mean_conditioning_adds_output_channels(self, MockEDM):
+ """hr_mean_conditioning should add n_output channels to img_in_channels."""
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset(n_input=4, n_output=3, n_static=2)
+ mgr = _make_manager_corrdiff(dataset=ds, hr_mean_conditioning=True)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 0}
+ )
+ # img_in_channels = 4 + 2 + 0 (month/hour) + 3 (hr_mean) + 0 (N_grid)
+ assert args["img_in_channels"] == 4 + 2 + 3
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_patching_adds_input_and_static_channels(self, MockEDM):
+ """use_patching should add an extra set of input + static channels."""
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset(n_input=4, n_static=2, n_output=3)
+ mgr = _make_manager_corrdiff(dataset=ds, use_patching=True, n_month_hour_channels=6, hr_mean_conditioning=True)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 5}
+ )
+ # img_in_channels = (4+2) + (4+2) for patching + 6 (month/hour) + 5 (N_grid) + 3 (hr_mean) = (4+2)*2 + 6 + 5 + 3
+ assert args["img_in_channels"] == (4 + 2) * 2 + 6 + 5 + 3
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_amp_mode_set_when_enabled(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, enable_amp=True)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 0}
+ )
+ assert args["amp_mode"] is True
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_amp_mode_absent_when_disabled(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset()
+ mgr = _make_manager_corrdiff(dataset=ds, enable_amp=False)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 0}
+ )
+ assert "amp_mode" not in args
+
+ @patch("hirad.training.training_manager.EDMPrecondSuperResolution")
+ def test_img_in_out_channels_in_model_args(self, MockEDM):
+ MockEDM.return_value = MagicMock(spec=nn.Module)
+ ds = _make_mock_dataset(n_input=4, n_output=3, n_static=2)
+ mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=6)
+ _, args = mgr.create_model(
+ "diffusion", {"N_grid_channels": 0}
+ )
+ assert "img_in_channels" in args
+ assert "img_out_channels" in args
+
+############################################################################
+# TrainingManagerCorrDiff — load_regression_model #
+############################################################################
+
+
+class TestLoadRegressionModel:
+ """Tests for TrainingManagerCorrDiff.load_regression_model."""
+
+ def test_missing_dir_raises_file_not_found(self, tmp_path):
+ mgr = _make_manager_corrdiff()
+ with pytest.raises(FileNotFoundError, match="not found"):
+ mgr.load_regression_model(str(tmp_path / "nonexistent"))
+
+ def test_missing_model_args_json_raises(self, tmp_path):
+ """Directory exists but model_args.json is missing."""
+ ckpt_dir = tmp_path / "ckpt"
+ ckpt_dir.mkdir()
+ mgr = _make_manager_corrdiff()
+ with pytest.raises(FileNotFoundError, match="model_args.json"):
+ mgr.load_regression_model(str(ckpt_dir))
+
+ @patch("hirad.training.training_manager.load_checkpoint")
+ @patch("hirad.training.training_manager.UNet")
+ def test_loads_and_returns_model(self, MockUNet, mock_load_ckpt, tmp_path):
+ """Should load model_args.json and return a UNet in eval mode."""
+ ckpt_dir = tmp_path / "ckpt"
+ ckpt_dir.mkdir()
+ model_args = {
+ "img_in_channels": 6,
+ "img_out_channels": 3,
+ "img_resolution": [64, 64],
+ }
+ (ckpt_dir / "model_args.json").write_text(json.dumps(model_args))
+
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.eval.return_value = mock_model
+ mock_model.requires_grad_.return_value = mock_model
+ mock_model.to.return_value = mock_model
+ MockUNet.return_value = mock_model
+ mock_load_ckpt.return_value = 0
+
+ mgr = _make_manager_corrdiff()
+ result = mgr.load_regression_model(str(ckpt_dir))
+
+ MockUNet.assert_called_once()
+ mock_model.eval.assert_called_once()
+ mock_model.requires_grad_.assert_called_once_with(False)
+ assert result is mock_model
+
+ @patch("hirad.training.training_manager.load_checkpoint")
+ @patch("hirad.training.training_manager.UNet")
+ def test_passes_apex_and_profile_and_amp_flags(self, MockUNet, mock_load_ckpt, tmp_path):
+ """UNet should receive use_apex_gn, profile_mode, and amp_mode."""
+ ckpt_dir = tmp_path / "ckpt"
+ ckpt_dir.mkdir()
+ (ckpt_dir / "model_args.json").write_text(
+ json.dumps({"img_in_channels": 6, "img_out_channels": 3, "img_resolution": [64, 64]})
+ )
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.eval.return_value = mock_model
+ mock_model.requires_grad_.return_value = mock_model
+ mock_model.to.return_value = mock_model
+ MockUNet.return_value = mock_model
+ mock_load_ckpt.return_value = 0
+
+ mgr = _make_manager_corrdiff(use_apex_gn=True, profile_mode=True, enable_amp=True)
+ mgr.load_regression_model(str(ckpt_dir))
+
+ call_kwargs = MockUNet.call_args[1]
+ assert call_kwargs["use_apex_gn"] is True
+ assert call_kwargs["profile_mode"] is True
+ assert call_kwargs["amp_mode"] is True
+
+ @patch("hirad.training.training_manager.load_checkpoint")
+ @patch("hirad.training.training_manager.UNet")
+ def test_apex_gn_sets_channels_last(self, MockUNet, mock_load_ckpt, tmp_path):
+ """With use_apex_gn, model.to(memory_format=channels_last) should be called."""
+ ckpt_dir = tmp_path / "ckpt"
+ ckpt_dir.mkdir()
+ (ckpt_dir / "model_args.json").write_text(
+ json.dumps({"img_in_channels": 6, "img_out_channels": 3, "img_resolution": [64, 64]})
+ )
+ mock_model = MagicMock(spec=nn.Module)
+ mock_model.eval.return_value = mock_model
+ mock_model.requires_grad_.return_value = mock_model
+ mock_model.to.return_value = mock_model
+ MockUNet.return_value = mock_model
+ mock_load_ckpt.return_value = 0
+
+ mgr = _make_manager_corrdiff(use_apex_gn=True)
+ mgr.load_regression_model(str(ckpt_dir))
+
+ mock_model.to.assert_any_call(memory_format=torch.channels_last)
+
+
+############################################################################
+# TrainingManagerCorrDiff — run_validation #
+############################################################################
+
+
+class TestRunValidation:
+ """Tests for TrainingManagerCorrDiff.run_validation."""
+
+ @staticmethod
+ def _make_loss_fn(loss_value=1.0, loss_size=B):
+ loss_fn = MagicMock()
+ loss_fn.return_value = torch.tensor([loss_value] * loss_size)
+ loss_fn.y_mean = None
+ return loss_fn
+
+ @staticmethod
+ def _make_validation_iterator(n_steps, n_out=C_OUT, n_in=C_IN):
+ batches = [
+ [torch.randn(B, n_out, H * W), torch.randn(B, n_in, H * W)]
+ for _ in range(n_steps)
+ ]
+ return iter(batches)
+
+ def test_returns_float(self):
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn()
+ it = self._make_validation_iterator(2)
+ result = mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=2,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=None,
+ )
+ assert isinstance(result, float)
+
+ def test_calls_loss_fn_per_step(self):
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn()
+ n_steps = 3
+ it = self._make_validation_iterator(n_steps)
+ mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=n_steps,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=None,
+ )
+ assert loss_fn.call_count == n_steps
+
+ def test_calls_loss_fn_per_patch_iter(self):
+ """Loss should be called validation_steps * len(patch_nums_iter) times."""
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn()
+ n_steps = 2
+ patch_nums_iter = [2, 2, 1]
+ it = self._make_validation_iterator(n_steps)
+ patching = MagicMock()
+ mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=n_steps,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=patching,
+ patch_nums_iter=patch_nums_iter,
+ use_patch_grad_acc=None,
+ )
+ assert loss_fn.call_count == n_steps * len(patch_nums_iter)
+
+ def test_sets_patch_num_on_patching(self):
+ """patching.set_patch_num should be called for each patch iteration."""
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn()
+ it = self._make_validation_iterator(1)
+ patching = MagicMock()
+ patch_nums_iter = [3, 2]
+ mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=1,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=patching,
+ patch_nums_iter=patch_nums_iter,
+ use_patch_grad_acc=None,
+ )
+ calls = [c.args[0] for c in patching.set_patch_num.call_args_list]
+ assert calls == [3, 2]
+
+ @patch("hirad.training.training_manager.mlflow")
+ def test_logs_to_mlflow_on_rank0(self, mock_mlflow):
+ dist = _make_mock_dist(rank=0, world_size=1)
+ mgr = _make_manager_corrdiff(dist=dist, logging_method="mlflow")
+ loss_fn = self._make_loss_fn()
+ it = self._make_validation_iterator(1)
+ mgr.run_validation(
+ cur_nimg=200,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=1,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=None,
+ )
+ mock_mlflow.log_metric.assert_called_once()
+ call_args = mock_mlflow.log_metric.call_args
+ assert call_args[0][0] == "validation_loss"
+ assert call_args[0][2] == 200 # cur_nimg
+
+ @patch("hirad.training.training_manager.mlflow")
+ def test_no_mlflow_on_non_rank0(self, mock_mlflow):
+ dist = _make_mock_dist(rank=1, world_size=1) # keep world size 1 not to trigger any distributed logic, but set rank to non-zero
+ mgr = _make_manager_corrdiff(dist=dist, logging_method="mlflow")
+ loss_fn = self._make_loss_fn()
+ it = self._make_validation_iterator(1)
+ mgr.run_validation(
+ cur_nimg=200,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=1,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=None,
+ )
+ mock_mlflow.log_metric.assert_not_called()
+
+ @patch("hirad.training.training_manager.mlflow")
+ def test_no_mlflow_when_logging_disabled(self, mock_mlflow):
+ dist = _make_mock_dist(rank=0, world_size=1)
+ mgr = _make_manager_corrdiff(dist=dist, logging_method=None)
+ loss_fn = self._make_loss_fn()
+ it = self._make_validation_iterator(1)
+ mgr.run_validation(
+ cur_nimg=200,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=1,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=None,
+ )
+ mock_mlflow.log_metric.assert_not_called()
+
+ def test_resets_y_mean_with_patch_grad_acc(self):
+ """When use_patch_grad_acc is True, loss_fn.y_mean should be reset each step."""
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn()
+ loss_fn.y_mean = torch.tensor(42.0)
+ it = self._make_validation_iterator(1)
+ mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=1,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=True,
+ )
+ # y_mean should have been set to None at the beginning of the step
+ assert loss_fn.y_mean is None
+
+ def test_average_loss_value_as_expected(self):
+ """Test that the average loss value is computed as expected."""
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn()
+ it = self._make_validation_iterator(1)
+ result = mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=1,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=None,
+ patch_nums_iter=[1],
+ use_patch_grad_acc=True,
+ )
+ assert result == 1.0
+
+ def test_average_loss_value_with_patching_as_expected(self):
+ """Test that the average loss value is computed as expected when using patching."""
+ mgr = _make_manager_corrdiff()
+ loss_fn = self._make_loss_fn(loss_size=B*3) # simulate 3 patches per batch_element
+ it = self._make_validation_iterator(3)
+ patch_nums_iter = [3, 3]
+ result = mgr.run_validation(
+ cur_nimg=100,
+ validation_dataset_iterator=it,
+ model=MagicMock(),
+ loss_fn=loss_fn,
+ validation_steps=3,
+ static_channels=None,
+ batch_size_per_gpu=B,
+ patching=MagicMock(),
+ patch_nums_iter=patch_nums_iter,
+ use_patch_grad_acc=True,
+ )
+ # With 2 total patch iterations with 3 patches per iteration and a loss of 1.0 per iteration, the average should still be 1.0
+ assert result == 1.0
\ No newline at end of file
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
new file mode 100644
index 00000000..2565f921
--- /dev/null
+++ b/tests/utils/test_checkpoint.py
@@ -0,0 +1,336 @@
+import os
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from hirad.utils.checkpoint import (
+ _get_checkpoint_filename,
+ load_checkpoint,
+ save_checkpoint,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers / fixtures
+# ---------------------------------------------------------------------------
+
+
+def _make_mock_manager(model_parallel_rank: int = 0, group_names=()):
+ """Return a mock DistributedManager with the given parallel rank."""
+ mgr = MagicMock()
+ mgr.group_names = group_names
+ mgr.group_rank.return_value = model_parallel_rank
+ return mgr
+
+
+@pytest.fixture(autouse=True)
+def _patch_distributed():
+ """Patch DistributedManager for every test so we never touch real dist."""
+ mgr = _make_mock_manager(model_parallel_rank=0, group_names=())
+ with patch(
+ "hirad.utils.checkpoint.DistributedManager"
+ ) as MockDM:
+ MockDM.is_initialized.return_value = True
+ MockDM.return_value = mgr
+ yield MockDM
+
+
+class _SimpleModel(torch.nn.Module):
+ """Tiny model used in save / load round-trip tests."""
+
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(4, 2)
+
+
+############################################################################
+# _get_checkpoint_filename #
+############################################################################
+
+
+class TestGetCheckpointFilenameWithIndex:
+ """When an explicit index is supplied."""
+
+ def test_returns_correct_filename(self, tmp_path):
+ result = _get_checkpoint_filename(str(tmp_path), index=3)
+ expected = str(tmp_path.resolve() / "checkpoint.0.3.pt")
+ assert result == expected
+
+ def test_custom_base_name(self, tmp_path):
+ result = _get_checkpoint_filename(str(tmp_path), base_name="model", index=0)
+ expected = str(tmp_path.resolve() / "model.0.0.pt")
+ assert result == expected
+
+ def test_custom_model_type(self, tmp_path):
+ result = _get_checkpoint_filename(
+ str(tmp_path), index=1, model_type="hirad"
+ )
+ expected = str(tmp_path.resolve() / "checkpoint.0.1.hirad")
+ assert result == expected
+
+ def test_saving_flag_ignored_when_index_given(self, tmp_path):
+ result_save = _get_checkpoint_filename(str(tmp_path), index=5, saving=True)
+ result_load = _get_checkpoint_filename(str(tmp_path), index=5, saving=False)
+ assert result_save == result_load
+
+
+class TestGetCheckpointFilenameNoIndex:
+ """When no index is supplied (auto-detect from existing files)."""
+
+ def test_no_existing_files_returns_index_zero(self, tmp_path):
+ result = _get_checkpoint_filename(str(tmp_path))
+ expected = str(tmp_path.resolve() / "checkpoint.0.0.pt")
+ assert result == expected
+
+ def test_loads_latest_when_files_exist(self, tmp_path):
+ # Create fake checkpoint files with indices 0, 1, 2
+ for i in range(3):
+ (tmp_path / f"checkpoint.0.{i}.pt").touch()
+ result = _get_checkpoint_filename(str(tmp_path), saving=False)
+ expected = str(tmp_path.resolve() / "checkpoint.0.2.pt")
+ assert result == expected
+
+ def test_saving_increments_latest_index(self, tmp_path):
+ for i in range(3):
+ (tmp_path / f"checkpoint.0.{i}.pt").touch()
+ result = _get_checkpoint_filename(str(tmp_path), saving=True)
+ expected = str(tmp_path.resolve() / "checkpoint.0.3.pt")
+ assert result == expected
+
+ def test_non_contiguous_indices_picks_largest(self, tmp_path):
+ for i in [0, 5, 10]:
+ (tmp_path / f"checkpoint.0.{i}.pt").touch()
+ result = _get_checkpoint_filename(str(tmp_path), saving=False)
+ expected = str(tmp_path.resolve() / "checkpoint.0.10.pt")
+ assert result == expected
+
+
+class TestGetCheckpointFilenameModelParallel:
+ """Ensure model-parallel rank is embedded correctly."""
+
+ def test_model_parallel_rank_in_filename(self, tmp_path, _patch_distributed):
+ mgr = _make_mock_manager(model_parallel_rank=3, group_names=("model_parallel",))
+ _patch_distributed.return_value = mgr
+ result = _get_checkpoint_filename(str(tmp_path), index=0)
+ expected = str(tmp_path.resolve() / "checkpoint.3.0.pt")
+ assert result == expected
+
+
+############################################################################
+# save_checkpoint #
+############################################################################
+
+
+class TestSaveCheckpointDirectory:
+ """Test directory creation behaviour of save_checkpoint."""
+
+ def test_creates_directory_if_missing(self, tmp_path):
+ out = str(tmp_path / "new_dir" / "sub")
+ model = _SimpleModel()
+ save_checkpoint(out, model=model, epoch=0)
+ assert Path(out).is_dir()
+
+ def test_existing_directory_is_fine(self, tmp_path):
+ model = _SimpleModel()
+ save_checkpoint(str(tmp_path), model=model, epoch=0)
+ # Just ensure no exception is raised
+ assert Path(tmp_path).is_dir()
+
+
+class TestSaveCheckpointModel:
+ """Test model state-dict saving."""
+
+ def test_model_checkpoint_file_created(self, tmp_path):
+ model = _SimpleModel()
+ save_checkpoint(str(tmp_path), model=model, epoch=0)
+ expected = tmp_path.resolve() / f"{model.__class__.__name__}.0.0.pt"
+ assert expected.exists()
+
+ def test_model_checkpoint_contains_state_dict(self, tmp_path):
+ model = _SimpleModel()
+ save_checkpoint(str(tmp_path), model=model, epoch=0)
+ file_name = tmp_path.resolve() / f"{model.__class__.__name__}.0.0.pt"
+ state = torch.load(file_name, map_location="cpu")
+ assert "linear.weight" in state
+ assert "linear.bias" in state
+
+
+class TestSaveCheckpointTraining:
+ """Test optimizer / scheduler / scaler / metadata saving."""
+
+ def test_optimizer_state_saved(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ save_checkpoint(str(tmp_path), optimizer=opt, epoch=1)
+ ckpt_file = tmp_path.resolve() / "checkpoint.0.1.pt"
+ assert ckpt_file.exists()
+ ckpt = torch.load(ckpt_file, map_location="cpu")
+ assert "optimizer_state_dict" in ckpt
+
+ def test_epoch_saved(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ save_checkpoint(str(tmp_path), optimizer=opt, epoch=5)
+ ckpt_file = tmp_path.resolve() / "checkpoint.0.5.pt"
+ ckpt = torch.load(ckpt_file, map_location="cpu")
+ assert ckpt["epoch"] == 5
+
+ def test_metadata_saved(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ meta = {"loss": 0.42, "note": "test run"}
+ save_checkpoint(str(tmp_path), optimizer=opt, epoch=0, metadata=meta)
+ ckpt_file = tmp_path.resolve() / "checkpoint.0.0.pt"
+ ckpt = torch.load(ckpt_file, map_location="cpu")
+ assert ckpt["metadata"] == meta
+
+ def test_no_training_objects_no_checkpoint_file(self, tmp_path):
+ """If only a model is provided, no training checkpoint should be created."""
+ model = _SimpleModel()
+ save_checkpoint(str(tmp_path), model=model, epoch=0)
+ training_ckpt = tmp_path.resolve() / "checkpoint.0.0.pt"
+ assert not training_ckpt.exists()
+
+ def test_scheduler_state_saved(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
+ save_checkpoint(str(tmp_path), optimizer=opt, scheduler=sched, epoch=0)
+ ckpt_file = tmp_path.resolve() / "checkpoint.0.0.pt"
+ ckpt = torch.load(ckpt_file, map_location="cpu")
+ assert "scheduler_state_dict" in ckpt
+
+ def test_scaler_state_saved(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ scaler = torch.cuda.amp.GradScaler()
+ save_checkpoint(str(tmp_path), optimizer=opt, scaler=scaler, epoch=0)
+ ckpt_file = tmp_path.resolve() / "checkpoint.0.0.pt"
+ ckpt = torch.load(ckpt_file, map_location="cpu")
+ assert "scaler_state_dict" in ckpt
+
+
+############################################################################
+# load_checkpoint #
+############################################################################
+
+
+class TestLoadCheckpointMissingDir:
+ """Loading from a non-existent directory should return 0 gracefully."""
+
+ def test_returns_zero_for_missing_dir(self, tmp_path):
+ result = load_checkpoint(str(tmp_path / "nonexistent"))
+ assert result == 0
+
+
+class TestLoadCheckpointModel:
+ """Round-trip save/load of model state dicts."""
+
+ def test_model_weights_restored(self, tmp_path):
+ model = _SimpleModel()
+ # Freeze initial weights for comparison
+ original_weight = model.linear.weight.data.clone()
+ save_checkpoint(str(tmp_path), model=model, epoch=0)
+
+ # Mutate weights so we can confirm they get restored
+ with torch.no_grad():
+ model.linear.weight.fill_(0.0)
+ assert not torch.equal(model.linear.weight.data, original_weight)
+
+ load_checkpoint(str(tmp_path), model=model, epoch=0)
+ assert torch.equal(model.linear.weight.data, original_weight)
+
+ def test_missing_model_file_is_graceful(self, tmp_path):
+ """If model checkpoint doesn't exist, load should not crash."""
+ tmp_path.mkdir(exist_ok=True)
+ model = _SimpleModel()
+ # No save happened, but directory exists – should warn and skip
+ result = load_checkpoint(str(tmp_path), model=model, epoch=0)
+ # Should still return 0 (no training checkpoint either)
+ assert result == 0
+
+
+class TestLoadCheckpointTraining:
+ """Round-trip save/load of training state (optimizer, scheduler, epoch, metadata)."""
+
+ def test_optimizer_restored(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ # Take a step so state is non-trivial
+ loss = model.linear(torch.randn(1, 4)).sum()
+ loss.backward()
+ opt.step()
+ original_state = {k: v for k, v in opt.state_dict().items()}
+
+ save_checkpoint(str(tmp_path), optimizer=opt, epoch=1)
+
+ # Create a fresh optimizer
+ opt2 = torch.optim.SGD(model.parameters(), lr=0.01)
+ load_checkpoint(str(tmp_path), optimizer=opt2, epoch=1)
+ assert opt2.state_dict()["param_groups"] == original_state["param_groups"]
+
+ def test_epoch_returned(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ save_checkpoint(str(tmp_path), optimizer=opt, epoch=7)
+ loaded_epoch = load_checkpoint(str(tmp_path), epoch=7)
+ assert loaded_epoch == 7
+
+ def test_metadata_restored(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ meta = {"lr": 1e-3, "run_id": "abc"}
+ save_checkpoint(str(tmp_path), optimizer=opt, epoch=0, metadata=meta)
+
+ restored_meta = {}
+ load_checkpoint(str(tmp_path), epoch=0, metadata_dict=restored_meta)
+ assert restored_meta == meta
+
+ def test_missing_checkpoint_file_returns_zero(self, tmp_path):
+ """If training checkpoint file doesn't exist, return 0."""
+ tmp_path.mkdir(exist_ok=True)
+ result = load_checkpoint(str(tmp_path), epoch=99)
+ assert result == 0
+
+
+class TestSaveLoadRoundTrip:
+ """Full round-trip tests combining model + training state."""
+
+ def test_full_round_trip(self, tmp_path):
+ model = _SimpleModel()
+ opt = torch.optim.SGD(model.parameters(), lr=0.01)
+ sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5)
+ meta = {"step": 100}
+
+ # Forward + backward to populate optimizer state
+ loss = model.linear(torch.randn(2, 4)).sum()
+ loss.backward()
+ opt.step()
+ sched.step()
+
+ original_weight = model.linear.weight.data.clone()
+
+ save_checkpoint(
+ str(tmp_path), model=model, optimizer=opt,
+ scheduler=sched, epoch=3, metadata=meta,
+ )
+
+ # Reset everything
+ with torch.no_grad():
+ model.linear.weight.fill_(0.0)
+ opt2 = torch.optim.SGD(model.parameters(), lr=0.01)
+ sched2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=5)
+ restored_meta = {}
+
+ loaded_epoch = load_checkpoint(
+ str(tmp_path), model=model, optimizer=opt2,
+ scheduler=sched2, epoch=3, metadata_dict=restored_meta,
+ )
+
+ assert loaded_epoch == 3
+ assert torch.equal(model.linear.weight.data, original_weight)
+ assert restored_meta == meta
+ assert opt2.state_dict()["param_groups"] == opt.state_dict()["param_groups"]
+ assert sched2.state_dict()["last_epoch"] == sched.state_dict()["last_epoch"]
diff --git a/tests/utils/test_console.py b/tests/utils/test_console.py
new file mode 100644
index 00000000..788aff4a
--- /dev/null
+++ b/tests/utils/test_console.py
@@ -0,0 +1,206 @@
+import logging
+import os
+import tempfile
+from unittest.mock import MagicMock, patch
+
+import pytest
+from termcolor import colored
+
+from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper
+
+
+############################################################################
+# PythonLogger #
+############################################################################
+
+
+class TestPythonLogger:
+ """Tests for PythonLogger."""
+
+ def test_default_name(self):
+ pl = PythonLogger()
+ assert pl.logger.name == "launch"
+
+ def test_custom_name(self):
+ pl = PythonLogger(name="my_trainer")
+ assert pl.logger.name == "my_trainer"
+
+ def test_log_calls_info(self, caplog):
+ pl = PythonLogger(name="test_log")
+ pl.logger.setLevel(logging.DEBUG)
+ with caplog.at_level(logging.INFO, logger="test_log"):
+ pl.log("hello")
+ assert "hello" in caplog.text
+
+ def test_info_uses_light_blue(self, caplog):
+ pl = PythonLogger(name="test_info")
+ pl.logger.setLevel(logging.DEBUG)
+ expected = colored("info msg", "light_blue")
+ with caplog.at_level(logging.INFO, logger="test_info"):
+ pl.info("info msg")
+ assert expected in caplog.text
+
+ def test_success_uses_light_green(self, caplog):
+ pl = PythonLogger(name="test_success")
+ pl.logger.setLevel(logging.DEBUG)
+ expected = colored("ok", "light_green")
+ with caplog.at_level(logging.INFO, logger="test_success"):
+ pl.success("ok")
+ assert expected in caplog.text
+
+ def test_warning_uses_light_yellow(self, caplog):
+ pl = PythonLogger(name="test_warning")
+ pl.logger.setLevel(logging.DEBUG)
+ expected = colored("careful", "light_yellow")
+ with caplog.at_level(logging.WARNING, logger="test_warning"):
+ pl.warning("careful")
+ assert expected in caplog.text
+
+ def test_error_uses_light_red(self, caplog):
+ pl = PythonLogger(name="test_error")
+ pl.logger.setLevel(logging.DEBUG)
+ expected = colored("bad", "light_red")
+ with caplog.at_level(logging.ERROR, logger="test_error"):
+ pl.error("bad")
+ assert expected in caplog.text
+
+ def test_file_logging_creates_file(self, tmp_path):
+ log_file = str(tmp_path / "test.log")
+ pl = PythonLogger(name="test_file_create")
+ pl.logger.setLevel(logging.DEBUG)
+ pl.file_logging(log_file)
+ pl.log("file message")
+ assert os.path.exists(log_file)
+ with open(log_file) as f:
+ content = f.read()
+ assert "file message" in content
+
+ def test_file_logging_removes_existing_file(self, tmp_path):
+ log_file = str(tmp_path / "old.log")
+ with open(log_file, "w") as f:
+ f.write("old content")
+ pl = PythonLogger(name="test_file_overwrite")
+ pl.logger.setLevel(logging.DEBUG)
+ pl.file_logging(log_file)
+ pl.log("new content")
+ with open(log_file) as f:
+ content = f.read()
+ assert "old content" not in content
+ assert "new content" in content
+
+ def test_file_logging_formatter(self, tmp_path):
+ log_file = str(tmp_path / "fmt.log")
+ pl = PythonLogger(name="test_fmt")
+ pl.logger.setLevel(logging.DEBUG)
+ pl.file_logging(log_file)
+ pl.log("fmt check")
+ with open(log_file) as f:
+ content = f.read()
+ # Format: [HH:MM:SS - name - LEVEL] message
+ assert "test_fmt" in content
+ assert "INFO" in content
+
+ def test_file_logging_level_is_debug(self, tmp_path):
+ log_file = str(tmp_path / "debug.log")
+ pl = PythonLogger(name="test_debug_level")
+ pl.logger.setLevel(logging.DEBUG)
+ pl.file_logging(log_file)
+ pl.logger.debug("debug msg")
+ with open(log_file) as f:
+ content = f.read()
+ assert "debug msg" in content
+
+
+############################################################################
+# RankZeroLoggingWrapper #
+############################################################################
+
+
+class TestRankZeroLoggingWrapper:
+ """Tests for RankZeroLoggingWrapper."""
+
+ def test_rank_zero_calls_method(self):
+ inner = MagicMock()
+ inner.log = MagicMock(return_value="logged")
+ dist = MagicMock()
+ dist.rank = 0
+
+ wrapper = RankZeroLoggingWrapper(inner, dist)
+ result = wrapper.log("hello")
+ inner.log.assert_called_once_with("hello")
+ assert result == "logged"
+
+ def test_non_zero_rank_suppresses_method(self):
+ inner = MagicMock()
+ inner.log = MagicMock(return_value="logged")
+ dist = MagicMock()
+ dist.rank = 1
+
+ wrapper = RankZeroLoggingWrapper(inner, dist)
+ result = wrapper.log("hello")
+ inner.log.assert_not_called()
+ assert result is None
+
+ def test_non_callable_attribute_returned_directly(self):
+ inner = MagicMock()
+ inner.some_value = 42
+ dist = MagicMock()
+ dist.rank = 5
+
+ wrapper = RankZeroLoggingWrapper(inner, dist)
+ assert wrapper.some_value == 42
+
+ def test_rank_zero_passes_kwargs(self):
+ inner = MagicMock()
+ inner.configure = MagicMock(return_value="configured")
+ dist = MagicMock()
+ dist.rank = 0
+
+ wrapper = RankZeroLoggingWrapper(inner, dist)
+ result = wrapper.configure(level="DEBUG", verbose=True)
+ inner.configure.assert_called_once_with(level="DEBUG", verbose=True)
+ assert result == "configured"
+
+ def test_multiple_ranks_only_zero_logs(self):
+ inner = MagicMock()
+ inner.info = MagicMock()
+
+ for rank in range(4):
+ dist = MagicMock()
+ dist.rank = rank
+ wrapper = RankZeroLoggingWrapper(inner, dist)
+ wrapper.info("msg")
+
+ # Only rank 0 should have triggered the call
+ inner.info.assert_called_once_with("msg")
+
+ def test_wrapper_preserves_return_value(self):
+ inner = MagicMock()
+ inner.compute = MagicMock(return_value={"loss": 0.5})
+ dist = MagicMock()
+ dist.rank = 0
+
+ wrapper = RankZeroLoggingWrapper(inner, dist)
+ result = wrapper.compute()
+ assert result == {"loss": 0.5}
+
+ def test_wrapper_with_python_logger(self):
+ pl = PythonLogger(name="test_wrapped")
+ pl.logger.setLevel(logging.DEBUG)
+ dist = MagicMock()
+ dist.rank = 0
+
+ wrapper = RankZeroLoggingWrapper(pl, dist)
+ # Should not raise
+ wrapper.log("from wrapper")
+
+ def test_wrapper_with_python_logger_non_zero_rank(self, caplog):
+ pl = PythonLogger(name="test_wrapped_silent")
+ pl.logger.setLevel(logging.DEBUG)
+ dist = MagicMock()
+ dist.rank = 3
+
+ wrapper = RankZeroLoggingWrapper(pl, dist)
+ with caplog.at_level(logging.INFO, logger="test_wrapped_silent"):
+ wrapper.log("should not appear")
+ assert "should not appear" not in caplog.text
diff --git a/tests/utils/test_dataset_utils.py b/tests/utils/test_dataset_utils.py
new file mode 100644
index 00000000..0a1a6af2
--- /dev/null
+++ b/tests/utils/test_dataset_utils.py
@@ -0,0 +1,507 @@
+import numpy as np
+import pytest
+import torch
+
+from hirad.utils.dataset_utils import GridData, regrid_icon_to_rotlatlon
+
+
+############################################################################
+# regrid_icon_to_rotlatlon #
+############################################################################
+
+
+class TestRegridIconToRotlatlon:
+ """Tests for the regrid_icon_to_rotlatlon function."""
+
+ @pytest.fixture
+ def simple_regrid_inputs(self):
+ """Create simple inputs for regrid testing."""
+ n_unstructured = 20
+ n_target = 6 # 3 x 2 grid
+ n_stencil = 3
+
+ data = torch.arange(n_unstructured, dtype=torch.float32)
+ indices = torch.tensor(
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14], [15, 16, 17]],
+ dtype=torch.long,
+ )
+ weights = torch.tensor(
+ [[0.5, 0.3, 0.2]] * n_target, dtype=torch.float32
+ )
+ return data, indices, weights
+
+ def test_output_shape_2d(self, simple_regrid_inputs):
+ data, indices, weights = simple_regrid_inputs
+ nx, ny = 3, 2
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=nx, ny=ny)
+ assert result.shape == (ny, nx)
+
+ def test_output_shape_with_batch(self, simple_regrid_inputs):
+ _, indices, weights = simple_regrid_inputs
+ batch, channels, n_unstructured = 2, 4, 20
+ nx, ny = 3, 2
+ data = torch.randn(batch, channels, n_unstructured)
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=nx, ny=ny)
+ assert result.shape == (batch, channels, ny, nx)
+
+ def test_output_shape_with_channels(self, simple_regrid_inputs):
+ _, indices, weights = simple_regrid_inputs
+ channels, n_unstructured = 5, 20
+ nx, ny = 3, 2
+ data = torch.randn(channels, n_unstructured)
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=nx, ny=ny)
+ assert result.shape == (channels, ny, nx)
+
+ def test_uniform_weights_give_mean(self):
+ """When all weights are equal the result should be the mean of stencil values."""
+ n_target = 4
+ n_stencil = 3
+ data = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0, 30.0], dtype=torch.float32)
+ indices = torch.tensor([[0, 1, 2], [0, 1, 2], [3, 4, 5], [3, 4, 5]], dtype=torch.long)
+ weights = torch.full((n_target, n_stencil), 1.0 / n_stencil)
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=2, ny=2)
+ expected_vals = torch.tensor([2.0, 2.0, 20.0, 20.0]).reshape(2, 2)
+ torch.testing.assert_close(result, expected_vals)
+
+ def test_single_weight_selects_value(self):
+ """Weight concentrated on one index should select that value."""
+ data = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32)
+ indices = torch.tensor([[0, 1, 2]], dtype=torch.long)
+ weights = torch.tensor([[1.0, 0.0, 0.0]], dtype=torch.float32)
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=1, ny=1)
+ assert result.item() == pytest.approx(10.0)
+
+ def test_output_values_with_batch_and_channels(self):
+ """Test that values are correctly computed with batch and channel dimensions."""
+ batch, channels, n_unstructured = 2, 3, 20
+ nx, ny = 3, 2
+ data = torch.arange(batch * channels * n_unstructured, dtype=torch.float32).reshape(batch, channels, n_unstructured)
+ indices = torch.tensor(
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14], [15, 16, 17]],
+ dtype=torch.long,
+ )
+ weights = torch.tensor(
+ [[0.5, 0.3, 0.2]] * (nx * ny), dtype=torch.float32
+ )
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=nx, ny=ny)
+ result_one_batch_channel = regrid_icon_to_rotlatlon(data[1, 2], indices, weights, nx=nx, ny=ny)
+ torch.testing.assert_close(result[1, 2], result_one_batch_channel)
+
+ def test_result_clamped_to_stencil_range(self):
+ """Result should be clamped between min and max of stencil values."""
+ data = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+ indices = torch.tensor([[0, 1, 2]], dtype=torch.long)
+ # Weights that would extrapolate outside [1, 3]
+ weights = torch.tensor([[2.0, -0.5, -0.5]], dtype=torch.float32)
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=1, ny=1)
+ assert result.item() == pytest.approx(1.0) # Should be clamped to min
+ weights = torch.tensor([[-0.5, -0.5, 2.0]], dtype=torch.float32)
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=1, ny=1)
+ assert result.item() == pytest.approx(3.0) # Should be clamped to max
+
+ def test_output_dtype_matches_input(self, simple_regrid_inputs):
+ data, indices, weights = simple_regrid_inputs
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=3, ny=2)
+ assert result.dtype == data.dtype
+
+ def test_constant_data_gives_constant_output(self):
+ """Constant input across the stencil should produce constant output."""
+ val = 7.0
+ data = torch.full((10,), val)
+ indices = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.long)
+ weights = torch.tensor([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3]])
+ result = regrid_icon_to_rotlatlon(data, indices, weights, nx=1, ny=2)
+ torch.testing.assert_close(result, torch.full((2, 1), val))
+
+
+############################################################################
+# GridData — init #
+############################################################################
+
+
+class TestGridDataInit:
+ """Tests for GridData initialization and input validation."""
+
+ @pytest.fixture
+ def regular_grid_points(self):
+ """Create a regular 5x5 grid of original points and target points inside it."""
+ lons_orig, lats_orig = np.meshgrid(
+ np.linspace(0, 4, 5), np.linspace(0, 4, 5)
+ )
+ lons_orig = lons_orig.ravel()
+ lats_orig = lats_orig.ravel()
+
+ lons_target = np.array([1.0, 2.0, 3.0, 1.5])
+ lats_target = np.array([1.0, 2.0, 3.0, 2.5])
+ return lons_orig, lats_orig, lons_target, lats_target
+
+ def test_creation(self, regular_grid_points):
+ gd = GridData(*regular_grid_points)
+ assert gd is not None
+
+ def test_mismatched_orig_raises(self):
+ with pytest.raises(ValueError, match="Original longitude and latitude"):
+ GridData(
+ np.array([0.0, 1.0]),
+ np.array([0.0]),
+ np.array([0.5]),
+ np.array([0.5]),
+ )
+
+ def test_mismatched_target_raises(self):
+ with pytest.raises(ValueError, match="Target longitude and latitude"):
+ GridData(
+ np.array([0.0, 1.0]),
+ np.array([0.0, 1.0]),
+ np.array([0.5, 0.6]),
+ np.array([0.5]),
+ )
+
+ def test_initial_state_is_numpy(self, regular_grid_points):
+ gd = GridData(*regular_grid_points)
+ assert gd.is_torch is False
+ assert gd.device is None
+ assert isinstance(gd._lambda1, np.ndarray)
+ assert isinstance(gd._lambda2, np.ndarray)
+ assert isinstance(gd._lambda3, np.ndarray)
+ assert isinstance(gd._simplex_id, np.ndarray)
+ assert isinstance(gd._tri.simplices, np.ndarray)
+
+ def test_barycentric_weights_sum_to_one(self, regular_grid_points):
+ gd = GridData(*regular_grid_points)
+ total = gd._lambda1 + gd._lambda2 + gd._lambda3
+ np.testing.assert_allclose(total, 1.0, atol=1e-12)
+
+ def test_barycentric_weights_non_negative_for_interior(self, regular_grid_points):
+ """Points inside the convex hull should have non-negative barycentric coords."""
+ gd = GridData(*regular_grid_points)
+ inside = gd._simplex_id != -1
+ assert np.all(gd._lambda1[inside] >= -1e-12)
+ assert np.all(gd._lambda2[inside] >= -1e-12)
+ assert np.all(gd._lambda3[inside] >= -1e-12)
+
+
+############################################################################
+# GridData — to_torch / to_numpy #
+############################################################################
+
+
+class TestGridDataDeviceConversion:
+ """Tests for to_torch and to_numpy conversions."""
+
+ @pytest.fixture
+ def grid_data(self):
+ lons_orig, lats_orig = np.meshgrid(
+ np.linspace(0, 4, 5), np.linspace(0, 4, 5)
+ )
+ lons_target = np.array([1.0, 2.0, 3.0])
+ lats_target = np.array([1.0, 2.0, 3.0])
+ return GridData(
+ lons_orig.ravel(), lats_orig.ravel(), lons_target, lats_target
+ )
+
+ def test_to_torch_sets_flag(self, grid_data):
+ grid_data.to_torch()
+ assert grid_data.is_torch is True
+ assert grid_data.device == torch.device("cpu")
+
+ def test_to_torch_produces_tensors(self, grid_data):
+ grid_data.to_torch()
+ assert isinstance(grid_data._lambda1, torch.Tensor)
+ assert isinstance(grid_data._lambda2, torch.Tensor)
+ assert isinstance(grid_data._lambda3, torch.Tensor)
+ assert isinstance(grid_data._simplex_id, torch.Tensor)
+ assert isinstance(grid_data._tri.simplices, torch.Tensor)
+
+ def test_to_torch_string_device(self, grid_data):
+ grid_data.to_torch("cpu")
+ assert grid_data.device == torch.device("cpu")
+ assert grid_data.is_torch is True
+
+ def test_to_numpy_restores_arrays(self, grid_data):
+ grid_data.to_torch()
+ grid_data.to_numpy()
+ assert grid_data.is_torch is False
+ assert grid_data.device is None
+ assert isinstance(grid_data._lambda1, np.ndarray)
+ assert isinstance(grid_data._lambda2, np.ndarray)
+ assert isinstance(grid_data._lambda3, np.ndarray)
+ assert isinstance(grid_data._simplex_id, np.ndarray)
+ assert isinstance(grid_data._tri.simplices, np.ndarray)
+
+ def test_to_numpy_noop_when_already_numpy(self, grid_data):
+ """Calling to_numpy when already numpy should be a no-op."""
+ lambda1_before = grid_data._lambda1
+ grid_data.to_numpy()
+ assert grid_data._lambda1 is lambda1_before
+
+ def test_roundtrip_preserves_values(self, grid_data):
+ lambda1_orig = grid_data._lambda1.copy()
+ lambda2_orig = grid_data._lambda2.copy()
+ lambda3_orig = grid_data._lambda3.copy()
+ simplex_id_orig = grid_data._simplex_id.copy()
+ tri_simplices_orig = grid_data._tri.simplices.copy()
+
+ grid_data.to_torch()
+ grid_data.to_numpy()
+
+ np.testing.assert_allclose(grid_data._lambda1, lambda1_orig, atol=1e-12)
+ np.testing.assert_allclose(grid_data._lambda2, lambda2_orig, atol=1e-12)
+ np.testing.assert_allclose(grid_data._lambda3, lambda3_orig, atol=1e-12)
+ np.testing.assert_array_equal(grid_data._simplex_id, simplex_id_orig)
+ np.testing.assert_array_equal(grid_data._tri.simplices, tri_simplices_orig)
+
+
+############################################################################
+# GridData — interpolate (numpy) #
+############################################################################
+
+
+class TestGridDataInterpolateNumpy:
+ """Tests for the interpolate method using numpy arrays."""
+
+ @pytest.fixture
+ def grid_data_on_regular(self):
+ """GridData with a regular grid as source and a few interior targets."""
+ lons_orig, lats_orig = np.meshgrid(
+ np.linspace(0, 4, 5), np.linspace(0, 4, 5)
+ )
+ lons_target = np.array([1.0, 2.0, 3.0])
+ lats_target = np.array([1.0, 2.0, 3.0])
+ return GridData(
+ lons_orig.ravel(), lats_orig.ravel(), lons_target, lats_target
+ )
+
+ def test_interpolate_constant_field(self, grid_data_on_regular):
+ """A constant field should interpolate to the same constant."""
+ n_orig = len(grid_data_on_regular.longitudes_orig)
+ values = np.full((1, n_orig), 5.0)
+ result = grid_data_on_regular.interpolate(values)
+ np.testing.assert_allclose(result, 5.0, atol=1e-12)
+
+ def test_interpolate_linear_field(self, grid_data_on_regular):
+ """A linear field f(x,y) = x + y should be reproduced exactly."""
+ gd = grid_data_on_regular
+ lons = gd.longitudes_orig
+ lats = gd.latitudes_orig
+ values = (lons + lats)[np.newaxis, :] # (1, n_orig)
+
+ result = gd.interpolate(values) # (1, n_target)
+ expected = gd.longitudes_target + gd.latitudes_target
+ np.testing.assert_allclose(result[0], expected, atol=1e-10)
+
+ def test_interpolate_multichannel(self, grid_data_on_regular):
+ """Multiple channels should be interpolated independently."""
+ gd = grid_data_on_regular
+ n_orig = len(gd.longitudes_orig)
+ n_channels = 3
+ values = np.random.RandomState(42).randn(n_channels, n_orig)
+ result = gd.interpolate(values)
+ assert result.shape == (n_channels, len(gd.longitudes_target))
+
+ def test_interpolate_batch_and_channels(self, grid_data_on_regular):
+ """Batch + channel dimensions should be preserved."""
+ gd = grid_data_on_regular
+ n_orig = len(gd.longitudes_orig)
+ batch, channels = 2, 3
+ values = np.random.RandomState(0).randn(batch, channels, n_orig)
+ result = gd.interpolate(values)
+ assert result.shape == (batch, channels, len(gd.longitudes_target))
+
+ def test_interpolate_wrong_last_dim_raises(self, grid_data_on_regular):
+ with pytest.raises(ValueError, match="Expected values with shape"):
+ grid_data_on_regular.interpolate(np.zeros((1, 7)))
+
+ def test_fill_value_for_outside_points(self):
+ """Points outside the convex hull should get fill_value."""
+ lons_orig = np.array([0.0, 1.0, 0.0, 1.0])
+ lats_orig = np.array([0.0, 0.0, 1.0, 1.0])
+ # One inside, one far outside
+ lons_target = np.array([0.5, 10.0])
+ lats_target = np.array([0.5, 10.0])
+
+ gd = GridData(lons_orig, lats_orig, lons_target, lats_target)
+ values = np.array([[1.0, 2.0, 3.0, 4.0]])
+ result = gd.interpolate(values, fill_value=-999.0)
+
+ # Outside point should be fill_value
+ assert result[0, 1] == -999.0
+
+ def test_fill_value_default_nan(self):
+ """Default fill_value should be NaN."""
+ lons_orig = np.array([0.0, 1.0, 0.0, 1.0])
+ lats_orig = np.array([0.0, 0.0, 1.0, 1.0])
+ lons_target = np.array([0.5, 10.0])
+ lats_target = np.array([0.5, 10.0])
+
+ gd = GridData(lons_orig, lats_orig, lons_target, lats_target)
+ values = np.array([[1.0, 2.0, 3.0, 4.0]])
+ result = gd.interpolate(values)
+ assert np.isnan(result[0, 1])
+
+ def test_callable_alias(self, grid_data_on_regular):
+ """__call__ should produce the same result as interpolate."""
+ gd = grid_data_on_regular
+ n_orig = len(gd.longitudes_orig)
+ values = np.random.RandomState(7).randn(1, n_orig)
+ np.testing.assert_array_equal(gd(values), gd.interpolate(values))
+
+ def test_channel_batch_consistency(self, grid_data_on_regular):
+ """Interpolate should give consistent results across channels and batches."""
+ gd = grid_data_on_regular
+ n_orig = len(gd.longitudes_orig)
+ batch, channels = 2, 3
+ values = np.random.RandomState(123).randn(batch, channels, n_orig)
+ result = gd.interpolate(values)
+ result_one_batch_channel = gd.interpolate(values[1, 1])
+
+ np.testing.assert_allclose(result[1, 1], result_one_batch_channel, atol=1e-12)
+
+
+############################################################################
+# GridData — interpolate (torch) #
+############################################################################
+
+
+class TestGridDataInterpolateTorch:
+ """Tests for interpolation using PyTorch tensors."""
+
+ @pytest.fixture
+ def grid_data_torch(self):
+ lons_orig, lats_orig = np.meshgrid(
+ np.linspace(0, 4, 5), np.linspace(0, 4, 5)
+ )
+ lons_target = np.array([1.0, 2.0, 3.0])
+ lats_target = np.array([1.0, 2.0, 3.0])
+ gd = GridData(
+ lons_orig.ravel(), lats_orig.ravel(), lons_target, lats_target
+ )
+ gd.to_torch()
+ return gd
+
+ def test_interpolate_constant_field_torch(self, grid_data_torch):
+ gd = grid_data_torch
+ n_orig = len(gd.longitudes_orig)
+ values = torch.full((1, n_orig), 5.0)
+ result = gd.interpolate(values)
+ assert isinstance(result, torch.Tensor)
+ torch.testing.assert_close(result, torch.full_like(result, 5.0), atol=1e-6, rtol=0)
+
+ def test_interpolate_linear_field_torch(self, grid_data_torch):
+ gd = grid_data_torch
+ lons = torch.from_numpy(gd.longitudes_orig)
+ lats = torch.from_numpy(gd.latitudes_orig)
+ values = (lons + lats).unsqueeze(0) # (1, n_orig)
+ result = gd.interpolate(values) # (1, n_target)
+ expected = torch.from_numpy(gd.longitudes_target + gd.latitudes_target)
+ torch.testing.assert_close(result[0], expected, atol=1e-6, rtol=1e-5)
+
+ def test_interpolate_multichannel_torch(self, grid_data_torch):
+ gd = grid_data_torch
+ n_orig = len(gd.longitudes_orig)
+ n_channels = 3
+ values = torch.randn(n_channels, n_orig)
+ result = gd.interpolate(values)
+ assert result.shape == (n_channels, len(gd.longitudes_target))
+
+ def test_interpolate_batch_and_channels_torch(self, grid_data_torch):
+ gd = grid_data_torch
+ n_orig = len(gd.longitudes_orig)
+ batch, channels = 2, 3
+ values = torch.randn(batch, channels, n_orig)
+ result = gd.interpolate(values)
+ assert result.shape == (batch, channels, len(gd.longitudes_target))
+
+ def test_interpolate_wrong_last_dim_raises_torch(self, grid_data_torch):
+ with pytest.raises(ValueError, match="Expected values with shape"):
+ grid_data_torch.interpolate(torch.zeros((1, 7)))
+
+ def test_fill_value_for_outside_points_torch(self):
+ lons_orig = np.array([0.0, 1.0, 0.0, 1.0])
+ lats_orig = np.array([0.0, 0.0, 1.0, 1.0])
+ lons_target = np.array([0.5, 10.0])
+ lats_target = np.array([0.5, 10.0])
+
+ gd = GridData(lons_orig, lats_orig, lons_target, lats_target)
+ gd.to_torch()
+ values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+ result = gd.interpolate(values, fill_value=-999.0)
+ assert result[0, 1].item() == -999.0
+
+
+ def test_fill_value_default_nan_torch(self):
+ lons_orig = np.array([0.0, 1.0, 0.0, 1.0])
+ lats_orig = np.array([0.0, 0.0, 1.0, 1.0])
+ lons_target = np.array([0.5, 10.0])
+ lats_target = np.array([0.5, 10.0])
+
+ gd = GridData(lons_orig, lats_orig, lons_target, lats_target)
+ gd.to_torch()
+ values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+ result = gd.interpolate(values)
+ assert torch.isnan(result[0, 1])
+
+ def test_channel_batch_consistency(self, grid_data_torch):
+ gd = grid_data_torch
+ n_orig = len(gd.longitudes_orig)
+ batch, channels = 2, 3
+ values = torch.randn(batch, channels, n_orig)
+ result = gd.interpolate(values)
+ result_one_batch_channel = gd.interpolate(values[1, 1])
+ torch.testing.assert_close(result[1, 1], result_one_batch_channel, atol=1e-6, rtol=0)
+
+ def test_numpy_and_torch_agree(self):
+ """Numpy and Torch paths should produce the same results."""
+ lons_orig, lats_orig = np.meshgrid(
+ np.linspace(0, 4, 5), np.linspace(0, 4, 5)
+ )
+ lons_target = np.array([1.0, 2.5, 3.5])
+ lats_target = np.array([1.0, 2.5, 3.5])
+
+ gd_np = GridData(
+ lons_orig.ravel(), lats_orig.ravel(), lons_target, lats_target
+ )
+ rng = np.random.RandomState(99)
+ values_np = rng.randn(2, len(lons_orig.ravel()))
+
+ result_np = gd_np.interpolate(values_np)
+
+ gd_np.to_torch()
+ values_torch = torch.from_numpy(values_np)
+ result_torch = gd_np.interpolate(values_torch)
+
+ np.testing.assert_allclose(
+ result_np, result_torch.numpy(), atol=1e-10
+ )
+
+
+############################################################################
+# GridData — outside hull warning #
+############################################################################
+
+
+class TestGridDataOutsideHull:
+ """Test behaviour when target points lie outside the convex hull."""
+
+ def test_outside_hull_warning(self, capsys):
+ lons_orig = np.array([0.0, 1.0, 0.0])
+ lats_orig = np.array([0.0, 0.0, 1.0])
+ lons_target = np.array([0.25, -5.0])
+ lats_target = np.array([0.25, -5.0])
+
+ GridData(lons_orig, lats_orig, lons_target, lats_target)
+ captured = capsys.readouterr()
+ assert "outside the convex hull" in captured.out
+
+ def test_no_warning_when_all_inside(self, capsys):
+ lons_orig, lats_orig = np.meshgrid(
+ np.linspace(0, 4, 5), np.linspace(0, 4, 5)
+ )
+ lons_target = np.array([1.0, 2.0])
+ lats_target = np.array([1.0, 2.0])
+
+ GridData(lons_orig.ravel(), lats_orig.ravel(), lons_target, lats_target)
+ captured = capsys.readouterr()
+ assert "outside the convex hull" not in captured.out
diff --git a/tests/utils/test_env_info.py b/tests/utils/test_env_info.py
new file mode 100644
index 00000000..b6cda265
--- /dev/null
+++ b/tests/utils/test_env_info.py
@@ -0,0 +1,440 @@
+import os
+from types import ModuleType
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+
+from hirad.utils.env_info import (
+ flatten_dict,
+ get_env_info,
+ get_git_info,
+ get_module_git_info,
+ get_module_version,
+)
+
+
+############################################################################
+# get_module_version #
+############################################################################
+
+
+class TestGetModuleVersion:
+ """Tests for get_module_version."""
+
+ def test_returns_version_string(self):
+ module = ModuleType("fake_mod")
+ module.__version__ = "1.2.3"
+ assert get_module_version(module) == "1.2.3"
+
+ def test_returns_none_when_no_version_attr(self):
+ module = ModuleType("no_ver")
+ assert get_module_version(module) is None
+
+ def test_returns_none_when_version_is_not_string(self):
+ module = ModuleType("bad_ver")
+ module.__version__ = (1, 2, 3)
+ assert get_module_version(module) is None
+
+ def test_returns_none_when_version_is_none(self):
+ module = ModuleType("none_ver")
+ module.__version__ = None
+ assert get_module_version(module) is None
+
+ def test_returns_none_when_version_is_int(self):
+ module = ModuleType("int_ver")
+ module.__version__ = 42
+ assert get_module_version(module) is None
+
+ def test_accepts_semver_string(self):
+ module = ModuleType("semver")
+ module.__version__ = "0.0.1-alpha"
+ assert get_module_version(module) == "0.0.1-alpha"
+
+
+############################################################################
+# get_git_info #
+############################################################################
+
+
+class TestGetGitInfo:
+ """Tests for get_git_info."""
+
+ def test_returns_none_for_non_git_path(self, tmp_path):
+ result = get_git_info(str(tmp_path))
+ assert result is None
+
+ def test_returns_dict_with_expected_keys(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "abc123"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ mock_remote = MagicMock()
+ mock_remote.url = "https://github.com/user/repo.git"
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = []
+ mock_repo.untracked_files = []
+ mock_repo.remotes = [mock_remote]
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ assert result is not None
+ assert result["sha1"] == "abc123"
+ assert result["diff"] == ""
+ assert result["untracked_files"] == []
+ assert result["remotes"] == ["https://github.com/user/repo.git"]
+
+ def test_untracked_files_are_sorted(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "def456"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = []
+ mock_repo.untracked_files = ["c.py", "a.py", "b.py"]
+ mock_repo.remotes = []
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ assert result["untracked_files"] == ["a.py", "b.py", "c.py"]
+
+ def test_diff_content_bytes_decoded(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "aaa"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ diff_item = MagicMock()
+ diff_item.a_blob.abspath = "/a/file.py"
+ diff_item.b_blob.abspath = "/b/file.py"
+ diff_item.diff = b"+new line"
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = [diff_item]
+ mock_repo.untracked_files = []
+ mock_repo.remotes = []
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ assert "+new line" in result["diff"]
+
+ def test_diff_content_string_passthrough(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "bbb"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ diff_item = MagicMock()
+ diff_item.a_blob.abspath = "/a/file.py"
+ diff_item.b_blob.abspath = "/b/file.py"
+ diff_item.diff = "+already a string"
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = [diff_item]
+ mock_repo.untracked_files = []
+ mock_repo.remotes = []
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ assert "+already a string" in result["diff"]
+
+ def test_diff_content_none_treated_as_empty(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "ccc"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ diff_item = MagicMock()
+ diff_item.a_blob.abspath = "/a/file.py"
+ diff_item.b_blob.abspath = "/b/file.py"
+ diff_item.diff = None
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = [diff_item]
+ mock_repo.untracked_files = []
+ mock_repo.remotes = []
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ # The diff should contain the header lines but no diff content
+ assert "--- a/a/file.py" in result["diff"]
+ assert "+++ b/b/file.py" in result["diff"]
+
+ def test_diff_multiple_diff_items_concatenated(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "eee"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ diff_item1 = MagicMock()
+ diff_item1.a_blob.abspath = "/a/file1.py"
+ diff_item1.b_blob.abspath = "/b/file1.py"
+ diff_item1.diff = "+line in file1"
+
+ diff_item2 = MagicMock()
+ diff_item2.a_blob.abspath = "/a/file2.py"
+ diff_item2.b_blob.abspath = "/b/file2.py"
+ diff_item2.diff = "+line in file2"
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = [diff_item1, diff_item2]
+ mock_repo.untracked_files = []
+ mock_repo.remotes = []
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ assert "+line in file1" in result["diff"]
+ assert "+line in file2" in result["diff"]
+ assert result["diff"].count("--- a/") == 2
+ assert result["diff"].count("+++ b/") == 2
+ assert diff_item1.a_blob.abspath in result["diff"]
+ assert diff_item1.b_blob.abspath in result["diff"]
+ assert diff_item2.a_blob.abspath in result["diff"]
+ assert diff_item2.b_blob.abspath in result["diff"]
+
+ def test_multiple_remotes(self):
+ mock_commit = MagicMock()
+ mock_commit.hexsha = "ddd"
+
+ mock_head = MagicMock()
+ mock_head.commit = mock_commit
+
+ remote1 = MagicMock()
+ remote1.url = "https://github.com/user/repo1.git"
+ remote2 = MagicMock()
+ remote2.url = "git@github.com:user/repo2.git"
+
+ mock_repo = MagicMock()
+ mock_repo.head = mock_head
+ mock_repo.index.diff.return_value = []
+ mock_repo.untracked_files = []
+ mock_repo.remotes = [remote1, remote2]
+
+ with patch("hirad.utils.env_info.Repo", return_value=mock_repo):
+ result = get_git_info("/some/path")
+
+ assert len(result["remotes"]) == 2
+
+
+############################################################################
+# get_module_git_info #
+############################################################################
+
+
+class TestGetModuleGitInfo:
+ """Tests for get_module_git_info."""
+
+ def test_returns_none_when_no_file_attr(self):
+ module = ModuleType("no_file")
+ # ModuleType does not set __file__ by default
+ assert get_module_git_info(module) is None
+
+ def test_returns_none_for_relative_path(self):
+ module = ModuleType("rel_path")
+ module.__file__ = "relative/path.py"
+ assert get_module_git_info(module) is None
+
+ def test_delegates_to_get_git_info(self, tmp_path):
+ module = ModuleType("abs_path")
+ module.__file__ = str(tmp_path / "pkg" / "mod.py")
+
+ fake_git_info = {"sha1": "abc", "diff": "", "untracked_files": [], "remotes": []}
+ with patch("hirad.utils.env_info.get_git_info", return_value=fake_git_info) as mock_fn:
+ result = get_module_git_info(module)
+
+ mock_fn.assert_called_once_with(str(tmp_path / "pkg"))
+ assert result == fake_git_info
+
+ def test_returns_none_when_get_git_info_returns_none(self, tmp_path):
+ module = ModuleType("no_git")
+ module.__file__ = str(tmp_path / "mod.py")
+
+ with patch("hirad.utils.env_info.get_git_info", return_value=None):
+ assert get_module_git_info(module) is None
+
+
+############################################################################
+# flatten_dict #
+############################################################################
+
+
+class TestFlattenDict:
+ """Tests for flatten_dict."""
+
+ def test_already_flat(self):
+ d = {"a": 1, "b": 2}
+ assert flatten_dict(d) == {"a": 1, "b": 2}
+
+ def test_one_level_nesting(self):
+ d = {"a": {"x": 1, "y": 2}, "b": 3}
+ assert flatten_dict(d) == {"a.x": 1, "a.y": 2, "b": 3}
+
+ def test_two_level_nesting(self):
+ d = {"a": {"b": {"c": 42}}}
+ assert flatten_dict(d) == {"a.b.c": 42}
+
+ def test_custom_separator(self):
+ d = {"a": {"b": 1}}
+ assert flatten_dict(d, sep="/") == {"a/b": 1}
+
+ def test_custom_parent_key(self):
+ d = {"x": 1}
+ assert flatten_dict(d, parent_key="root") == {"root.x": 1}
+
+ def test_empty_dict(self):
+ assert flatten_dict({}) == {}
+
+ def test_mixed_nested_and_flat(self):
+ d = {"a": 1, "b": {"c": 2}, "d": {"e": {"f": 3}}}
+ expected = {"a": 1, "b.c": 2, "d.e.f": 3}
+ assert flatten_dict(d) == expected
+
+ def test_preserves_non_dict_values(self):
+ d = {"a": [1, 2], "b": {"c": "text"}, "d": None}
+ expected = {"a": [1, 2], "b.c": "text", "d": None}
+ assert flatten_dict(d) == expected
+
+
+############################################################################
+# get_env_info #
+############################################################################
+
+
+class TestGetEnvInfo:
+ """Tests for get_env_info."""
+
+ def _make_module(self, name, version=None, file_path=None):
+ """Create a fake module with optional version and __file__."""
+ mod = ModuleType(name)
+ if version is not None:
+ mod.__version__ = version
+ if file_path is not None:
+ mod.__file__ = file_path
+ return mod
+
+ def test_returns_tuple_of_two(self):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ result = get_env_info()
+ assert isinstance(result, tuple)
+ assert len(result) == 2
+
+ def test_always_includes_python_version(self):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False)
+ assert "python" in info
+ assert "version" in info["python"]
+
+ def test_flatten_true_flattens_output(self):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=True)
+ assert "python.version" in info
+
+ def test_flatten_false_keeps_nested(self):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False)
+ assert isinstance(info.get("python"), dict)
+
+ def test_excludes_builtin_modules(self):
+ import sys
+
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False)
+ for name in sys.builtin_module_names:
+ assert name not in info
+
+ def test_exclude_prefixes_filters_modules(self):
+ fake_mod = self._make_module("fakepkg_test", version="1.0.0")
+ with patch.dict("sys.modules", {"fakepkg_test": fake_mod}):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False, exclude_prefixes=["fakepkg_test"])
+ assert "fakepkg_test" not in info
+
+ def test_exclude_prefixes_filters_submodules(self):
+ fake_sub = self._make_module("fakepkg_test.sub", version="2.0.0")
+ with patch.dict("sys.modules", {"fakepkg_test.sub": fake_sub}):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False, exclude_prefixes=["fakepkg_test"])
+ assert "fakepkg_test.sub" not in info
+
+ def test_module_with_version_included(self):
+ fake_mod = self._make_module("mypkg_env_test", version="3.1.4")
+ with patch.dict("sys.modules", {"mypkg_env_test": fake_mod}):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False)
+ assert "mypkg_env_test" in info
+ assert info["mypkg_env_test"]["version"] == "3.1.4"
+
+ def test_module_with_git_info_included(self):
+ fake_mod = self._make_module("gitpkg_test", version="0.1.0")
+ git_info = {
+ "sha1": "abc123",
+ "diff": "",
+ "untracked_files": [],
+ "remotes": ["https://example.com/repo.git"],
+ }
+ with patch.dict("sys.modules", {"gitpkg_test": fake_mod}):
+ with patch(
+ "hirad.utils.env_info.get_module_git_info",
+ side_effect=lambda m: git_info.copy(),
+ ):
+ info, _ = get_env_info(flatten=False)
+ assert "gitpkg_test" in info
+ assert "git" in info["gitpkg_test"]
+ assert info["gitpkg_test"]["git"]["sha1"] == "abc123"
+ assert "diff" not in info["gitpkg_test"]["git"]
+
+ def test_diffs_collected_in_second_element(self):
+ fake_mod = self._make_module("diffpkg_test", version="1.0.0")
+ git_info = {
+ "sha1": "aaa",
+ "diff": "+added line\n",
+ "untracked_files": [],
+ "remotes": [],
+ }
+ with patch.dict("sys.modules", {"diffpkg_test": fake_mod}):
+ with patch(
+ "hirad.utils.env_info.get_module_git_info",
+ side_effect=lambda m: git_info.copy(),
+ ):
+ _, diffs_str = get_env_info(flatten=False)
+ assert "+added line" in diffs_str
+
+ def test_modules_without_version_and_git_excluded(self):
+ fake_mod = self._make_module("bare_mod_test")
+ with patch.dict("sys.modules", {"bare_mod_test": fake_mod}):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False)
+ assert "bare_mod_test" not in info
+
+ def test_version_suffix_modules_skipped(self):
+ """Modules ending with .version or ._version are skipped."""
+ fake_ver = self._make_module("somepkg.version", version="1.0")
+ fake_uver = self._make_module("somepkg._version", version="1.0")
+ with patch.dict(
+ "sys.modules",
+ {"somepkg.version": fake_ver, "somepkg._version": fake_uver},
+ ):
+ with patch("hirad.utils.env_info.get_module_git_info", return_value=None):
+ info, _ = get_env_info(flatten=False)
+ assert "somepkg.version" not in info
+ assert "somepkg._version" not in info
diff --git a/tests/utils/test_function_utils.py b/tests/utils/test_function_utils.py
new file mode 100644
index 00000000..57a69072
--- /dev/null
+++ b/tests/utils/test_function_utils.py
@@ -0,0 +1,351 @@
+import datetime
+
+import numpy as np
+import pytest
+import torch
+
+from hirad.utils.function_utils import (
+ InfiniteSampler,
+ StackedRandomGenerator,
+ get_time_from_range,
+ time_range,
+)
+
+
+############################################################################
+# time_range #
+############################################################################
+
+
+class TestTimeRange:
+ """Tests for the time_range generator."""
+
+ def test_basic_hourly_range(self):
+ start = datetime.datetime(2024, 1, 1, 0, 0)
+ end = datetime.datetime(2024, 1, 1, 3, 0)
+ step = datetime.timedelta(hours=1)
+ result = list(time_range(start, end, step))
+ assert result == [
+ datetime.datetime(2024, 1, 1, 0, 0),
+ datetime.datetime(2024, 1, 1, 1, 0),
+ datetime.datetime(2024, 1, 1, 2, 0),
+ ]
+
+ def test_inclusive_range(self):
+ start = datetime.datetime(2024, 1, 1, 0, 0)
+ end = datetime.datetime(2024, 1, 1, 3, 0)
+ step = datetime.timedelta(hours=1)
+ result = list(time_range(start, end, step, inclusive=True))
+ assert result == [
+ datetime.datetime(2024, 1, 1, 0, 0),
+ datetime.datetime(2024, 1, 1, 1, 0),
+ datetime.datetime(2024, 1, 1, 2, 0),
+ datetime.datetime(2024, 1, 1, 3, 0),
+ ]
+
+ def test_exclusive_does_not_include_end(self):
+ start = datetime.datetime(2024, 6, 1, 12, 0)
+ end = datetime.datetime(2024, 6, 1, 14, 0)
+ step = datetime.timedelta(hours=1)
+ result = list(time_range(start, end, step, inclusive=False))
+ assert end not in result
+
+ def test_empty_range_when_start_equals_end(self):
+ t = datetime.datetime(2024, 1, 1, 0, 0)
+ result = list(time_range(t, t, datetime.timedelta(hours=1)))
+ assert result == []
+
+ def test_inclusive_single_element_when_start_equals_end(self):
+ t = datetime.datetime(2024, 1, 1, 0, 0)
+ result = list(time_range(t, t, datetime.timedelta(hours=1), inclusive=True))
+ assert result == [t]
+
+ def test_empty_range_when_start_after_end(self):
+ start = datetime.datetime(2024, 1, 2)
+ end = datetime.datetime(2024, 1, 1)
+ result = list(time_range(start, end, datetime.timedelta(hours=1)))
+ assert result == []
+
+ def test_sub_hourly_step(self):
+ start = datetime.datetime(2024, 1, 1, 0, 0)
+ end = datetime.datetime(2024, 1, 1, 0, 30)
+ step = datetime.timedelta(minutes=10)
+ result = list(time_range(start, end, step))
+ assert len(result) == 3
+ assert result == [
+ datetime.datetime(2024, 1, 1, 0, 0),
+ datetime.datetime(2024, 1, 1, 0, 10),
+ datetime.datetime(2024, 1, 1, 0, 20),
+ ]
+
+ def test_daily_step(self):
+ start = datetime.datetime(2024, 1, 1)
+ end = datetime.datetime(2024, 1, 4)
+ step = datetime.timedelta(days=1)
+ result = list(time_range(start, end, step))
+ assert len(result) == 3
+ assert result[0] == datetime.datetime(2024, 1, 1)
+ assert result[-1] == datetime.datetime(2024, 1, 3)
+
+
+############################################################################
+# get_time_from_range #
+############################################################################
+
+
+class TestGetTimeFromRange:
+ """Tests for get_time_from_range."""
+
+ def test_basic_range_with_default_interval(self):
+ times = get_time_from_range(["2024-01-01T00:00:00", "2024-01-01T03:00:00"])
+ assert len(times) == 4 # inclusive: 00, 01, 02, 03
+ assert times[0] == "2024-01-01T00:00:00"
+ assert times[-1] == "2024-01-01T03:00:00"
+
+ def test_custom_interval(self):
+ times = get_time_from_range(
+ ["2024-01-01T00:00:00", "2024-01-01T06:00:00", 2]
+ )
+ assert len(times) == 4 # 00, 02, 04, 06
+ assert times == [
+ "2024-01-01T00:00:00",
+ "2024-01-01T02:00:00",
+ "2024-01-01T04:00:00",
+ "2024-01-01T06:00:00",
+ ]
+
+ def test_single_time_when_start_equals_end(self):
+ times = get_time_from_range(["2024-06-15T12:00:00", "2024-06-15T12:00:00"])
+ assert times == ["2024-06-15T12:00:00"]
+
+ def test_multi_day_range(self):
+ times = get_time_from_range(
+ ["2024-01-01T00:00:00", "2024-01-02T00:00:00", 6]
+ )
+ assert len(times) == 5 # 00, 06, 12, 18, 00+1day
+ assert times[-2] == "2024-01-01T18:00:00"
+
+ def test_custom_time_format(self):
+ fmt = "%Y%m%d-%H%M"
+ times = get_time_from_range(["20240101-0000", "20240101-0300"], time_format=fmt)
+ assert len(times) == 4
+ assert times[0] == "20240101-0000"
+ assert times[-1] == "20240101-0300"
+
+ def test_returns_strings(self):
+ times = get_time_from_range(["2024-01-01T00:00:00", "2024-01-01T02:00:00"])
+ assert all(isinstance(t, str) for t in times)
+
+
+############################################################################
+# StackedRandomGenerator #
+############################################################################
+
+
+class TestStackedRandomGenerator:
+ """Tests for StackedRandomGenerator."""
+
+ def test_randn_shape(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[1, 2, 3])
+ out = gen.randn([3, 4, 5])
+ assert out.shape == (3, 4, 5)
+
+ def test_randn_batch_mismatch_raises(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[1, 2])
+ with pytest.raises(ValueError, match="Expected first dimension"):
+ gen.randn([5, 4])
+
+ def test_randn_reproducibility(self):
+ gen1 = StackedRandomGenerator(device="cpu", seeds=[42, 99])
+ gen2 = StackedRandomGenerator(device="cpu", seeds=[42, 99])
+ out1 = gen1.randn([2, 8])
+ out2 = gen2.randn([2, 8])
+ assert torch.allclose(out1, out2)
+
+ def test_randn_different_seeds_give_different_output(self):
+ gen1 = StackedRandomGenerator(device="cpu", seeds=[1, 2])
+ gen2 = StackedRandomGenerator(device="cpu", seeds=[3, 4])
+ out1 = gen1.randn([2, 100])
+ out2 = gen2.randn([2, 100])
+ assert not torch.allclose(out1, out2)
+
+ def test_randn_like(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[10, 20])
+ template = torch.zeros(2, 3, 4)
+ out = gen.randn_like(template)
+ assert out.shape == template.shape
+ assert out.dtype == template.dtype
+
+ def test_randint_shape(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[1, 2, 3])
+ out = gen.randint(0, 10, size=[3, 5])
+ assert out.shape == (3, 5)
+
+ def test_randint_batch_mismatch_raises(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[1])
+ with pytest.raises(ValueError, match="Expected first dimension"):
+ gen.randint(0, 10, size=[4, 5])
+
+ def test_randint_values_in_range(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[7, 8])
+ out = gen.randint(0, 5, size=[2, 100])
+ assert (out >= 0).all()
+ assert (out < 5).all()
+
+ def test_randint_reproducibility(self):
+ gen1 = StackedRandomGenerator(device="cpu", seeds=[42, 99])
+ gen2 = StackedRandomGenerator(device="cpu", seeds=[42, 99])
+ out1 = gen1.randint(0, 100, size=[2, 50])
+ out2 = gen2.randint(0, 100, size=[2, 50])
+ assert torch.equal(out1, out2)
+
+ def test_randint_different_seeds_give_different_output(self):
+ gen1 = StackedRandomGenerator(device="cpu", seeds=[1, 2])
+ gen2 = StackedRandomGenerator(device="cpu", seeds=[3, 4])
+ out1 = gen1.randint(0, 100, size=[2, 50])
+ out2 = gen2.randint(0, 100, size=[2, 50])
+ assert not torch.equal(out1, out2)
+
+ def test_single_seed(self):
+ gen = StackedRandomGenerator(device="cpu", seeds=[0])
+ out = gen.randn([1, 10])
+ assert out.shape == (1, 10)
+
+
+############################################################################
+# InfiniteSampler #
+############################################################################
+
+
+class TestInfiniteSampler:
+ """Tests for InfiniteSampler."""
+
+ @pytest.fixture
+ def simple_dataset(self):
+ """A minimal dataset with 10 items."""
+ return list(range(10))
+
+ def test_yields_indices(self, simple_dataset):
+ sampler = InfiniteSampler(simple_dataset, shuffle=False)
+ it = iter(sampler)
+ indices = [next(it) for _ in range(10)]
+ assert indices == list(range(10))
+
+ def test_infinite_iteration(self, simple_dataset):
+ sampler = InfiniteSampler(simple_dataset, shuffle=False)
+ it = iter(sampler)
+ # Should be able to draw more samples than the dataset size
+ indices = [next(it) for _ in range(25)]
+ assert len(indices) == 25
+
+ def test_loops_over_dataset(self, simple_dataset):
+ sampler = InfiniteSampler(simple_dataset, shuffle=False)
+ it = iter(sampler)
+ first_pass = [next(it) for _ in range(10)]
+ second_pass = [next(it) for _ in range(10)]
+ assert first_pass == list(range(10))
+ assert second_pass == list(range(10))
+
+ def test_shuffle_produces_different_order(self, simple_dataset):
+ sampler = InfiniteSampler(simple_dataset, shuffle=True, seed=42)
+ it = iter(sampler)
+ indices = [next(it) for _ in range(10)]
+ # With shuffling, the indices should not be in sorted order
+ # (extremely unlikely for seed=42 with 10 items)
+ assert indices != list(range(10))
+
+ def test_going_through_full_dataset_with_shuffle(self, simple_dataset):
+ sampler = InfiniteSampler(simple_dataset, shuffle=True, seed=123)
+ it = iter(sampler)
+ seen = set()
+ for _ in range(10):
+ idx = next(it)
+ assert idx not in seen # should see each index once before repeats
+ seen.add(idx)
+
+ def test_seed_reproducibility(self, simple_dataset):
+ sampler1 = InfiniteSampler(simple_dataset, shuffle=True, seed=123)
+ sampler2 = InfiniteSampler(simple_dataset, shuffle=True, seed=123)
+ it1 = iter(sampler1)
+ it2 = iter(sampler2)
+ for _ in range(30):
+ assert next(it1) == next(it2)
+
+ def test_different_seeds_different_order(self, simple_dataset):
+ sampler1 = InfiniteSampler(simple_dataset, shuffle=True, seed=1)
+ sampler2 = InfiniteSampler(simple_dataset, shuffle=True, seed=999)
+ it1 = iter(sampler1)
+ it2 = iter(sampler2)
+ seq1 = [next(it1) for _ in range(20)]
+ seq2 = [next(it2) for _ in range(20)]
+ assert seq1 != seq2
+
+ def test_distributed_sampling(self, simple_dataset):
+ """Each rank should yield non-overlapping indices."""
+ sampler0 = InfiniteSampler(
+ simple_dataset, rank=0, num_replicas=2, shuffle=False
+ )
+ sampler1 = InfiniteSampler(
+ simple_dataset, rank=1, num_replicas=2, shuffle=False
+ )
+ it0 = iter(sampler0)
+ it1 = iter(sampler1)
+ indices0 = [next(it0) for _ in range(5)]
+ indices1 = [next(it1) for _ in range(5)]
+ # The two ranks should receive different indices
+ assert set(indices0) != set(indices1)
+
+ def test_start_idx(self, simple_dataset):
+ sampler_default = InfiniteSampler(simple_dataset, shuffle=False, start_idx=0)
+ sampler_offset = InfiniteSampler(simple_dataset, shuffle=False, start_idx=5)
+ it_default = iter(sampler_default)
+ it_offset = iter(sampler_offset)
+ # Skip the first 5 from the default sampler
+ for _ in range(5):
+ next(it_default)
+ # Now they should be aligned
+ for _ in range(10):
+ assert next(it_default) == next(it_offset)
+
+ def test_start_idx_larger_than_dataset_size(self, simple_dataset):
+ sampler = InfiniteSampler(simple_dataset, shuffle=False, start_idx=12)
+ it = iter(sampler)
+ # start_idx=12 should wrap around to index 2 on the first yield
+ assert next(it) == 2
+
+ def test_window_size_zero_no_shuffle_effect(self, simple_dataset):
+ sampler = InfiniteSampler(
+ simple_dataset, shuffle=True, seed=42, window_size=0.0
+ )
+ it = iter(sampler)
+ # With window_size=0, window rounds to 0, so no swapping occurs.
+ # Items come out in seed-shuffled initial order but stay fixed.
+ indices = [next(it) for _ in range(10)]
+ assert len(set(indices)) == 10 # all unique in first pass
+
+ # --- Validation tests ---
+
+ def test_empty_dataset_raises(self):
+ with pytest.raises(ValueError, match="at least one item"):
+ InfiniteSampler([])
+
+ def test_invalid_num_replicas_raises(self, simple_dataset):
+ with pytest.raises(ValueError, match="num_replicas must be positive"):
+ InfiniteSampler(simple_dataset, num_replicas=0)
+
+ def test_invalid_rank_raises(self, simple_dataset):
+ with pytest.raises(ValueError, match="rank must be non-negative"):
+ InfiniteSampler(simple_dataset, rank=-1, num_replicas=2)
+
+ def test_rank_exceeds_replicas_raises(self, simple_dataset):
+ with pytest.raises(ValueError, match="rank must be non-negative"):
+ InfiniteSampler(simple_dataset, rank=3, num_replicas=2)
+
+ def test_invalid_window_size_raises(self, simple_dataset):
+ with pytest.raises(ValueError, match="window_size must be between"):
+ InfiniteSampler(simple_dataset, window_size=1.5)
+
+ def test_negative_window_size_raises(self, simple_dataset):
+ with pytest.raises(ValueError, match="window_size must be between"):
+ InfiniteSampler(simple_dataset, window_size=-0.1)
+
diff --git a/tests/utils/test_inference_utils.py b/tests/utils/test_inference_utils.py
new file mode 100644
index 00000000..d6d4ae50
--- /dev/null
+++ b/tests/utils/test_inference_utils.py
@@ -0,0 +1,423 @@
+import pytest
+import os
+import numpy as np
+import torch
+from unittest.mock import MagicMock, patch
+
+from hirad.utils.inference_utils import (
+ calculate_bounds,
+ regression_step,
+ diffusion_step,
+ save_results_as_torch,
+)
+
+
+############################################################################
+# calculate_bounds #
+############################################################################
+
+
+class TestCalculateBounds:
+ """Tests for calculate_bounds."""
+
+ def test_single_array(self):
+ arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+ vmin, vmax = calculate_bounds(arr)
+ assert vmin == 1.0
+ assert vmax == 5.0
+
+ def test_multiple_arrays(self):
+ a = np.array([1.0, 5.0])
+ b = np.array([-3.0, 2.0])
+ c = np.array([0.0, 10.0])
+ vmin, vmax = calculate_bounds(a, b, c)
+ assert vmin == -3.0
+ assert vmax == 10.0
+
+ def test_no_arrays(self):
+ vmin, vmax = calculate_bounds()
+ assert vmin is None
+ assert vmax is None
+
+ def test_all_none(self):
+ vmin, vmax = calculate_bounds(None, None)
+ assert vmin is None
+ assert vmax is None
+
+ def test_some_none(self):
+ arr = np.array([2.0, 8.0])
+ vmin, vmax = calculate_bounds(None, arr, None)
+ assert vmin == 2.0
+ assert vmax == 8.0
+
+ def test_single_value_array(self):
+ arr = np.array([42.0])
+ vmin, vmax = calculate_bounds(arr)
+ assert vmin == 42.0
+ assert vmax == 42.0
+
+ def test_negative_values(self):
+ arr = np.array([-10.0, -5.0, -1.0])
+ vmin, vmax = calculate_bounds(arr)
+ assert vmin == -10.0
+ assert vmax == -1.0
+
+ def test_masked_array(self):
+ data = np.array([1.0, np.nan, 3.0, 4.0, np.nan])
+ masked = np.ma.masked_invalid(data)
+ vmin, vmax = calculate_bounds(masked)
+ assert vmin == 1.0
+ assert vmax == 4.0
+
+ def test_masked_array_all_masked(self):
+ data = np.ma.array([1.0, 2.0], mask=[True, True])
+ vmin, vmax = calculate_bounds(data)
+ assert vmin == None
+ assert vmax == None
+
+ def test_mixed_masked_and_regular(self):
+ regular = np.array([0.0, 5.0])
+ masked = np.ma.masked_invalid(np.array([np.nan, 10.0, np.nan]))
+ vmin, vmax = calculate_bounds(regular, masked)
+ assert vmin == 0.0
+ assert vmax == 10.0
+
+ def test_2d_array(self):
+ arr = np.array([[1.0, 2.0], [3.0, 4.0]])
+ vmin, vmax = calculate_bounds(arr)
+ assert vmin == 1.0
+ assert vmax == 4.0
+
+ def test_scalar_input(self):
+ vmin, vmax = calculate_bounds(np.float64(3.14))
+ assert vmin == pytest.approx(3.14)
+ assert vmax == pytest.approx(3.14)
+
+
+############################################################################
+# regression_step #
+############################################################################
+
+
+class TestRegressionStep:
+ """Tests for regression_step."""
+
+ def _make_mock_net(self, output_shape):
+ """Create a mock network that returns a tensor of the given shape."""
+ net = MagicMock()
+ net.return_value = torch.randn(output_shape)
+ return net
+
+ def test_batch_size_greater_than_1_raises(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(2, 3, 8, 8) # batch_size=2
+ latents_shape = torch.Size([1, 4, 8, 8])
+ with pytest.raises(ValueError, match="batch size of 1"):
+ regression_step(net, img_lr, latents_shape)
+
+ def test_batch_size_1_succeeds(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ latents_shape = torch.Size([1, 4, 8, 8])
+ result = regression_step(net, img_lr, latents_shape)
+ assert result.shape == torch.Size([1, 4, 8, 8])
+
+ def test_output_replicated_when_latents_batch_gt_1(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ latents_shape = torch.Size([5, 4, 8, 8])
+ result = regression_step(net, img_lr, latents_shape)
+ assert result.shape == torch.Size([5, 4, 8, 8])
+
+ def test_net_called_with_img_lr(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ latents_shape = torch.Size([1, 4, 8, 8])
+ regression_step(net, img_lr, latents_shape)
+ assert net.called
+
+ def test_with_lead_time_label(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ latents_shape = torch.Size([1, 4, 8, 8])
+ lead_time = torch.tensor([1.0])
+ result = regression_step(
+ net, img_lr, latents_shape, lead_time_label=lead_time
+ )
+ assert result.shape == torch.Size([1, 4, 8, 8])
+ # Verify lead_time_label was passed to the net
+ _, kwargs = net.call_args
+ assert "lead_time_label" in kwargs
+
+ def test_with_static_channels(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ static = torch.randn(1, 2, 8, 8)
+ latents_shape = torch.Size([1, 4, 8, 8])
+ result = regression_step(
+ net, img_lr, latents_shape, static_channels=static
+ )
+ assert result.shape == torch.Size([1, 4, 8, 8])
+ # Net should receive img_lr concatenated with static channels (3+2=5)
+ _, kwargs = net.call_args
+ assert kwargs["img_lr"].shape[1] == 5
+
+ def test_with_date_embedding(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ date_emb = torch.randn(1, 4)
+ latents_shape = torch.Size([1, 4, 8, 8])
+ result = regression_step(
+ net, img_lr, latents_shape, date_embedding=date_emb
+ )
+ assert result.shape == torch.Size([1, 4, 8, 8])
+ # Net should receive img_lr concatenated with date embedding (3+4=7)
+ _, kwargs = net.call_args
+ assert kwargs["img_lr"].shape[1] == 7
+
+ def test_with_all_optional_inputs(self):
+ net = self._make_mock_net((1, 4, 8, 8))
+ img_lr = torch.randn(1, 3, 8, 8)
+ static = torch.randn(1, 2, 8, 8)
+ date_emb = torch.randn(1, 4)
+ lead_time = torch.tensor([1.0])
+ latents_shape = torch.Size([1, 4, 8, 8])
+ result = regression_step(
+ net,
+ img_lr,
+ latents_shape,
+ lead_time_label=lead_time,
+ static_channels=static,
+ date_embedding=date_emb,
+ )
+ assert result.shape == torch.Size([1, 4, 8, 8])
+ # img_lr should have 3 + 2 (static) + 4 (date) = 9 channels
+ _, kwargs = net.call_args
+ assert kwargs["img_lr"].shape[1] == 9
+
+
+############################################################################
+# diffusion_step #
+############################################################################
+
+
+class TestDiffusionStep:
+ """Tests for diffusion_step."""
+
+ def test_img_lr_shape_mismatch_raises(self):
+ net = MagicMock()
+ sampler_fn = MagicMock()
+ img_lr = torch.randn(1, 3, 16, 16)
+ with pytest.raises(ValueError, match="does not match expected shape"):
+ diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(32, 32),
+ img_out_channels=4,
+ rank_batches=[[0]],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ )
+
+ def test_mean_hr_shape_mismatch_raises(self):
+ net = MagicMock()
+ sampler_fn = MagicMock()
+ img_lr = torch.randn(1, 3, 32, 32)
+ mean_hr = torch.randn(1, 4, 16, 16)
+ with pytest.raises(ValueError, match="does not match expected shape"):
+ diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(32, 32),
+ img_out_channels=4,
+ rank_batches=[[0]],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ mean_hr=mean_hr,
+ )
+
+ def test_mean_hr_batch_size_not_1_raises(self):
+ net = MagicMock()
+ sampler_fn = MagicMock()
+ img_lr = torch.randn(1, 3, 32, 32)
+ mean_hr = torch.randn(2, 4, 32, 32)
+ with pytest.raises(ValueError, match="batch size 1"):
+ diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(32, 32),
+ img_out_channels=4,
+ rank_batches=[[0]],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ mean_hr=mean_hr,
+ )
+
+ def test_empty_rank_batches(self):
+ net = MagicMock()
+ sampler_fn = MagicMock()
+ img_lr = torch.randn(1, 3, 8, 8)
+ # Empty batches of seeds
+ with pytest.raises(ValueError, match="rank_batches is empty"):
+ diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(8, 8),
+ img_out_channels=4,
+ rank_batches=[],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ )
+
+ def test_missmatch_batch_size_and_image_shape(self):
+ net = MagicMock()
+ sampler_fn = MagicMock()
+ img_lr = torch.randn(1, 3, 8, 8)
+ # rank_batches has batch size 2 but img_shape is for batch size 1
+ with pytest.raises(ValueError, match="does not match img_lr batch size"):
+ diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(8, 8),
+ img_out_channels=4,
+ rank_batches=[[0,1], [2,3]],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ )
+
+ def test_generates_correct_number_of_samples(self):
+ net = MagicMock()
+ generated = torch.randn(1, 4, 8, 8)
+ sampler_fn = MagicMock(return_value=generated)
+ img_lr = torch.randn(1, 3, 8, 8)
+ result = diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(8, 8),
+ img_out_channels=4,
+ rank_batches=[[0], [1], [2]],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ )
+ assert result.shape == torch.Size([3, 4, 8, 8])
+ assert sampler_fn.call_count == 3
+
+ def test_passes_additional_args_to_sampler(self):
+ net = MagicMock()
+ generated = torch.randn(1, 4, 8, 8)
+ sampler_fn = MagicMock(return_value=generated)
+ img_lr = torch.randn(1, 3, 8, 8)
+ mean_hr = torch.randn(1, 4, 8, 8)
+ lead_time = torch.tensor([1.0])
+ static = torch.randn(1, 2, 8, 8)
+ date_emb = torch.randn(1, 4)
+
+ diffusion_step(
+ net=net,
+ sampler_fn=sampler_fn,
+ img_shape=(8, 8),
+ img_out_channels=4,
+ rank_batches=[[42]],
+ img_lr=img_lr,
+ rank=0,
+ device=torch.device("cpu"),
+ mean_hr=mean_hr,
+ lead_time_label=lead_time,
+ static_channels=static,
+ date_embedding=date_emb,
+ )
+
+ _, kwargs = sampler_fn.call_args
+ assert "mean_hr" in kwargs
+ assert "lead_time_label" in kwargs
+ assert "static_channels" in kwargs
+ assert "date_embedding" in kwargs
+
+
+############################################################################
+# save_results_as_torch #
+############################################################################
+
+
+class TestSaveResultsAsTorch:
+ """Tests for save_results_as_torch."""
+
+ def test_creates_output_directory(self, tmp_path):
+ output_dir = tmp_path / "results" / "nested"
+ image_pred = torch.randn(2, 4, 8, 8)
+ image_hr = torch.randn(1, 4, 8, 8)
+ image_lr = torch.randn(1, 3, 8, 8)
+ mean_pred = torch.randn(1, 4, 8, 8)
+
+ save_results_as_torch(
+ str(output_dir), "step_0", image_pred, image_hr, image_lr, mean_pred
+ )
+ assert output_dir.exists()
+
+ def test_saves_all_files_with_mean(self, tmp_path):
+ image_pred = torch.randn(2, 4, 8, 8)
+ image_hr = torch.randn(1, 4, 8, 8)
+ image_lr = torch.randn(1, 3, 8, 8)
+ mean_pred = torch.randn(1, 4, 8, 8)
+
+ save_results_as_torch(
+ str(tmp_path), "step_0", image_pred, image_hr, image_lr, mean_pred
+ )
+
+ assert os.path.isfile(tmp_path / "step_0-regression-prediction")
+ assert os.path.isfile(tmp_path / "step_0-target")
+ assert os.path.isfile(tmp_path / "step_0-predictions")
+ assert os.path.isfile(tmp_path / "step_0-baseline")
+
+ def test_skips_regression_when_mean_is_none(self, tmp_path):
+ image_pred = torch.randn(2, 4, 8, 8)
+ image_hr = torch.randn(1, 4, 8, 8)
+ image_lr = torch.randn(1, 3, 8, 8)
+
+ save_results_as_torch(
+ str(tmp_path), "step_1", image_pred, image_hr, image_lr, None
+ )
+
+ assert not os.path.isfile(tmp_path / "step_1-regression-prediction")
+ assert os.path.isfile(tmp_path / "step_1-target")
+ assert os.path.isfile(tmp_path / "step_1-predictions")
+ assert os.path.isfile(tmp_path / "step_1-baseline")
+
+ def test_saved_tensors_are_loadable_and_correct(self, tmp_path):
+ image_pred = torch.randn(2, 4, 8, 8)
+ image_hr = torch.randn(1, 4, 8, 8)
+ image_lr = torch.randn(1, 3, 8, 8)
+ mean_pred = torch.randn(1, 4, 8, 8)
+
+ save_results_as_torch(
+ str(tmp_path), "t0", image_pred, image_hr, image_lr, mean_pred
+ )
+
+ loaded_pred = torch.load(tmp_path / "t0-predictions", weights_only=True)
+ loaded_hr = torch.load(tmp_path / "t0-target", weights_only=True)
+ loaded_lr = torch.load(tmp_path / "t0-baseline", weights_only=True)
+ loaded_mean = torch.load(tmp_path / "t0-regression-prediction", weights_only=True)
+
+ assert torch.equal(loaded_pred, image_pred)
+ assert torch.equal(loaded_hr, image_hr)
+ assert torch.equal(loaded_lr, image_lr)
+ assert torch.equal(loaded_mean, mean_pred)
+
+ def test_different_time_steps_dont_overwrite(self, tmp_path):
+ t1 = torch.randn(1, 4, 8, 8)
+ t2 = torch.randn(1, 4, 8, 8)
+
+ save_results_as_torch(str(tmp_path), "step_0", t1, t1, t1, None)
+ save_results_as_torch(str(tmp_path), "step_1", t2, t2, t2, None)
+
+ loaded_0 = torch.load(tmp_path / "step_0-target", weights_only=True)
+ loaded_1 = torch.load(tmp_path / "step_1-target", weights_only=True)
+
+ assert torch.equal(loaded_0, t1)
+ assert torch.equal(loaded_1, t2)
\ No newline at end of file
diff --git a/tests/utils/test_model_utils.py b/tests/utils/test_model_utils.py
new file mode 100644
index 00000000..68471282
--- /dev/null
+++ b/tests/utils/test_model_utils.py
@@ -0,0 +1,107 @@
+import pytest
+import numpy as np
+import torch
+
+from hirad.utils.model_utils import weight_init
+
+
+class TestWeightInitShape:
+ """Test that weight_init returns tensors of the correct shape."""
+
+ @pytest.mark.parametrize("mode", [
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ ])
+ @pytest.mark.parametrize("shape", [
+ (1,),
+ (3, 3),
+ (64, 32, 3, 3),
+ (128, 64, 5, 5),
+ ])
+ def test_output_shape(self, mode, shape):
+ result = weight_init(shape, mode, fan_in=32, fan_out=64)
+ assert result.shape == torch.Size(shape)
+
+ @pytest.mark.parametrize("mode", [
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ ])
+ def test_output_is_tensor(self, mode):
+ result = weight_init((4, 4), mode, fan_in=4, fan_out=4)
+ assert isinstance(result, torch.Tensor)
+
+
+class TestWeightInitValues:
+ """Test that weight_init produces values within expected bounds."""
+
+ def test_xavier_uniform_bound(self):
+ fan_in, fan_out = 256, 512
+ bound = np.sqrt(6 / (fan_in + fan_out))
+ result = weight_init((10000,), "xavier_uniform", fan_in, fan_out)
+ assert result.min() >= -bound - 1e-7
+ assert result.max() <= bound + 1e-7
+
+ def test_kaiming_uniform_bound(self):
+ fan_in, fan_out = 256, 512
+ bound = np.sqrt(3 / fan_in)
+ result = weight_init((10000,), "kaiming_uniform", fan_in, fan_out)
+ assert result.min() >= -bound - 1e-7
+ assert result.max() <= bound + 1e-7
+
+ def test_xavier_normal_mean_and_std(self):
+ fan_in, fan_out = 256, 512
+ expected_std = np.sqrt(2 / (fan_in + fan_out))
+ result = weight_init((100000,), "xavier_normal", fan_in, fan_out)
+ assert abs(result.mean().item()) < 0.05
+ assert abs(result.std().item() - expected_std) < 0.01
+
+ def test_kaiming_normal_mean_and_std(self):
+ fan_in, fan_out = 256, 512
+ expected_std = np.sqrt(1 / fan_in)
+ result = weight_init((100000,), "kaiming_normal", fan_in, fan_out)
+ assert abs(result.mean().item()) < 0.05
+ assert abs(result.std().item() - expected_std) < 0.01
+
+
+class TestWeightInitSymmetry:
+ """Test that uniform modes are centered around zero."""
+
+ @pytest.mark.parametrize("mode", ["xavier_uniform", "kaiming_uniform"])
+ def test_uniform_centered_around_zero(self, mode):
+ result = weight_init((100000,), mode, fan_in=128, fan_out=128)
+ assert abs(result.mean().item()) < 0.05
+
+
+class TestWeightInitScaling:
+ """Test that changing fan_in/fan_out changes the scale of outputs."""
+
+ @pytest.mark.parametrize("mode", [
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ ])
+ def test_larger_fan_in_reduces_scale(self, mode):
+ small_fan = weight_init((10000,), mode, fan_in=16, fan_out=16)
+ large_fan = weight_init((10000,), mode, fan_in=1024, fan_out=1024)
+ assert small_fan.std() > large_fan.std()
+
+
+class TestWeightInitInvalidMode:
+ """Test that invalid modes raise ValueError."""
+
+ @pytest.mark.parametrize("mode", [
+ "invalid",
+ "",
+ "Xavier_uniform",
+ "KAIMING_NORMAL",
+ "he_normal",
+ "glorot_uniform",
+ ])
+ def test_invalid_mode_raises_value_error(self, mode):
+ with pytest.raises(ValueError, match="Invalid init mode"):
+ weight_init((4, 4), mode, fan_in=4, fan_out=4)
\ No newline at end of file
diff --git a/tests/utils/test_patching.py b/tests/utils/test_patching.py
new file mode 100644
index 00000000..05c7d8ac
--- /dev/null
+++ b/tests/utils/test_patching.py
@@ -0,0 +1,644 @@
+import math
+
+import pytest
+import torch
+
+from hirad.utils.patching import (
+ BasePatching2D,
+ GridPatching2D,
+ RandomPatching2D,
+ image_batching,
+ image_fuse,
+)
+
+
+############################################################################
+# BasePatching2D — init #
+############################################################################
+
+
+class TestBasePatching2DInit:
+ """Tests for BasePatching2D initialization and validation."""
+
+ def test_non_2d_img_shape_raises(self):
+ """img_shape with wrong number of dimensions should raise ValueError."""
+ with pytest.raises(ValueError, match="img_shape must be 2D"):
+ # Use GridPatching2D as a concrete subclass
+ GridPatching2D(img_shape=(64, 64, 3), patch_shape=(32, 32))
+
+ def test_non_2d_patch_shape_raises(self):
+ """patch_shape with wrong number of dimensions should raise ValueError."""
+ with pytest.raises(ValueError, match="patch_shape must be 2D"):
+ GridPatching2D(img_shape=(64, 64), patch_shape=(32, 32, 1))
+
+ def test_patch_larger_than_image_warns(self):
+ """patch_shape larger than img_shape should issue a warning."""
+ with pytest.warns(UserWarning, match="larger than"):
+ GridPatching2D(img_shape=(32, 32), patch_shape=(64, 64))
+
+ def test_patch_clamped_to_image_shape(self):
+ """patch_shape should be clamped to img_shape when it exceeds it."""
+ with pytest.warns(UserWarning):
+ patcher = GridPatching2D(img_shape=(32, 48), patch_shape=(64, 64))
+ assert patcher.patch_shape == (32, 48)
+
+ def test_valid_shapes_stored(self):
+ patcher = GridPatching2D(img_shape=(64, 128), patch_shape=(32, 32))
+ assert patcher.img_shape == (64, 128)
+ assert patcher.patch_shape == (32, 32)
+
+
+############################################################################
+# BasePatching2D — global_index #
+############################################################################
+
+
+class TestBasePatching2DGlobalIndex:
+ """Tests for the global_index method."""
+
+ def test_global_index_shape(self):
+ patcher = GridPatching2D(img_shape=(64, 64), patch_shape=(32, 32))
+ gi = patcher.global_index(batch_size=1)
+ assert gi.ndim == 4
+ assert gi.shape[1] == 2
+ assert gi.shape[2] == patcher.patch_shape[0]
+ assert gi.shape[3] == patcher.patch_shape[1]
+
+ def test_global_index_values_within_image(self):
+ """All global indices should fall within the original image dimensions."""
+ patcher = GridPatching2D(
+ img_shape=(64, 128), patch_shape=(32, 32), overlap_pix=4
+ )
+ gi = patcher.global_index(batch_size=1)
+ assert gi[:, 0].min() >= 0
+ assert gi[:, 1].min() >= 0
+ # Padded indices may exceed image shape, but y/x coords should be valid
+ assert gi[:, 0].max() < patcher.img_shape[0] + patcher.patch_shape[0]
+ assert gi[:, 1].max() < patcher.img_shape[1] + patcher.patch_shape[1]
+
+ def test_global_index_values_simple_case(self):
+ """ In a simple 4x4 image with 2x2 patches and no overlap, global indices should be predictable. """
+ patcher = GridPatching2D(img_shape=(4, 4), patch_shape=(2, 2), overlap_pix=0)
+ gi = patcher.global_index(batch_size=1)
+ expected_indices = torch.tensor([
+ [[[0, 0], [1, 1]], [[0, 1], [0, 1]]],
+ [[[2, 2], [3, 3]], [[0, 1], [0, 1]]],
+ [[[0, 0], [1, 1]], [[2, 3], [2, 3]]],
+ [[[2, 2], [3, 3]], [[2, 3], [2, 3]]]
+ ])
+ assert torch.equal(gi.cpu(), expected_indices)
+
+ def test_global_index_device(self):
+ patcher = GridPatching2D(img_shape=(32, 32), patch_shape=(16, 16))
+ gi = patcher.global_index(batch_size=1, device="cpu")
+ assert gi.device == torch.device("cpu")
+
+
+############################################################################
+# BasePatching2D — fuse not implemented #
+############################################################################
+
+
+class TestBasePatching2DFuse:
+ """Tests that fuse raises NotImplementedError for subclasses that don't implement it."""
+
+ def test_random_patching_fuse_raises(self):
+ """RandomPatching2D does not implement fuse."""
+ patcher = RandomPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), patch_num=4
+ )
+ dummy = torch.randn(4, 3, 32, 32)
+ with pytest.raises(NotImplementedError, match="fuse"):
+ patcher.fuse(dummy)
+
+
+############################################################################
+# BasePatching2D — apply abstract method #
+############################################################################
+
+
+class TestBasePatching2DApply:
+ """Tests that apply raises NotImplementedError for subclasses that don't implement it."""
+
+ def test_grid_patching_apply_raises(self):
+ """BasePatching2D does not implement apply."""
+ with pytest.raises(TypeError, match="apply"):
+ patcher = BasePatching2D(img_shape=(64, 64), patch_shape=(32, 32))
+
+
+############################################################################
+# RandomPatching2D — init #
+############################################################################
+
+
+class TestRandomPatching2DInit:
+ """Tests for RandomPatching2D initialization."""
+
+ def test_patch_num_stored(self):
+ patcher = RandomPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), patch_num=8
+ )
+ assert patcher.patch_num == 8
+
+ def test_patch_indices_generated_on_init(self):
+ patcher = RandomPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), patch_num=5
+ )
+ assert len(patcher.patch_indices) == 5
+
+ def test_patch_indices_within_bounds(self):
+ img_h, img_w = 100, 120
+ patch_h, patch_w = 30, 40
+ patcher = RandomPatching2D(
+ img_shape=(img_h, img_w), patch_shape=(patch_h, patch_w), patch_num=20
+ )
+ for py, px in patcher.patch_indices:
+ assert 0 <= py <= img_h - patch_h
+ assert 0 <= px <= img_w - patch_w
+
+
+############################################################################
+# RandomPatching2D — set / reset indices #
+############################################################################
+
+
+class TestRandomPatching2DIndices:
+ """Tests for patch index manipulation."""
+
+ def test_reset_changes_indices(self):
+ """Resetting should produce new random indices (with overwhelming probability)."""
+ patcher = RandomPatching2D(
+ img_shape=(256, 256), patch_shape=(32, 32), patch_num=50
+ )
+ old_indices = list(patcher.patch_indices)
+ patcher.reset_patch_indices()
+ # Extremely unlikely to be identical for 50 patches on a 256x256 image
+ assert patcher.patch_indices != old_indices
+
+ def test_get_patch_indices_returns_current(self):
+ patcher = RandomPatching2D(
+ img_shape=(64, 64), patch_shape=(16, 16), patch_num=3
+ )
+ assert patcher.get_patch_indices() is patcher.patch_indices
+
+ def test_set_patch_num_updates_count_and_indices(self):
+ patcher = RandomPatching2D(
+ img_shape=(64, 64), patch_shape=(16, 16), patch_num=3
+ )
+ patcher.set_patch_num(10)
+ assert patcher.patch_num == 10
+ assert len(patcher.patch_indices) == 10
+
+
+############################################################################
+# RandomPatching2D — apply #
+############################################################################
+
+
+class TestRandomPatching2DApply:
+ """Tests for RandomPatching2D.apply."""
+
+ @pytest.fixture
+ def patcher_and_input(self):
+ img_shape = (64, 64)
+ patch_shape = (16, 16)
+ patch_num = 4
+ patcher = RandomPatching2D(img_shape, patch_shape, patch_num)
+ batch_size, channels = 2, 3
+ x = torch.randn(batch_size, channels, *img_shape)
+ return patcher, x, batch_size, channels
+
+ def test_output_shape(self, patcher_and_input):
+ patcher, x, batch_size, channels = patcher_and_input
+ out = patcher.apply(x)
+ assert out.shape == (
+ batch_size * patcher.patch_num,
+ channels,
+ patcher.patch_shape[0],
+ patcher.patch_shape[1],
+ )
+
+ def test_output_values_match_input_slices(self, patcher_and_input):
+ """Each patch in the output should correspond to the correct slice of input."""
+ patcher, x, batch_size, _ = patcher_and_input
+ out = patcher.apply(x)
+ for i, (py, px) in enumerate(patcher.patch_indices):
+ expected = x[
+ :, :,
+ py : py + patcher.patch_shape[0],
+ px : px + patcher.patch_shape[1],
+ ]
+ torch.testing.assert_close(
+ out[batch_size * i : batch_size * (i + 1)], expected
+ )
+
+ def test_apply_with_additional_input(self, patcher_and_input):
+ """Additional input should be concatenated along channel dim."""
+ patcher, x, batch_size, channels = patcher_and_input
+ add_channels = 2
+ additional = torch.randn(batch_size, add_channels, 32, 32)
+ out = patcher.apply(x, additional_input=additional)
+ assert out.shape[1] == channels + add_channels
+ assert torch.allclose(out[:, :channels], patcher.apply(x))
+ assert torch.allclose(out[:, channels:], torch.nn.functional.interpolate(
+ additional, size=patcher.patch_shape, mode="bilinear").repeat(patcher.patch_num, 1, 1, 1))
+
+ def test_apply_single_patch(self):
+ """Test with a single patch."""
+ patcher = RandomPatching2D(
+ img_shape=(32, 32), patch_shape=(32, 32), patch_num=1
+ )
+ x = torch.randn(1, 1, 32, 32)
+ out = patcher.apply(x)
+ torch.testing.assert_close(out, x)
+
+
+############################################################################
+# GridPatching2D — init #
+############################################################################
+
+
+class TestGridPatching2DInit:
+ """Tests for GridPatching2D initialization."""
+
+ def test_patch_num_no_overlap(self):
+ """Without overlap, patches should tile the image exactly."""
+ patcher = GridPatching2D(img_shape=(64, 64), patch_shape=(32, 32))
+ expected_x = math.ceil(64 / 32)
+ expected_y = math.ceil(64 / 32)
+ assert patcher.patch_num == expected_x * expected_y
+
+ def test_patch_num_with_overlap(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), overlap_pix=8
+ )
+ expected_x = math.ceil(64 / (32 - 8))
+ expected_y = math.ceil(64 / (32 - 8))
+ assert patcher.patch_num == expected_x * expected_y
+
+ def test_patch_num_with_boundary(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), boundary_pix=4
+ )
+ expected_x = math.ceil(64 / (32 - 4))
+ expected_y = math.ceil(64 / (32 - 4))
+ assert patcher.patch_num == expected_x * expected_y
+
+ def test_patch_num_with_overlap_and_boundary(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32),
+ overlap_pix=8, boundary_pix=10
+ )
+ expected_x = math.ceil(64 / (32 - 8 - 10))
+ expected_y = math.ceil(64 / (32 - 8 - 10))
+ assert patcher.patch_num == expected_x * expected_y
+
+ def test_non_divisible_image_shape(self):
+ """Image dimensions that don't divide evenly by stride should still work."""
+ patcher = GridPatching2D(img_shape=(100, 77), patch_shape=(32, 32))
+ assert patcher.patch_num == math.ceil(100 / 32) * math.ceil(77 / 32)
+
+
+############################################################################
+# GridPatching2D — apply #
+############################################################################
+
+
+class TestGridPatching2DApply:
+ """Tests for GridPatching2D.apply."""
+
+ @pytest.fixture
+ def grid_patcher_and_input(self):
+ img_shape = (64, 64)
+ patch_shape = (32, 32)
+ patcher = GridPatching2D(img_shape, patch_shape)
+ batch_size, channels = 2, 3
+ x = torch.randn(batch_size, channels, *img_shape)
+ return patcher, x, batch_size, channels
+
+ def test_output_shape(self, grid_patcher_and_input):
+ patcher, x, batch_size, channels = grid_patcher_and_input
+ out = patcher.apply(x)
+ assert out.shape == (
+ batch_size * patcher.patch_num,
+ channels,
+ patcher.patch_shape[0],
+ patcher.patch_shape[1],
+ )
+
+ def test_output_shape_with_additional_input(self, grid_patcher_and_input):
+ patcher, x, batch_size, channels = grid_patcher_and_input
+ add_channels = 5
+ additional = torch.randn(batch_size, add_channels, 16, 16)
+ out = patcher.apply(x, additional_input=additional)
+ assert out.shape == (
+ batch_size * patcher.patch_num,
+ channels + add_channels,
+ patcher.patch_shape[0],
+ patcher.patch_shape[1],
+ )
+
+ def test_apply_with_overlap(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), overlap_pix=8
+ )
+ x = torch.randn(1, 1, 64, 64)
+ out = patcher.apply(x)
+ assert out.shape == (patcher.patch_num, 1, 32, 32)
+
+ def test_apply_with_boundary(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), boundary_pix=4
+ )
+ x = torch.randn(1, 1, 64, 64)
+ out = patcher.apply(x)
+ assert out.shape == (patcher.patch_num, 1, 32, 32)
+
+ def test_apply_with_overlap_and_boundary(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32),
+ overlap_pix=8, boundary_pix=10
+ )
+ x = torch.randn(1, 1, 64, 64)
+ out = patcher.apply(x)
+ assert out.shape == (patcher.patch_num, 1, 32, 32)
+
+
+############################################################################
+# GridPatching2D — fuse #
+############################################################################
+
+
+class TestGridPatching2DFuse:
+ """Tests for GridPatching2D.fuse."""
+
+ def test_fuse_output_shape(self):
+ img_shape = (64, 64)
+ patcher = GridPatching2D(img_shape, patch_shape=(32, 32))
+ batch_size, channels = 2, 3
+ patches = torch.randn(
+ batch_size * patcher.patch_num, channels,
+ patcher.patch_shape[0], patcher.patch_shape[1],
+ )
+ fused = patcher.fuse(patches, batch_size=batch_size)
+ assert fused.shape == (batch_size, channels, *img_shape)
+
+ def test_fuse_output_shape_with_overlap(self):
+ img_shape = (64, 128)
+ patcher = GridPatching2D(
+ img_shape, patch_shape=(32, 32), overlap_pix=8
+ )
+ batch_size, channels = 2, 2
+ patches = torch.randn(
+ batch_size * patcher.patch_num, channels,
+ patcher.patch_shape[0], patcher.patch_shape[1],
+ )
+ fused = patcher.fuse(patches, batch_size=batch_size)
+ assert fused.shape == (batch_size, channels, *img_shape)
+
+
+############################################################################
+# GridPatching2D — roundtrip (apply → fuse) #
+############################################################################
+
+
+class TestGridPatching2DRoundtrip:
+ """Tests that apply followed by fuse reconstructs the original image."""
+
+ @pytest.mark.parametrize(
+ "img_shape, patch_shape, overlap_pix, boundary_pix",
+ [
+ ((64, 64), (32, 32), 0, 0),
+ ((64, 64), (32, 32), 8, 0),
+ ((64, 64), (32, 32), 0, 4),
+ ((64, 64), (32, 32), 8, 4),
+ ((100, 77), (32, 32), 4, 2),
+ ((48, 96), (24, 48), 6, 0),
+ ],
+ )
+ def test_roundtrip_reconstructs_image(
+ self, img_shape, patch_shape, overlap_pix, boundary_pix
+ ):
+ """Patching and then fusing should recover the original image."""
+ patcher = GridPatching2D(
+ img_shape, patch_shape,
+ overlap_pix=overlap_pix, boundary_pix=boundary_pix,
+ )
+ batch_size, channels = 2, 3
+ x = torch.randn(batch_size, channels, *img_shape)
+ patches = patcher.apply(x)
+ reconstructed = patcher.fuse(patches, batch_size=batch_size)
+ torch.testing.assert_close(reconstructed, x, atol=1e-5, rtol=1e-5)
+
+ def test_roundtrip_single_patch_covers_image(self):
+ """A single patch covering the full image should roundtrip exactly."""
+ img_shape = (32, 32)
+ patcher = GridPatching2D(img_shape, patch_shape=(32, 32))
+ x = torch.randn(1, 1, *img_shape)
+ patches = patcher.apply(x)
+ reconstructed = patcher.fuse(patches, batch_size=1)
+ torch.testing.assert_close(reconstructed, x)
+
+ def test_roundtrip_preserves_dtype(self):
+ patcher = GridPatching2D(
+ img_shape=(64, 64), patch_shape=(32, 32), overlap_pix=4
+ )
+ x = torch.randn(1, 1, 64, 64, dtype=torch.float64)
+ patches = patcher.apply(x)
+ reconstructed = patcher.fuse(patches, batch_size=1)
+ assert reconstructed.dtype == x.dtype
+
+
+############################################################################
+# image_batching — function #
+############################################################################
+
+
+class TestImageBatching:
+ """Tests for the image_batching standalone function."""
+
+ def test_output_shape_no_overlap(self):
+ x = torch.randn(2, 3, 64, 64)
+ out = image_batching(x, patch_shape_y=32, patch_shape_x=32,
+ overlap_pix=0, boundary_pix=0)
+ patch_num = math.ceil(64 / 32) * math.ceil(64 / 32)
+ assert out.shape == (patch_num * 2, 3, 32, 32)
+
+ def test_output_shape_with_interp(self):
+ batch_size = 2
+ x = torch.randn(batch_size, 3, 64, 64)
+ interp = torch.randn(batch_size, 5, 32, 32)
+ out = image_batching(x, 32, 32, overlap_pix=0, boundary_pix=0,
+ input_interp=interp)
+ assert out.shape == (math.ceil(64 / 32) * math.ceil(64 / 32) * batch_size, 3 + 5, 32, 32)
+
+ def test_invalid_patch_shape_x_raises(self):
+ x = torch.randn(1, 1, 64, 64)
+ with pytest.raises(ValueError, match="patch_shape_x"):
+ image_batching(x, patch_shape_y=32, patch_shape_x=2,
+ overlap_pix=1, boundary_pix=1)
+
+ def test_invalid_patch_shape_y_raises(self):
+ x = torch.randn(1, 1, 64, 64)
+ with pytest.raises(ValueError, match="patch_shape_y"):
+ image_batching(x, patch_shape_y=2, patch_shape_x=32,
+ overlap_pix=1, boundary_pix=1)
+
+ def test_interp_batch_mismatch_raises(self):
+ x = torch.randn(2, 3, 64, 64)
+ interp = torch.randn(3, 5, 32, 32) # wrong batch size
+ with pytest.raises(ValueError, match="batch size"):
+ image_batching(x, 32, 32, 0, 0, input_interp=interp)
+
+ def test_interp_shape_mismatch_raises(self):
+ x = torch.randn(2, 3, 64, 64)
+ interp = torch.randn(2, 5, 16, 16) # wrong spatial dims
+ with pytest.raises(ValueError, match="patch shape"):
+ image_batching(x, 32, 32, 0, 0, input_interp=interp)
+
+ def test_patch_too_small_for_overlap_and_boundary_x_raises(self):
+ x = torch.randn(1, 1, 64, 64)
+ with pytest.raises(ValueError, match="patch_shape_x"):
+ image_batching(x, 32, 11, overlap_pix=5, boundary_pix=3)
+
+ def test_patch_too_small_for_overlap_and_boundary_y_raises(self):
+ x = torch.randn(1, 1, 64, 64)
+ with pytest.raises(ValueError, match="patch_shape_y"):
+ image_batching(x, 11, 32, overlap_pix=5, boundary_pix=3)
+
+ def test_int32_input_preserves_dtype(self):
+ x = torch.randint(0, 100, (1, 1, 32, 32), dtype=torch.int32)
+ out = image_batching(x, 16, 16, 0, 0)
+ assert out.dtype == torch.int32
+
+ def test_int64_input_preserves_dtype(self):
+ x = torch.randint(0, 100, (1, 1, 32, 32), dtype=torch.int64)
+ out = image_batching(x, 16, 16, 0, 0)
+ assert out.dtype == torch.int64
+
+ def test_patch_is_matching_the_original(self):
+ """Patches should match the corresponding slices of the original image."""
+ x = torch.randn(3, 2, 32, 32)
+ patches = image_batching(x, 16, 16, overlap_pix=0, boundary_pix=0)
+ expected_patches = torch.cat([
+ x[:, :, 0:16, 0:16],
+ x[:, :, 16:32, 0:16],
+ x[:, :, 0:16, 16:32],
+ x[:, :, 16:32, 16:32],
+ ], dim=0)
+ torch.testing.assert_close(patches, expected_patches)
+
+ def test_patch_is_matching_the_original_with_overlap_and_boundary(self):
+ """Patches should match the corresponding slices of the original image, even with overlap and boundary."""
+ x = torch.randn(3, 2, 32, 32)
+ patches = image_batching(x, 16, 16, overlap_pix=4, boundary_pix=2)
+ # test if the patches at the corners and center match the expected slices of the original image
+ # where padding is applied, compare only to the valid region of the original image
+ # padding can be changed without affecting the validity of the extracted patch region, so we focus on the original image slices
+ expected_patch_middle = x[:, :, 8:24, 8:24]
+ expected_patch_top_left = x[:, :, 0:14, 0:14]
+ expected_patch_bottom_left = x[:, :, 28:, 0:14]
+ expected_patch_top_right = x[:, :, 0:14, 28:]
+ expected_patch_bottom_right = x[:, :, 28:, 28:]
+ torch.testing.assert_close(patches[3*5:3*6], expected_patch_middle)
+ torch.testing.assert_close(patches[0:3,:,2:,2:], expected_patch_top_left)
+ torch.testing.assert_close(patches[3*3:3*4,:,:4,2:], expected_patch_bottom_left)
+ torch.testing.assert_close(patches[3*12:3*13, :, 2:, :4], expected_patch_top_right)
+ torch.testing.assert_close(patches[3*15:, :, :4, :4], expected_patch_bottom_right)
+
+############################################################################
+# image_fuse — function #
+############################################################################
+
+
+class TestImageFuse:
+ """Tests for the image_fuse standalone function."""
+
+ def test_output_shape(self):
+ batch_size = 2
+ img_shape_y, img_shape_x = 64, 64
+ patch_shape_y, patch_shape_x = 32, 32
+ patch_num_x = math.ceil(img_shape_x / patch_shape_x)
+ patch_num_y = math.ceil(img_shape_y / patch_shape_y)
+ patch_num = patch_num_x * patch_num_y
+ channels = 3
+ patches = torch.randn(patch_num * batch_size, channels,
+ patch_shape_y, patch_shape_x)
+ out = image_fuse(patches, img_shape_y, img_shape_x,
+ batch_size, overlap_pix=0, boundary_pix=0)
+ assert out.shape == (batch_size, channels, img_shape_y, img_shape_x)
+
+ def test_fuse_constant_patches(self):
+ """Fusing constant-valued patches should yield a constant image."""
+ val = 5.0
+ img_shape = (32, 32)
+ patcher = GridPatching2D(img_shape, patch_shape=(16, 16))
+ patches = torch.full(
+ (patcher.patch_num, 1, 16, 16), val
+ )
+ fused = image_fuse(patches, img_shape[0], img_shape[1],
+ batch_size=1, overlap_pix=0, boundary_pix=0)
+ torch.testing.assert_close(fused, torch.full((1, 1, *img_shape), val))
+
+ #TODO: after normalizing by overlap count, the output may not be exactly the same as the input for integer types, so we would need to round and cast back to the original dtype.
+ # We can add it after implementing that logic in image_fuse, but first we have to see if it would affect existing model checkpoints.
+ # def test_int32_dtype_preserved(self):
+ # x = torch.randint(0, 100, (1, 1, 8, 8), dtype=torch.int32)
+ # patches = image_batching(x, 4, 4, 0, 0)
+ # fused = image_fuse(patches, 8, 8, batch_size=1,
+ # overlap_pix=0, boundary_pix=0)
+ # print(fused)
+ # assert fused.dtype == torch.int32
+
+ # def test_int64_dtype_preserved(self):
+ # x = torch.randint(0, 100, (1, 1, 8, 8), dtype=torch.int64)
+ # patches = image_batching(x, 4, 4, 0, 0)
+ # fused = image_fuse(patches, 8, 8, batch_size=1,
+ # overlap_pix=0, boundary_pix=0)
+ # assert fused.dtype == torch.int64
+
+
+############################################################################
+# image_batching + image_fuse — roundtrip #
+############################################################################
+
+
+class TestImageBatchingFuseRoundtrip:
+ """Tests that image_batching followed by image_fuse recovers the original."""
+
+ @pytest.mark.parametrize(
+ "img_shape_y, img_shape_x, patch_shape_y, patch_shape_x, overlap_pix, boundary_pix",
+ [
+ (64, 64, 32, 32, 0, 0),
+ (64, 64, 32, 32, 8, 0),
+ (64, 64, 32, 32, 0, 4),
+ (64, 64, 32, 32, 8, 4),
+ (48, 96, 24, 48, 0, 0),
+ (100, 77, 32, 32, 4, 2),
+ ],
+ )
+ def test_roundtrip(
+ self, img_shape_y, img_shape_x,
+ patch_shape_y, patch_shape_x,
+ overlap_pix, boundary_pix,
+ ):
+ batch_size, channels = 2, 3
+ x = torch.randn(batch_size, channels, img_shape_y, img_shape_x)
+ patches = image_batching(
+ x, patch_shape_y, patch_shape_x, overlap_pix, boundary_pix
+ )
+ reconstructed = image_fuse(
+ patches, img_shape_y, img_shape_x,
+ batch_size, overlap_pix, boundary_pix,
+ )
+ torch.testing.assert_close(reconstructed, x, atol=1e-5, rtol=1e-5)
+
+ def test_roundtrip_channels_last(self):
+ """Roundtrip should work with channels_last memory format."""
+ x = torch.randn(2, 3, 64, 64).to(memory_format=torch.channels_last)
+ patches = image_batching(x, 32, 32, 0, 0)
+ reconstructed = image_fuse(patches, 64, 64, batch_size=2,
+ overlap_pix=0, boundary_pix=0)
+ torch.testing.assert_close(
+ reconstructed.contiguous(), x.contiguous(), atol=1e-5, rtol=1e-5
+ )
diff --git a/tests/utils/test_train_helpers.py b/tests/utils/test_train_helpers.py
new file mode 100644
index 00000000..36353524
--- /dev/null
+++ b/tests/utils/test_train_helpers.py
@@ -0,0 +1,954 @@
+import warnings
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from hirad.utils.train_helpers import (
+ calculate_patch_per_iter,
+ check_model_health,
+ compute_num_accumulation_rounds,
+ handle_and_clip_gradients,
+ init_mlflow,
+ is_time_for_periodic_task,
+ set_patch_shape,
+ set_seed,
+ update_learning_rate,
+)
+
+
+############################################################################
+# set_patch_shape #
+############################################################################
+
+
+class TestSetPatchShape:
+ """Tests for set_patch_shape."""
+
+ def test_patch_equals_image_disables_patching(self):
+ use_patching, img, patch = set_patch_shape((128, 128), (128, 128))
+ assert use_patching is False
+ assert img == (128, 128)
+ assert patch == (128, 128)
+
+ def test_none_patch_defaults_to_image(self):
+ use_patching, img, patch = set_patch_shape((128, 256), (None, None))
+ assert use_patching is False
+ assert patch == (128, 256)
+
+ def test_patch_larger_than_image_clamped(self):
+ use_patching, img, patch = set_patch_shape((64, 64), (128, 128))
+ assert use_patching is False
+ assert patch == (64, 64)
+
+ def test_valid_square_patch_enables_patching(self):
+ use_patching, img, patch = set_patch_shape((256, 256), (64, 64))
+ assert use_patching is True
+ assert patch == (64, 64)
+
+ def test_patch_not_multiple_of_32_raises(self):
+ with pytest.raises(ValueError, match="multiple of 32"):
+ set_patch_shape((256, 256), (50, 50))
+
+ def test_rectangular_patch_raises(self):
+ with pytest.raises(NotImplementedError, match="Rectangular patch"):
+ set_patch_shape((256, 256), (64, 128))
+
+ def test_img_shape_returned_unchanged(self):
+ _, img, _ = set_patch_shape((100, 200), (None, None))
+ assert img == (100, 200)
+
+ def test_patch_32_is_valid(self):
+ use_patching, _, patch = set_patch_shape((256, 256), (32, 32))
+ assert use_patching is True
+ assert patch == (32, 32)
+
+
+############################################################################
+# set_seed #
+############################################################################
+
+
+class TestSetSeed:
+ """Tests for set_seed."""
+
+ def test_reproducibility(self):
+ set_seed(42)
+ a_np = np.random.rand(5)
+ a_torch = torch.rand(5)
+
+ set_seed(42)
+ b_np = np.random.rand(5)
+ b_torch = torch.rand(5)
+
+ np.testing.assert_array_equal(a_np, b_np)
+ torch.testing.assert_close(a_torch, b_torch)
+
+ def test_different_ranks_give_different_seeds(self):
+ set_seed(0)
+ a_np = np.random.rand(100)
+ a_torch = torch.rand(100)
+
+ set_seed(1)
+ b_np = np.random.rand(100)
+ b_torch = torch.rand(100)
+
+ assert not np.array_equal(a_np, b_np)
+ assert not torch.allclose(a_torch, b_torch)
+
+ def test_large_rank_wraps(self):
+ """Ranks larger than 2^31 should still work due to modulo."""
+ large_rank = (1 << 31) + 5
+ set_seed(large_rank)
+ a_np = np.random.rand(5)
+ a_torch = torch.rand(5)
+
+ set_seed(5)
+ b_np = np.random.rand(5)
+ b_torch = torch.rand(5)
+ # rank % (1<<31) should give 5 in both cases
+ np.testing.assert_array_equal(a_np, b_np)
+ torch.testing.assert_close(a_torch, b_torch)
+
+
+############################################################################
+# compute_num_accumulation_rounds #
+############################################################################
+
+
+class TestComputeNumAccumulationRounds:
+ """Tests for compute_num_accumulation_rounds."""
+
+ def test_single_gpu_no_accumulation(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=16, batch_size_per_gpu=16, world_size=1
+ )
+ assert batch_gpu_total == 16
+ assert num_rounds == 1
+
+ def test_multi_gpu_no_accumulation(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=32, batch_size_per_gpu=8, world_size=4
+ )
+ assert batch_gpu_total == 8
+ assert num_rounds == 1
+
+ def test_accumulation_rounds(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=64, batch_size_per_gpu=8, world_size=2
+ )
+ assert batch_gpu_total == 32
+ assert num_rounds == 4
+
+ def test_none_batch_size_per_gpu_defaults_to_total(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=32, batch_size_per_gpu=None, world_size=2
+ )
+ assert batch_gpu_total == 16
+ assert num_rounds == 1
+
+ def test_batch_size_per_gpu_larger_than_total_clamped(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=16, batch_size_per_gpu=64, world_size=2
+ )
+ assert batch_gpu_total == 8
+ assert num_rounds == 1
+
+ def test_invalid_batch_sizes_raise(self):
+ """total_batch_size not divisible properly should raise ValueError."""
+ with pytest.raises(ValueError, match="total_batch_size must be equal"):
+ compute_num_accumulation_rounds(
+ total_batch_size=17, batch_size_per_gpu=4, world_size=2
+ )
+
+ def test_world_size_1_full_accumulation(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=64, batch_size_per_gpu=16, world_size=1
+ )
+ assert batch_gpu_total == 64
+ assert num_rounds == 4
+
+ def test_exact_division(self):
+ batch_gpu_total, num_rounds = compute_num_accumulation_rounds(
+ total_batch_size=128, batch_size_per_gpu=16, world_size=4
+ )
+ assert batch_gpu_total == 32
+ assert num_rounds == 2
+ assert 16 * 2 * 4 == 128
+
+
+############################################################################
+# handle_and_clip_gradients #
+############################################################################
+
+
+class TestHandleAndClipGradients:
+ """Tests for handle_and_clip_gradients."""
+
+ @pytest.fixture
+ def simple_model(self):
+ """Create a simple linear model with computed gradients."""
+ model = torch.nn.Linear(4, 2, bias=False)
+ x = torch.randn(1, 4)
+ loss = model(x).sum()
+ loss.backward()
+ return model
+
+ def test_nan_gradients_replaced(self):
+ model = torch.nn.Linear(4, 2, bias=False)
+ x = torch.randn(1, 4)
+ loss = model(x).sum()
+ loss.backward()
+ # Inject NaN into gradient
+ model.weight.grad[0, 0] = float("nan")
+ handle_and_clip_gradients(model)
+ assert torch.isfinite(model.weight.grad).all()
+
+ def test_inf_gradients_replaced(self):
+ model = torch.nn.Linear(4, 2, bias=False)
+ x = torch.randn(1, 4)
+ loss = model(x).sum()
+ loss.backward()
+ model.weight.grad[0, 0] = float("inf")
+ model.weight.grad[1, 0] = float("-inf")
+ handle_and_clip_gradients(model)
+ assert torch.isfinite(model.weight.grad).all()
+
+ def test_gradient_clipping(self, simple_model):
+ # Set a large gradient
+ simple_model.weight.grad.fill_(100.0)
+ handle_and_clip_gradients(simple_model, grad_clip_threshold=1.0)
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ simple_model.parameters(), float("inf")
+ )
+ assert grad_norm <= 1.0 + 1e-6
+
+ def test_no_clipping_when_none(self, simple_model):
+ original_grad = simple_model.weight.grad.clone()
+ handle_and_clip_gradients(simple_model, grad_clip_threshold=None)
+ torch.testing.assert_close(simple_model.weight.grad, original_grad)
+
+ def test_params_without_grad_skipped(self):
+ """Parameters without gradients should not cause errors."""
+ model = torch.nn.Linear(4, 2, bias=True)
+ # Only weight has grad, bias does not
+ model.weight.grad = torch.randn_like(model.weight)
+ model.bias.grad = None
+ handle_and_clip_gradients(model) # Should not raise
+
+
+############################################################################
+# check_model_health #
+############################################################################
+
+
+class TestCheckModelHealth:
+ """Tests for check_model_health."""
+
+ @pytest.fixture
+ def logger(self):
+ return MagicMock()
+
+ def test_healthy_model_returns_true(self, logger):
+ model = torch.nn.Linear(4, 2)
+ x = torch.randn(1, 4)
+ loss = model(x).sum()
+ loss.backward()
+ assert check_model_health(model, step=0, logger=logger) is True
+ logger.warning.assert_not_called()
+
+ def test_nan_weights_returns_false(self, logger):
+ model = torch.nn.Linear(4, 2)
+ with torch.no_grad():
+ model.weight[0, 0] = float("nan")
+ result = check_model_health(model, step=5, logger=logger)
+ assert result is False
+ logger.warning.assert_called_once()
+ assert "Weights" in logger.warning.call_args[0][0]
+
+ def test_inf_weights_returns_false(self, logger):
+ model = torch.nn.Linear(4, 2)
+ with torch.no_grad():
+ model.weight[0, 0] = float("inf")
+ result = check_model_health(model, step=3, logger=logger)
+ assert result is False
+ logger.warning.assert_called_once()
+ assert "Weights" in logger.warning.call_args[0][0]
+
+ def test_nan_gradients_returns_false(self, logger):
+ model = torch.nn.Linear(4, 2)
+ x = torch.randn(1, 4)
+ loss = model(x).sum()
+ loss.backward()
+ model.weight.grad[0, 0] = float("nan")
+ result = check_model_health(model, step=10, logger=logger)
+ assert result is False
+ assert "Gradients" in logger.warning.call_args[0][0]
+
+ def test_inf_gradients_returns_false(self, logger):
+ model = torch.nn.Linear(4, 2)
+ x = torch.randn(1, 4)
+ loss = model(x).sum()
+ loss.backward()
+ model.weight.grad[1, 1] = float("inf")
+ result = check_model_health(model, step=7, logger=logger)
+ assert result is False
+ assert "Gradients" in logger.warning.call_args[0][0]
+
+ def test_no_grad_params_healthy(self, logger):
+ """Parameters without gradients should not cause false negatives."""
+ model = torch.nn.Linear(4, 2, bias=True)
+ # No backward called, so no gradients
+ result = check_model_health(model, step=0, logger=logger)
+ assert result is True
+
+ def test_step_number_in_warning(self, logger):
+ model = torch.nn.Linear(4, 2)
+ with torch.no_grad():
+ model.weight[0, 0] = float("nan")
+ check_model_health(model, step=42, logger=logger)
+ assert "42" in logger.warning.call_args[0][0]
+
+
+############################################################################
+# is_time_for_periodic_task #
+############################################################################
+
+
+class TestIsTimeForPeriodicTask:
+ """Tests for is_time_for_periodic_task."""
+
+ def test_exact_frequency_match(self):
+ assert is_time_for_periodic_task(
+ cur_nimg=100, freq=100, done=False, batch_size=10, rank=0
+ ) is True
+
+ def test_within_batch_of_frequency(self):
+ # cur_nimg=105, freq=100 => 105 % 100 = 5 < batch_size=10
+ assert is_time_for_periodic_task(
+ cur_nimg=105, freq=100, done=False, batch_size=10, rank=0
+ ) is True
+
+ def test_not_time_yet(self):
+ # cur_nimg=50, freq=100 => 50 % 100 = 50 >= batch_size=10
+ assert is_time_for_periodic_task(
+ cur_nimg=50, freq=100, done=False, batch_size=10, rank=0
+ ) is False
+
+ def test_done_always_returns_true(self):
+ assert is_time_for_periodic_task(
+ cur_nimg=50, freq=100, done=True, batch_size=10, rank=0
+ ) is True
+
+ def test_rank_0_only_blocks_other_ranks(self):
+ assert is_time_for_periodic_task(
+ cur_nimg=100, freq=100, done=False, batch_size=10,
+ rank=1, rank_0_only=True,
+ ) is False
+
+ def test_rank_0_only_allows_rank_0(self):
+ assert is_time_for_periodic_task(
+ cur_nimg=100, freq=100, done=False, batch_size=10,
+ rank=0, rank_0_only=True,
+ ) is True
+
+ def test_rank_0_only_false_allows_any_rank(self):
+ assert is_time_for_periodic_task(
+ cur_nimg=100, freq=100, done=False, batch_size=10,
+ rank=3, rank_0_only=False,
+ ) is True
+
+ def test_done_overrides_rank_0_only(self):
+ """done=True should return True even for non-zero ranks with rank_0_only."""
+ assert is_time_for_periodic_task(
+ cur_nimg=50, freq=100, done=True, batch_size=10,
+ rank=2, rank_0_only=True,
+ ) is False # rank_0_only check happens first
+
+ def test_zero_cur_nimg(self):
+ # 0 % freq = 0 < batch_size => True
+ assert is_time_for_periodic_task(
+ cur_nimg=0, freq=100, done=False, batch_size=10, rank=0
+ ) is True
+
+ def test_batch_size_equals_freq(self):
+ # Every step should trigger when batch_size >= freq
+ assert is_time_for_periodic_task(
+ cur_nimg=37, freq=100, done=False, batch_size=100, rank=0
+ ) is True
+
+
+############################################################################
+# init_mlflow #
+############################################################################
+
+
+class TestInitMlflow:
+ """Tests for init_mlflow."""
+
+ @pytest.fixture
+ def base_cfg(self):
+ """Minimal config DictConfig for init_mlflow."""
+ return OmegaConf.create(
+ {
+ "logging": {
+ "uri": "http://mlflow-server:5000",
+ "experiment_name": "test_experiment",
+ "run_name": "test_run",
+ },
+ }
+ )
+
+ @pytest.fixture
+ def cfg_no_uri(self):
+ """Config with logging.uri set to None."""
+ return OmegaConf.create(
+ {
+ "logging": {
+ "uri": None,
+ "experiment_name": "test_experiment",
+ "run_name": "test_run",
+ },
+ }
+ )
+
+ @pytest.fixture
+ def dist_rank0_single(self):
+ """DistributedManager mock: rank 0, world_size 1."""
+ dist = MagicMock()
+ dist.rank = 0
+ dist.world_size = 1
+ dist._local_rank = 0
+ return dist
+
+ @pytest.fixture
+ def dist_rank0_multi(self):
+ """DistributedManager mock: rank 0, world_size 4."""
+ dist = MagicMock()
+ dist.rank = 0
+ dist.world_size = 4
+ dist._local_rank = 0
+ return dist
+
+ @pytest.fixture
+ def dist_rank0_large(self):
+ """DistributedManager mock: rank 0, world_size 8 (>4)."""
+ dist = MagicMock()
+ dist.rank = 0
+ dist.world_size = 8
+ dist._local_rank = 0
+ return dist
+
+ @pytest.fixture
+ def mock_mlflow(self):
+ """Patch mlflow and related utilities used in init_mlflow."""
+ with patch("hirad.utils.train_helpers.mlflow") as m_mlflow, \
+ patch("hirad.utils.train_helpers.get_env_info") as m_env, \
+ patch("hirad.utils.train_helpers.flatten_dict") as m_flat:
+ # get_env_info returns (dict, git_diff_string)
+ m_env.return_value = ({"pkg": {"version": "1.0"}}, "diff contents")
+ # flatten_dict passthrough
+ m_flat.side_effect = lambda x: x
+ # active_run mock
+ mock_run = MagicMock()
+ mock_run.info.run_id = "new-run-id-123"
+ m_mlflow.active_run.return_value = mock_run
+ yield {
+ "mlflow": m_mlflow,
+ "get_env_info": m_env,
+ "flatten_dict": m_flat,
+ }
+
+ # --- Rank 0, fresh run (no existing run_id.txt) ---
+
+ def test_rank0_fresh_run_sets_tracking_uri(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].set_tracking_uri.assert_called_once_with(
+ "http://mlflow-server:5000"
+ )
+
+ def test_rank0_fresh_run_sets_experiment(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].set_experiment.assert_called_once_with(
+ experiment_name="test_experiment"
+ )
+
+ def test_rank0_fresh_run_starts_with_run_name(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].start_run.assert_called_once_with(
+ run_name="test_run", log_system_metrics=True
+ )
+
+ def test_rank0_fresh_run_saves_run_id(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ run_id_file = tmp_path / "run_id.txt"
+ assert run_id_file.exists()
+ assert run_id_file.read_text() == "new-run-id-123"
+
+ def test_rank0_fresh_run_logs_params(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].log_params.assert_called_once()
+
+ def test_rank0_fresh_run_logs_env_info(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].log_dict.assert_any_call(
+ {"pkg": {"version": "1.0"}}, "environment.json"
+ )
+
+ def test_rank0_fresh_run_logs_git_diff(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].log_text.assert_called_once_with(
+ "diff contents", "git_diff.txt"
+ )
+
+ def test_rank0_fresh_run_logs_config(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].log_dict.assert_any_call(base_cfg, "config.json")
+
+ # --- Rank 0, no git diff ---
+
+ def test_rank0_no_git_diff_skips_log_text(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ mock_mlflow["get_env_info"].return_value = ({"pkg": {}}, "")
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].log_text.assert_not_called()
+
+ # --- Rank 0, resuming from checkpoint (run_id.txt exists) ---
+
+ def test_rank0_resume_reads_run_id(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ (tmp_path / "run_id.txt").write_text("existing-run-id-456")
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].start_run.assert_called_once_with(
+ run_id="existing-run-id-456", log_system_metrics=True
+ )
+
+ def test_rank0_resume_does_not_log_params(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ (tmp_path / "run_id.txt").write_text("existing-run-id-456")
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].log_params.assert_not_called()
+
+ def test_rank0_resume_does_not_overwrite_run_id(self, base_cfg, dist_rank0_single, mock_mlflow, tmp_path):
+ (tmp_path / "run_id.txt").write_text("existing-run-id-456")
+ init_mlflow(base_cfg, dist_rank0_single, write_dir=str(tmp_path))
+ # File content should remain unchanged
+ assert (tmp_path / "run_id.txt").read_text() == "existing-run-id-456"
+
+ # --- URI handling ---
+
+ def test_rank0_none_uri_skips_set_tracking(self, cfg_no_uri, dist_rank0_single, mock_mlflow, tmp_path):
+ init_mlflow(cfg_no_uri, dist_rank0_single, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].set_tracking_uri.assert_not_called()
+
+ # --- System metrics node ID ---
+
+ def test_rank0_small_world_sets_node_id(self, base_cfg, dist_rank0_multi, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_multi, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].system_metrics.set_system_metrics_node_id.assert_called_once_with(
+ "node-0"
+ )
+
+ def test_rank0_large_world_disables_system_metrics(self, base_cfg, dist_rank0_large, mock_mlflow, tmp_path):
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist_rank0_large, write_dir=str(tmp_path))
+ # world_size > 4: log_system_metrics=False
+ mock_mlflow["mlflow"].start_run.assert_called_once()
+ _, kwargs = mock_mlflow["mlflow"].start_run.call_args
+ assert kwargs["log_system_metrics"] is False
+
+ def test_rank0_small_world_enables_system_metrics(self, base_cfg, dist_rank0_multi, mock_mlflow, tmp_path):
+ init_mlflow(base_cfg, dist_rank0_multi, write_dir=str(tmp_path))
+ _, kwargs = mock_mlflow["mlflow"].start_run.call_args
+ assert kwargs["log_system_metrics"] is True
+
+ # --- Distributed barrier for large world_size ---
+
+ def test_large_world_calls_barrier(self, base_cfg, dist_rank0_large, mock_mlflow, tmp_path):
+ with patch("hirad.utils.train_helpers.torch") as m_torch:
+ init_mlflow(base_cfg, dist_rank0_large, write_dir=str(tmp_path))
+ m_torch.distributed.barrier.assert_called_once()
+
+ def test_small_world_skips_barrier(self, base_cfg, dist_rank0_multi, mock_mlflow, tmp_path):
+ with patch("hirad.utils.train_helpers.torch") as m_torch:
+ init_mlflow(base_cfg, dist_rank0_multi, write_dir=str(tmp_path))
+ m_torch.distributed.barrier.assert_not_called()
+
+ # --- Sub-node MLflow activation (non-rank-0 local rank 0) ---
+
+ def test_sub_node_rank4_local0_starts_run(self, base_cfg, mock_mlflow, tmp_path):
+ """rank=4, _local_rank=0, world_size=8 should activate sub mlflow."""
+ dist = MagicMock()
+ dist.rank = 4
+ dist.world_size = 8
+ dist._local_rank = 0
+ (tmp_path / "run_id.txt").write_text("existing-run-id-456")
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].start_run.assert_called_once_with(
+ run_id="existing-run-id-456", log_system_metrics=True
+ )
+
+ def test_sub_node_sets_correct_node_id(self, base_cfg, mock_mlflow, tmp_path):
+ """rank=4, world_size=8 should set node id to 'node-1'."""
+ dist = MagicMock()
+ dist.rank = 4
+ dist.world_size = 8
+ dist._local_rank = 0
+ (tmp_path / "run_id.txt").write_text("run-id")
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].system_metrics.set_system_metrics_node_id.assert_called_once_with(
+ "node-1"
+ )
+
+ def test_rank1_large_world_activates_sub_mlflow(self, base_cfg, mock_mlflow, tmp_path):
+ """rank=1, world_size=8 should activate sub mlflow (special case)."""
+ dist = MagicMock()
+ dist.rank = 1
+ dist.world_size = 8
+ dist._local_rank = 1
+ (tmp_path / "run_id.txt").write_text("run-id")
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ # rank=1 special case: node_id is "node-0"
+ mock_mlflow["mlflow"].system_metrics.set_system_metrics_node_id.assert_called_once_with(
+ "node-0"
+ )
+ mock_mlflow["mlflow"].start_run.assert_called_once_with(
+ run_id="run-id", log_system_metrics=True
+ )
+
+ def test_rank1_small_world_skips_sub_mlflow(self, base_cfg, mock_mlflow, tmp_path):
+ """rank=1, world_size=4 should NOT activate sub mlflow."""
+ dist = MagicMock()
+ dist.rank = 1
+ dist.world_size = 4
+ dist._local_rank = 1
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ # rank != 0, not local_rank 0, world_size <= 4 => no start_run
+ mock_mlflow["mlflow"].start_run.assert_not_called()
+
+ def test_non_local_rank0_non_rank1_skips_sub_mlflow(self, base_cfg, mock_mlflow, tmp_path):
+ """rank=2, _local_rank=2, world_size=8 should not activate sub mlflow."""
+ dist = MagicMock()
+ dist.rank = 2
+ dist.world_size = 8
+ dist._local_rank = 2
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].start_run.assert_not_called()
+
+ def test_non_local_rank0_non_node_0_skips_sub_mlflow(self, base_cfg, mock_mlflow, tmp_path):
+ """rank=5, _local_rank=1, world_size=8 should not activate sub mlflow."""
+ dist = MagicMock()
+ dist.rank = 5
+ dist.world_size = 8
+ dist._local_rank = 1
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].start_run.assert_not_called()
+
+ def test_sub_node_sets_tracking_uri(self, base_cfg, mock_mlflow, tmp_path):
+ """Sub-node should set tracking URI when configured."""
+ dist = MagicMock()
+ dist.rank = 4
+ dist.world_size = 8
+ dist._local_rank = 0
+ (tmp_path / "run_id.txt").write_text("run-id")
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].set_tracking_uri.assert_called_with(
+ "http://mlflow-server:5000"
+ )
+
+ def test_sub_node_none_uri_skips_starting_mlflow(self, cfg_no_uri, mock_mlflow, tmp_path):
+ """Sub-node should skip starting mlflow when URI is None."""
+ dist = MagicMock()
+ dist.rank = 4
+ dist.world_size = 8
+ dist._local_rank = 0
+ (tmp_path / "run_id.txt").write_text("run-id")
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(cfg_no_uri, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].set_tracking_uri.assert_not_called()
+ mock_mlflow["mlflow"].start_run.assert_not_called()
+
+ def test_sub_node_sets_experiment(self, base_cfg, mock_mlflow, tmp_path):
+ """Sub-node should set the experiment name."""
+ dist = MagicMock()
+ dist.rank = 4
+ dist.world_size = 8
+ dist._local_rank = 0
+ (tmp_path / "run_id.txt").write_text("run-id")
+ with patch("hirad.utils.train_helpers.torch"):
+ init_mlflow(base_cfg, dist, write_dir=str(tmp_path))
+ mock_mlflow["mlflow"].set_experiment.assert_called_with(
+ experiment_name="test_experiment"
+ )
+
+
+############################################################################
+# update_learning_rate #
+############################################################################
+
+
+class TestUpdateLearningRate:
+ """Tests for update_learning_rate."""
+
+ @staticmethod
+ def _make_optimizer(lr, num_groups=1):
+ """Create a simple SGD optimizer with `num_groups` param groups."""
+ params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(num_groups)]
+ optimizer = torch.optim.SGD(
+ [{"params": [p], "lr": lr} for p in params]
+ )
+ return optimizer
+
+ # ------------------------------------------------------------------
+ # Rampup phase (cur_nimg < lr_rampup)
+ # ------------------------------------------------------------------
+ def test_rampup_halfway(self):
+ """At half the rampup period the LR should be lr * 0.5."""
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=1.0, lr_decay_rate=1, cur_nimg=500)
+ assert result == pytest.approx(0.01 * 0.5)
+
+ def test_rampup_quarter(self):
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.02, lr_rampup=2000,
+ lr_decay=1.0, lr_decay_rate=1, cur_nimg=500)
+ assert result == pytest.approx(0.02 * 0.25)
+
+ def test_rampup_at_zero(self):
+ """At cur_nimg=0 the LR should be 0 during rampup."""
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=1.0, lr_decay_rate=1, cur_nimg=0)
+ assert result == pytest.approx(0.0)
+
+ # ------------------------------------------------------------------
+ # Rampup boundary (cur_nimg == lr_rampup)
+ # ------------------------------------------------------------------
+ def test_rampup_exact_boundary(self):
+ """At the exact rampup boundary LR should equal base lr (no decay yet)."""
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=0.5, lr_decay_rate=500, cur_nimg=1000)
+ # rampup factor = min(1000/1000, 1) = 1 → lr = 0.01
+ # decay exponent = (1000 - 1000) // 500 = 0 → 0.5^0 = 1
+ assert result == pytest.approx(0.01)
+
+ # ------------------------------------------------------------------
+ # Post-rampup with decay
+ # ------------------------------------------------------------------
+ def test_decay_one_step(self):
+ """One decay step after rampup."""
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=0.5, lr_decay_rate=500, cur_nimg=1500)
+ # rampup clamped at 1, decay = 0.5 ^ ((1500-1000)//500) = 0.5^1
+ assert result == pytest.approx(0.01 * 0.5)
+
+ def test_decay_two_steps(self):
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=0.5, lr_decay_rate=500, cur_nimg=2000)
+ # decay = 0.5 ^ ((2000-1000)//500) = 0.5^2 = 0.25
+ assert result == pytest.approx(0.01 * 0.25)
+
+ def test_decay_partial_step_floors(self):
+ """Decay uses integer division, so partial steps are floored."""
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=0.5, lr_decay_rate=500, cur_nimg=1499)
+ # (1499-1000)//500 = 0 → no decay yet
+ assert result == pytest.approx(0.01)
+
+ # ------------------------------------------------------------------
+ # No rampup (lr_rampup == 0)
+ # ------------------------------------------------------------------
+ def test_no_rampup_applies_decay_to_existing_lr(self):
+ """When lr_rampup=0 the base lr is NOT overwritten;
+ decay is applied to the optimizer's current lr."""
+ opt = self._make_optimizer(lr=0.04)
+ result = update_learning_rate(opt, lr=999, # ignored for the set step
+ lr_rampup=0, lr_decay=0.5,
+ lr_decay_rate=100, cur_nimg=100)
+ # g["lr"] stays 0.04 (rampup branch skipped), then *= 0.5^(100//100) = 0.5
+ assert result == pytest.approx(0.04 * 0.5)
+
+ def test_no_rampup_no_decay(self):
+ """lr_rampup=0, lr_decay=1.0 → LR unchanged."""
+ opt = self._make_optimizer(lr=0.03)
+ result = update_learning_rate(opt, lr=999, lr_rampup=0,
+ lr_decay=1.0, lr_decay_rate=100, cur_nimg=500)
+ assert result == pytest.approx(0.03)
+
+ # ------------------------------------------------------------------
+ # No decay (lr_decay == 1.0)
+ # ------------------------------------------------------------------
+ def test_rampup_without_decay(self):
+ """Rampup works independently of decay when decay=1.0."""
+ opt = self._make_optimizer(lr=0.1)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=1.0, lr_decay_rate=500, cur_nimg=2000)
+ assert result == pytest.approx(0.01)
+
+ # ------------------------------------------------------------------
+ # Multiple param groups
+ # ------------------------------------------------------------------
+ def test_multiple_param_groups(self):
+ """All param groups are updated; return value is the last group's LR."""
+ opt = self._make_optimizer(lr=0.1, num_groups=3)
+ result = update_learning_rate(opt, lr=0.01, lr_rampup=1000,
+ lr_decay=1.0, lr_decay_rate=1, cur_nimg=500)
+ expected = 0.01 * 0.5
+ for g in opt.param_groups:
+ assert g["lr"] == pytest.approx(expected)
+ assert result == pytest.approx(expected)
+
+ # ------------------------------------------------------------------
+ # Successive calls (simulating a training loop)
+ # ------------------------------------------------------------------
+ def test_successive_calls_during_rampup(self):
+ """LR should grow linearly across successive rampup calls."""
+ opt = self._make_optimizer(lr=0.1)
+ lr, rampup = 0.01, 1000
+ lrs = []
+ for step in range(0, 1001, 200):
+ lrs.append(
+ update_learning_rate(opt, lr=lr, lr_rampup=rampup,
+ lr_decay=1.0, lr_decay_rate=1, cur_nimg=step)
+ )
+ expected = [lr * min(s / rampup, 1) for s in range(0, 1001, 200)]
+ for got, exp in zip(lrs, expected):
+ assert got == pytest.approx(exp)
+
+ def test_successive_calls_with_decay(self):
+ """LR should decrease in staircase fashion after rampup."""
+ opt = self._make_optimizer(lr=0.1)
+ lr, rampup, decay, rate = 0.01, 0, 0.9, 100
+ prev_lr = None
+ for step in [0, 50, 100, 150, 200]:
+ # reset optimizer lr before each call since no rampup means
+ # the function mutates the existing lr multiplicatively
+ for g in opt.param_groups:
+ g["lr"] = lr
+ cur = update_learning_rate(opt, lr=lr, lr_rampup=rampup,
+ lr_decay=decay, lr_decay_rate=rate,
+ cur_nimg=step)
+ expected = lr * decay ** (step // rate)
+ assert cur == pytest.approx(expected)
+
+
+############################################################################
+# calculate_patch_per_iter #
+############################################################################
+
+
+class TestCalculatePatchPerIter:
+ """Tests for calculate_patch_per_iter."""
+
+ # ------------------------------------------------------------------
+ # max_patch_per_gpu is None / falsy → single iteration
+ # ------------------------------------------------------------------
+ def test_no_max_returns_single_element(self):
+ """When max_patch_per_gpu is None, return [patch_num]."""
+ assert calculate_patch_per_iter(4, None, 1) == [4]
+
+ def test_no_max_zero_returns_single_element(self):
+ """When max_patch_per_gpu is 0 (falsy), return [patch_num]."""
+ assert calculate_patch_per_iter(8, 0, 2) == [8]
+
+ def test_no_max_patch_num_one(self):
+ assert calculate_patch_per_iter(1, None, 1) == [1]
+
+ # ------------------------------------------------------------------
+ # max_patch_per_gpu provided – fits in a single iteration
+ # ------------------------------------------------------------------
+ def test_single_iter_exact_fit(self):
+ """patch_num fits exactly within max_patch_per_gpu."""
+ # max_patch_num_per_iter = min(4, 8//2) = 4 → 1 iteration
+ assert calculate_patch_per_iter(4, 8, 2) == [4]
+
+ def test_single_iter_max_exceeds_patch_num(self):
+ """max allows more patches than needed; still one iteration."""
+ # max_patch_num_per_iter = min(2, 16//1) = 2 → 1 iteration
+ assert calculate_patch_per_iter(2, 16, 1) == [2]
+
+ # ------------------------------------------------------------------
+ # max_patch_per_gpu provided – requires multiple iterations
+ # ------------------------------------------------------------------
+ def test_even_split(self):
+ """patch_num divides evenly into iterations."""
+ # max_patch_num_per_iter = min(8, 4//1) = 4 → 2 iterations of 4
+ assert calculate_patch_per_iter(8, 4, 1) == [4, 4]
+
+ def test_uneven_split(self):
+ """Last iteration gets fewer patches."""
+ # max_patch_num_per_iter = min(7, 4//1) = 4
+ # iterations = ceil(7/4) = 2 → [4, 3]
+ assert calculate_patch_per_iter(7, 4, 1) == [4, 3]
+
+ def test_three_iterations(self):
+ """Requires three iterations with a remainder."""
+ # max_patch_num_per_iter = min(10, 4//1) = 4
+ # iterations = ceil(10/4) = 3 → [4, 4, 2]
+ assert calculate_patch_per_iter(10, 4, 1) == [4, 4, 2]
+
+ def test_patch_num_one_less_than_max(self):
+ # max_patch_num_per_iter = min(3, 4//1) = 3 → single iteration
+ assert calculate_patch_per_iter(3, 4, 1) == [3]
+
+ def test_patch_num_one_more_than_max(self):
+ # max_patch_num_per_iter = min(5, 4//1) = 4
+ # iterations = ceil(5/4) = 2 → [4, 1]
+ assert calculate_patch_per_iter(5, 4, 1) == [4, 1]
+
+ # ------------------------------------------------------------------
+ # batch_size_per_gpu interaction
+ # ------------------------------------------------------------------
+ def test_batch_size_reduces_max_per_iter(self):
+ """Larger batch size reduces the effective max patches per iter."""
+ # max_patch_num_per_iter = min(6, 8//4) = 2
+ # iterations = ceil(6/2) = 3 → [2, 2, 2]
+ assert calculate_patch_per_iter(6, 8, 4) == [2, 2, 2]
+
+ def test_batch_size_equals_max(self):
+ """batch_size_per_gpu == max_patch_per_gpu → 1 patch per iter."""
+ # max_patch_num_per_iter = min(3, 4//4) = 1
+ # iterations = ceil(3/1) = 3 → [1, 1, 1]
+ assert calculate_patch_per_iter(3, 4, 4) == [1, 1, 1]
+
+ # ------------------------------------------------------------------
+ # Validation / edge cases
+ # ------------------------------------------------------------------
+ def test_max_less_than_batch_raises(self):
+ """max_patch_per_gpu < batch_size_per_gpu should raise."""
+ with pytest.raises(ValueError, match="max_patch_per_gpu"):
+ calculate_patch_per_iter(4, 2, 4)
+
+ def test_sum_equals_patch_num(self):
+ """Sum of returned list must always equal patch_num."""
+ for patch_num in range(1, 20):
+ for batch in [1, 2, 4]:
+ for max_ppg in [batch, batch * 2, batch * 3, batch * 5]:
+ result = calculate_patch_per_iter(patch_num, max_ppg, batch)
+ assert sum(result) == patch_num, (
+ f"patch_num={patch_num}, max_ppg={max_ppg}, batch={batch}: "
+ f"sum({result}) = {sum(result)} != {patch_num}"
+ )
+
+ def test_no_element_exceeds_max(self):
+ """No single iteration should exceed max_patch_num_per_iter."""
+ for patch_num in range(1, 20):
+ for batch in [1, 2, 4]:
+ for max_ppg in [batch, batch * 2, batch * 3]:
+ max_per_iter = min(patch_num, max_ppg // batch)
+ result = calculate_patch_per_iter(patch_num, max_ppg, batch)
+ assert all(r <= max_per_iter for r in result), (
+ f"patch_num={patch_num}, max_ppg={max_ppg}, batch={batch}: "
+ f"{result} has element > {max_per_iter}"
+ )
+