From a430e8c560fa0e9d267ed07a3dc0020589ccd818 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 10 Apr 2025 14:52:53 +0200 Subject: [PATCH 01/66] add basic example for regridding using scipy --- read_datasets.py | 159 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 read_datasets.py diff --git a/read_datasets.py b/read_datasets.py new file mode 100644 index 0000000..4ae11e4 --- /dev/null +++ b/read_datasets.py @@ -0,0 +1,159 @@ +from anemoi.datasets import open_dataset +import matplotlib.pyplot as plt +import cartopy.crs as ccrs +import numpy as np +from scipy.interpolate import griddata + +COSMO_PATH = '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr' # path in Balfrin +ERA_PATH = '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' # path in Balfrin + +# trim edge removes boundary +cosmo = open_dataset(COSMO_PATH, select="2t", trim_edge=20) # open cosmo, select only 2m-temperature +start_date = cosmo.metadata()['start_date'] # get start and end date of cosmo +end_date = cosmo.metadata()['end_date'] +era = open_dataset(ERA_PATH, select="2t", start=start_date, end=end_date) # load era5 2m-temperature in the time-range of cosmo + + +# get indeces of era5 data that is in the bounding rectangle of cosmo data - this is just for plotting +min_lat_cosmo = min(cosmo.latitudes) +max_lat_cosmo = max(cosmo.latitudes) +min_lon_cosmo = min(cosmo.longitudes) +max_lon_cosmo = max(cosmo.longitudes) +box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) +box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) +indeces = np.where(box_lon*box_lat) + + +#### Approach 1 ######################################################### +#### Scipy Interpolate ################################################## + +grid = np.column_stack((era.longitudes, era.latitudes)) # stack lon-lat columns of era5 points +values = np.array(era[0,0,0,:]) # get era grid 2m-temperature values on the first avaialble date-time + +interp_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes)) # stack lon-lat column of cosmo points + +values_int = griddata(grid,values,interp_grid,method='linear') # interpolate era5 to cosmo grid using scipy griddata linear + + +################ plotting ################################################ + +# plot era original +fig = plt.figure() +fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +p = ax.scatter(x=era.longitudes[indeces], y=era.latitudes[indeces], c=era[0, 0, 0, :][indeces]) +ax.coastlines() +ax.gridlines(draw_labels=True) +plt.colorbar(p, label="K", orientation="horizontal") +plt.savefig("temperature-2m-era.jpg") + +# plot cosmo original +fig = plt.figure() +fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=cosmo[0, 0, 0, :]) +ax.coastlines() +ax.gridlines(draw_labels=True) +plt.colorbar(p, label="K", orientation="horizontal") +plt.savefig("temperature-2m-cosmo.jpg") + +#plot inerpolated era5 +fig = plt.figure() +fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=values_int) +ax.coastlines() +ax.gridlines(draw_labels=True) +plt.colorbar(p, label="K", orientation="horizontal") +plt.savefig("temperature-2m-era-downscaled.jpg") + + + + + + + + + + + + + +# cosmo = xr.open_zarr(COSMO_PATH) +# print(cosmo.attrs['data_request']) +# era = xr.open_zarr(ERA_PATH) +# print(cosmo[0,0,0,0]) +# print(era) + + + + +# min_lat_cosmo = min(cosmo.latitudes) +# max_lat_cosmo = max(cosmo.latitudes) +# min_lon_cosmo = min(cosmo.longitudes) +# max_lon_cosmo = max(cosmo.longitudes) +# box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) +# box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) +# indeces = np.where(box_lon*box_lat) + +# ds = xr.tutorial.open_dataset( +# "air_temperature" +# ) # use xr.tutorial.load_dataset() for xarray Date: Thu, 10 Apr 2025 14:56:53 +0200 Subject: [PATCH 02/66] rename and delete comments --- read_datasets.py => interpolate_basic.py | 96 +----------------------- 1 file changed, 1 insertion(+), 95 deletions(-) rename read_datasets.py => interpolate_basic.py (50%) diff --git a/read_datasets.py b/interpolate_basic.py similarity index 50% rename from read_datasets.py rename to interpolate_basic.py index 4ae11e4..d17ab5a 100644 --- a/read_datasets.py +++ b/interpolate_basic.py @@ -62,98 +62,4 @@ ax.coastlines() ax.gridlines(draw_labels=True) plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-era-downscaled.jpg") - - - - - - - - - - - - - -# cosmo = xr.open_zarr(COSMO_PATH) -# print(cosmo.attrs['data_request']) -# era = xr.open_zarr(ERA_PATH) -# print(cosmo[0,0,0,0]) -# print(era) - - - - -# min_lat_cosmo = min(cosmo.latitudes) -# max_lat_cosmo = max(cosmo.latitudes) -# min_lon_cosmo = min(cosmo.longitudes) -# max_lon_cosmo = max(cosmo.longitudes) -# box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) -# box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) -# indeces = np.where(box_lon*box_lat) - -# ds = xr.tutorial.open_dataset( -# "air_temperature" -# ) # use xr.tutorial.load_dataset() for xarray Date: Thu, 10 Apr 2025 15:07:11 +0200 Subject: [PATCH 03/66] start a branch for transfering corrdiff from nvidia-modulus --- interpolate_basic.py | 65 -------------------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 interpolate_basic.py diff --git a/interpolate_basic.py b/interpolate_basic.py deleted file mode 100644 index d17ab5a..0000000 --- a/interpolate_basic.py +++ /dev/null @@ -1,65 +0,0 @@ -from anemoi.datasets import open_dataset -import matplotlib.pyplot as plt -import cartopy.crs as ccrs -import numpy as np -from scipy.interpolate import griddata - -COSMO_PATH = '/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr' # path in Balfrin -ERA_PATH = '/scratch/mch/apennino/data/aifs-ea-an-oper-0001-mars-n320-1979-2022-6h-v6.zarr' # path in Balfrin - -# trim edge removes boundary -cosmo = open_dataset(COSMO_PATH, select="2t", trim_edge=20) # open cosmo, select only 2m-temperature -start_date = cosmo.metadata()['start_date'] # get start and end date of cosmo -end_date = cosmo.metadata()['end_date'] -era = open_dataset(ERA_PATH, select="2t", start=start_date, end=end_date) # load era5 2m-temperature in the time-range of cosmo - - -# get indeces of era5 data that is in the bounding rectangle of cosmo data - this is just for plotting -min_lat_cosmo = min(cosmo.latitudes) -max_lat_cosmo = max(cosmo.latitudes) -min_lon_cosmo = min(cosmo.longitudes) -max_lon_cosmo = max(cosmo.longitudes) -box_lat = np.logical_and(era.latitudes>=min_lat_cosmo,era.latitudes<=max_lat_cosmo) -box_lon = np.logical_and(era.longitudes>=min_lon_cosmo,era.longitudes<=max_lon_cosmo) -indeces = np.where(box_lon*box_lat) - - -#### Approach 1 ######################################################### -#### Scipy Interpolate ################################################## - -grid = np.column_stack((era.longitudes, era.latitudes)) # stack lon-lat columns of era5 points -values = np.array(era[0,0,0,:]) # get era grid 2m-temperature values on the first avaialble date-time - -interp_grid = np.column_stack((cosmo.longitudes, cosmo.latitudes)) # stack lon-lat column of cosmo points - -values_int = griddata(grid,values,interp_grid,method='linear') # interpolate era5 to cosmo grid using scipy griddata linear - - -################ plotting ################################################ - -# plot era original -fig = plt.figure() -fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -p = ax.scatter(x=era.longitudes[indeces], y=era.latitudes[indeces], c=era[0, 0, 0, :][indeces]) -ax.coastlines() -ax.gridlines(draw_labels=True) -plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-era.jpg") - -# plot cosmo original -fig = plt.figure() -fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=cosmo[0, 0, 0, :]) -ax.coastlines() -ax.gridlines(draw_labels=True) -plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-cosmo.jpg") - -#plot inerpolated era5 -fig = plt.figure() -fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) -p = ax.scatter(x=cosmo.longitudes, y=cosmo.latitudes, c=values_int) -ax.coastlines() -ax.gridlines(draw_labels=True) -plt.colorbar(p, label="K", orientation="horizontal") -plt.savefig("temperature-2m-era-downscaled.jpg") \ No newline at end of file From 99d55dffe2893523a9a0c795a68604c20e81ef3e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 11 Apr 2025 13:43:37 +0200 Subject: [PATCH 04/66] add model and loss scripts --- src/losses/loss.py | 914 +++++++++++++++++++++ src/models/layers.py | 567 ++++++++++++++ src/models/preconditioning copy.py | 1176 ++++++++++++++++++++++++++++ src/models/preconditioning.py | 1176 ++++++++++++++++++++++++++++ src/models/song_unet.py | 906 +++++++++++++++++++++ src/models/unet.py | 267 +++++++ src/models/utils.py | 66 ++ 7 files changed, 5072 insertions(+) create mode 100644 src/losses/loss.py create mode 100644 src/models/layers.py create mode 100644 src/models/preconditioning copy.py create mode 100644 src/models/preconditioning.py create mode 100644 src/models/song_unet.py create mode 100644 src/models/unet.py create mode 100644 src/models/utils.py diff --git a/src/losses/loss.py b/src/losses/loss.py new file mode 100644 index 0000000..18dde13 --- /dev/null +++ b/src/losses/loss.py @@ -0,0 +1,914 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Loss functions used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +import random +from typing import Callable, Optional, Union + +import numpy as np +import torch + + +class VPLoss: + """ + Loss function corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + beta_d: float, optional + Coefficient for the diffusion process, by default 19.9. + beta_min: float, optional + Minimum bound, by defaults 0.1. + epsilon_t: float, optional + Small positive value, by default 1e-5. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + """ + + def __init__( + self, beta_d: float = 19.9, beta_min: float = 0.1, epsilon_t: float = 1e-5 + ): + self.beta_d = beta_d + self.beta_min = beta_min + self.epsilon_t = epsilon_t + + def __call__( + self, + net: torch.nn.Module, + images: torch.Tensor, + labels: torch.Tensor, + augment_pipe: Optional[Callable] = None, + ): + """ + Calculate and return the loss corresponding to the variance preserving (VP) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'epsilon_t' and random values. The calculated loss is weighted based on the + inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + def sigma( + self, t: Union[float, torch.Tensor] + ): # NOTE: also exists in preconditioning + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + +class VELoss: + """ + Loss function corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + + Note: + ----- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__(self, sigma_min: float = 0.02, sigma_max: float = 100.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance exploding (VE) + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'sigma_min' and 'sigma_max' and random values. The calculated loss is weighted + based on the inverse of 'sigma^2'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) + weight = 1 / sigma**2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class EDMLoss: + """ + Loss function proposed in the EDM paper. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, images, condition=None, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + y, augment_labels = ( + augment_pipe(images) if augment_pipe is not None else (images, None) + ) + n = torch.randn_like(y) * sigma + if condition is not None: + D_yn = net( + y + n, + sigma, + condition=condition, + class_labels=labels, + augment_labels=augment_labels, + ) + else: + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class EDMLossSR: + """ + Variation of the loss function proposed in the EDM paper for Super-Resolution. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss corresponding to the EDM formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the input images. + The noise level is determined by 'sigma', which is computed as a function of + 'P_mean' and 'P_std' random values. The calculated loss is weighted as a + function of 'sigma' and 'sigma_data'. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input images to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # augment for conditional generaiton + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + n = torch.randn_like(y) * sigma + D_yn = net(y + n, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + +class RegressionLoss: + """ + Regression loss function for the U-Net for deterministic predictions. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + """ + Calculate and return the loss for the U-Net for deterministic predictions. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = ( + 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + ) + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + + return loss + + +class ResLoss: + """ + Mixture loss function for denoising score matching. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + regression_net, + img_shape_x, + img_shape_y, + patch_shape_x, + patch_shape_y, + patch_num, + P_mean: float = 0.0, + P_std: float = 1.2, + sigma_data: float = 0.5, + hr_mean_conditioning: bool = False, + ): + self.unet = regression_net + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + self.img_shape_x = img_shape_x + self.img_shape_y = img_shape_y + self.patch_shape_x = patch_shape_x + self.patch_shape_y = patch_shape_y + self.patch_num = patch_num + self.hr_mean_conditioning = hr_mean_conditioning + + def __call__( + self, + net, + img_clean, + img_lr, + labels=None, + lead_time_label=None, + augment_pipe=None, + ): + """ + Calculate and return the loss for denoising score matching. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # augment for conditional generaiton + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + y_lr_res = y_lr + + # global index + b = y.shape[0] + Nx = torch.arange(self.img_shape_x).int() + Ny = torch.arange(self.img_shape_y).int() + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ + None, + ].expand(b, -1, -1, -1) + + # form residual + if lead_time_label is not None: + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + sigma, + labels, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + sigma, + labels, + augment_labels=augment_labels, + ) + + y = y - y_mean + + if self.hr_mean_conditioning: + y_lr = torch.cat((y_mean, y_lr), dim=1).contiguous() + global_index = None + # patchified training + # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 + if ( + self.img_shape_x != self.patch_shape_x + or self.img_shape_y != self.patch_shape_y + ): + c_in = y_lr.shape[1] + c_out = y.shape[1] + rnd_normal = torch.randn( + [img_clean.shape[0] * self.patch_num, 1, 1, 1], device=img_clean.device + ) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / ( + sigma * self.sigma_data + ) ** 2 + + # global interpolation + input_interp = torch.nn.functional.interpolate( + img_lr, + (self.patch_shape_y, self.patch_shape_x), + mode="bilinear", + ) + + # patch generation from a single sample (not from random samples due to memory consumption of regression) + y_new = torch.zeros( + b * self.patch_num, + c_out, + self.patch_shape_y, + self.patch_shape_x, + device=img_clean.device, + ) + y_lr_new = torch.zeros( + b * self.patch_num, + c_in + input_interp.shape[1], + self.patch_shape_y, + self.patch_shape_x, + device=img_clean.device, + ) + global_index = torch.zeros( + b * self.patch_num, + 2, + self.patch_shape_y, + self.patch_shape_x, + dtype=torch.int, + device=img_clean.device, + ) + for i in range(self.patch_num): + rnd_x = random.randint(0, self.img_shape_x - self.patch_shape_x) + rnd_y = random.randint(0, self.img_shape_y - self.patch_shape_y) + y_new[b * i : b * (i + 1),] = y[ + :, + :, + rnd_y : rnd_y + self.patch_shape_y, + rnd_x : rnd_x + self.patch_shape_x, + ] + global_index[b * i : b * (i + 1),] = grid[ + :, + :, + rnd_y : rnd_y + self.patch_shape_y, + rnd_x : rnd_x + self.patch_shape_x, + ] + y_lr_new[b * i : b * (i + 1),] = torch.cat( + ( + y_lr[ + :, + :, + rnd_y : rnd_y + self.patch_shape_y, + rnd_x : rnd_x + self.patch_shape_x, + ], + input_interp, + ), + 1, + ) + y = y_new + y_lr = y_lr_new + latent = y + torch.randn_like(y) * sigma + + if lead_time_label is not None: + D_yn = net( + latent, + y_lr, + sigma, + labels, + global_index=global_index, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + latent, + y_lr, + sigma, + labels, + global_index=global_index, + augment_labels=augment_labels, + ) + loss = weight * ((D_yn - y) ** 2) + + return loss + + +class VELoss_dfsr: + """ + Loss function for dfsr model, modified from class VELoss. + + Parameters + ---------- + beta_start : float + Noise level at the initial step of the forward diffusion process, by default 0.0001. + beta_end : float + Noise level at the Final step of the forward diffusion process, by default 0.02. + num_diffusion_timesteps : int + Total number of forward/backward diffusion steps, by default 1000. + + + Note: + ----- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + beta_start: float = 0.0001, + beta_end: float = 0.02, + num_diffusion_timesteps: int = 1000, + ): + # scheduler for diffusion: + self.beta_schedule = "linear" + self.beta_start = beta_start + self.beta_end = beta_end + self.num_diffusion_timesteps = num_diffusion_timesteps + betas = self.get_beta_schedule( + beta_schedule=self.beta_schedule, + beta_start=self.beta_start, + beta_end=self.beta_end, + num_diffusion_timesteps=self.num_diffusion_timesteps, + ) + self.betas = torch.from_numpy(betas).float() + self.num_timesteps = betas.shape[0] + + def get_beta_schedule( + self, beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps + ): + """ + Compute the variance scheduling parameters {beta(0), ..., beta(t), ..., beta(T)} + based on the VP formulation. + + beta_schedule: str + Method to construct the sequence of beta(t)'s. + beta_start: float + Noise level at the initial step of the forward diffusion process, e.g., beta(0) + beta_end: float + Noise level at the final step of the forward diffusion process, e.g., beta(T) + num_diffusion_timesteps: int + Total number of forward/backward diffusion steps + """ + + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "sigmoid": + betas = np.linspace(-6, 6, num_diffusion_timesteps) + betas = sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(beta_schedule) + if betas.shape != (num_diffusion_timesteps,): + raise ValueError( + f"Expected betas to have shape ({num_diffusion_timesteps},), " + f"but got {betas.shape}" + ) + return betas + + def __call__(self, net, images, labels, augment_pipe=None): + """ + Calculate and return the loss corresponding to the variance preserving + formulation. + + The method adds random noise to the input images and calculates the loss as the + square difference between the network's predictions and the noise samples added + to the t-th step of the diffusion process. + The noise level is determined by 'beta_t' based on the given parameters 'beta_start', + 'beta_end' and the current diffusion timestep t. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + images: torch.Tensor + Input fluid flow data samples to the neural network. + + labels: torch.Tensor + Ground truth labels for the input fluid flow data samples. Not required for dfsr. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + t = torch.randint( + low=0, high=self.num_timesteps, size=(images.size(0) // 2 + 1,) + ).to(images.device) + t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[: images.size(0)] + e = torch.randn_like(images) + b = self.betas.to(images.device) + a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) + x = images * a.sqrt() + e * (1.0 - a).sqrt() + + output = net(x, t, labels) + loss = (e - output).square() + + return loss + + +class RegressionLossCE: + """ + A regression loss function for the GEFS-HRRR model with probability channels, adapted + from RegressionLoss. In this version, probability channels are evaluated using + CrossEntropyLoss instead of MSELoss. + + Parameters + ---------- + P_mean: float, optional + Mean value for `sigma` computation, by default -1.2. + P_std: float, optional: + Standard deviation for `sigma` computation, by default 1.2. + sigma_data: float, optional + Standard deviation for data, by default 0.5. + prob_channels: list, optional + A index list of output probability channels. + + Note + ---- + Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + P_mean: float = -1.2, + P_std: float = 1.2, + sigma_data: float = 0.5, + prob_channels: list = [4, 5, 6, 7, 8], + ): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + self.entropy = torch.nn.CrossEntropyLoss(reduction="none") + self.prob_channels = prob_channels + + def __call__( + self, + net, + img_clean, + img_lr, + lead_time_label=None, + labels=None, + augment_pipe=None, + ): + """ + Calculate and return the loss for the U-Net for deterministic predictions. + + Parameters: + ---------- + net: torch.nn.Module + The neural network model that will make predictions. + + img_clean: torch.Tensor + Input images (high resolution) to the neural network. + + img_lr: torch.Tensor + Input images (low resolution) to the neural network. + + lead_time_label: torch.Tensor + Lead time labels for input batches. + + labels: torch.Tensor + Ground truth labels for the input images. + + augment_pipe: callable, optional + An optional data augmentation function that takes images as input and + returns augmented images. If not provided, no data augmentation is applied. + + Returns: + ------- + torch.Tensor + A tensor representing the loss calculated based on the network's + predictions. + """ + all_channels = list(range(img_clean.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = ( + 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + ) + + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + + input = torch.zeros_like(y, device=img_clean.device) + + if lead_time_label is not None: + D_yn = net( + input, + y_lr, + sigma, + labels, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + input, + y_lr, + sigma, + labels, + augment_labels=augment_labels, + ) + loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2) + loss2 = ( + weight + * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[ + :, None + ] + ) + loss = torch.cat((loss1, loss2), dim=1) + return loss diff --git a/src/models/layers.py b/src/models/layers.py new file mode 100644 index 0000000..1fb3b17 --- /dev/null +++ b/src/models/layers.py @@ -0,0 +1,567 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architecture layers used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from typing import Any, Dict, List + +import numpy as np +import torch +from einops import rearrange +from torch.nn.functional import silu + +from physicsnemo.models.diffusion import weight_init + + +class Linear(torch.nn.Module): + """ + A fully connected (dense) layer implementation. The layer's weights and biases can + be initialized using custom initialization strategies like "kaiming_normal", + and can be further scaled by factors `init_weight` and `init_bias`. + + Parameters + ---------- + in_features : int + Size of each input sample. + out_features : int + Size of each output sample. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an additive + bias. By default True. + init_mode : str, optional (default="kaiming_normal") + The mode/type of initialization to use for weights and biases. Supported modes + are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + By default "kaiming_normal". + init_weight : float, optional + A scaling factor to multiply with the initialized weights. By default 1. + init_bias : float, optional + A scaling factor to multiply with the initialized biases. By default 0. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + init_mode: str = "kaiming_normal", + init_weight: int = 1, + init_bias: int = 0, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter( + weight_init([out_features, in_features], **init_kwargs) * init_weight + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) + if bias + else None + ) + + def forward(self, x): + x = x @ self.weight.to(x.dtype).t() + if self.bias is not None: + x = x.add_(self.bias.to(x.dtype)) + return x + + +class Conv2d(torch.nn.Module): + """ + A custom 2D convolutional layer implementation with support for up-sampling, + down-sampling, and custom weight and bias initializations. The layer's weights + and biases canbe initialized using custom initialization strategies like + "kaiming_normal", and can be further scaled by factors `init_weight` and + `init_bias`. + + Parameters + ---------- + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels produced by the convolution. + kernel : int + Size of the convolving kernel. + bias : bool, optional + The biases of the layer. If set to `None`, the layer will not learn an + additive bias. By default True. + up : bool, optional + Whether to perform up-sampling. By default False. + down : bool, optional + Whether to perform down-sampling. By default False. + resample_filter : List[int], optional + Filter to be used for resampling. By default [1, 1]. + fused_resample : bool, optional + If True, performs fused up-sampling and convolution or fused down-sampling + and convolution. By default False. + init_mode : str, optional (default="kaiming_normal") + init_mode : str, optional (default="kaiming_normal") + The mode/type of initialization to use for weights and biases. Supported modes + are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + By default "kaiming_normal". + init_weight : float, optional + A scaling factor to multiply with the initialized weights. By default 1.0. + init_bias : float, optional + A scaling factor to multiply with the initialized biases. By default 0.0. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + bias: bool = True, + up: bool = False, + down: bool = False, + resample_filter: List[int] = [1, 1], + fused_resample: bool = False, + init_mode: str = "kaiming_normal", + init_weight: float = 1.0, + init_bias: float = 0.0, + ): + if up and down: + raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + self.fused_resample = fused_resample + init_kwargs = dict( + mode=init_mode, + fan_in=in_channels * kernel * kernel, + fan_out=out_channels * kernel * kernel, + ) + self.weight = ( + torch.nn.Parameter( + weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) + * init_weight + ) + if kernel + else None + ) + self.bias = ( + torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) + if kernel and bias + else None + ) + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() + self.register_buffer("resample_filter", f if up or down else None) + + def forward(self, x): + w = self.weight.to(x.dtype) if self.weight is not None else None + b = self.bias.to(x.dtype) if self.bias is not None else None + f = ( + self.resample_filter.to(x.dtype) + if self.resample_filter is not None + else None + ) + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.fused_resample and self.up and w is not None: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=max(f_pad - w_pad, 0), + ) + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.fused_resample and self.down and w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) + else: + if self.up: + x = torch.nn.functional.conv_transpose2d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if self.down: + x = torch.nn.functional.conv2d( + x, + f.tile([self.in_channels, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None: + x = x.add_(b.reshape(1, -1, 1, 1)) + return x + + +class GroupNorm(torch.nn.Module): + """ + A custom Group Normalization layer implementation. + + Group Normalization (GN) divides the channels of the input tensor into groups and + normalizes the features within each group independently. It does not require the + batch size as in Batch Normalization, making itsuitable for batch sizes of any size + or even for batch-free scenarios. + + Parameters + ---------- + num_channels : int + Number of channels in the input tensor. + num_groups : int, optional + Desired number of groups to divide the input channels, by default 32. + This might be adjusted based on the `min_channels_per_group`. + min_channels_per_group : int, optional + Minimum channels required per group. This ensures that no group has fewer + channels than this number. By default 4. + eps : float, optional + A small number added to the variance to prevent division by zero, by default + 1e-5. + + Notes + ----- + If `num_channels` is not divisible by `num_groups`, the actual number of groups + might be adjusted to satisfy the `min_channels_per_group` condition. + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + min_channels_per_group: int = 4, + eps: float = 1e-5, + ): + super().__init__() + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + + def forward(self, x): + if self.training: + # Use default torch implementation of GroupNorm for training + # This does not support channels last memory format + x = torch.nn.functional.group_norm( + x, + num_groups=self.num_groups, + weight=self.weight.to(x.dtype), + bias=self.bias.to(x.dtype), + eps=self.eps, + ) + else: + # Use custom GroupNorm implementation that supports channels last + # memory layout for inference + dtype = x.dtype + x = x.float() + x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) + + mean = x.mean(dim=[2, 3, 4], keepdim=True) + var = x.var(dim=[2, 3, 4], keepdim=True) + + x = (x - mean) * (var + self.eps).rsqrt() + x = rearrange(x, "b g c h w -> b (g c) h w") + + weight = rearrange(self.weight, "c -> 1 c 1 1") + bias = rearrange(self.bias, "c -> 1 c 1 1") + x = x * weight + bias + + x = x.type(dtype) + return x + + +class AttentionOp(torch.autograd.Function): + """ + Attention weight computation, i.e., softmax(Q^T * K). + Performs all computation using FP32, but uses the original datatype for + inputs/outputs/gradients to conserve memory. + """ + + @staticmethod + def forward(ctx, q, k): + w = ( + torch.einsum( + "ncq,nck->nqk", + q.to(torch.float32), + (k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32), + ) + .softmax(dim=2) + .to(q.dtype) + ) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + db = torch._softmax_backward_data( + grad_output=dw.to(torch.float32), + output=w.to(torch.float32), + dim=2, + input_dtype=torch.float32, + ) + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( + q.dtype + ) / np.sqrt(k.shape[1]) + dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( + k.dtype + ) / np.sqrt(k.shape[1]) + return dq, dk + + +class UNetBlock(torch.nn.Module): + """ + Unified U-Net block with optional up/downsampling and self-attention. Represents + the union of all features employed by the DDPM++, NCSN++, and ADM architectures. + + Parameters: + ----------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + emb_channels : int + Number of embedding channels. + up : bool, optional + If True, applies upsampling in the forward pass. By default False. + down : bool, optional + If True, applies downsampling in the forward pass. By default False. + attention : bool, optional + If True, enables the self-attention mechanism in the block. By default False. + num_heads : int, optional + Number of attention heads. If None, defaults to `out_channels // 64`. + channels_per_head : int, optional + Number of channels per attention head. By default 64. + dropout : float, optional + Dropout probability. By default 0.0. + skip_scale : float, optional + Scale factor applied to skip connections. By default 1.0. + eps : float, optional + Epsilon value used for normalization layers. By default 1e-5. + resample_filter : List[int], optional + Filter for resampling layers. By default [1, 1]. + resample_proj : bool, optional + If True, resampling projection is enabled. By default False. + adaptive_scale : bool, optional + If True, uses adaptive scaling in the forward pass. By default True. + init : dict, optional + Initialization parameters for convolutional and linear layers. + init_zero : dict, optional + Initialization parameters with zero weights for certain layers. By default + {'init_weight': 0}. + init_attn : dict, optional + Initialization parameters specific to attention mechanism layers. + Defaults to 'init' if not provided. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + emb_channels: int, + up: bool = False, + down: bool = False, + attention: bool = False, + num_heads: int = None, + channels_per_head: int = 64, + dropout: float = 0.0, + skip_scale: float = 1.0, + eps: float = 1e-5, + resample_filter: List[int] = [1, 1], + resample_proj: bool = False, + adaptive_scale: bool = True, + init: Dict[str, Any] = dict(), + init_zero: Dict[str, Any] = dict(init_weight=0), + init_attn: Any = None, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.emb_channels = emb_channels + self.num_heads = ( + 0 + if not attention + else num_heads + if num_heads is not None + else out_channels // channels_per_head + ) + self.dropout = dropout + self.skip_scale = skip_scale + self.adaptive_scale = adaptive_scale + + self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.conv0 = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=3, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + self.affine = Linear( + in_features=emb_channels, + out_features=out_channels * (2 if adaptive_scale else 1), + **init, + ) + self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + self.conv1 = Conv2d( + in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + ) + + self.skip = None + if out_channels != in_channels or up or down: + kernel = 1 if resample_proj or out_channels != in_channels else 0 + self.skip = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel=kernel, + up=up, + down=down, + resample_filter=resample_filter, + **init, + ) + + if self.num_heads: + self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.qkv = Conv2d( + in_channels=out_channels, + out_channels=out_channels * 3, + kernel=1, + **(init_attn if init_attn is not None else init), + ) + self.proj = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=1, + **init_zero, + ) + + def forward(self, x, emb): + torch.cuda.nvtx.range_push("UNetBlock") + orig = x + x = self.conv0(silu(self.norm0(x))) + + params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = silu(self.norm1(x.add_(params))) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) + ) + x = x.add_(self.skip(orig) if self.skip is not None else orig) + x = x * self.skip_scale + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(2) + ) + w = AttentionOp.apply(q, k) + a = torch.einsum("nqk,nck->ncq", w, v) + x = self.proj(a.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + torch.cuda.nvtx.range_pop() + return x + + +class PositionalEmbedding(torch.nn.Module): + """ + A module for generating positional embeddings based on timesteps. + This embedding technique is employed in the DDPM++ and ADM architectures. + + Parameters: + ----------- + num_channels : int + Number of channels for the embedding. + max_positions : int, optional + Maximum number of positions for the embeddings, by default 10000. + endpoint : bool, optional + If True, the embedding considers the endpoint. By default False. + + """ + + def __init__( + self, num_channels: int, max_positions: int = 10000, endpoint: bool = False + ): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange( + start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device + ) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + + +class FourierEmbedding(torch.nn.Module): + """ + Generates Fourier embeddings for timesteps, primarily used in the NCSN++ + architecture. + + This class generates embeddings by first multiplying input tensor `x` and + internally stored random frequencies, and then concatenating the cosine and sine of + the resultant. + + Parameters: + ----------- + num_channels : int + The number of channels in the embedding. The final embedding size will be + 2 * num_channels because of concatenation of cosine and sine results. + scale : int, optional + A scale factor applied to the random frequencies, controlling their range + and thereby the frequency of oscillations in the embedding space. By default 16. + """ + + def __init__(self, num_channels: int, scale: int = 16): + super().__init__() + self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + + def forward(self, x): + x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/src/models/preconditioning copy.py b/src/models/preconditioning copy.py new file mode 100644 index 0000000..52a1660 --- /dev/null +++ b/src/models/preconditioning copy.py @@ -0,0 +1,1176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preconditioning schemes used in the paper"Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +import importlib +import warnings +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import nvtx +import torch + +from physicsnemo.models.diffusion import ( + DhariwalUNet, # noqa: F401 for globals + SongUNet, # noqa: F401 for globals +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +network_module = importlib.import_module("physicsnemo.models.diffusion") + + +@dataclass +class VPPrecondMetaData(ModelMetaData): + """VPPrecond meta data""" + + name: str = "VPPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VPPrecond(Module): + """ + Preconditioning corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + beta_d : float + Extent of the noise level schedule, by default 19.9. + beta_min : float + Initial slope of the noise level schedule, by default 0.1. + M : int + Original number of timesteps in the DDPM formulation, by default 1000. + epsilon_t : float + Minimum t-value used during training, by default 1e-5. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + epsilon_t: float = 1e-5, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VPPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t: Union[float, torch.Tensor]): + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma: Union[float, torch.Tensor]): + """ + Compute the inverse of the sigma function for a given sigma. + + This function effectively calculates t from a given sigma(t) based on the + parameters `beta_d` and `beta_min`. + + Parameters + ---------- + sigma : Union[float, torch.Tensor] + The sigma(t) value or set of sigma(t) values for which to compute the + inverse. + + Returns + ------- + torch.Tensor + The computed t value(s) corresponding to the provided sigma(t). + """ + sigma = torch.as_tensor(sigma) + return ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class VEPrecondMetaData(ModelMetaData): + """VEPrecond meta data""" + + name: str = "VEPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VEPrecond(Module): + """ + Preconditioning corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VEPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class iDDPMPrecondMetaData(ModelMetaData): + """iDDPMPrecond meta data""" + + name: str = "iDDPMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class iDDPMPrecond(Module): + """ + Preconditioning corresponding to the improved DDPM (iDDPM) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + C_1 : float + Timestep adjustment at low noise levels., by default 0.001. + C_2 : float + Timestep adjustment at high noise levels., by default 0.008. + M: int + Original number of timesteps in the DDPM formulation, by default 1000. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion + probabilistic models. In International Conference on Machine Learning + (pp. 8162-8171). PMLR. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + C_1=0.001, + C_2=0.008, + M=1000, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__(meta=iDDPMPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels * 2, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) + / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) + - 1 + ).sqrt() + self.register_buffer("u", u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = ( + self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + ) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + """ + Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. + + Parameters + ---------- + j : Union[int, torch.Tensor] + The timestep or set of timesteps for which to compute alpha_bar(j). + + Returns + ------- + torch.Tensor + The computed alpha_bar(j) value(s). + """ + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + """ + Round the provided sigma value(s) to the nearest value(s) in a + pre-defined set `u`. + + Parameters + ---------- + sigma : Union[float, list, torch.Tensor] + The sigma value(s) to round. + return_index : bool, optional + Whether to return the index/indices of the rounded value(s) in `u` instead + of the rounded value(s) themselves, by default False. + + Returns + ------- + torch.Tensor + The rounded sigma value(s) or their index/indices in `u`, depending on the + value of `return_index`. + """ + sigma = torch.as_tensor(sigma) + index = torch.cdist( + sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), + self.u.reshape(1, -1, 1), + ).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + + +@dataclass +class EDMPrecondMetaData(ModelMetaData): + """EDMPrecond meta data""" + + name: str = "EDMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecond(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels (for both input and output). If your model + requires a different number of input or output chanels, + override this by passing either of the optional + img_in_channels or img_out_channels args + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + img_in_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the input + This is useful in the case of additional (conditional) channels + img_out_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the output + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + img_in_channels=None, + img_out_channels=None, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondMetaData) + self.img_resolution = img_resolution + if img_in_channels is not None: + img_in_channels = img_in_channels + else: + img_in_channels = img_channels + if img_out_channels is not None: + img_out_channels = img_out_channels + else: + img_out_channels = img_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + + F_x = self.model( + arg.to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondSRMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels # TODO: this is not used, remove it + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.scale_cond_input = scale_cond_input + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) # TODO needs better handling + self.scaling_fn = self._get_scaling_fn() + + def _get_scaling_fn(self): + if self.scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + return self._legacy_scaling_fn + else: + return self._scaling_fn + + @staticmethod + def _scaling_fn(x, img_lr, c_in): + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + + @staticmethod + def _legacy_scaling_fn(x, img_lr, c_in): + return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + + @nvtx.annotate(message="EDMPrecondSR", color="orange") + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + # Concatenate input channels + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + if img_lr is None: + arg = c_in * x + else: + arg = self.scaling_fn(x, img_lr, c_in) + arg = arg.to(dtype) + + F_x = self.model( + arg, + c_noise.flatten(), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + See EDMPrecond.round_sigma + """ + return EDMPrecond.round_sigma(sigma) + + +class VEPrecond_dfsr(torch.nn.Module): + """ + Preconditioning for dfsr model, modified from class VEPrecond, where the input + argument 'sigma' in forward propagation function is used to receive the timestep + of the backward diffusion process. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=self.img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + # print("sigma: ", sigma) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + return F_x + + +class VEPrecond_dfsr_cond(torch.nn.Module): + """ + Preconditioning for dfsr model with physics-informed conditioning input, modified + from class VEPrecond, where the input argument 'sigma' in forward propagation function + is used to receive the timestep of the backward diffusion process. The gradient of PDE + residual with respect to the vorticity in the governing Navier-Stokes equation is computed + as the physics-informed conditioning variable and is combined with the backward diffusion + timestep before being sent to the underlying model for noise prediction. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: + [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity + flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=model_kwargs["model_channels"] * 2, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + # modules to embed residual loss + self.conv_in = torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ) + self.emb_conv = torch.nn.Sequential( + torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=1, + stride=1, + padding=0, + ), + torch.nn.GELU(), + torch.nn.Conv2d( + model_kwargs["model_channels"], + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ), + ) + self.dataset_mean = dataset_mean + self.dataset_scale = dataset_scale + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma + + # Compute physics-informed conditioning information using vorticity residual + dx = ( + self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) + / self.dataset_scale + ) + x = self.conv_in(x) + cond_emb = self.emb_conv(dx) + x = torch.cat((x, cond_emb), dim=1) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + return F_x + + def voriticity_residual(self, w, re=1000.0, dt=1 / 32): + """ + Compute the gradient of PDE residual with respect to a given vorticity w using the + spectrum method. + + Parameters + ---------- + w: torch.Tensor + The fluid flow data sample (vorticity). + re: float + The value of Reynolds number used in the governing Navier-Stokes equation. + dt: float + Time step used to compute the time-derivative of vorticity included in the governing + Navier-Stokes equation. + + Returns + ------- + torch.Tensor + The computed vorticity gradient. + """ + + # w [b t h w] + w = w.clone() + w.requires_grad_(True) + nx = w.size(2) + device = w.device + + w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) + # Wavenumbers in y-direction + k_max = nx // 2 + N = nx + k_x = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(N, 1) + .repeat(1, N) + .reshape(1, 1, N, N) + ) + k_y = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(1, N) + .repeat(N, 1) + .reshape(1, 1, N, N) + ) + # Negative Laplacian in Fourier space + lap = k_x**2 + k_y**2 + lap[..., 0, 0] = 1.0 + psi_h = w_h / lap + + u_h = 1j * k_y * psi_h + v_h = -1j * k_x * psi_h + wx_h = 1j * k_x * w_h + wy_h = 1j * k_y * w_h + wlap_h = -lap * w_h + + u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) + v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) + wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) + wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) + wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) + advection = u * wx + v * wy + + wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) + + # establish forcing term + x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) + x = x[0:-1] + X, Y = torch.meshgrid(x, x) + f = -4 * torch.cos(4 * Y) + + residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f + residual_loss = (residual**2).mean() + dw = torch.autograd.grad(residual_loss, w)[0] + + return dw diff --git a/src/models/preconditioning.py b/src/models/preconditioning.py new file mode 100644 index 0000000..52a1660 --- /dev/null +++ b/src/models/preconditioning.py @@ -0,0 +1,1176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preconditioning schemes used in the paper"Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +import importlib +import warnings +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import nvtx +import torch + +from physicsnemo.models.diffusion import ( + DhariwalUNet, # noqa: F401 for globals + SongUNet, # noqa: F401 for globals +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +network_module = importlib.import_module("physicsnemo.models.diffusion") + + +@dataclass +class VPPrecondMetaData(ModelMetaData): + """VPPrecond meta data""" + + name: str = "VPPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VPPrecond(Module): + """ + Preconditioning corresponding to the variance preserving (VP) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + beta_d : float + Extent of the noise level schedule, by default 19.9. + beta_min : float + Initial slope of the noise level schedule, by default 0.1. + M : int + Original number of timesteps in the DDPM formulation, by default 1000. + epsilon_t : float + Minimum t-value used during training, by default 1e-5. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + epsilon_t: float = 1e-5, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VPPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t: Union[float, torch.Tensor]): + """ + Compute the sigma(t) value for a given t based on the VP formulation. + + The function calculates the noise level schedule for the diffusion process based + on the given parameters `beta_d` and `beta_min`. + + Parameters + ---------- + t : Union[float, torch.Tensor] + The timestep or set of timesteps for which to compute sigma(t). + + Returns + ------- + torch.Tensor + The computed sigma(t) value(s). + """ + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma: Union[float, torch.Tensor]): + """ + Compute the inverse of the sigma function for a given sigma. + + This function effectively calculates t from a given sigma(t) based on the + parameters `beta_d` and `beta_min`. + + Parameters + ---------- + sigma : Union[float, torch.Tensor] + The sigma(t) value or set of sigma(t) values for which to compute the + inverse. + + Returns + ------- + torch.Tensor + The computed t value(s) corresponding to the provided sigma(t). + """ + sigma = torch.as_tensor(sigma) + return ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class VEPrecondMetaData(ModelMetaData): + """VEPrecond meta data""" + + name: str = "VEPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class VEPrecond(Module): + """ + Preconditioning corresponding to the variance exploding (VE) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__(meta=VEPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class iDDPMPrecondMetaData(ModelMetaData): + """iDDPMPrecond meta data""" + + name: str = "iDDPMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class iDDPMPrecond(Module): + """ + Preconditioning corresponding to the improved DDPM (iDDPM) formulation. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + C_1 : float + Timestep adjustment at low noise levels., by default 0.001. + C_2 : float + Timestep adjustment at high noise levels., by default 0.008. + M: int + Original number of timesteps in the DDPM formulation, by default 1000. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion + probabilistic models. In International Conference on Machine Learning + (pp. 8162-8171). PMLR. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + C_1=0.001, + C_2=0.008, + M=1000, + model_type="DhariwalUNet", + **model_kwargs, + ): + super().__init__(meta=iDDPMPrecondMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_channels, + out_channels=img_channels * 2, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) + / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) + - 1 + ).sqrt() + self.register_buffer("u", u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + c_noise = ( + self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + ) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + """ + Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. + + Parameters + ---------- + j : Union[int, torch.Tensor] + The timestep or set of timesteps for which to compute alpha_bar(j). + + Returns + ------- + torch.Tensor + The computed alpha_bar(j) value(s). + """ + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + """ + Round the provided sigma value(s) to the nearest value(s) in a + pre-defined set `u`. + + Parameters + ---------- + sigma : Union[float, list, torch.Tensor] + The sigma value(s) to round. + return_index : bool, optional + Whether to return the index/indices of the rounded value(s) in `u` instead + of the rounded value(s) themselves, by default False. + + Returns + ------- + torch.Tensor + The rounded sigma value(s) or their index/indices in `u`, depending on the + value of `return_index`. + """ + sigma = torch.as_tensor(sigma) + index = torch.cdist( + sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), + self.u.reshape(1, -1, 1), + ).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + + +@dataclass +class EDMPrecondMetaData(ModelMetaData): + """EDMPrecond meta data""" + + name: str = "EDMPrecond" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecond(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels (for both input and output). If your model + requires a different number of input or output chanels, + override this by passing either of the optional + img_in_channels or img_out_channels args + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "DhariwalUNet". + img_in_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the input + This is useful in the case of additional (conditional) channels + img_out_channels: int + Optional setting for when number of input channels =/= number of output + channels. If set, will override img_channels for the output + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + """ + + def __init__( + self, + img_resolution, + img_channels, + label_dim=0, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="DhariwalUNet", + img_in_channels=None, + img_out_channels=None, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondMetaData) + self.img_resolution = img_resolution + if img_in_channels is not None: + img_in_channels = img_in_channels + else: + img_in_channels = img_channels + if img_out_channels is not None: + img_out_channels = img_out_channels + else: + img_out_channels = img_channels + + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward( + self, + x, + sigma, + condition=None, + class_labels=None, + force_fp32=False, + **model_kwargs, + ): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + arg = c_in * x + + if condition is not None: + arg = torch.cat([arg, condition], dim=1) + + F_x = self.model( + arg.to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(Module): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, + **model_kwargs, + ): + super().__init__(meta=EDMPrecondSRMetaData) + self.img_resolution = img_resolution + self.img_channels = img_channels # TODO: this is not used, remove it + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.scale_cond_input = scale_cond_input + + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) # TODO needs better handling + self.scaling_fn = self._get_scaling_fn() + + def _get_scaling_fn(self): + if self.scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + return self._legacy_scaling_fn + else: + return self._scaling_fn + + @staticmethod + def _scaling_fn(x, img_lr, c_in): + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + + @staticmethod + def _legacy_scaling_fn(x, img_lr, c_in): + return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + + @nvtx.annotate(message="EDMPrecondSR", color="orange") + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + # Concatenate input channels + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + if img_lr is None: + arg = c_in * x + else: + arg = self.scaling_fn(x, img_lr, c_in) + arg = arg.to(dtype) + + F_x = self.model( + arg, + c_noise.flatten(), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + @staticmethod + def round_sigma(sigma: Union[float, List, torch.Tensor]): + """ + Convert a given sigma value(s) to a tensor representation. + See EDMPrecond.round_sigma + """ + return EDMPrecond.round_sigma(sigma) + + +class VEPrecond_dfsr(torch.nn.Module): + """ + Preconditioning for dfsr model, modified from class VEPrecond, where the input + argument 'sigma' in forward propagation function is used to receive the timestep + of the backward diffusion process. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. + Advances in neural information processing systems. 2020;33:6840-51. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=self.img_channels, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + # print("sigma: ", sigma) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + return F_x + + +class VEPrecond_dfsr_cond(torch.nn.Module): + """ + Preconditioning for dfsr model with physics-informed conditioning input, modified + from class VEPrecond, where the input argument 'sigma' in forward propagation function + is used to receive the timestep of the backward diffusion process. The gradient of PDE + residual with respect to the vorticity in the governing Navier-Stokes equation is computed + as the physics-informed conditioning variable and is combined with the backward diffusion + timestep before being sent to the underlying model for noise prediction. + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + label_dim : int + Number of class labels, 0 = unconditional, by default 0. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.02. + sigma_max : float + Maximum supported noise level, by default 100.0. + model_type :str + Class name of the underlying model, by default "SongUNet". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + Reference: + [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity + flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972. + """ + + def __init__( + self, + img_resolution: int, + img_channels: int, + label_dim: int = 0, + use_fp16: bool = False, + sigma_min: float = 0.02, + sigma_max: float = 100.0, + dataset_mean: float = 5.85e-05, + dataset_scale: float = 4.79, + model_type: str = "SongUNet", + **model_kwargs: dict, + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.model = globals()[model_type]( + img_resolution=img_resolution, + in_channels=model_kwargs["model_channels"] * 2, + out_channels=img_channels, + label_dim=label_dim, + **model_kwargs, + ) # TODO needs better handling + + # modules to embed residual loss + self.conv_in = torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ) + self.emb_conv = torch.nn.Sequential( + torch.nn.Conv2d( + img_channels, + model_kwargs["model_channels"], + kernel_size=1, + stride=1, + padding=0, + ), + torch.nn.GELU(), + torch.nn.Conv2d( + model_kwargs["model_channels"], + model_kwargs["model_channels"], + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + ), + ) + self.dataset_mean = dataset_mean + self.dataset_scale = dataset_scale + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if class_labels is None + else class_labels.to(torch.float32).reshape(-1, self.label_dim) + ) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + c_in = 1 + c_noise = sigma + + # Compute physics-informed conditioning information using vorticity residual + dx = ( + self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) + / self.dataset_scale + ) + x = self.conv_in(x) + cond_emb = self.emb_conv(dx) + x = torch.cat((x, cond_emb), dim=1) + + F_x = self.model( + (c_in * x).to(dtype), + c_noise.flatten(), + class_labels=class_labels, + **model_kwargs, + ) + + if F_x.dtype != dtype: + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + return F_x + + def voriticity_residual(self, w, re=1000.0, dt=1 / 32): + """ + Compute the gradient of PDE residual with respect to a given vorticity w using the + spectrum method. + + Parameters + ---------- + w: torch.Tensor + The fluid flow data sample (vorticity). + re: float + The value of Reynolds number used in the governing Navier-Stokes equation. + dt: float + Time step used to compute the time-derivative of vorticity included in the governing + Navier-Stokes equation. + + Returns + ------- + torch.Tensor + The computed vorticity gradient. + """ + + # w [b t h w] + w = w.clone() + w.requires_grad_(True) + nx = w.size(2) + device = w.device + + w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) + # Wavenumbers in y-direction + k_max = nx // 2 + N = nx + k_x = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(N, 1) + .repeat(1, N) + .reshape(1, 1, N, N) + ) + k_y = ( + torch.cat( + ( + torch.arange(start=0, end=k_max, step=1, device=device), + torch.arange(start=-k_max, end=0, step=1, device=device), + ), + 0, + ) + .reshape(1, N) + .repeat(N, 1) + .reshape(1, 1, N, N) + ) + # Negative Laplacian in Fourier space + lap = k_x**2 + k_y**2 + lap[..., 0, 0] = 1.0 + psi_h = w_h / lap + + u_h = 1j * k_y * psi_h + v_h = -1j * k_x * psi_h + wx_h = 1j * k_x * w_h + wy_h = 1j * k_y * w_h + wlap_h = -lap * w_h + + u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) + v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) + wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) + wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) + wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) + advection = u * wx + v * wy + + wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) + + # establish forcing term + x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) + x = x[0:-1] + X, Y = torch.meshgrid(x, x) + f = -4 * torch.cos(4 * Y) + + residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f + residual_loss = (residual**2).mean() + dw = torch.autograd.grad(residual_loss, w)[0] + + return dw diff --git a/src/models/song_unet.py b/src/models/song_unet.py new file mode 100644 index 0000000..d38484b --- /dev/null +++ b/src/models/song_unet.py @@ -0,0 +1,906 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import nvtx +import torch +from torch.nn.functional import silu +from torch.utils.checkpoint import checkpoint + +from physicsnemo.models.diffusion import ( + Conv2d, + FourierEmbedding, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + + +@dataclass +class MetaData(ModelMetaData): + name: str = "SongUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class SongUNet(Module): + """ + Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with + optional self-attention, embeddings, and encoder-decoder components. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image, 1 value represents a square image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + checkpoint_level : int, optional (default=0) + How many layers should use gradient checkpointing, 0 is None + additive_pos_embed: bool = False, + Set to True to add a learned position embedding after the first conv (used in StormCast) + + + Reference + ---------- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [16], + dropout: float = 0.10, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + checkpoint_level: int = 0, + additive_pos_embed: bool = False, + ): + valid_embedding_types = ["fourier", "positional", "zero"] + if embedding_type not in valid_embedding_types: + raise ValueError( + f"Invalid embedding_type: {embedding_type}. Must be one of {valid_embedding_types}." + ) + + valid_encoder_types = ["standard", "skip", "residual"] + if encoder_type not in valid_encoder_types: + raise ValueError( + f"Invalid encoder_type: {encoder_type}. Must be one of {valid_encoder_types}." + ) + + valid_decoder_types = ["standard", "skip"] + if decoder_type not in valid_decoder_types: + raise ValueError( + f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." + ) + + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + self.embedding_type = embedding_type + emb_channels = model_channels * channel_mult_emb + self.emb_channels = emb_channels + noise_channels = model_channels * channel_mult_noise + init = dict(init_mode="xavier_uniform") + init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) + init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) + block_kwargs = dict( + emb_channels=emb_channels, + num_heads=1, + dropout=dropout, + skip_scale=np.sqrt(0.5), + eps=1e-6, + resample_filter=resample_filter, + resample_proj=True, + adaptive_scale=False, + init=init, + init_zero=init_zero, + init_attn=init_attn, + ) + + # for compatibility with older versions that took only 1 dimension + self.img_resolution = img_resolution + if isinstance(img_resolution, int): + self.img_shape_y = self.img_shape_x = img_resolution + else: + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] + + # set the threshold for checkpointing based on image resolution + self.checkpoint_threshold = (self.img_shape_y >> checkpoint_level) + 1 + + # Optional additive learned positition embed after the first conv + self.additive_pos_embed = additive_pos_embed + if self.additive_pos_embed: + self.spatial_emb = torch.nn.Parameter( + torch.randn(1, model_channels, self.img_shape_y, self.img_shape_x) + ) + torch.nn.init.trunc_normal_(self.spatial_emb, std=0.02) + + # Mapping. + if self.embedding_type != "zero": + self.map_noise = ( + PositionalEmbedding(num_channels=noise_channels, endpoint=True) + if embedding_type == "positional" + else FourierEmbedding(num_channels=noise_channels) + ) + self.map_label = ( + Linear(in_features=label_dim, out_features=noise_channels, **init) + if label_dim + else None + ) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=noise_channels, + bias=False, + **init, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=noise_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + caux = in_channels + for level, mult in enumerate(channel_mult): + res = self.img_shape_y >> level + if level == 0: + cin = cout + cout = model_channels + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + if encoder_type == "skip": + self.enc[f"{res}x{res}_aux_down"] = Conv2d( + in_channels=caux, + out_channels=caux, + kernel=0, + down=True, + resample_filter=resample_filter, + ) + self.enc[f"{res}x{res}_aux_skip"] = Conv2d( + in_channels=caux, out_channels=cout, kernel=1, **init + ) + if encoder_type == "residual": + self.enc[f"{res}x{res}_aux_residual"] = Conv2d( + in_channels=caux, + out_channels=cout, + kernel=3, + down=True, + resample_filter=resample_filter, + fused_resample=True, + **init, + ) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = res in attn_resolutions + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + skips = [ + block.out_channels for name, block in self.enc.items() if "aux" not in name + ] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = self.img_shape_y >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = idx == num_blocks and res in attn_resolutions + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + if decoder_type == "skip" or level == 0: + if decoder_type == "skip" and level < len(channel_mult) - 1: + self.dec[f"{res}x{res}_aux_up"] = Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel=0, + up=True, + resample_filter=resample_filter, + ) + self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( + num_channels=cout, eps=1e-6 + ) + self.dec[f"{res}x{res}_aux_conv"] = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + @nvtx.annotate(message="SongUNet", color="blue") + def forward(self, x, noise_labels, class_labels, augment_labels=None): + if self.embedding_type != "zero": + # Mapping. + emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) + + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + with nvtx.annotate(f"SongUNet encoder: {name}", color="blue"): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: + x = block(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + # For UNetBlocks check if we should use gradient checkpointing + if isinstance(block, UNetBlock): + if x.shape[-1] > self.checkpoint_threshold: + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + else: + x = block(x) + skips.append(x) + + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + with nvtx.annotate(f"SongUNet decoder: {name}", color="blue"): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + # check for checkpointing on decoder blocks and up sampling blocks + if ( + x.shape[-1] > self.checkpoint_threshold and "_block" in name + ) or ( + x.shape[-1] > (self.checkpoint_threshold / 2) and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + return aux + + +class SongUNetPosEmbd(SongUNet): + """ + Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with + optional self-attention,embeddings, and encoder-decoder components. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image, 1 value represents a square image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.13. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + + + Reference + ---------- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [28], + dropout: float = 0.13, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + gridtype: str = "sinusoidal", + N_grid_channels: int = 4, + checkpoint_level: int = 0, + ): + super().__init__( + img_resolution, + in_channels, + out_channels, + label_dim, + augment_dim, + model_channels, + channel_mult, + channel_mult_emb, + num_blocks, + attn_resolutions, + dropout, + label_dropout, + embedding_type, + channel_mult_noise, + encoder_type, + decoder_type, + resample_filter, + checkpoint_level, + ) + + self.gridtype = gridtype + self.N_grid_channels = N_grid_channels + self.pos_embd = self._get_positional_embedding() + + @nvtx.annotate(message="SongUNet", color="blue") + def forward( + self, x, noise_labels, class_labels, global_index=None, augment_labels=None + ): + # append positional embedding to input conditioning + if self.pos_embd is not None: + selected_pos_embd = self.positional_embedding_indexing(x, global_index) + x = torch.cat((x, selected_pos_embd), dim=1) + + return super().forward(x, noise_labels, class_labels, augment_labels) + + def positional_embedding_indexing(self, x, global_index): + if global_index is None: + selected_pos_embd = ( + self.pos_embd.to(x.dtype) + .to(x.device)[None] + .expand((x.shape[0], -1, -1, -1)) + ) + else: + B = global_index.shape[0] + X = global_index.shape[2] + Y = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, X, Y) to (2, B*X*Y) + selected_pos_embd = self.pos_embd.to(x.device)[ + :, global_index[0], global_index[1] + ] # (N_pe, B*X*Y) + selected_pos_embd = ( + torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, X, Y)), + (1, 0, 2, 3), + ) + .to(x.device) + .to(x.dtype) + ) # (B, N_pe, X, Y) + return selected_pos_embd + + def _get_positional_embedding(self): + if self.N_grid_channels == 0: + return None + elif self.gridtype == "learnable": + grid = torch.nn.Parameter( + torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) + ) + elif self.gridtype == "linear": + if self.N_grid_channels != 2: + raise ValueError("N_grid_channels must be set to 2 for gridtype linear") + x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) + y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) + grid_x, grid_y = np.meshgrid(y, x) + grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: + # print('sinusuidal grid added ......') + x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) + x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) + y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) + y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) + grid_x1, grid_y1 = np.meshgrid(y1, x1) + grid_x2, grid_y2 = np.meshgrid(y2, x2) + grid = torch.squeeze( + torch.from_numpy( + np.expand_dims( + np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 + ) + ) + ) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: + if self.N_grid_channels % 4 != 0: + raise ValueError("N_grid_channels must be a factor of 4") + num_freq = self.N_grid_channels // 4 + freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) + grid_list = [] + grid_x, grid_y = np.meshgrid( + np.linspace(0, 2 * np.pi, self.img_shape_x), + np.linspace(0, 2 * np.pi, self.img_shape_y), + ) + for freq in freq_bands: + for p_fn in [np.sin, np.cos]: + grid_list.append(p_fn(grid_x * freq)) + grid_list.append(p_fn(grid_y * freq)) + grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid.requires_grad = False + elif self.gridtype == "test" and self.N_grid_channels == 2: + idx_x = torch.arange(self.img_shape_y) + idx_y = torch.arange(self.img_shape_x) + mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) + grid = torch.stack((mesh_x, mesh_y), dim=0) + else: + raise ValueError("Gridtype not supported.") + return grid + + +class SongUNetPosLtEmbd(SongUNet): + """ + This model is adapated from SongUNetPosEmbd, with the incoporatation of lead-time aware + embedding for the GEFS-HRRR model. The lead-time embedding is activated by setting the + lead_time_channels and lead_time_steps parameters. + + Parameters + ----------- + img_resolution : Union[List[int], int] + The resolution of the input/output image, 1 value represents a square image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 128. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 4. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [16]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.13. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + embedding_type : str, optional + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + By default 'positional'. + channel_mult_noise : int, optional + Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. + encoder_type : str, optional + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default + 'standard'. + decoder_type : str, optional + Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default + 'standard'. + resample_filter : List[int], optional (default=[1,1]) + Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. + lead_time_channels: int, optional + Length of lead time embedding vector + lead_time_steps: int, optional + Total number of lead times + + + Reference + ---------- + Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Poole, B., 2020. Score-based generative modeling through stochastic differential + equations. arXiv preprint arXiv:2011.13456. + + Note + ----- + Equivalent to the original implementation by Song et al., available at + https://github.com/yang-song/score_sde_pytorch + + Example + -------- + >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + """ + + def __init__( + self, + img_resolution: Union[List[int], int], + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attn_resolutions: List[int] = [28], + dropout: float = 0.13, + label_dropout: float = 0.0, + embedding_type: str = "positional", + channel_mult_noise: int = 1, + encoder_type: str = "standard", + decoder_type: str = "standard", + resample_filter: List[int] = [1, 1], + gridtype: str = "sinusoidal", + N_grid_channels: int = 4, + lead_time_channels: int = None, + lead_time_steps: int = 9, + prob_channels: List[int] = [], + checkpoint_level: int = 0, + ): + super().__init__( + img_resolution, + in_channels, + out_channels, + label_dim, + augment_dim, + model_channels, + channel_mult, + channel_mult_emb, + num_blocks, + attn_resolutions, + dropout, + label_dropout, + embedding_type, + channel_mult_noise, + encoder_type, + decoder_type, + resample_filter, + checkpoint_level, + ) + + self.gridtype = gridtype + self.N_grid_channels = N_grid_channels + self.pos_embd = self._get_positional_embedding() + self.lead_time_channels = lead_time_channels + self.lead_time_steps = lead_time_steps + self.lt_embd = self._get_lead_time_embedding() + self.prob_channels = prob_channels + if self.prob_channels: + self.scalar = torch.nn.Parameter( + torch.ones((1, len(self.prob_channels), 1, 1)) + ) + + @nvtx.annotate(message="SongUNet", color="blue") + def forward( + self, + x, + noise_labels, + class_labels, + lead_time_label=None, + global_index=None, + augment_labels=None, + ): + # append positional embedding to input conditioning + embeds = [] + if self.pos_embd is not None: + embeds.append(self.pos_embd.to(x.device)) + if self.lt_embd is not None: + embeds.append( + torch.reshape( + self.lt_embd[lead_time_label.int()], + (self.lead_time_channels, self.img_shape_y, self.img_shape_x), + ).to(x.device) + ) + if len(embeds) > 0: + embeds = torch.cat(embeds, dim=0) + selected_pos_embd = self.positional_embedding_indexing( + x, embeds, global_index + ) + x = torch.cat((x, selected_pos_embd), dim=1) + out = super().forward(x, noise_labels, class_labels, augment_labels) + # if training mode, let crossEntropyLoss do softmax. The model outputs logits. + # if eval mode, the model outputs probability + all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + if self.prob_channels and (not self.training): + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar).softmax(dim=1), + ), + dim=1, + ) + elif self.prob_channels and self.training: + out_final = torch.cat( + (out[:, scalar_channels], (out[:, self.prob_channels] * self.scalar)), + dim=1, + ) + else: + out_final = out + return out_final + + def positional_embedding_indexing(self, x, pos_embd, global_index): + if global_index is None: + selected_pos_embd = ( + pos_embd.to(x.dtype).to(x.device)[None].expand((x.shape[0], -1, -1, -1)) + ) + else: + B = global_index.shape[0] + X = global_index.shape[2] + Y = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, X, Y) to (2, B*X*Y) + selected_pos_embd = pos_embd.to(x.device)[ + :, global_index[0], global_index[1] + ] # (N_pe, B*X*Y) + selected_pos_embd = ( + torch.permute( + torch.reshape(selected_pos_embd, (pos_embd.shape[0], B, X, Y)), + (1, 0, 2, 3), + ) + .to(x.device) + .to(x.dtype) + ) # (B, N_pe, X, Y) + return selected_pos_embd + + def _get_positional_embedding(self): + if self.N_grid_channels == 0: + return None + elif self.gridtype == "learnable": + grid = torch.nn.Parameter( + torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) + ) + elif self.gridtype == "linear": + if self.N_grid_channels != 2: + raise ValueError("N_grid_channels must be set to 2 for gridtype linear") + x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) + y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) + grid_x, grid_y = np.meshgrid(y, x) + grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: + # print('sinusuidal grid added ......') + x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) + x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) + y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) + y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) + grid_x1, grid_y1 = np.meshgrid(y1, x1) + grid_x2, grid_y2 = np.meshgrid(y2, x2) + grid = torch.squeeze( + torch.from_numpy( + np.expand_dims( + np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 + ) + ) + ) + grid.requires_grad = False + elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: + if self.N_grid_channels % 4 != 0: + raise ValueError("N_grid_channels must be a factor of 4") + num_freq = self.N_grid_channels // 4 + freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) + grid_list = [] + grid_x, grid_y = np.meshgrid( + np.linspace(0, 2 * np.pi, self.img_shape_x), + np.linspace(0, 2 * np.pi, self.img_shape_y), + ) + for freq in freq_bands: + for p_fn in [np.sin, np.cos]: + grid_list.append(p_fn(grid_x * freq)) + grid_list.append(p_fn(grid_y * freq)) + grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid.requires_grad = False + elif self.gridtype == "test" and self.N_grid_channels == 2: + idx_x = torch.arange(self.img_shape_y) + idx_y = torch.arange(self.img_shape_x) + mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) + grid = torch.stack((mesh_x, mesh_y), dim=0) + else: + raise ValueError("Gridtype not supported.") + return grid + + def _get_lead_time_embedding(self): + if (self.lead_time_steps is None) or (self.lead_time_channels is None): + return None + grid = torch.nn.Parameter( + torch.randn( + self.lead_time_steps, + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ) + ) + return grid diff --git a/src/models/unet.py b/src/models/unet.py new file mode 100644 index 0000000..7270606 --- /dev/null +++ b/src/models/unet.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from dataclasses import dataclass + +import torch + +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.models.module import Module + +network_module = importlib.import_module("physicsnemo.models.diffusion") + + +@dataclass +class MetaData(ModelMetaData): + name: str = "UNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class UNet(Module): # TODO a lot of redundancy, need to clean up + """ + U-Net Wrapper for CorrDiff. + + Parameters + ----------- + img_resolution : int + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. + model_type: str, optional + Class name of the underlying model, by default 'DhariwalUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + + References + ---------- + Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + **model_kwargs, + ): + super().__init__(meta=MetaData) + + self.img_channels = img_channels + + # for compatibility with older versions that took only 1 dimension + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_x = img_resolution[0] + self.img_shape_y = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels + img_out_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): + # SR: concatenate input channels + if img_lr is not None: + x = torch.cat((x, img_lr), dim=1) + + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), # (c_in * x).to(dtype), + torch.zeros( + sigma.numel(), dtype=sigma.dtype, device=sigma.device + ), # c_noise.flatten() + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + # skip connection - for SR there's size mismatch bwtween input and output + D_x = F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + """ + Convert a given sigma value(s) to a tensor representation. + + Parameters + ---------- + sigma : Union[float list, torch.Tensor] + The sigma value(s) to convert. + + Returns + ------- + torch.Tensor + The tensor representation of the provided sigma value(s). + """ + return torch.as_tensor(sigma) + + +class StormCastUNet(Module): + """ + U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. + + Parameters + ----------- + img_resolution : int or List[int] + The resolution of the input/output image. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16: bool, optional + Execute the underlying model at FP16 precision?, by default False. + sigma_min: float, optional + Minimum supported noise level, by default 0. + sigma_max: float, optional + Maximum supported noise level, by default float('inf'). + sigma_data: float, optional + Expected standard deviation of the training data, by default 0.5. + model_type: str, optional + Class name of the underlying model, by default 'DhariwalUNet'. + **model_kwargs : dict + Keyword arguments for the underlying model. + + """ + + def __init__( + self, + img_resolution, + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNet", + **model_kwargs, + ): + super().__init__(meta=MetaData("StormCastUNet")) + + if isinstance(img_resolution, int): + self.img_shape_x = self.img_shape_y = img_resolution + else: + self.img_shape_x = img_resolution[0] + self.img_shape_y = img_resolution[1] + + self.img_in_channels = img_in_channels + self.img_out_channels = img_out_channels + + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + model_class = getattr(network_module, model_type) + self.model = model_class( + img_resolution=img_resolution, + in_channels=img_in_channels, + out_channels=img_out_channels, + **model_kwargs, + ) + + def forward(self, x, force_fp32=False, **model_kwargs): + """Run a forward pass of the StormCast regression U-Net. + + Args: + x (torch.Tensor): input to the U-Net + force_fp32 (bool, optional): force casting to fp_32 if True. Defaults to False. + + Raises: + ValueError: If input data type is a mismatch with provided options + + Returns: + D_x (torch.Tensor): Output (prediction) of the U-Net + """ + + x = x.to(torch.float32) + dtype = ( + torch.float16 + if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") + else torch.float32 + ) + + F_x = self.model( + x.to(dtype), + torch.zeros(x.shape[0], dtype=x.dtype, device=x.device), + class_labels=None, + **model_kwargs, + ) + + if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): + raise ValueError( + f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + ) + + D_x = F_x.to(torch.float32) + return D_x diff --git a/src/models/utils.py b/src/models/utils.py new file mode 100644 index 0000000..e1cde9d --- /dev/null +++ b/src/models/utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. + """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') From 1f404b38d32d2e6f3a187696bf5c942d14fa1bc6 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 11 Apr 2025 17:29:49 +0200 Subject: [PATCH 05/66] add utils and change imports --- src/distributed/__init__.py | 1 + src/distributed/config.py | 247 ++++++ src/distributed/manager.py | 775 ++++++++++++++++++ src/models/__init__.py | 3 + src/models/layers.py | 2 +- src/models/preconditioning copy.py | 1176 ---------------------------- src/models/preconditioning.py | 18 +- src/models/song_unet.py | 6 +- src/models/unet.py | 8 +- src/utils/capture.py | 513 ++++++++++++ src/utils/checkpoint.py | 398 ++++++++++ src/utils/console.py | 88 +++ src/utils/deterministic_sampler.py | 231 ++++++ src/utils/function_utils.py | 775 ++++++++++++++++++ src/utils/inference_utils.py | 253 ++++++ src/utils/model_utils.py | 66 ++ src/utils/stochastic_sampler.py | 533 +++++++++++++ src/utils/train_helpers.py | 107 +++ 18 files changed, 4007 insertions(+), 1193 deletions(-) create mode 100644 src/distributed/__init__.py create mode 100644 src/distributed/config.py create mode 100644 src/distributed/manager.py create mode 100644 src/models/__init__.py delete mode 100644 src/models/preconditioning copy.py create mode 100644 src/utils/capture.py create mode 100644 src/utils/checkpoint.py create mode 100644 src/utils/console.py create mode 100644 src/utils/deterministic_sampler.py create mode 100644 src/utils/function_utils.py create mode 100644 src/utils/inference_utils.py create mode 100644 src/utils/model_utils.py create mode 100644 src/utils/stochastic_sampler.py create mode 100644 src/utils/train_helpers.py diff --git a/src/distributed/__init__.py b/src/distributed/__init__.py new file mode 100644 index 0000000..0da01f3 --- /dev/null +++ b/src/distributed/__init__.py @@ -0,0 +1 @@ +from .manager import DistributedManager \ No newline at end of file diff --git a/src/distributed/config.py b/src/distributed/config.py new file mode 100644 index 0000000..c5414b4 --- /dev/null +++ b/src/distributed/config.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Union + +from treelib import Tree + + +class ProcessGroupNode: + """ + Class to store the attributes of a distributed process group + + Attributes + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, number of processes in the process group + """ + + def __init__( + self, + name: str, + size: Optional[int] = None, + ): + """ + Constructor for the ProcessGroupNode class + + Parameters + ---------- + name : str + Name of the process group + size : Optional[int] + Optional, size of the process group + """ + self.name = name + self.size = size + + def __str__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return "ProcessGroupNode(" f"name={self.name}, " f"size={self.size}, " + + def __repr__(self): + """ + String representation of the process group node + + Returns + ------- + str + String representation of the process group node + """ + return self.__str__() + + +class ProcessGroupConfig: + """ + Class to define the configuration of a model's parallel process group structure as a + tree. Each node of the tree is of type `ProcessGroupNode`. + + Once the process group config structure (i.e, the tree structure) is set, it is + sufficient to set only the sizes for each leaf process group. Then, the size of + every parent group can be automatically computed as the product reduction of the + sub-tree of that parent group node. + + Examples + -------- + >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig + >>> + >>> # Create world group that contains all processes that are part of this job + >>> world = ProcessGroupNode("world") + >>> + >>> # Create the process group config with the highest level process group + >>> config = ProcessGroupConfig(world) + >>> + >>> # Create model and data parallel sub-groups + >>> # Sub-groups of a single node are guaranteed to be orthogonal by construction + >>> # Nodes can be added with either the name of the node or the node itself + >>> config.add_node(ProcessGroupNode("model_parallel"), parent=world) + >>> config.add_node(ProcessGroupNode("data_parallel"), parent="world") + >>> + >>> # Create spatial and channel parallel sub-groups + >>> config.add_node(ProcessGroupNode("spatial_parallel"), parent="model_parallel") + >>> config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel") + >>> + >>> config.leaf_groups() + ['data_parallel', 'spatial_parallel', 'channel_parallel'] + >>> + >>> # Set leaf group sizes + >>> # Note: product of all leaf-node sizes should be the world size + >>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2, "data_parallel": 4} + >>> config.set_leaf_group_sizes(group_sizes) # Update all parent group sizes too + >>> config.get_node("model_parallel").size + 6 + """ + + def __init__(self, node: ProcessGroupNode): + """ + Constructor to the ProcessGroupConfig class + + Parameters + ---------- + node : ProcessGroupNode + Root node of the tree, typically would be 'world' + Note, it is generally recommended to set the child groups for 'world' + to 'model_parallel' and 'data_parallel' to aid with distributed + data parallel training unless there is a specific reason to choose a + different structure + """ + self.root = node + self.root_id = node.name + self.tree = Tree() + self.tree.create_node(node.name, node.name, data=node) + + def add_node(self, node: ProcessGroupNode, parent=Union[str, ProcessGroupNode]): + """ + Add a node to the process group config + + Parameters + ---------- + node : ProcessGroupNode + The new node to be added to the config + parent : Union[str, ProcessGroupNode] + Parent node of the node to be added. Should already be in the config. + If str, it is the name of the parent node. Otherwise, the parent + ProcessGroupNode itself. + """ + if isinstance(parent, ProcessGroupNode): + parent = parent.name + self.tree.create_node(node.name, node.name, data=node, parent=parent) + + def get_node(self, name: str) -> ProcessGroupNode: + """ + Method to get the node given the name of the node + + Parameters + ---------- + name : str + Name of the node to retrieve + + Returns + ------- + ProcessGroupNode + Node with the given name from the config + """ + return self.tree.get_node(name).data + + def update_parent_sizes(self, verbose: bool = False) -> int: + """ + Method to update parent node sizes after setting the sizes for each leaf node + + Parameters + ---------- + verbose : bool + If True, print a message each time a parent node size was updated + + Returns + ------- + int + Size of the root node + """ + return _tree_product_reduction(self.tree, self.root_id, verbose=verbose) + + def leaf_groups(self) -> List[str]: + """ + Get a list of all leaf group names + + Returns + ------- + List[str] + List of all leaf node names + """ + return [n.identifier for n in self.tree.leaves()] + + def set_leaf_group_sizes( + self, group_sizes: Dict[str, int], update_parent_sizes: bool = True + ): + """ + Set process group sizes for all leaf groups + + Parameters + ---------- + group_sizes : Dict[str, int] + Dictionary with a mapping of each leaf group name to its size + update_parent_sizes : bool + Update all parent group sizes based on the leaf group if True + If False, only set the leaf group sizes. + """ + for id, size in group_sizes.items(): + if not self.tree.contains(id): + raise AssertionError( + f"Process group {id} is not in this process group config" + ) + node = self.tree.get_node(id) + if not node.is_leaf(): + raise AssertionError(f"Process group {id} is not a leaf group") + node.data.size = size + + if update_parent_sizes: + self.update_parent_sizes() + + +def _tree_product_reduction(tree, node_id, verbose=False): + """ + Function to traverse a tree and compute the product reduction of + the sub-tree for each node starting from `node_id` + """ + children = tree.children(node_id) + node = tree.get_node(node_id) + if not children: + if node.data.size is None: + raise AssertionError("Leaf nodes should have a valid size set") + return node.data.size + + product = 1 + + for child in children: + product *= _tree_product_reduction(tree, child.identifier) + + if node.data.size != product: + if verbose: + print( + "Updating size of node " + f"{node.data.name} from {node.data.size} to {product}" + ) + node.data.size = product + + return product diff --git a/src/distributed/manager.py b/src/distributed/manager.py new file mode 100644 index 0000000..facb466 --- /dev/null +++ b/src/distributed/manager.py @@ -0,0 +1,775 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import os +import queue +import warnings +from typing import Optional, Tuple +from warnings import warn + +import numpy as np +import torch +import torch.distributed as dist + +from src.distributed.config import ProcessGroupConfig, ProcessGroupNode + +warnings.simplefilter("default", DeprecationWarning) + + +class UndefinedGroupError(Exception): + """Exception for querying an undefined process group using the PhysicsNeMo DistributedManager""" + + def __init__(self, name: str): + """ + + Parameters + ---------- + name : str + Name of the process group being queried. + + """ + message = ( + f"Cannot query process group '{name}' before it is explicitly created." + ) + super().__init__(message) + + +class UninitializedDistributedManagerWarning(Warning): + """Warning to indicate usage of an uninitialized DistributedManager""" + + def __init__(self): + message = ( + "A DistributedManager object is being instantiated before " + + "this singleton class has been initialized. Instantiating a manager before " + + "initialization can lead to unexpected results where processes fail " + + "to communicate. Initialize the distributed manager via " + + "DistributedManager.initialize() before instantiating." + ) + super().__init__(message) + + +class DistributedManager(object): + """Distributed Manager for setting up distributed training environment. + + This is a singleton that creates a persistance class instance for storing parallel + environment information through out the life time of the program. This should be + used to help set up Distributed Data Parallel and parallel datapipes. + + Note + ---- + One should call `DistributedManager.initialize()` prior to constructing a manager + object + + Example + ------- + >>> DistributedManager.initialize() + >>> manager = DistributedManager() + >>> manager.rank + 0 + >>> manager.world_size + 1 + """ + + _shared_state = {} + + def __new__(cls): + obj = super(DistributedManager, cls).__new__(cls) + obj.__dict__ = cls._shared_state + + # Set the defaults + if not hasattr(obj, "_rank"): + obj._rank = 0 + if not hasattr(obj, "_world_size"): + obj._world_size = 1 + if not hasattr(obj, "_local_rank"): + obj._local_rank = 0 + if not hasattr(obj, "_distributed"): + obj._distributed = False + if not hasattr(obj, "_device"): + obj._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if not hasattr(obj, "_cuda"): + obj._cuda = torch.cuda.is_available() + if not hasattr(obj, "_broadcast_buffers"): + obj._broadcast_buffers = False + if not hasattr(obj, "_find_unused_parameters"): + obj._find_unused_parameters = False + if not hasattr(obj, "_initialization_method"): + obj._initialization_method = "None" + if not hasattr(obj, "_groups"): + obj._groups = {} + if not hasattr(obj, "_group_ranks"): + obj._group_ranks = {} + if not hasattr(obj, "_group_names"): + obj._group_names = {} + if not hasattr(obj, "_is_initialized"): + obj._is_initialized = False + if not hasattr(obj, "_global_mesh"): + obj._global_mesh = None # Lazy initialized right when it's first needed + if not hasattr(obj, "_mesh_dims"): + obj._mesh_dims = {} # Dictionary mapping axis names to sizes + + return obj + + def __init__(self): + if not self._is_initialized: + raise UninitializedDistributedManagerWarning() + super().__init__() + + @property + def rank(self): + """Process rank""" + return self._rank + + @property + def local_rank(self): + """Process rank on local machine""" + return self._local_rank + + @property + def world_size(self): + """Number of processes in distributed environment""" + return self._world_size + + @property + def device(self): + """Process device""" + return self._device + + @property + def distributed(self): + """Distributed environment""" + return self._distributed + + @property + def cuda(self): + """If cuda is available""" + return self._cuda + + @property + def mesh_dims(self): + """Mesh Dimensions as dictionary (axis name : size)""" + return self._mesh_dims + + @property + def group_names(self): + """ + Returns a list of all named process groups created + """ + return self._groups.keys() + + @property + def global_mesh(self): + """ + Returns the global mesh. If it's not initialized, it will be created when this is called. + """ + if self._global_mesh is None: + # Fully flat mesh (1D) by default: + self.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("world",)) + + return self._global_mesh + + def mesh_names(self): + """ + Return mesh axis names + """ + return self._mesh_dims.keys() + + def mesh_sizes(self): + """ + Return mesh axis sizes + """ + return self._mesh_dims.values() + + def group(self, name=None): + """ + Returns a process group with the given name + If name is None, group is also None indicating the default process group + If named group does not exist, UndefinedGroupError exception is raised + """ + if name in self._groups.keys(): + return self._groups[name] + elif name is None: + return None + else: + raise UndefinedGroupError(name) + + def mesh(self, name=None): + """ + Return a device_mesh with the given name. + Does not initialize. If the mesh is not created + already, will raise and error + + Parameters + ---------- + name : str, optional + Name of desired mesh, by default None + """ + + if name in self._global_mesh.axis_names: + return self._global_mesh[name] + elif name is None: + return self._global_mesh + else: + raise UndefinedGroupError(f"Mesh axis {name} not defined") + + def group_size(self, name=None): + """ + Returns the size of named process group + """ + if name is None: + return self._world_size + group = self.group(name) + return dist.get_world_size(group=group) + + def group_rank(self, name=None): + """ + Returns the rank in named process group + """ + if name is None: + return self._rank + group = self.group(name) + return dist.get_rank(group=group) + + def group_name(self, group=None): + """ + Returns the name of process group + """ + if group is None: + return None + return self._group_names[group] + + @property + def broadcast_buffers(self): + """broadcast_buffers in PyTorch DDP""" + return self._broadcast_buffers + + @broadcast_buffers.setter + def broadcast_buffers(self, broadcast: bool): + """Setter for broadcast_buffers""" + self._broadcast_buffers = broadcast + + @property + def find_unused_parameters(self): + """find_unused_parameters in PyTorch DDP""" + return self._find_unused_parameters + + @find_unused_parameters.setter + def find_unused_parameters(self, find_params: bool): + """Setter for find_unused_parameters""" + if find_params: + warn( + "Setting `find_unused_parameters` in DDP to true, " + "use only if necessary." + ) + self._find_unused_parameters = find_params + + def __str__(self): + output = ( + f"Initialized process {self.rank} of {self.world_size} using " + f"method '{self._initialization_method}'. Device set to {str(self.device)}" + ) + return output + + @classmethod + def is_initialized(cls) -> bool: + """If manager singleton has been initialized""" + return cls._shared_state.get("_is_initialized", False) + + @staticmethod + def get_available_backend(): + """Get communication backend""" + if torch.cuda.is_available() and torch.distributed.is_nccl_available(): + return "nccl" + else: + return "gloo" + + @staticmethod + def initialize_env(): + """Setup method using generic initialization""" + rank = int(os.environ.get("RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + if "LOCAL_RANK" in os.environ: + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + local_rank = int(local_rank) + else: + local_rank = rank % torch.cuda.device_count() + + else: + local_rank = rank % torch.cuda.device_count() + + # Read env variables + addr = os.environ.get("MASTER_ADDR") + port = os.environ.get("MASTER_PORT") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + ) + + @staticmethod + def initialize_open_mpi(addr, port): + """Setup method using OpenMPI initialization""" + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")) + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="openmpi", + ) + + @staticmethod + def initialize_slurm(port): + """Setup method using SLURM initialization""" + rank = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NPROCS")) + local_rank = int(os.environ.get("SLURM_LOCALID")) + addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") + + DistributedManager.setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=DistributedManager.get_available_backend(), + method="slurm", + ) + + @staticmethod + def initialize(): + """ + Initialize distributed manager + + Current supported initialization methods are: + `ENV`: PyTorch environment variable initialization + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + `SLURM`: Initialization on SLURM systems. + Uses `SLURM_PROCID`, `SLURM_NPROCS`, `SLURM_LOCALID` and + `SLURM_LAUNCH_NODE_IPADDR` environment variables. + `OPENMPI`: Initialization for OpenMPI launchers. + Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and + `OMPI_COMM_WORLD_LOCAL_RANK` environment variables. + + Initialization by default is done using the first valid method in the order + listed above. Initialization method can also be explicitly controlled using the + `PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD` environment variable and setting it + to one of the options above. + """ + if DistributedManager.is_initialized(): + warn("Distributed manager is already intialized") + return + + addr = os.getenv("MASTER_ADDR", "localhost") + port = os.getenv("MASTER_PORT", "12355") + # https://pytorch.org/docs/master/notes/cuda.html#id5 + # was changed in version 2.2 + if torch.__version__ < (2, 2): + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + else: + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + initialization_method = os.getenv( + "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD" + ) + if initialization_method is None: + try: + DistributedManager.initialize_env() + except TypeError: + if "SLURM_PROCID" in os.environ: + DistributedManager.initialize_slurm(port) + elif "OMPI_COMM_WORLD_RANK" in os.environ: + DistributedManager.initialize_open_mpi(addr, port) + else: + warn( + "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job" + ) + DistributedManager._shared_state["_is_initialized"] = True + elif initialization_method == "ENV": + DistributedManager.initialize_env() + elif initialization_method == "SLURM": + DistributedManager.initialize_slurm(port) + elif initialization_method == "OPENMPI": + DistributedManager.initialize_open_mpi(addr, port) + else: + raise RuntimeError( + "Unknown initialization method " + f"{initialization_method}. " + "Supported values for " + "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD are " + "ENV, SLURM and OPENMPI" + ) + + # Set per rank numpy random seed for data sampling + np.random.seed(seed=DistributedManager().rank) + + def initialize_mesh( + self, mesh_shape: Tuple[int, ...], mesh_dim_names: Tuple[str, ...] + ) -> dist.DeviceMesh: + """ + Initialize a global device mesh over the entire distributed job. + + Creates a multi-dimensional mesh of processes that can be used for distributed + operations. The mesh shape must multiply to equal the total world size, with + one dimension optionally being flexible (-1). + + Parameters + ---------- + mesh_shape : Tuple[int, ...] + Tuple of ints describing the size of each mesh dimension. Product must equal + world_size. One dimension can be -1 to be automatically calculated. + + mesh_dim_names : Tuple[str, ...] + Names for each mesh dimension. Must match length of mesh_shape. + + Returns + ------- + torch.distributed.DeviceMesh + The initialized device mesh + + Raises + ------ + RuntimeError + If mesh dimensions are invalid or don't match world size + AssertionError + If distributed environment is not available + """ + + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + # Assert basic properties: + if len(mesh_shape) == 0: + raise RuntimeError( + "Device Mesh requires at least one mesh dimension in `mesh_shape`" + ) + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names must have the same length, but found " + f"{len(mesh_shape)} and {len(mesh_dim_names)} respectively." + ) + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError("Mesh dimension names must be unique") + + # Check against the total mesh shape vs. world size: + total_mesh_shape = np.prod(mesh_shape) + + # Allow one shape to be -1 + if -1 in mesh_shape: + residual_shape = int(self.world_size / (-1 * total_mesh_shape)) + + # Replace -1 with the computed size: + mesh_shape = [residual_shape if m == -1 else m for m in mesh_shape] + # Recompute total shape: + total_mesh_shape = np.prod(mesh_shape) + + if total_mesh_shape != self.world_size: + raise RuntimeError( + "Device Mesh num elements must equal world size of " + f"{total_mesh_shape} but was configured by user with " + f"global size of {self.world_size}." + ) + + # Actually create the mesh: + self._global_mesh = dist.init_device_mesh( + "cuda" if self.cuda else "cpu", + mesh_shape, + mesh_dim_names=mesh_dim_names, + ) + + # Finally, upon success, cache the mesh dimensions: + self._mesh_dims = {key: val for key, val in zip(mesh_dim_names, mesh_shape)} + + return self._global_mesh + + @staticmethod + def setup( + rank=0, + world_size=1, + local_rank=None, + addr="localhost", + port="12355", + backend="nccl", + method="env", + ): + """Set up PyTorch distributed process group and update manager attributes""" + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + + DistributedManager._shared_state["_is_initialized"] = True + manager = DistributedManager() + + manager._distributed = torch.distributed.is_available() + if manager._distributed: + # Update rank and world_size if using distributed + manager._rank = rank + manager._world_size = world_size + if local_rank is None: + manager._local_rank = rank % torch.cuda.device_count() + else: + manager._local_rank = local_rank + + manager._device = torch.device( + f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu" + ) + + if manager._distributed: + # Setup distributed process group + try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + device_id=manager.device, + ) + except TypeError: + # device_id only introduced in PyTorch 2.3 + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) + + if torch.cuda.is_available(): + # Set device for this process and empty cache to optimize memory usage + torch.cuda.set_device(manager.device) + torch.cuda.device(manager.device) + torch.cuda.empty_cache() + + manager._initialization_method = method + + @staticmethod + def create_process_subgroup( + name: str, size: int, group_name: Optional[str] = None, verbose: bool = False + ): # pragma: no cover + """ + Create a process subgroup of a parent process group. This must be a collective + call by all processes participating in this application. + + Parameters + ---------- + name : str + Name of the process subgroup to be created. + + size : int + Size of the process subgroup to be created. This must be an integer factor of + the parent group's size. + + group_name : Optional[str] + Name of the parent process group, optional. If None, the default process group + will be used. Default None. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if name in manager._groups: + raise AssertionError(f"Group with name {name} already exists") + + # Get parent group's params + group = manager._groups[group_name] if group_name else None + group_size = dist.get_world_size(group=group) + num_groups = manager.world_size // group_size + + # Get number of sub-groups per parent group + if group_size % size != 0: + raise AssertionError( + f"Cannot divide group size {group_size} evenly into subgroups of" + f" size {size}" + ) + num_subgroups = group_size // size + + # Create all the sub-groups + # Note: all ranks in the job need to create all sub-groups in + # the same order even if a rank is not part of a sub-group + manager._group_ranks[name] = [] + for g in range(num_groups): + for i in range(num_subgroups): + # Get global ranks that are part of this sub-group + start = i * size + end = start + size + if group_name: + ranks = manager._group_ranks[group_name][g][start:end] + else: + ranks = list(range(start, end)) + # Create sub-group and keep track of ranks + tmp_group = dist.new_group(ranks=ranks) + manager._group_ranks[name].append(ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[name] = tmp_group + manager._group_names[tmp_group] = name + + if verbose and manager.rank == 0: + print(f"Process group '{name}':") + for grp in manager._group_ranks[name]: + print(" ", grp) + + @staticmethod + def create_orthogonal_process_group( + orthogonal_group_name: str, group_name: str, verbose: bool = False + ): # pragma: no cover + """ + Create a process group that is orthogonal to the specified process group. + + Parameters + ---------- + orthogonal_group_name : str + Name of the orthogonal process group to be created. + + group_name : str + Name of the existing process group. + + verbose : bool + Print out ranks of each created process group, default False. + + """ + manager = DistributedManager() + if not manager.distributed: + raise AssertionError( + "torch.distributed is unavailable. " + "Check pytorch build to ensure the distributed package is available. " + "If building PyTorch from source, set `USE_DISTRIBUTED=1` " + "to enable the distributed package" + ) + + if group_name not in manager._groups: + raise ValueError(f"Group with name {group_name} does not exist") + if orthogonal_group_name in manager._groups: + raise ValueError(f"Group with name {orthogonal_group_name} already exists") + + group_ranks = manager._group_ranks[group_name] + orthogonal_ranks = [list(i) for i in zip(*group_ranks)] + + for ranks in orthogonal_ranks: + tmp_group = dist.new_group(ranks=ranks) + if manager.rank in ranks: + # Set group in manager only if this rank is part of the group + manager._groups[orthogonal_group_name] = tmp_group + manager._group_names[tmp_group] = orthogonal_group_name + + manager._group_ranks[orthogonal_group_name] = orthogonal_ranks + + if verbose and manager.rank == 0: + print(f"Process group '{orthogonal_group_name}':") + for grp in manager._group_ranks[orthogonal_group_name]: + print(" ", grp) + + @staticmethod + def create_group_from_node( + node: ProcessGroupNode, + parent: Optional[str] = None, + verbose: bool = False, + ): # pragma: no cover + if node.size is None: + raise AssertionError( + "Cannot create groups from a ProcessGroupNode that is not fully" + " populated. Ensure that config.set_leaf_group_sizes is called first" + " with `update_parent_sizes = True`" + ) + + DistributedManager.create_process_subgroup( + node.name, node.size, group_name=parent, verbose=verbose + ) + # Create orthogonal process group + orthogonal_group = f"__orthogonal_to_{node.name}" + DistributedManager.create_orthogonal_process_group( + orthogonal_group, node.name, verbose=verbose + ) + return orthogonal_group + + @staticmethod + def create_groups_from_config( + config: ProcessGroupConfig, verbose: bool = False + ): # pragma: no cover + + warnings.warn( + "DistributedManager.create_groups_from_config is no longer the most simple " + "way to organize process groups. Please switch to DeviceMesh, " + "and DistributedManager.initialize_mesh", + category=DeprecationWarning, + stacklevel=2, + ) + + # Traverse process group tree in breadth first order + # to create nested process groups + q = queue.Queue() + q.put(config.root_id) + DistributedManager.create_group_from_node(config.root) + + while not q.empty(): + node_id = q.get() + if verbose: + print(f"Node ID: {node_id}") + + children = config.tree.children(node_id) + if verbose: + print(f" Children: {children}") + + parent_group = node_id + for child in children: + # Create child group and replace parent group by orthogonal group so + # that each child forms an independent block of processes + parent_group = DistributedManager.create_group_from_node( + child.data, + parent=parent_group, + ) + + # Add child ids to the queue + q.put(child.identifier) + + @atexit.register + @staticmethod + def cleanup(): + """Clean up distributed group and singleton""" + # Destroying group.WORLD is enough for all process groups to get destroyed + if ( + "_is_initialized" in DistributedManager._shared_state + and DistributedManager._shared_state["_is_initialized"] + and "_distributed" in DistributedManager._shared_state + and DistributedManager._shared_state["_distributed"] + ): + if torch.cuda.is_available(): + dist.barrier(device_ids=[DistributedManager().local_rank]) + else: + dist.barrier() + dist.destroy_process_group() + DistributedManager._shared_state = {} diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..6b790ae --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,3 @@ +from .unet import UNet +from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding \ No newline at end of file diff --git a/src/models/layers.py b/src/models/layers.py index 1fb3b17..d5a1ab2 100644 --- a/src/models/layers.py +++ b/src/models/layers.py @@ -26,7 +26,7 @@ from einops import rearrange from torch.nn.functional import silu -from physicsnemo.models.diffusion import weight_init +from src.utils.model_utils import weight_init class Linear(torch.nn.Module): diff --git a/src/models/preconditioning copy.py b/src/models/preconditioning copy.py deleted file mode 100644 index 52a1660..0000000 --- a/src/models/preconditioning copy.py +++ /dev/null @@ -1,1176 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Preconditioning schemes used in the paper"Elucidating the Design Space of -Diffusion-Based Generative Models". -""" - -import importlib -import warnings -from dataclasses import dataclass -from typing import List, Union - -import numpy as np -import nvtx -import torch - -from physicsnemo.models.diffusion import ( - DhariwalUNet, # noqa: F401 for globals - SongUNet, # noqa: F401 for globals -) -from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module - -network_module = importlib.import_module("physicsnemo.models.diffusion") - - -@dataclass -class VPPrecondMetaData(ModelMetaData): - """VPPrecond meta data""" - - name: str = "VPPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class VPPrecond(Module): - """ - Preconditioning corresponding to the variance preserving (VP) formulation. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - beta_d : float - Extent of the noise level schedule, by default 19.9. - beta_min : float - Initial slope of the noise level schedule, by default 0.1. - M : int - Original number of timesteps in the DDPM formulation, by default 1000. - epsilon_t : float - Minimum t-value used during training, by default 1e-5. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - beta_d: float = 19.9, - beta_min: float = 0.1, - M: int = 1000, - epsilon_t: float = 1e-5, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__(meta=VPPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.beta_d = beta_d - self.beta_min = beta_min - self.M = M - self.epsilon_t = epsilon_t - self.sigma_min = float(self.sigma(epsilon_t)) - self.sigma_max = float(self.sigma(1)) - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = 1 - c_out = -sigma - c_in = 1 / (sigma**2 + 1).sqrt() - c_noise = (self.M - 1) * self.sigma_inv(sigma) - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - def sigma(self, t: Union[float, torch.Tensor]): - """ - Compute the sigma(t) value for a given t based on the VP formulation. - - The function calculates the noise level schedule for the diffusion process based - on the given parameters `beta_d` and `beta_min`. - - Parameters - ---------- - t : Union[float, torch.Tensor] - The timestep or set of timesteps for which to compute sigma(t). - - Returns - ------- - torch.Tensor - The computed sigma(t) value(s). - """ - t = torch.as_tensor(t) - return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() - - def sigma_inv(self, sigma: Union[float, torch.Tensor]): - """ - Compute the inverse of the sigma function for a given sigma. - - This function effectively calculates t from a given sigma(t) based on the - parameters `beta_d` and `beta_min`. - - Parameters - ---------- - sigma : Union[float, torch.Tensor] - The sigma(t) value or set of sigma(t) values for which to compute the - inverse. - - Returns - ------- - torch.Tensor - The computed t value(s) corresponding to the provided sigma(t). - """ - sigma = torch.as_tensor(sigma) - return ( - (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() - - self.beta_min - ) / self.beta_d - - def round_sigma(self, sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - - Parameters - ---------- - sigma : Union[float list, torch.Tensor] - The sigma value(s) to convert. - - Returns - ------- - torch.Tensor - The tensor representation of the provided sigma value(s). - """ - return torch.as_tensor(sigma) - - -@dataclass -class VEPrecondMetaData(ModelMetaData): - """VEPrecond meta data""" - - name: str = "VEPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class VEPrecond(Module): - """ - Preconditioning corresponding to the variance exploding (VE) formulation. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.02. - sigma_max : float - Maximum supported noise level, by default 100.0. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - sigma_min: float = 0.02, - sigma_max: float = 100.0, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__(meta=VEPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = 1 - c_out = sigma - c_in = 1 - c_noise = (0.5 * sigma).log() - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - def round_sigma(self, sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - - Parameters - ---------- - sigma : Union[float list, torch.Tensor] - The sigma value(s) to convert. - - Returns - ------- - torch.Tensor - The tensor representation of the provided sigma value(s). - """ - return torch.as_tensor(sigma) - - -@dataclass -class iDDPMPrecondMetaData(ModelMetaData): - """iDDPMPrecond meta data""" - - name: str = "iDDPMPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class iDDPMPrecond(Module): - """ - Preconditioning corresponding to the improved DDPM (iDDPM) formulation. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - C_1 : float - Timestep adjustment at low noise levels., by default 0.001. - C_2 : float - Timestep adjustment at high noise levels., by default 0.008. - M: int - Original number of timesteps in the DDPM formulation, by default 1000. - model_type :str - Class name of the underlying model, by default "DhariwalUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Nichol, A.Q. and Dhariwal, P., 2021, July. Improved denoising diffusion - probabilistic models. In International Conference on Machine Learning - (pp. 8162-8171). PMLR. - """ - - def __init__( - self, - img_resolution, - img_channels, - label_dim=0, - use_fp16=False, - C_1=0.001, - C_2=0.008, - M=1000, - model_type="DhariwalUNet", - **model_kwargs, - ): - super().__init__(meta=iDDPMPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.C_1 = C_1 - self.C_2 = C_2 - self.M = M - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_channels, - out_channels=img_channels * 2, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - u = torch.zeros(M + 1) - for j in range(M, 0, -1): # M, ..., 1 - u[j - 1] = ( - (u[j] ** 2 + 1) - / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - - 1 - ).sqrt() - self.register_buffer("u", u) - self.sigma_min = float(u[M - 1]) - self.sigma_max = float(u[0]) - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = 1 - c_out = -sigma - c_in = 1 / (sigma**2 + 1).sqrt() - c_noise = ( - self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) - ) - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x[:, : self.img_channels].to(torch.float32) - return D_x - - def alpha_bar(self, j): - """ - Compute the alpha_bar(j) value for a given j based on the iDDPM formulation. - - Parameters - ---------- - j : Union[int, torch.Tensor] - The timestep or set of timesteps for which to compute alpha_bar(j). - - Returns - ------- - torch.Tensor - The computed alpha_bar(j) value(s). - """ - j = torch.as_tensor(j) - return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 - - def round_sigma(self, sigma, return_index=False): - """ - Round the provided sigma value(s) to the nearest value(s) in a - pre-defined set `u`. - - Parameters - ---------- - sigma : Union[float, list, torch.Tensor] - The sigma value(s) to round. - return_index : bool, optional - Whether to return the index/indices of the rounded value(s) in `u` instead - of the rounded value(s) themselves, by default False. - - Returns - ------- - torch.Tensor - The rounded sigma value(s) or their index/indices in `u`, depending on the - value of `return_index`. - """ - sigma = torch.as_tensor(sigma) - index = torch.cdist( - sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), - self.u.reshape(1, -1, 1), - ).argmin(2) - result = index if return_index else self.u[index.flatten()].to(sigma.dtype) - return result.reshape(sigma.shape).to(sigma.device) - - -@dataclass -class EDMPrecondMetaData(ModelMetaData): - """EDMPrecond meta data""" - - name: str = "EDMPrecond" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class EDMPrecond(Module): - """ - Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels (for both input and output). If your model - requires a different number of input or output chanels, - override this by passing either of the optional - img_in_channels or img_out_channels args - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.0. - sigma_max : float - Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "DhariwalUNet". - img_in_channels: int - Optional setting for when number of input channels =/= number of output - channels. If set, will override img_channels for the input - This is useful in the case of additional (conditional) channels - img_out_channels: int - Optional setting for when number of input channels =/= number of output - channels. If set, will override img_channels for the output - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the - design space of diffusion-based generative models. Advances in Neural Information - Processing Systems, 35, pp.26565-26577. - """ - - def __init__( - self, - img_resolution, - img_channels, - label_dim=0, - use_fp16=False, - sigma_min=0.0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="DhariwalUNet", - img_in_channels=None, - img_out_channels=None, - **model_kwargs, - ): - super().__init__(meta=EDMPrecondMetaData) - self.img_resolution = img_resolution - if img_in_channels is not None: - img_in_channels = img_in_channels - else: - img_in_channels = img_channels - if img_out_channels is not None: - img_out_channels = img_out_channels - else: - img_out_channels = img_channels - - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data - - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_in_channels, - out_channels=img_out_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward( - self, - x, - sigma, - condition=None, - class_labels=None, - force_fp32=False, - **model_kwargs, - ): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() - c_noise = sigma.log() / 4 - - arg = c_in * x - - if condition is not None: - arg = torch.cat([arg, condition], dim=1) - - F_x = self.model( - arg.to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - - Parameters - ---------- - sigma : Union[float list, torch.Tensor] - The sigma value(s) to convert. - - Returns - ------- - torch.Tensor - The tensor representation of the provided sigma value(s). - """ - return torch.as_tensor(sigma) - - -@dataclass -class EDMPrecondSRMetaData(ModelMetaData): - """EDMPrecondSR meta data""" - - name: str = "EDMPrecondSR" - # Optimization - jit: bool = False - cuda_graphs: bool = False - amp_cpu: bool = False - amp_gpu: bool = True - torch_fx: bool = False - # Data type - bf16: bool = False - # Inference - onnx: bool = False - # Physics informed - func_torch: bool = False - auto_grad: bool = False - - -class EDMPrecondSR(Module): - """ - Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) for super-resolution tasks - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - img_in_channels : int - Number of input color channels. - img_out_channels : int - Number of output color channels. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.0. - sigma_max : float - Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "SongUNetPosEmbd". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - References: - - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the - design space of diffusion-based generative models. Advances in Neural Information - Processing Systems, 35, pp.26565-26577. - - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., - Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. - Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. - arXiv preprint arXiv:2309.15214. - """ - - def __init__( - self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, - sigma_min=0.0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - scale_cond_input=True, - **model_kwargs, - ): - super().__init__(meta=EDMPrecondSRMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels # TODO: this is not used, remove it - self.img_in_channels = img_in_channels - self.img_out_channels = img_out_channels - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.scale_cond_input = scale_cond_input - - model_class = getattr(network_module, model_type) - self.model = model_class( - img_resolution=img_resolution, - in_channels=img_in_channels + img_out_channels, - out_channels=img_out_channels, - **model_kwargs, - ) # TODO needs better handling - self.scaling_fn = self._get_scaling_fn() - - def _get_scaling_fn(self): - if self.scale_cond_input: - warnings.warn( - "scale_cond_input=True does not properly scale the conditional input. " - "(see https://github.com/NVIDIA/modulus/issues/229). " - "This setup will be deprecated. " - "Please set scale_cond_input=False.", - DeprecationWarning, - ) - return self._legacy_scaling_fn - else: - return self._scaling_fn - - @staticmethod - def _scaling_fn(x, img_lr, c_in): - return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) - - @staticmethod - def _legacy_scaling_fn(x, img_lr, c_in): - return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) - - @nvtx.annotate(message="EDMPrecondSR", color="orange") - def forward( - self, - x, - img_lr, - sigma, - force_fp32=False, - **model_kwargs, - ): - # Concatenate input channels - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() - c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() - c_noise = sigma.log() / 4 - - if img_lr is None: - arg = c_in * x - else: - arg = self.scaling_fn(x, img_lr, c_in) - arg = arg.to(dtype) - - F_x = self.model( - arg, - c_noise.flatten(), - class_labels=None, - **model_kwargs, - ) - - if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - D_x = c_skip * x + c_out * F_x.to(torch.float32) - return D_x - - @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): - """ - Convert a given sigma value(s) to a tensor representation. - See EDMPrecond.round_sigma - """ - return EDMPrecond.round_sigma(sigma) - - -class VEPrecond_dfsr(torch.nn.Module): - """ - Preconditioning for dfsr model, modified from class VEPrecond, where the input - argument 'sigma' in forward propagation function is used to receive the timestep - of the backward diffusion process. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.02. - sigma_max : float - Maximum supported noise level, by default 100.0. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. - Advances in neural information processing systems. 2020;33:6840-51. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - sigma_min: float = 0.02, - sigma_max: float = 100.0, - dataset_mean: float = 5.85e-05, - dataset_scale: float = 4.79, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__() - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.model = globals()[model_type]( - img_resolution=img_resolution, - in_channels=self.img_channels, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - # print("sigma: ", sigma) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_in = 1 - c_noise = sigma # Change the definitation of c_noise to avoid -inf values for zero sigma - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - - if F_x.dtype != dtype: - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - - return F_x - - -class VEPrecond_dfsr_cond(torch.nn.Module): - """ - Preconditioning for dfsr model with physics-informed conditioning input, modified - from class VEPrecond, where the input argument 'sigma' in forward propagation function - is used to receive the timestep of the backward diffusion process. The gradient of PDE - residual with respect to the vorticity in the governing Navier-Stokes equation is computed - as the physics-informed conditioning variable and is combined with the backward diffusion - timestep before being sent to the underlying model for noise prediction. - - Parameters - ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. - label_dim : int - Number of class labels, 0 = unconditional, by default 0. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float - Minimum supported noise level, by default 0.02. - sigma_max : float - Maximum supported noise level, by default 100.0. - model_type :str - Class name of the underlying model, by default "SongUNet". - **model_kwargs : dict - Keyword arguments for the underlying model. - - Note - ---- - Reference: - [1] Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - [2] Shu D, Li Z, Farimani AB. A physics-informed diffusion model for high-fidelity - flow field reconstruction. Journal of Computational Physics. 2023 Apr 1;478:111972. - """ - - def __init__( - self, - img_resolution: int, - img_channels: int, - label_dim: int = 0, - use_fp16: bool = False, - sigma_min: float = 0.02, - sigma_max: float = 100.0, - dataset_mean: float = 5.85e-05, - dataset_scale: float = 4.79, - model_type: str = "SongUNet", - **model_kwargs: dict, - ): - super().__init__() - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.model = globals()[model_type]( - img_resolution=img_resolution, - in_channels=model_kwargs["model_channels"] * 2, - out_channels=img_channels, - label_dim=label_dim, - **model_kwargs, - ) # TODO needs better handling - - # modules to embed residual loss - self.conv_in = torch.nn.Conv2d( - img_channels, - model_kwargs["model_channels"], - kernel_size=3, - stride=1, - padding=1, - padding_mode="circular", - ) - self.emb_conv = torch.nn.Sequential( - torch.nn.Conv2d( - img_channels, - model_kwargs["model_channels"], - kernel_size=1, - stride=1, - padding=0, - ), - torch.nn.GELU(), - torch.nn.Conv2d( - model_kwargs["model_channels"], - model_kwargs["model_channels"], - kernel_size=3, - stride=1, - padding=1, - padding_mode="circular", - ), - ) - self.dataset_mean = dataset_mean - self.dataset_scale = dataset_scale - - def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) - class_labels = ( - None - if self.label_dim == 0 - else torch.zeros([1, self.label_dim], device=x.device) - if class_labels is None - else class_labels.to(torch.float32).reshape(-1, self.label_dim) - ) - dtype = ( - torch.float16 - if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") - else torch.float32 - ) - - c_in = 1 - c_noise = sigma - - # Compute physics-informed conditioning information using vorticity residual - dx = ( - self.voriticity_residual((x * self.dataset_scale + self.dataset_mean)) - / self.dataset_scale - ) - x = self.conv_in(x) - cond_emb = self.emb_conv(dx) - x = torch.cat((x, cond_emb), dim=1) - - F_x = self.model( - (c_in * x).to(dtype), - c_noise.flatten(), - class_labels=class_labels, - **model_kwargs, - ) - - if F_x.dtype != dtype: - raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." - ) - return F_x - - def voriticity_residual(self, w, re=1000.0, dt=1 / 32): - """ - Compute the gradient of PDE residual with respect to a given vorticity w using the - spectrum method. - - Parameters - ---------- - w: torch.Tensor - The fluid flow data sample (vorticity). - re: float - The value of Reynolds number used in the governing Navier-Stokes equation. - dt: float - Time step used to compute the time-derivative of vorticity included in the governing - Navier-Stokes equation. - - Returns - ------- - torch.Tensor - The computed vorticity gradient. - """ - - # w [b t h w] - w = w.clone() - w.requires_grad_(True) - nx = w.size(2) - device = w.device - - w_h = torch.fft.fft2(w[:, 1:-1], dim=[2, 3]) - # Wavenumbers in y-direction - k_max = nx // 2 - N = nx - k_x = ( - torch.cat( - ( - torch.arange(start=0, end=k_max, step=1, device=device), - torch.arange(start=-k_max, end=0, step=1, device=device), - ), - 0, - ) - .reshape(N, 1) - .repeat(1, N) - .reshape(1, 1, N, N) - ) - k_y = ( - torch.cat( - ( - torch.arange(start=0, end=k_max, step=1, device=device), - torch.arange(start=-k_max, end=0, step=1, device=device), - ), - 0, - ) - .reshape(1, N) - .repeat(N, 1) - .reshape(1, 1, N, N) - ) - # Negative Laplacian in Fourier space - lap = k_x**2 + k_y**2 - lap[..., 0, 0] = 1.0 - psi_h = w_h / lap - - u_h = 1j * k_y * psi_h - v_h = -1j * k_x * psi_h - wx_h = 1j * k_x * w_h - wy_h = 1j * k_y * w_h - wlap_h = -lap * w_h - - u = torch.fft.irfft2(u_h[..., :, : k_max + 1], dim=[2, 3]) - v = torch.fft.irfft2(v_h[..., :, : k_max + 1], dim=[2, 3]) - wx = torch.fft.irfft2(wx_h[..., :, : k_max + 1], dim=[2, 3]) - wy = torch.fft.irfft2(wy_h[..., :, : k_max + 1], dim=[2, 3]) - wlap = torch.fft.irfft2(wlap_h[..., :, : k_max + 1], dim=[2, 3]) - advection = u * wx + v * wy - - wt = (w[:, 2:, :, :] - w[:, :-2, :, :]) / (2 * dt) - - # establish forcing term - x = torch.linspace(0, 2 * np.pi, nx + 1, device=device) - x = x[0:-1] - X, Y = torch.meshgrid(x, x) - f = -4 * torch.cos(4 * Y) - - residual = wt + (advection - (1.0 / re) * wlap + 0.1 * w[:, 1:-1]) - f - residual_loss = (residual**2).mean() - dw = torch.autograd.grad(residual_loss, w)[0] - - return dw diff --git a/src/models/preconditioning.py b/src/models/preconditioning.py index 52a1660..7b621e2 100644 --- a/src/models/preconditioning.py +++ b/src/models/preconditioning.py @@ -27,13 +27,13 @@ import numpy as np import nvtx import torch +import torch.nn as nn -from physicsnemo.models.diffusion import ( +from src.models import ( DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module network_module = importlib.import_module("physicsnemo.models.diffusion") @@ -58,7 +58,7 @@ class VPPrecondMetaData(ModelMetaData): auto_grad: bool = False -class VPPrecond(Module): +class VPPrecond(nn.Module): """ Preconditioning corresponding to the variance preserving (VP) formulation. @@ -241,7 +241,7 @@ class VEPrecondMetaData(ModelMetaData): auto_grad: bool = False -class VEPrecond(Module): +class VEPrecond(nn.Module): """ Preconditioning corresponding to the variance exploding (VE) formulation. @@ -370,7 +370,7 @@ class iDDPMPrecondMetaData(ModelMetaData): auto_grad: bool = False -class iDDPMPrecond(Module): +class iDDPMPrecond(nn.Module): """ Preconditioning corresponding to the improved DDPM (iDDPM) formulation. @@ -544,7 +544,7 @@ class EDMPrecondMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecond(Module): +class EDMPrecond(nn.Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM) @@ -713,7 +713,7 @@ class EDMPrecondSRMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecondSR(Module): +class EDMPrecondSR(nn.Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM) for super-resolution tasks @@ -861,7 +861,7 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]): return EDMPrecond.round_sigma(sigma) -class VEPrecond_dfsr(torch.nn.Module): +class VEPrecond_dfsr(nn.Module): """ Preconditioning for dfsr model, modified from class VEPrecond, where the input argument 'sigma' in forward propagation function is used to receive the timestep @@ -953,7 +953,7 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) return F_x -class VEPrecond_dfsr_cond(torch.nn.Module): +class VEPrecond_dfsr_cond(nn.Module): """ Preconditioning for dfsr model with physics-informed conditioning input, modified from class VEPrecond, where the input argument 'sigma' in forward propagation function diff --git a/src/models/song_unet.py b/src/models/song_unet.py index d38484b..68adbda 100644 --- a/src/models/song_unet.py +++ b/src/models/song_unet.py @@ -27,8 +27,9 @@ import torch from torch.nn.functional import silu from torch.utils.checkpoint import checkpoint +import torch.nn as nn -from physicsnemo.models.diffusion import ( +from src.models import ( Conv2d, FourierEmbedding, GroupNorm, @@ -37,7 +38,6 @@ UNetBlock, ) from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module @dataclass @@ -58,7 +58,7 @@ class MetaData(ModelMetaData): auto_grad: bool = False -class SongUNet(Module): +class SongUNet(nn.Module): """ Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with optional self-attention, embeddings, and encoder-decoder components. diff --git a/src/models/unet.py b/src/models/unet.py index 7270606..db8e4f8 100644 --- a/src/models/unet.py +++ b/src/models/unet.py @@ -18,11 +18,11 @@ from dataclasses import dataclass import torch +import torch.nn as nn from physicsnemo.models.meta import ModelMetaData -from physicsnemo.models.module import Module -network_module = importlib.import_module("physicsnemo.models.diffusion") +network_module = importlib.import_module("src.models") @dataclass @@ -43,7 +43,7 @@ class MetaData(ModelMetaData): auto_grad: bool = False -class UNet(Module): # TODO a lot of redundancy, need to clean up +class UNet(nn.Module): # TODO a lot of redundancy, need to clean up """ U-Net Wrapper for CorrDiff. @@ -166,7 +166,7 @@ def round_sigma(self, sigma): return torch.as_tensor(sigma) -class StormCastUNet(Module): +class StormCastUNet(nn.Module): """ U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. diff --git a/src/utils/capture.py b/src/utils/capture.py new file mode 100644 index 0000000..50057f9 --- /dev/null +++ b/src/utils/capture.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import os +import time +from contextlib import nullcontext +from logging import Logger +from typing import Any, Callable, Dict, NewType, Optional, Union + +import torch + +from src.distributed import DistributedManager + +float16 = NewType("float16", torch.float16) +bfloat16 = NewType("bfloat16", torch.bfloat16) +optim = NewType("optim", torch.optim) + + +class _StaticCapture(object): + """Base class for StaticCapture decorator. + + This class should not be used, rather StaticCaptureTraining and StaticCaptureEvaluate + should be used instead for training and evaluation functions. + """ + + # Grad scaler and checkpoint class variables use for checkpoint saving and loading + # Since an instance of Static capture does not exist for checkpoint functions + # one must use class functions to access state dicts + _amp_scalers = {} + _amp_scaler_checkpoints = {} + _logger = logging.getLogger("capture") + + def __new__(cls, *args, **kwargs): + obj = super(_StaticCapture, cls).__new__(cls) + obj.amp_scalers = cls._amp_scalers + obj.amp_scaler_checkpoints = cls._amp_scaler_checkpoints + obj.logger = cls._logger + return obj + + def __init__( + self, + model: "physicsnemo.Module", + optim: Optional[optim] = None, + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_autocast: bool = True, + use_gradscaler: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + gradient_clip_norm: Optional[float] = None, + label: Optional[str] = None, + ): + self.logger = logger if logger else self.logger + # Checkpoint label (used for gradscaler) + self.label = label if label else f"scaler_{len(self.amp_scalers.keys())}" + + # DDP fix + if not isinstance(model, physicsnemo.models.Module) and hasattr( + model, "module" + ): + model = model.module + + if not isinstance(model, physicsnemo.models.Module): + self.logger.error("Model not a PhysicsNeMo Module!") + raise ValueError("Model not a PhysicsNeMo Module!") + if compile: + model = torch.compile(model) + + self.model = model + + self.optim = optim + self.eval = False + self.no_grad = False + self.gradient_clip_norm = gradient_clip_norm + + # Set up toggles for optimizations + if not (amp_type == torch.float16 or amp_type == torch.bfloat16): + raise ValueError("AMP type must be torch.float16 or torch.bfloat16") + # CUDA device + if "cuda" in str(self.model.device): + # CUDA graphs + if use_graphs and not self.model.meta.cuda_graphs: + self.logger.warning( + f"Model {model.meta.name} does not support CUDA graphs, turning off" + ) + use_graphs = False + self.cuda_graphs_enabled = use_graphs + + # AMP GPU + if not self.model.meta.amp_gpu: + self.logger.warning( + f"Model {model.meta.name} does not support AMP on GPUs, turning off" + ) + use_autocast = False + use_gradscaler = False + self.use_gradscaler = use_gradscaler + self.use_autocast = use_autocast + + self.amp_device = "cuda" + # Check if bfloat16 is suppored on the GPU + if amp_type == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + self.logger.warning( + "Current CUDA device does not support bfloat16, falling back to float16" + ) + amp_type = torch.float16 + self.amp_dtype = amp_type + # Gradient Scaler + scaler_enabled = self.use_gradscaler and amp_type == torch.float16 + self.scaler = self._init_amp_scaler(scaler_enabled, self.logger) + + self.replay_stream = torch.cuda.Stream(self.model.device) + # CPU device + else: + self.cuda_graphs_enabled = False + # AMP CPU + if use_autocast and not self.model.meta.amp_cpu: + self.logger.warning( + f"Model {model.meta.name} does not support AMP on CPUs, turning off" + ) + use_autocast = False + + self.use_autocast = use_autocast + self.amp_device = "cpu" + # Only float16 is supported on CPUs + # https://pytorch.org/docs/stable/amp.html#cpu-op-specific-behavior + if amp_type == torch.float16 and use_autocast: + self.logger.warning( + "torch.float16 not supported for CPU AMP, switching to torch.bfloat16" + ) + amp_type = torch.bfloat16 + self.amp_dtype = torch.bfloat16 + # Gradient Scaler (not enabled) + self.scaler = self._init_amp_scaler(False, self.logger) + self.replay_stream = None + + if self.cuda_graphs_enabled: + self.graph = torch.cuda.CUDAGraph() + + self.output = None + self.iteration = 0 + self.cuda_graph_warmup = cuda_graph_warmup # Default for DDP = 11 + + def __call__(self, fn: Callable) -> Callable: + self.function = fn + + @functools.wraps(fn) + def decorated(*args: Any, **kwds: Any) -> Any: + """Training step decorator function""" + + with torch.no_grad() if self.no_grad else nullcontext(): + if self.cuda_graphs_enabled: + self._cuda_graph_forward(*args, **kwds) + else: + self._zero_grads() + self.output = self._amp_forward(*args, **kwds) + + if not self.eval: + # Update model parameters + self.scaler.step(self.optim) + self.scaler.update() + + return self.output + + return decorated + + def _cuda_graph_forward(self, *args: Any, **kwargs: Any) -> Any: + """Forward training step with CUDA graphs + + Returns + ------- + Any + Output of neural network forward + """ + # Graph warm up + if self.iteration < self.cuda_graph_warmup: + self.replay_stream.wait_stream(torch.cuda.current_stream()) + self._zero_grads() + with torch.cuda.stream(self.replay_stream): + output = self._amp_forward(*args, **kwargs) + self.output = output.detach() + torch.cuda.current_stream().wait_stream(self.replay_stream) + # CUDA Graphs + else: + # Graph record + if self.iteration == self.cuda_graph_warmup: + self.logger.warning(f"Recording graph of '{self.function.__name__}'") + self._zero_grads() + torch.cuda.synchronize() + if DistributedManager().distributed: + torch.distributed.barrier() + # TODO: temporary workaround till this issue is fixed: + # https://github.com/pytorch/pytorch/pull/104487#issuecomment-1638665876 + delay = os.environ.get("PHYSICSNEMO_CUDA_GRAPH_CAPTURE_DELAY", "10") + time.sleep(int(delay)) + with torch.cuda.graph(self.graph): + output = self._amp_forward(*args, **kwargs) + self.output = output.detach() + # Graph replay + self.graph.replay() + + self.iteration += 1 + return self.output + + def _zero_grads(self): + """Zero gradients + + Default to `set_to_none` since this will in general have lower memory + footprint, and can modestly improve performance. + + Note + ---- + Zeroing gradients can potentially cause an invalid CUDA memory access in another + graph. However if your graph involves gradients, you much set your gradients to none. + If there is already a graph recorded that includes these gradients, this will error. + Use the `NoGrad` version of capture to avoid this issue for inferencers / validators. + """ + # Skip zeroing if no grad is being used + if self.no_grad: + return + + try: + self.optim.zero_grad(set_to_none=True) + except Exception: + if self.optim: + self.optim.zero_grad() + # For apex optim support and eval mode (need to reset model grads) + self.model.zero_grad(set_to_none=True) + + def _amp_forward(self, *args, **kwargs) -> Any: + """Compute loss and gradients (if training) with AMP + + Returns + ------- + Any + Output of neural network forward + """ + with torch.autocast( + self.amp_device, enabled=self.use_autocast, dtype=self.amp_dtype + ): + output = self.function(*args, **kwargs) + + if not self.eval: + # In training mode output should be the loss + self.scaler.scale(output).backward() + if self.gradient_clip_norm is not None: + self.scaler.unscale_(self.optim) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.gradient_clip_norm + ) + + return output + + def _init_amp_scaler( + self, scaler_enabled: bool, logger: Logger + ) -> torch.cuda.amp.GradScaler: + # Create gradient scaler + scaler = torch.cuda.amp.GradScaler(enabled=scaler_enabled) + # Store scaler in class variable + self.amp_scalers[self.label] = scaler + logging.debug(f"Created gradient scaler {self.label}") + + # If our checkpoint dictionary has weights for this scaler lets load + if self.label in self.amp_scaler_checkpoints: + try: + scaler.load_state_dict(self.amp_scaler_checkpoints[self.label]) + del self.amp_scaler_checkpoints[self.label] + self.logger.info(f"Loaded grad scaler state dictionary {self.label}.") + except Exception as e: + self.logger.error( + f"Failed to load grad scaler {self.label} state dict from saved " + + "checkpoints. Did you switch the ordering of declared static captures?" + ) + raise ValueError(e) + return scaler + + @classmethod + def state_dict(cls) -> Dict[str, Any]: + """Class method for accsessing the StaticCapture state dictionary. + Use this in a training checkpoint function. + + Returns + ------- + Dict[str, Any] + Dictionary of states to save for file + """ + scaler_states = {} + for key, value in cls._amp_scalers.items(): + scaler_states[key] = value.state_dict() + + return scaler_states + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any]) -> None: + """Class method for loading a StaticCapture state dictionary. + Use this in a training checkpoint function. + + Returns + ------- + Dict[str, Any] + Dictionary of states to save for file + """ + for key, value in state_dict.items(): + # If scaler has been created already load the weights + if key in cls._amp_scalers: + try: + cls._amp_scalers[key].load_state_dict(value) + cls._logger.info(f"Loaded grad scaler state dictionary {key}.") + except Exception as e: + cls._logger.error( + f"Failed to load grad scaler state dict with id {key}." + + " Something went wrong!" + ) + raise ValueError(e) + # Otherwise store in checkpoints for later use + else: + cls._amp_scaler_checkpoints[key] = value + + @classmethod + def reset_state(cls): + cls._amp_scalers = {} + cls._amp_scaler_checkpoints = {} + + +class StaticCaptureTraining(_StaticCapture): + """A performance optimization decorator for PyTorch training functions. + + This class should be initialized as a decorator on a function that computes the + forward pass of the neural network and loss function. The user should only call the + defind training step function. This will apply optimizations including: AMP and + Cuda Graphs. + + Parameters + ---------- + model : physicsnemo.models.Module + PhysicsNeMo Model + optim : torch.optim + Optimizer + logger : Optional[Logger], optional + PhysicsNeMo Launch Logger, by default None + use_graphs : bool, optional + Toggle CUDA graphs if supported by model, by default True + use_amp : bool, optional + Toggle AMP if supported by mode, by default True + cuda_graph_warmup : int, optional + Number of warmup steps for cuda graphs, by default 11 + amp_type : Union[float16, bfloat16], optional + Auto casting type for AMP, by default torch.float16 + gradient_clip_norm : Optional[float], optional + Threshold for gradient clipping + label : Optional[str], optional + Static capture checkpoint label, by default None + + Raises + ------ + ValueError + If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. + + Example + ------- + >>> # Create model + >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) + >>> input = torch.rand(8, 2) + >>> output = torch.rand(8, 2) + >>> # Create optimizer + >>> optim = torch.optim.Adam(model.parameters(), lr=0.001) + >>> # Create training step function with optimization wrapper + >>> @StaticCaptureTraining(model=model, optim=optim) + ... def training_step(model, invar, outvar): + ... predvar = model(invar) + ... loss = torch.sum(torch.pow(predvar - outvar, 2)) + ... return loss + ... + >>> # Sample training loop + >>> for i in range(3): + ... loss = training_step(model, input, output) + ... + + Note + ---- + Static captures must be checkpointed when training using the `state_dict()` if AMP + is being used with gradient scaler. By default, this requires static captures to be + instantiated in the same order as when they were checkpointed. The label parameter + can be used to relax/circumvent this ordering requirement. + + Note + ---- + Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA + memory access errors on some systems. Prioritize capturing training graphs when this + occurs. + """ + + def __init__( + self, + model: "physicsnemo.Module", + optim: torch.optim, + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_amp: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + gradient_clip_norm: Optional[float] = None, + label: Optional[str] = None, + ): + super().__init__( + model, + optim, + logger, + use_graphs, + use_amp, + use_amp, + compile, + cuda_graph_warmup, + amp_type, + gradient_clip_norm, + label, + ) + + +class StaticCaptureEvaluateNoGrad(_StaticCapture): + + """An performance optimization decorator for PyTorch no grad evaluation. + + This class should be initialized as a decorator on a function that computes run the + forward pass of the model that does not require gradient calculations. This is the + recommended method to use for inference and validation methods. + + Parameters + ---------- + model : physicsnemo.models.Module + PhysicsNeMo Model + logger : Optional[Logger], optional + PhysicsNeMo Launch Logger, by default None + use_graphs : bool, optional + Toggle CUDA graphs if supported by model, by default True + use_amp : bool, optional + Toggle AMP if supported by mode, by default True + cuda_graph_warmup : int, optional + Number of warmup steps for cuda graphs, by default 11 + amp_type : Union[float16, bfloat16], optional + Auto casting type for AMP, by default torch.float16 + label : Optional[str], optional + Static capture checkpoint label, by default None + + Raises + ------ + ValueError + If the model provided is not a physicsnemo.models.Module. I.e. has no meta data. + + Example + ------- + >>> # Create model + >>> model = physicsnemo.models.mlp.FullyConnected(2, 64, 2) + >>> input = torch.rand(8, 2) + >>> # Create evaluate function with optimization wrapper + >>> @StaticCaptureEvaluateNoGrad(model=model) + ... def eval_step(model, invar): + ... predvar = model(invar) + ... return predvar + ... + >>> output = eval_step(model, input) + >>> output.size() + torch.Size([8, 2]) + + Note + ---- + Capturing multiple cuda graphs in a single program can lead to potential invalid CUDA + memory access errors on some systems. Prioritize capturing training graphs when this + occurs. + """ + + def __init__( + self, + model: "physicsnemo.Module", + logger: Optional[Logger] = None, + use_graphs: bool = True, + use_amp: bool = True, + compile: bool = False, + cuda_graph_warmup: int = 11, + amp_type: Union[float16, bfloat16] = torch.float16, + label: Optional[str] = None, + ): + super().__init__( + model, + None, + logger, + use_graphs, + use_amp, + compile, + False, + cuda_graph_warmup, + amp_type, + None, + label, + ) + self.eval = True # No optimizer/scaler calls + self.no_grad = True # No grad context and no grad zeroing diff --git a/src/utils/checkpoint.py b/src/utils/checkpoint.py new file mode 100644 index 0000000..8ec70fa --- /dev/null +++ b/src/utils/checkpoint.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import re +from pathlib import Path +from typing import Any, Dict, List, NewType, Optional, Union + +import torch +from torch.cuda.amp import GradScaler +from torch.optim.lr_scheduler import _LRScheduler + +from src.distributed import DistributedManager +from src.utils.console import PythonLogger +from src.utils.capture import _StaticCapture + +optimizer = NewType("optimizer", torch.optim) +scheduler = NewType("scheduler", _LRScheduler) +scaler = NewType("scaler", GradScaler) + +checkpoint_logging = PythonLogger("checkpoint") + + +def _get_checkpoint_filename( + path: str, + base_name: str = "checkpoint", + index: Union[int, None] = None, + saving: bool = False, + model_type: str = "mdlus", +) -> str: + """Gets the file name /path of checkpoint + + This function has three different ways of providing a checkout filename: + - If supplied an index this will return the checkpoint name using that index. + - If index is None and saving is false, this will get the checkpoint with the + largest index (latest save). + - If index is None and saving is true, it will return the next valid index file name + which is calculated by indexing the largest checkpoint index found by one. + + Parameters + ---------- + path : str + Path to checkpoints + base_name: str, optional + Base file name, by default checkpoint + index : Union[int, None], optional + Checkpoint index, by default None + saving : bool, optional + Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for PhysicsNeMo models and "pt" for PyTorch models + + + Returns + ------- + str + Checkpoint file name + """ + # Get model parallel rank so all processes in the first model parallel group + # can save their checkpoint. In the case without model parallelism, + # model_parallel_rank should be the same as the process rank itself and + # only rank 0 saves + if not DistributedManager.is_initialized(): + checkpoint_logging.warning( + "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors" + ) + DistributedManager.initialize() + manager = DistributedManager() + model_parallel_rank = ( + manager.group_rank("model_parallel") + if "model_parallel" in manager.group_names + else 0 + ) + + # Input file name + checkpoint_filename = str( + Path(path).resolve() / f"{base_name}.{model_parallel_rank}" + ) + + # File extension for PhysicsNeMo models or PyTorch models + file_extension = ".mdlus" if model_type == "mdlus" else ".pt" + + # If epoch is provided load that file + if index is not None: + checkpoint_filename = checkpoint_filename + f".{index}" + checkpoint_filename += file_extension + # Otherwise try loading the latest epoch or rolling checkpoint + else: + file_names = [ + Path(fname).name + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ) + ] + + if len(file_names) > 0: + # If checkpoint from a null index save exists load that + # This is the most likely line to error since it will fail with + # invalid checkpoint names + file_idx = [ + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) + for fname in file_names + ] + file_idx.sort() + # If we are saving index by 1 to get the next free file name + if saving: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" + else: + checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" + checkpoint_filename += file_extension + else: + checkpoint_filename += ".0" + file_extension + + return checkpoint_filename + + +def _unique_model_names( + models: List[torch.nn.Module], +) -> Dict[str, torch.nn.Module]: + """Util to clean model names and index if repeat names, will also strip DDP wrappers + if they exist. + + Parameters + ---------- + model : List[torch.nn.Module] + List of models to generate names for + + Returns + ------- + Dict[str, torch.nn.Module] + Dictionary of model names and respective modules + """ + # Loop through provided models and set up base names + model_dict = {} + for model0 in models: + if hasattr(model0, "module"): + # Strip out DDP layer + model0 = model0.module + # Base name of model is meta.name unless pytorch model + base_name = model0.__class__.__name__ + if isinstance(model0, physicsnemo.models.Module): + base_name = model0.meta.name + # If we have multiple models of the same name, introduce another index + if base_name in model_dict: + model_dict[base_name].append(model0) + else: + model_dict[base_name] = [model0] + + # Set up unique model names if needed + output_dict = {} + for key, model in model_dict.items(): + if len(model) > 1: + for i, model0 in enumerate(model): + output_dict[key + str(i)] = model0 + else: + output_dict[key] = model[0] + + return output_dict + + +def save_checkpoint( + path: str, + models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Training checkpoint saving utility + + This will save a training checkpoint in the provided path following the file naming + convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint + method in PhysicsNeMo core can then be used to read this file. + + Parameters + ---------- + path : str + Path to save the training checkpoint + models : Union[torch.nn.Module, List[torch.nn.Module], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler. Will attempt to save on in static capture if none provided, by + default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none this will save the checkpoint in the next + valid index, by default None + metadata : Optional[Dict[str, Any]], optional + Additional metadata to save, by default None + """ + # Create checkpoint directory if it does not exist + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Output directory {path} does not exist, will " "attempt to create" + ) + Path(path).mkdir(parents=True, exist_ok=True) + + # == Saving model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = ( + "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" + ) + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type=model_type + ) + + # Save state dictionary + if isinstance(model, physicsnemo.models.Module): + model.save(file_name) + else: + torch.save(model.state_dict(), file_name) + checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + + # == Saving training checkpoint == + checkpoint_dict = {} + # Optimizer state dict + if optimizer: + checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() + + # Scheduler state dict + if scheduler: + checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() + + # Scheduler state dict + if scaler: + checkpoint_dict["scaler_state_dict"] = scaler.state_dict() + # Static capture is being used, save its grad scaler + if _StaticCapture._amp_scalers: + checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() + + # Output file name + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pt" + ) + if epoch: + checkpoint_dict["epoch"] = epoch + if metadata: + checkpoint_dict["metadata"] = metadata + + # Save checkpoint to memory + if bool(checkpoint_dict): + torch.save( + checkpoint_dict, + output_filename, + ) + checkpoint_logging.success(f"Saved training checkpoint: {output_filename}") + + +def load_checkpoint( + path: str, + models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + optimizer: Union[optimizer, None] = None, + scheduler: Union[scheduler, None] = None, + scaler: Union[scaler, None] = None, + epoch: Union[int, None] = None, + metadata_dict: Optional[Dict[str, Any]] = {}, + device: Union[str, torch.device] = "cpu", +) -> int: + """Checkpoint loading utility + + This loader is designed to be used with the save checkpoint utility in PhysicsNeMo + Launch. Given a path, this method will try to find a checkpoint and load state + dictionaries into the provided training objects. + + Parameters + ---------- + path : str + Path to training checkpoint + models : Union[torch.nn.Module, List[torch.nn.Module], None], optional + A single or list of PyTorch models, by default None + optimizer : Union[optimizer, None], optional + Optimizer, by default None + scheduler : Union[scheduler, None], optional + Learning rate scheduler, by default None + scaler : Union[scaler, None], optional + AMP grad scaler, by default None + epoch : Union[int, None], optional + Epoch checkpoint to load. If none is provided this will attempt to load the + checkpoint with the largest index, by default None + metadata_dict: Optional[Dict[str, Any]], optional + Dictionary to store metadata from the checkpoint, by default None + device : Union[str, torch.device], optional + Target device, by default "cpu" + + Returns + ------- + int + Loaded epoch + """ + # Check if checkpoint directory exists + if not Path(path).is_dir(): + checkpoint_logging.warning( + f"Provided checkpoint directory {path} does not exist, skipping load" + ) + return 0 + + # == Loading model checkpoint == + if models: + if not isinstance(models, list): + models = [models] + models = _unique_model_names(models) + for name, model in models.items(): + # Get model type + model_type = ( + "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" + ) + + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, model_type=model_type + ) + if not Path(file_name).exists(): + checkpoint_logging.error( + f"Could not find valid model file {file_name}, skipping load" + ) + continue + # Load state dictionary + if isinstance(model, physicsnemo.models.Module): + model.load(file_name) + else: + model.load_state_dict(torch.load(file_name, map_location=device)) + + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) + + # == Loading training checkpoint == + checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") + if not Path(checkpoint_filename).is_file(): + checkpoint_logging.warning( + "Could not find valid checkpoint file, skipping load" + ) + return 0 + + checkpoint_dict = torch.load(checkpoint_filename, map_location=device) + checkpoint_logging.success( + f"Loaded checkpoint file {checkpoint_filename} to device {device}" + ) + + # Optimizer state dict + if optimizer and "optimizer_state_dict" in checkpoint_dict: + optimizer.load_state_dict(checkpoint_dict["optimizer_state_dict"]) + checkpoint_logging.success("Loaded optimizer state dictionary") + + # Scheduler state dict + if scheduler and "scheduler_state_dict" in checkpoint_dict: + scheduler.load_state_dict(checkpoint_dict["scheduler_state_dict"]) + checkpoint_logging.success("Loaded scheduler state dictionary") + + # Scaler state dict + if scaler and "scaler_state_dict" in checkpoint_dict: + scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) + checkpoint_logging.success("Loaded grad scaler state dictionary") + + if "static_capture_state_dict" in checkpoint_dict: + _StaticCapture.load_state_dict(checkpoint_dict["static_capture_state_dict"]) + checkpoint_logging.success("Loaded static capture state dictionary") + + epoch = 0 + if "epoch" in checkpoint_dict: + epoch = checkpoint_dict["epoch"] + + # Update metadata if exists and the dictionary object is provided + metadata = checkpoint_dict.get("metadata", {}) + for key, value in metadata.items(): + metadata_dict[key] = value + + return epoch diff --git a/src/utils/console.py b/src/utils/console.py new file mode 100644 index 0000000..4231576 --- /dev/null +++ b/src/utils/console.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from termcolor import colored + + +class PythonLogger: + """Simple console logger for DL training + This is a WIP + """ + + def __init__(self, name: str = "launch"): + self.logger = logging.getLogger(name) + + def file_logging(self, file_name: str = "launch.log"): + """Log to file""" + if os.path.exists(file_name): + try: + os.remove(file_name) + except FileNotFoundError: + # ignore if already removed (can happen with multiple processes) + pass + formatter = logging.Formatter( + "[%(asctime)s - %(name)s - %(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + filehandler = logging.FileHandler(file_name) + filehandler.setFormatter(formatter) + filehandler.setLevel(logging.DEBUG) + self.logger.addHandler(filehandler) + + def log(self, message: str): + """Log message""" + self.logger.info(message) + + def info(self, message: str): + """Log info""" + self.logger.info(colored(message, "light_blue")) + + def success(self, message: str): + """Log success""" + self.logger.info(colored(message, "light_green")) + + def warning(self, message: str): + """Log warning""" + self.logger.warning(colored(message, "light_yellow")) + + def error(self, message: str): + """Log error""" + self.logger.error(colored(message, "light_red")) + + +class RankZeroLoggingWrapper: + """Wrapper class to only log from rank 0 process in distributed training.""" + + def __init__(self, obj, dist): + self.obj = obj + self.dist = dist + + def __getattr__(self, name): + attr = getattr(self.obj, name) + if callable(attr): + + def wrapper(*args, **kwargs): + if self.dist.rank == 0: + return attr(*args, **kwargs) + else: + return None + + return wrapper + else: + return attr diff --git a/src/utils/deterministic_sampler.py b/src/utils/deterministic_sampler.py new file mode 100644 index 0000000..4b2f32b --- /dev/null +++ b/src/utils/deterministic_sampler.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import nvtx +import torch + +from src.models import EDMPrecond + +# ruff: noqa: E731 + + +@nvtx.annotate(message="deterministic_sampler", color="red") +def deterministic_sampler( + net, + latents, + img_lr, + img_shape=None, + class_labels=None, + randn_like=torch.randn_like, + num_steps=18, + sigma_min=None, + sigma_max=None, + rho=7, + solver="heun", + discretization="edm", + schedule="linear", + scaling="none", + epsilon_s=1e-3, + C_1=0.001, + C_2=0.008, + M=1000, + alpha=1, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, +): + """ + Generalized sampler, representing the superset of all sampling methods discussed + in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + """ + + # conditioning + x_lr = img_lr + + if solver not in ["euler", "heun"]: + raise ValueError(f"Unknown solver {solver}") + if discretization not in ["vp", "ve", "iddpm", "edm"]: + raise ValueError(f"Unknown discretization {discretization}") + if schedule not in ["vp", "ve", "linear"]: + raise ValueError(f"Unknown schedule {schedule}") + if scaling not in ["vp", "none"]: + raise ValueError(f"Unknown scaling {scaling}") + + # Helper functions for VP & VE noise level schedules. + vp_sigma = ( + lambda beta_d, beta_min: lambda t: ( + np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1 + ) + ** 0.5 + ) + vp_sigma_deriv = ( + lambda beta_d, beta_min: lambda t: 0.5 + * (beta_min + beta_d * t) + * (sigma(t) + 1 / sigma(t)) + ) + vp_sigma_inv = ( + lambda beta_d, beta_min: lambda sigma: ( + (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min + ) + / beta_d + ) + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma**2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) + sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ + discretization + ] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1) + sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = ( + 2 + * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) + / (epsilon_s - 1) + ) + vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == "vp": + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == "ve": + orig_t_steps = (sigma_max**2) * ( + (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) + ) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == "iddpm": + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 + ).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[ + ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) + .round() + .to(torch.int64) + ] + else: + sigma_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + + # Define noise level schedule. + if schedule == "vp": + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == "ve": + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == "vp": + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = ( + min(S_churn / num_steps, np.sqrt(2) - 1) + if S_min <= sigma(t_cur) <= S_max + else 0 + ) + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + ( + sigma(t_hat) ** 2 - sigma(t_cur) ** 2 + ).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_hat / s(t_hat), + sigma(t_hat), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to( + torch.float64 + ) + d_cur = ( + sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) + ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == "euler" or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + if isinstance(net, EDMPrecond): + # Conditioning info is passed as keyword arg + denoised = net( + x_prime / s(t_prime), + sigma(t_prime), + condition=x_lr, + class_labels=class_labels, + ).to(torch.float64) + else: + denoised = net( + x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels + ).to(torch.float64) + d_prime = ( + sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) + ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ( + (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime + ) + + return x_next diff --git a/src/utils/function_utils.py b/src/utils/function_utils.py new file mode 100644 index 0000000..dcbb127 --- /dev/null +++ b/src/utils/function_utils.py @@ -0,0 +1,775 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Miscellaneous utility classes and functions.""" + +import contextlib +import ctypes +import datetime +import fnmatch +import importlib +import inspect +import os +import re +import shutil +import sys +import types +import warnings +from typing import Any, List, Tuple, Union + +import cftime +import numpy as np +import torch + +# ruff: noqa: E722 PERF203 S110 E713 S324 + + +class EasyDict(dict): # pragma: no cover + """ + Convenience class that behaves like a dict but allows access with the attribute + syntax. + """ + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class StackedRandomGenerator: # pragma: no cover + """ + Wrapper for torch.Generator that allows specifying a different random seed + for each sample in a minibatch. + """ + + def __init__(self, device, seeds): + super().__init__() + self.generators = [ + torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds + ] + + def randn(self, size, **kwargs): + if size[0] != len(self.generators): + raise ValueError( + f"Expected first dimension of size {len(self.generators)}, got {size[0]}" + ) + return torch.stack( + [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] + ) + + def randn_like(self, input): + return self.randn( + input.shape, dtype=input.dtype, layout=input.layout, device=input.device + ) + + def randint(self, *args, size, **kwargs): + if size[0] != len(self.generators): + raise ValueError( + f"Expected first dimension of size {len(self.generators)}, got {size[0]}" + ) + return torch.stack( + [ + torch.randint(*args, size=size[1:], generator=gen, **kwargs) + for gen in self.generators + ] + ) + + +def parse_int_list(s): # pragma: no cover + """ + Parse a comma separated list of numbers or ranges and return a list of ints. + Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + """ + if isinstance(s, list): + return s + ranges = [] + range_re = re.compile(r"^(\d+)-(\d+)$") + for p in s.split(","): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1)) + else: + ranges.append(int(p)) + return ranges + + +# Small util functions +# ------------------------------------------------------------------------------------- +def convert_datetime_to_cftime( + time: datetime.datetime, cls=cftime.DatetimeGregorian +) -> cftime.DatetimeGregorian: + """Convert a Python datetime object to a cftime DatetimeGregorian object.""" + return cls(time.year, time.month, time.day, time.hour, time.minute, time.second) + + +def time_range( + start_time: datetime.datetime, + end_time: datetime.datetime, + step: datetime.timedelta, + inclusive: bool = False, +): + """Like the Python `range` iterator, but with datetimes.""" + t = start_time + while (t <= end_time) if inclusive else (t < end_time): + yield t + t += step + + +def format_time(seconds: Union[int, float]) -> str: # pragma: no cover + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format( + s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 + ) + + +def format_time_brief(seconds: Union[int, float]) -> str: # pragma: no cover + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def tuple_product(t: Tuple) -> Any: # pragma: no cover + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double, +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: # pragma: no cover + """ + Given a type name string (or an object having a __name__ attribute), return + matching Numpy and ctypes types that have the same size in bytes. + """ + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + if type_str not in _str_to_ctype.keys(): + raise ValueError("Unknown type name: " + type_str) + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + if my_dtype.itemsize != ctypes.sizeof(my_ctype): + raise ValueError( + "Numpy and ctypes types for '{}' have different sizes!".format(type_str) + ) + + return my_dtype, my_ctype + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------- + + +def get_module_from_obj_name( + obj_name: str, +) -> Tuple[types.ModuleType, str]: # pragma: no cover + """ + Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed). + """ + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [ + (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) + ] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith( + "No module named '" + module_name + "'" + ): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module( + module: types.ModuleType, obj_name: str +) -> Any: # pragma: no cover + """ + Traverses the object name and returns the last (rightmost) python object. + """ + if obj_name == "": + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: # pragma: no cover + """ + Finds the python object with the given name. + """ + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name( + *args, func_name: str = None, **kwargs +) -> Any: # pragma: no cover + """ + Finds the python object with the given name and calls it as a function. + """ + if func_name is None: + raise ValueError("func_name must be specified") + func_obj = get_obj_by_name(func_name) + if not callable(func_obj): + raise ValueError(func_name + " is not callable") + return func_obj(*args, **kwargs) + + +def construct_class_by_name( + *args, class_name: str = None, **kwargs +) -> Any: # pragma: no cover + """ + Finds the python class with the given name and constructs it with the given + arguments. + """ + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: # pragma: no cover + """ + Get the directory path of the module containing the given object name. + """ + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: # pragma: no cover + """ + Determine whether the given object is a top-level function, i.e., defined at module + scope using 'def'. + """ + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: # pragma: no cover + """ + Return the fully-qualified name of a top-level function. + """ + if not is_top_level_function(obj): + raise ValueError("Object is not a top-level function") + module = obj.__module__ + if module == "__main__": + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + + +def list_dir_recursively_with_ignore( + dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False +) -> List[Tuple[str, str]]: # pragma: no cover + """ + List all files recursively in a given directory while ignoring given file and + directory names. Returns list of tuples containing both absolute and relative paths. + """ + if not os.path.isdir(dir_path): + raise RuntimeError(f"Directory does not exist: {dir_path}") + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + if len(absolute_paths) != len(relative_paths): + raise ValueError("Number of absolute and relative paths do not match") + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs( + files: List[Tuple[str, str]] +) -> None: # pragma: no cover + """ + Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories. + """ + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# ---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + + +def constant( + value, shape=None, dtype=None, device=None, memory_format=None +): # pragma: no cover + """Cached construction of constant tensors""" + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cpu") + if memory_format is None: + memory_format = torch.contiguous_format + + key = ( + value.shape, + value.dtype, + value.tobytes(), + shape, + dtype, + device, + memory_format, + ) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + + +# ---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + + def nan_to_num( + input, nan=0.0, posinf=None, neginf=None, *, out=None + ): # pylint: disable=redefined-builtin # pragma: no cover + """Replace NaN/Inf with specified numerical values""" + if not isinstance(input, torch.Tensor): + raise TypeError("input should be a Tensor") + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + if nan != 0: + raise ValueError("nan_to_num only supports nan=0") + return torch.clamp( + input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out + ) + + +# ---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +# ---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + + +@contextlib.contextmanager +def suppress_tracer_warnings(): # pragma: no cover + """ + Context manager to temporarily suppress known warnings in torch.jit.trace(). + Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + """ + flt = ("ignore", None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + + +# ---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + + +def assert_shape(tensor, ref_shape): # pragma: no cover + """ + Assert that the shape of a tensor matches the given list of integers. + None indicates that the size of a dimension is allowed to vary. + Performs symbolic assertion when used in torch.jit.trace(). + """ + if tensor.ndim != len(ref_shape): + raise AssertionError( + f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" + ) + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(torch.as_tensor(size), ref_size), + f"Wrong size for dimension {idx}", + ) + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(size, torch.as_tensor(ref_size)), + f"Wrong size for dimension {idx}: expected {ref_size}", + ) + elif size != ref_size: + raise AssertionError( + f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" + ) + + +# ---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + + +def profiled_function(fn): # pragma: no cover + """Function decorator that calls torch.autograd.profiler.record_function().""" + + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + + decorator.__name__ = fn.__name__ + return decorator + + +# ---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + + +class InfiniteSampler(torch.utils.data.Sampler): # pragma: no cover + """ + Sampler for torch.utils.data.DataLoader that loops over the dataset + indefinitely, shuffling items as it goes. + """ + + def __init__( + self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 + ): + if not len(dataset) > 0: + raise ValueError("Dataset must contain at least one item") + if not num_replicas > 0: + raise ValueError("num_replicas must be positive") + if not 0 <= rank < num_replicas: + raise ValueError("rank must be non-negative and less than num_replicas") + if not 0 <= window_size <= 1: + raise ValueError("window_size must be between 0 and 1") + super().__init__() + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + + +# ---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + + +def params_and_buffers(module): # pragma: no cover + """Get parameters and buffers of a nn.Module""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + return list(module.parameters()) + list(module.buffers()) + + +def named_params_and_buffers(module): # pragma: no cover + """Get named parameters and buffers of a nn.Module""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + return list(module.named_parameters()) + list(module.named_buffers()) + + +@torch.no_grad() +def copy_params_and_buffers( + src_module, dst_module, require_all=False +): # pragma: no cover + """Copy parameters and buffers from a source module to target module""" + if not isinstance(src_module, torch.nn.Module): + raise TypeError("src_module must be a torch.nn.Module instance") + if not isinstance(dst_module, torch.nn.Module): + raise TypeError("dst_module must be a torch.nn.Module instance") + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + if not ((name in src_tensors) or (not require_all)): + raise ValueError(f"Missing source tensor for {name}") + if name in src_tensors: + tensor.copy_(src_tensors[name]) + + +# ---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + + +@contextlib.contextmanager +def ddp_sync(module, sync): # pragma: no cover + """ + Context manager for easily enabling/disabling DistributedDataParallel + synchronization. + """ + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + + +# ---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + + +def check_ddp_consistency(module, ignore_regex=None): # pragma: no cover + """Check DistributedDataParallel consistency across processes.""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + "." + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + if not (tensor == other).all(): + raise RuntimeError(f"DDP consistency check failed for {fullname}") + + +# ---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + + +def print_module_summary( + module, inputs, max_nesting=3, skip_redundant=True +): # pragma: no cover + """Print summary table of module hierarchy.""" + if not isinstance(module, torch.nn.Module): + raise TypeError("module must be a torch.nn.Module instance") + if isinstance(module, torch.jit.ScriptModule): + raise TypeError("module must not be a torch.jit.ScriptModule instance") + if not isinstance(inputs, (tuple, list)): + raise TypeError("inputs must be a tuple or list") + + # Register hooks. + entries = [] + nesting = [0] + + def pre_hook(_mod, _inputs): + nesting[0] += 1 + + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= { + id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs + } + + # Filter out redundant entries. + if skip_redundant: + entries = [ + e + for e in entries + if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) + ] + + # Construct table. + rows = [ + [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] + ] + rows += [["---"] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = "" if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] + rows += [ + [ + name + (":0" if len(e.outputs) >= 2 else ""), + str(param_size) if param_size else "-", + str(buffer_size) if buffer_size else "-", + (output_shapes + ["-"])[0], + (output_dtypes + ["-"])[0], + ] + ] + for idx in range(1, len(e.outputs)): + rows += [ + [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] + ] + param_total += param_size + buffer_total += buffer_size + rows += [["---"] * len(rows[0])] + rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + for row in rows: + print( + " ".join( + cell + " " * (width - len(cell)) for cell, width in zip(row, widths) + ) + ) + return outputs + + +# ---------------------------------------------------------------------------- diff --git a/src/utils/inference_utils.py b/src/utils/inference_utils.py new file mode 100644 index 0000000..842bdd3 --- /dev/null +++ b/src/utils/inference_utils.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import cftime +import nvtx +import torch +import tqdm + +from src.utils.function_utils import StackedRandomGenerator, time_range + +############################################################################ +# CorrDiff Generation Utilities # +############################################################################ + + +def regression_step( + net: torch.nn.Module, + img_lr: torch.Tensor, + latents_shape: torch.Size, + lead_time_label: torch.Tensor = None, +) -> torch.Tensor: + """ + Given a low-res input, performs a regression step to produce ensemble mean. + This function performs the regression on a single instance and then replicates + the results across the batch dimension. + + Args: + net (torch.nn.Module): U-Net model for regression. + img_lr (torch.Tensor): Low-resolution input. + latents_shape (torch.Size): Shape of the latent representation. Typically + (batch_size, out_channels, image_shape_x, image_shape_y). + + + Returns: + torch.Tensor: Predicted output at the next time step. + """ + # Create a tensor of zeros with the given shape and move it to the appropriate device + x_hat = torch.zeros(latents_shape, dtype=torch.float64, device=net.device) + t_hat = torch.tensor(1.0, dtype=torch.float64, device=net.device) + + # Perform regression on a single batch element + with torch.inference_mode(): + if lead_time_label is not None: + x = net(x_hat[0:1], img_lr, t_hat, lead_time_label=lead_time_label) + else: + x = net(x_hat[0:1], img_lr, t_hat) + + # If the batch size is greater than 1, repeat the prediction + if x_hat.shape[0] > 1: + x = x.repeat([d if i == 0 else 1 for i, d in enumerate(x_hat.shape)]) + + return x + + +def diffusion_step( # TODO generalize the module and add defaults + net: torch.nn.Module, + sampler_fn: callable, + seed_batch_size: int, + img_shape: tuple, + img_out_channels: int, + rank_batches: list, + img_lr: torch.Tensor, + rank: int, + device: torch.device, + hr_mean: torch.Tensor = None, + lead_time_label: torch.Tensor = None, +) -> torch.Tensor: + + """ + Generate images using diffusion techniques as described in the relevant paper. + + Args: + net (torch.nn.Module): The diffusion model network. + sampler_fn (callable): Function used to sample images from the diffusion model. + seed_batch_size (int): Number of seeds per batch. + img_shape (tuple): Shape of the images, (height, width). + img_out_channels (int): Number of output channels for the image. + rank_batches (list): List of batches of seeds to process. + img_lr (torch.Tensor): Low-resolution input image. + rank (int): Rank of the current process for distributed processing. + device (torch.device): Device to perform computations. + mean_hr (torch.Tensor, optional): High-resolution mean tensor, to be used as an additional input. By default None. + + Returns: + torch.Tensor: Generated images concatenated across batches. + """ + + img_lr = img_lr.to(memory_format=torch.channels_last) + + # Handling of the high-res mean + additional_args = {} + if hr_mean is not None: + additional_args["mean_hr"] = hr_mean + if lead_time_label is not None: + additional_args["lead_time_label"] = lead_time_label + additional_args["img_shape"] = img_shape + + # Loop over batches + all_images = [] + for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(rank != 0)): + with nvtx.annotate(f"generate {len(all_images)}", color="rapids"): + batch_size = len(batch_seeds) + if batch_size == 0: + continue + + # Initialize random generator, and generate latents + rnd = StackedRandomGenerator(device, batch_seeds) + latents = rnd.randn( + [ + seed_batch_size, + img_out_channels, + img_shape[0], + img_shape[1], + ], + device=device, + ).to(memory_format=torch.channels_last) + + with torch.inference_mode(): + images = sampler_fn( + net, latents, img_lr, randn_like=rnd.randn_like, **additional_args + ) + all_images.append(images) + return torch.cat(all_images) + + +############################################################################ +# CorrDiff writer utilities # +############################################################################ + + +class NetCDFWriter: + """NetCDF Writer""" + + def __init__( + self, f, lat, lon, input_channels, output_channels, has_lead_time=False + ): + self._f = f + self.has_lead_time = has_lead_time + # create unlimited dimensions + f.createDimension("time") + f.createDimension("ensemble") + + if lat.shape != lon.shape: + raise ValueError("lat and lon must have the same shape") + ny, nx = lat.shape + + # create lat/lon grid + f.createDimension("x", nx) + f.createDimension("y", ny) + + v = f.createVariable("lat", "f", dimensions=("y", "x")) + # NOTE rethink this for datasets whose samples don't have constant lat-lon. + v[:] = lat + v.standard_name = "latitude" + v.units = "degrees_north" + + v = f.createVariable("lon", "f", dimensions=("y", "x")) + v[:] = lon + v.standard_name = "longitude" + v.units = "degrees_east" + + # create time dimension + if has_lead_time: + v = f.createVariable("time", "str", ("time")) + else: + v = f.createVariable("time", "i8", ("time")) + v.calendar = "standard" + v.units = "hours since 1990-01-01 00:00:00" + + self.truth_group = f.createGroup("truth") + self.prediction_group = f.createGroup("prediction") + self.input_group = f.createGroup("input") + + for variable in output_channels: + name = variable.name + variable.level + self.truth_group.createVariable(name, "f", dimensions=("time", "y", "x")) + self.prediction_group.createVariable( + name, "f", dimensions=("ensemble", "time", "y", "x") + ) + + # setup input data in netCDF + + for variable in input_channels: + name = variable.name + variable.level + self.input_group.createVariable(name, "f", dimensions=("time", "y", "x")) + + def write_input(self, channel_name, time_index, val): + """Write input data to NetCDF file.""" + self.input_group[channel_name][time_index] = val + + def write_truth(self, channel_name, time_index, val): + """Write ground truth data to NetCDF file.""" + self.truth_group[channel_name][time_index] = val + + def write_prediction(self, channel_name, time_index, ensemble_index, val): + """Write prediction data to NetCDF file.""" + self.prediction_group[channel_name][ensemble_index, time_index] = val + + def write_time(self, time_index, time): + """Write time information to NetCDF file.""" + if self.has_lead_time: + self._f["time"][time_index] = time + else: + time_v = self._f["time"] + self._f["time"][time_index] = cftime.date2num( + time, time_v.units, time_v.calendar + ) + + +############################################################################ +# CorrDiff time utilities # +############################################################################ + + +def get_time_from_range(times_range, time_format="%Y-%m-%dT%H:%M:%S"): + """Generates a list of times within a given range. + + Args: + times_range: A list containing start time, end time, and optional interval (hours). + time_format: The format of the input times (default: "%Y-%m-%dT%H:%M:%S"). + + Returns: + A list of times within the specified range. + """ + + start_time = datetime.datetime.strptime(times_range[0], time_format) + end_time = datetime.datetime.strptime(times_range[1], time_format) + interval = ( + datetime.timedelta(hours=times_range[2]) + if len(times_range) > 2 + else datetime.timedelta(hours=1) + ) + + times = [ + t.strftime(time_format) + for t in time_range(start_time, end_time, interval, inclusive=True) + ] + return times diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py new file mode 100644 index 0000000..e1cde9d --- /dev/null +++ b/src/utils/model_utils.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): + """ + Unified routine for initializing weights and biases. + This function provides a unified interface for various weight initialization + strategies like Xavier (Glorot) and Kaiming (He) initializations. + + Parameters + ---------- + shape : tuple + The shape of the tensor to initialize. It could represent weights or biases + of a layer in a neural network. + mode : str + The mode/type of initialization to use. Supported values are: + - "xavier_uniform": Xavier (Glorot) uniform initialization. + - "xavier_normal": Xavier (Glorot) normal initialization. + - "kaiming_uniform": Kaiming (He) uniform initialization. + - "kaiming_normal": Kaiming (He) normal initialization. + fan_in : int + The number of input units in the weight tensor. For convolutional layers, + this typically represents the number of input channels times the kernel height + times the kernel width. + fan_out : int + The number of output units in the weight tensor. For convolutional layers, + this typically represents the number of output channels times the kernel height + times the kernel width. + + Returns + ------- + torch.Tensor + The initialized tensor based on the specified mode. + + Raises + ------ + ValueError + If the provided `mode` is not one of the supported initialization modes. + """ + if mode == "xavier_uniform": + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == "xavier_normal": + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == "kaiming_uniform": + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == "kaiming_normal": + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') diff --git a/src/utils/stochastic_sampler.py b/src/utils/stochastic_sampler.py new file mode 100644 index 0000000..ddcf9cc --- /dev/null +++ b/src/utils/stochastic_sampler.py @@ -0,0 +1,533 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Any, Callable, Optional + +import torch +from torch import Tensor + + +def image_batching( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + patch_shape_y: int, + patch_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, + input_interp: Optional[Tensor] = None, +) -> Tensor: + """ + Splits a full image into a batch of patched images. + + This function takes a full image and splits it into patches, adding padding where necessary. + It can also concatenate additional interpolated data to each patch if provided. + + Parameters + ---------- + input : Tensor + The input tensor representing the full image with shape (batch_size, channels, img_shape_x, img_shape_y). + img_shape_x : int + The width (x-dimension) of the original full image. + img_shape_y : int + The height (y-dimension) of the original full image. + patch_shape_x : int + The width (x-dimension) of each image patch. + patch_shape_y : int + The height (y-dimension) of each image patch. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + input_interp : Optional[Tensor], optional + Optional additional data to concatenate to each patch with shape (batch_size, interp_channels, patch_shape_x, patch_shape_y). + By default None. + + Returns + ------- + Tensor + A tensor containing the image patches, with shape (total_patches * batch_size, channels [+ interp_channels], patch_shape_x, patch_shape_y). + """ + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + input_padded = torch.zeros( + input.shape[0], input.shape[1], padded_shape_y, padded_shape_x + ).to(input.device) + image_padding = torch.nn.ReflectionPad2d( + (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) + input_padded = image_padding(input) + patch_num = patch_num_x * patch_num_y + if input_interp is not None: + output = torch.zeros( + patch_num * batch_size, + input.shape[1] + input_interp.shape[1], + patch_shape_y, + patch_shape_x, + ).to(input.device) + else: + output = torch.zeros( + patch_num * batch_size, input.shape[1], patch_shape_y, patch_shape_x + ).to(input.device) + for x_index in range(patch_num_x): + for y_index in range(patch_num_y): + x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) + y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) + if input_interp is not None: + output[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + ] = torch.cat( + ( + input_padded[ + :, + :, + y_start : y_start + patch_shape_y, + x_start : x_start + patch_shape_x, + ], + input_interp, + ), + dim=1, + ) + else: + output[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + ] = input_padded[ + :, + :, + y_start : y_start + patch_shape_y, + x_start : x_start + patch_shape_x, + ] + return output + + +def image_fuse( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + patch_shape_y: int, + patch_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, +) -> Tensor: + """ + Reconstructs a full image from a batch of patched images. + + This function takes a batch of image patches and reconstructs the full image + by stitching the patches together. The function accounts for overlapping and + boundary pixels, ensuring that overlapping areas are averaged. + + Parameters + ---------- + input : Tensor + The input tensor containing the image patches with shape (total_patches * batch_size, channels, patch_shape_x, patch_shape_y). + img_shape_x : int + The width (x-dimension) of the original full image. + img_shape_y : int + The height (y-dimension) of the original full image. + patch_shape_x : int + The width (x-dimension) of each image patch. + patch_shape_y : int + The height (y-dimension) of each image patch. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + + Returns + ------- + Tensor + The reconstructed full image tensor with shape (batch_size, channels, img_shape_x, img_shape_y). + + """ + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + residual_x = patch_shape_x - pad_x_right # residual pixels in the last patch + residual_y = patch_shape_y - pad_y_right # residual pixels in the last patch + output = torch.zeros( + batch_size, input.shape[1], img_shape_y, img_shape_x, device=input.device + ) + one_map = torch.ones(1, 1, input.shape[2], input.shape[3], device=input.device) + count_map = torch.zeros( + 1, 1, img_shape_y, img_shape_x, device=input.device + ) # to count the overlapping times + for x_index in range(patch_num_x): + for y_index in range(patch_num_y): + x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) + y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) + if (x_index == patch_num_x - 1) and (y_index != patch_num_y - 1): + output[ + :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: + ] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + count_map[ + :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: + ] += one_map[ + :, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + elif (y_index == patch_num_y - 1) and ((x_index != patch_num_x - 1)): + output[ + :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix + ] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + count_map[ + :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix + ] += one_map[ + :, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + elif x_index == patch_num_x - 1 and y_index == patch_num_y - 1: + output[:, :, y_start:, x_start:] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + count_map[:, :, y_start:, x_start:] += one_map[ + :, + :, + boundary_pix : residual_y + boundary_pix, + boundary_pix : residual_x + boundary_pix, + ] + else: + output[ + :, + :, + y_start : y_start + patch_shape_y - 2 * boundary_pix, + x_start : x_start + patch_shape_x - 2 * boundary_pix, + ] += input[ + (x_index * patch_num_y + y_index) + * batch_size : (x_index * patch_num_y + y_index + 1) + * batch_size, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + count_map[ + :, + :, + y_start : y_start + patch_shape_y - 2 * boundary_pix, + x_start : x_start + patch_shape_x - 2 * boundary_pix, + ] += one_map[ + :, + :, + boundary_pix : patch_shape_y - boundary_pix, + boundary_pix : patch_shape_x - boundary_pix, + ] + return output / count_map + + +def stochastic_sampler( + net: Any, + latents: Tensor, + img_lr: Tensor, + class_labels: Optional[Tensor] = None, + randn_like: Callable[[Tensor], Tensor] = torch.randn_like, + img_shape: int = 448, + patch_shape: int = 448, + overlap_pix: int = 4, + boundary_pix: int = 2, + mean_hr: Optional[Tensor] = None, + lead_time_label: Optional[Tensor] = None, + num_steps: int = 18, + sigma_min: float = 0.002, + sigma_max: float = 800, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, +) -> Tensor: + """ + Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution and patch-based diffusion. + + Parameters + ---------- + net : Any + The neural network model that generates denoised images from noisy inputs. + latents : Tensor + The latent variables (e.g., noise) used as the initial input for the sampler. + img_lr : Tensor + Low-resolution input image for conditioning the super-resolution process. + class_labels : Optional[Tensor], optional + Class labels for conditional generation, if required by the model. By default None. + randn_like : Callable[[Tensor], Tensor] + Function to generate random noise with the same shape as the input tensor. + By default torch.randn_like. + img_shape : int + The height and width of the full image (assumed to be square). By default 448. + patch_shape : int + The height and width of each patch (assumed to be square). By default 448. + overlap_pix : int + Number of overlapping pixels between adjacent patches. By default 4. + boundary_pix : int + Number of pixels to be cropped as a boundary from each patch. By default 2. + mean_hr : Optional[Tensor], optional + Optional tensor containing mean high-resolution images for conditioning. By default None. + num_steps : int + Number of time steps for the sampler. By default 18. + sigma_min : float + Minimum noise level. By default 0.002. + sigma_max : float + Maximum noise level. By default 800. + rho : float + Exponent used in the time step discretization. By default 7. + S_churn : float + Churn parameter controlling the level of noise added in each step. By default 0. + S_min : float + Minimum time step for applying churn. By default 0. + S_max : float + Maximum time step for applying churn. By default float("inf"). + S_noise : float + Noise scaling factor applied during the churn step. By default 1. + + Returns + ------- + Tensor + The final denoised image produced by the sampler. + """ + + # Adjust noise levels based on what's supported by the network. + "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + if isinstance(img_shape, tuple): + img_shape_y, img_shape_x = img_shape + else: + img_shape_x = img_shape_y = img_shape + if patch_shape > img_shape_x or patch_shape > img_shape_y: + patch_shape = min(img_shape_x, img_shape_y) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat( + [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + ) # t_N = 0 + + b = latents.shape[0] + Nx = torch.arange(img_shape_x) + Ny = torch.arange(img_shape_y) + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ + None, + ].expand(b, -1, -1, -1) + + # conditioning = [mean_hr, img_lr, global_lr, pos_embd] + batch_size = img_lr.shape[0] + x_lr = img_lr + if mean_hr is not None: + x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) + global_index = None + + # input and position padding + patching + if patch_shape != img_shape_x or patch_shape != img_shape_y: + input_interp = torch.nn.functional.interpolate( + img_lr, (patch_shape, patch_shape), mode="bilinear" + ) + x_lr = image_batching( + x_lr, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + input_interp, + ) + global_index = image_batching( + grid.float(), + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ).int() + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + # Increase noise temporarily. + gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. Perform patching operation on score tensor if patch-based generation is used + # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr + + if patch_shape != img_shape_x or patch_shape != img_shape_y: + x_hat_batch = image_batching( + x_hat, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + else: + x_hat_batch = x_hat + x_hat_batch = x_hat_batch.to(latents.device) + x_lr = x_lr.to(latents.device) + if global_index is not None: + global_index = global_index.to(latents.device) + + if lead_time_label is not None: + denoised = net( + x_hat_batch, + x_lr, + t_hat, + class_labels, + lead_time_label=lead_time_label, + global_index=global_index, + ).to(torch.float64) + else: + denoised = net( + x_hat_batch, + x_lr, + t_hat, + class_labels, + global_index=global_index, + ).to(torch.float64) + if patch_shape != img_shape_x or patch_shape != img_shape_y: + + denoised = image_fuse( + denoised, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + if patch_shape != img_shape_x or patch_shape != img_shape_y: + x_next_batch = image_batching( + x_next, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + else: + x_next_batch = x_next + # ask about this fix + x_next_batch = x_next_batch.to(latents.device) + if lead_time_label is not None: + denoised = net( + x_next_batch, + x_lr, + t_next, + class_labels, + lead_time_label=lead_time_label, + global_index=global_index, + ).to(torch.float64) + else: + denoised = net( + x_next_batch, + x_lr, + t_next, + class_labels, + global_index=global_index, + ).to(torch.float64) + if patch_shape != img_shape_x or patch_shape != img_shape_y: + denoised = image_fuse( + denoised, + img_shape_y, + img_shape_x, + patch_shape, + patch_shape, + batch_size, + overlap_pix, + boundary_pix, + ) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next diff --git a/src/utils/train_helpers.py b/src/utils/train_helpers.py new file mode 100644 index 0000000..d4529ac --- /dev/null +++ b/src/utils/train_helpers.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +from omegaconf import ListConfig + + +def set_patch_shape(img_shape, patch_shape): + img_shape_y, img_shape_x = img_shape + patch_shape_y, patch_shape_x = patch_shape + if (patch_shape_x is None) or (patch_shape_x > img_shape_x): + patch_shape_x = img_shape_x + if (patch_shape_y is None) or (patch_shape_y > img_shape_y): + patch_shape_y = img_shape_y + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x != patch_shape_y: + raise NotImplementedError("Rectangular patch not supported yet") + if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: + raise ValueError("Patch shape needs to be a multiple of 32") + return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + + +def set_seed(rank): + """ + Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings + """ + np.random.seed(rank % (1 << 31)) + torch.manual_seed(np.random.randint(1 << 31)) + + +def configure_cuda_for_consistent_precision(): + """ + Configures CUDA and cuDNN settings to ensure consistent precision by + disabling TensorFloat-32 (TF32) and reduced precision settings. + """ + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + +def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_size): + """ + Calculate the total batch size per GPU in a distributed setting, log the batch size per GPU, ensure it's within valid limits, + determine the number of accumulation rounds, and validate that the global batch size matches the expected value. + """ + batch_gpu_total = total_batch_size // world_size + batch_size_per_gpu = batch_size_per_gpu + if batch_size_per_gpu is None or batch_size_per_gpu > batch_gpu_total: + batch_size_per_gpu = batch_gpu_total + num_accumulation_rounds = batch_gpu_total // batch_size_per_gpu + if total_batch_size != batch_size_per_gpu * num_accumulation_rounds * world_size: + raise ValueError( + "total_batch_size must be equal to batch_size_per_gpu * num_accumulation_rounds * world_size" + ) + return batch_gpu_total, num_accumulation_rounds + + +def handle_and_clip_gradients(model, grad_clip_threshold=None): + """ + Handles NaNs and infinities in the gradients and optionally clips the gradients. + + Parameters: + - model (torch.nn.Module): The model whose gradients need to be processed. + - grad_clip_threshold (float, optional): The threshold for gradient clipping. If None, no clipping is performed. + """ + # Replace NaNs and infinities in gradients + for param in model.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0.0, posinf=1e5, neginf=-1e5, out=param.grad + ) + + # Clip gradients if a threshold is provided + if grad_clip_threshold is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) + + +def parse_model_args(args): + """Convert ListConfig values in args to tuples.""" + return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} + + +def is_time_for_periodic_task( + cur_nimg, freq, done, batch_size, rank, rank_0_only=False +): + """Should we perform a task that is done every `freq` samples?""" + if rank_0_only and rank != 0: + return False + elif done: # Run periodic tasks also at the end of training + return True + else: + return cur_nimg % freq < batch_size From 3b172ccc9093ceaa3b4a18a76e26d770a6301253 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 14 Apr 2025 17:14:37 +0200 Subject: [PATCH 06/66] restructure file system and add pyproject.toml --- README.md | 0 pyproject.toml | 24 +++++++++++ src/hirad/conf/train_regression.yaml | 0 src/{ => hirad}/distributed/__init__.py | 0 src/{ => hirad}/distributed/config.py | 0 src/{ => hirad}/distributed/manager.py | 0 src/hirad/losses/__init__.py | 0 src/{ => hirad}/losses/loss.py | 0 src/{ => hirad}/models/__init__.py | 0 .../__pycache__/__init__.cpython-312.pyc | Bin 0 -> 185 bytes .../models/__pycache__/dummy.cpython-312.pyc | Bin 0 -> 451 bytes .../models/__pycache__/unet.cpython-312.pyc | Bin 0 -> 8966 bytes src/{ => hirad}/models/layers.py | 0 src/{ => hirad}/models/preconditioning.py | 0 src/{ => hirad}/models/song_unet.py | 0 src/{ => hirad}/models/unet.py | 0 src/{ => hirad}/models/utils.py | 0 src/hirad/training/train.py | 0 src/hirad/utils/__init__.py | 0 src/{ => hirad}/utils/capture.py | 0 src/{ => hirad}/utils/checkpoint.py | 0 src/{ => hirad}/utils/console.py | 0 .../utils/deterministic_sampler.py | 0 src/{ => hirad}/utils/function_utils.py | 0 src/{ => hirad}/utils/inference_utils.py | 0 src/{ => hirad}/utils/model_utils.py | 0 src/{ => hirad}/utils/stochastic_sampler.py | 0 src/{ => hirad}/utils/train_helpers.py | 0 src/hirad_gen.egg-info/PKG-INFO | 38 ++++++++++++++++++ src/hirad_gen.egg-info/SOURCES.txt | 28 +++++++++++++ src/hirad_gen.egg-info/dependency_links.txt | 1 + src/hirad_gen.egg-info/top_level.txt | 7 ++++ 32 files changed, 98 insertions(+) create mode 100644 README.md create mode 100644 pyproject.toml create mode 100644 src/hirad/conf/train_regression.yaml rename src/{ => hirad}/distributed/__init__.py (100%) rename src/{ => hirad}/distributed/config.py (100%) rename src/{ => hirad}/distributed/manager.py (100%) create mode 100644 src/hirad/losses/__init__.py rename src/{ => hirad}/losses/loss.py (100%) rename src/{ => hirad}/models/__init__.py (100%) create mode 100644 src/hirad/models/__pycache__/__init__.cpython-312.pyc create mode 100644 src/hirad/models/__pycache__/dummy.cpython-312.pyc create mode 100644 src/hirad/models/__pycache__/unet.cpython-312.pyc rename src/{ => hirad}/models/layers.py (100%) rename src/{ => hirad}/models/preconditioning.py (100%) rename src/{ => hirad}/models/song_unet.py (100%) rename src/{ => hirad}/models/unet.py (100%) rename src/{ => hirad}/models/utils.py (100%) create mode 100644 src/hirad/training/train.py create mode 100644 src/hirad/utils/__init__.py rename src/{ => hirad}/utils/capture.py (100%) rename src/{ => hirad}/utils/checkpoint.py (100%) rename src/{ => hirad}/utils/console.py (100%) rename src/{ => hirad}/utils/deterministic_sampler.py (100%) rename src/{ => hirad}/utils/function_utils.py (100%) rename src/{ => hirad}/utils/inference_utils.py (100%) rename src/{ => hirad}/utils/model_utils.py (100%) rename src/{ => hirad}/utils/stochastic_sampler.py (100%) rename src/{ => hirad}/utils/train_helpers.py (100%) create mode 100644 src/hirad_gen.egg-info/PKG-INFO create mode 100644 src/hirad_gen.egg-info/SOURCES.txt create mode 100644 src/hirad_gen.egg-info/dependency_links.txt create mode 100644 src/hirad_gen.egg-info/top_level.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b2fa56c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "hirad-gen" +version = "0.1.0" +description = "High resolution atmospheric downscaling using generative machine learning" +authors = [ + { name="Petar Stamenkovic", email="petar.stamenkovic@meteoswiss.ch" } +] +readme = "README.md" +requires-python = ">=3.12" +license = {file = "LICENSE"} + +dependencies = [ + "torch>=2.6.0" +] + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] \ No newline at end of file diff --git a/src/hirad/conf/train_regression.yaml b/src/hirad/conf/train_regression.yaml new file mode 100644 index 0000000..e69de29 diff --git a/src/distributed/__init__.py b/src/hirad/distributed/__init__.py similarity index 100% rename from src/distributed/__init__.py rename to src/hirad/distributed/__init__.py diff --git a/src/distributed/config.py b/src/hirad/distributed/config.py similarity index 100% rename from src/distributed/config.py rename to src/hirad/distributed/config.py diff --git a/src/distributed/manager.py b/src/hirad/distributed/manager.py similarity index 100% rename from src/distributed/manager.py rename to src/hirad/distributed/manager.py diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/losses/loss.py b/src/hirad/losses/loss.py similarity index 100% rename from src/losses/loss.py rename to src/hirad/losses/loss.py diff --git a/src/models/__init__.py b/src/hirad/models/__init__.py similarity index 100% rename from src/models/__init__.py rename to src/hirad/models/__init__.py diff --git a/src/hirad/models/__pycache__/__init__.cpython-312.pyc b/src/hirad/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70d5748263686e7ee01a00c9c9ffd61e4261990b GIT binary patch literal 185 zcmX@j%ge<81eP(8=gQRm475F^(AnLQ) zQDjD9>o~pA)a3YQB1#<>S}gTyk;VCqxJWL~XOnrH3tdDayH#g=oRqaZdZNZE447`7==k|mj}EK82PsyZ3&%#yR( zo!Q)(l}KR)vxSQ;+q#mP6r@`haex4oDiu*4g2+$KQ(yWbg(|!mH~@j92#mZ?pxh>Y zNzS?RvrBqy2Tp*#_zF6E&;2>~@7#0GcZUBJk4G3t8zaA+{guElzd*)K0VlI@7c$F? z%*bq!$>OiC$d-IrA8VKSvwm8~mG~@A%lySaDVPnijE_;6Ledr3_?5Yvtj%S^iqFR! zWn}(MMh+;!n|{0erqehZp|wG%4Jmw~t0BY9xjnKm+A<6+BVW)mzQCrVAEAKhG;4)r z$&_+M$uNM4PwBE!oKj3_5=t*c43f)~u*k@{1LChUOJ?Jrfmw!CW?7lZ`p&=}W&H}L z@Urh&sFD3ySQBuZ!-XivJ6xD@0m=ol5z2*}x+vwsPF;+05vMLrxv0Y>C>L|MHp<13 zTS%natdP4Fi#s!|m=>o=CFO|aFQ}#!%~fPcoF!6u&ai?~sVwHomD6@^wp_77rcQEm zVt(G@X7VEkEKb+7c`K5yXgQHqSz)PS>L@{;PA9F9D57*xv?8JiyH+V;J|>Fasz^n< zhJ;{ZmftXm)h3FPrs<|+s=8)~A_-#^H>2xCf?7$UNMcChND@fekf279b|5#H4WP^? zgToa=A;xgoFkyV{{P2tFsqx97qlz{R2RLjW;BZA#%uIQKpkhQ*i@+!Autvv;O zAA=ri?I;CuU2EHW7mr`%R=M7K*UrUbHS<2Vlk~U)Z_*_mN-=G`2;^^?v?$h;iL^ED>V%E33oSzeCFp*PuVKnBgv25Zbbn+?6fq@yGV;-vrJ7$zb>Ix_^P zFMI>ex~vc(uM=THCuCC1=Q9!7;G{%g?@ZW6$__OW<<-)xNEAaaR`9rlF+tT#T7kk( z&nZGvrJff|C{eX?#T?cvCj1C$Nt#tMo~Ya$oITvCErNvUN@)h>2NH665hm%jYxGn# z(JkE4xzU%7{FxG9Mqez#nRcsLfUb%8^2mWP0d;jy(95_bQqkLv!}Ch6Vk&eKDw?d2 z;)1Hp3RDdR$rO&9962DAiIP(d==l7gFtZ@YN?xiIp-w6q&ORCHY)KMJsx~I%i@IdC z7->q?)KaA+7?pBaC#E6`nywm(P*g4`FtaAP`|KTbWJ>d&BU76H17*@#4fLy8zCYb8 z9u4Ye%Rf9{R&qF1xU8Bi5gFEgL8Vpz&mA?#BoZtPwhSe0S;o`$NOUtrbD^w^fgsXi zZzj;B3uty+v0Kk#ljQ!%IfZ+{f8d`rsS-7af^r5Z- zR>w5xwZi#t9SE)%RgP0iULlH>Q?}kQS~~@MA!+KMFqIh;4naHpqH10m6sCcl(B~9w zP3~7Iq)g|iCtz)r-EUYVpZlX zu&*XqK|9@IZ#td_%T#?;Ms#0t#>D4`rYdWC@^9j7Ad$8zJI_LU*DGQmk0aBhZL6}f z56%0g7Od=d%Dk3UKeWI%kftgJTw`1wIB(=?061Ki-7*LqzJ)dy0uCQVn#?{g_lTRT zhTWN8j{2BtROYKJjLltw232@X-tku>-)H{;SLFF;3rspNeS@_)cm^bT3?JcM-SAt9 z&DP(FHW^_n)@)BnH}puiTRd1Kg`h_O7ukfz9rw+Op&ZwGTM;mS8XAF|g69qX2v!?& zQdtq_o0$cx4QrYQO=3awz*fjL%U0MkoL*J}x5X|+i(&<5VTC66APH~@;0~X6ycLV5 z_dOlAI72Ds$zw2Sl17q2GK^#&k`W}MNcIDfc$2WdANfKCzgWuLf!Ao-3@i$5B0D04%E2&p`M2U zCf0RT`+n|zbZ>q1g}bpAYLoY4y({K_?;2j)HFRCNw`(6QsW&rgyE53jw=V3%-@dhO zp~n4vyxy}1eiH7_j>mv&`>?C;9qoEywd-Il`m>a<&hTB4`p$H{yZ`Rqv3uQP^`7*D zAe60#n6{k{!jP+RZ$+rjlXG1hSOtU1Pp|75T&-oegn!v%I#-fm3eI)IhHe7I_2qps zd&}pkG+Sj0%`%;)jf-FKuBj`9Vcd&<87{(??_Kb&YeTZ^ms$LG(=jub;gT1;i`-~O zo~&|48s_d@S!H+R#?KrNbrEeT3nto;`6f;c$mfXt|Dl}F=rIam$0loB}q4D8Fk zrupIR*^=Cu7`+VxDF>^8e*q}jS}xRt|CM)F|5q%u`tn*9s-hbnnI?N+UVtq9vReA+ zcaSU&Vb=|x#mnfvSOIi@zy!u@+fbBdz%Tx!$8(d=p`a}VQx|3w;gbmXE#Ohl>Zb5X z7`$x5R3teASdX_snLj6)me16!Aa&68k6J-+&e5#e#_JGF$l~bOc8ie%xk$kDqu&ZJ z5x^rl%KJyjL1<#j^B)1z?w=Ir4ZKIL^%JP|gZb(NJT3WBj9Ng4TY!c}KbVL`W-Z!& zH`@1MGPNAO7G4=y$*(+ledK;}s5ZV9i(i>sn*2^`x$j!vYHU}HtwXEOwa`1M@Akdj zx0>w#AUS+5Iego9d;Ipd?(AMoK3n7015A5rIsT9F?-g#p{NvYt_}Y)p{qWpR@;`n0 zrH6hd-f?B>Z>N6Q-h~p^gJ5{qV@z95O}iypLdLR0I@x%i@*cMrjFgqSuk75p_?0gW1T+P4U!NikSjx8NqN&PhT z$SwYZf$@6-((E`kfO$S-6{gdBgAPOsq2?79_TH zT*)uxzjJ2!+_iJ7iAQVxdRs@$F}cl#6nz~0zcB~|{G%^0x9$Z1^<|~w&^Y3r(}b>F zPzbng!YsIjnm|ulxPaO)V5Ct$CBalQ16)G%@)QHSV(QpC7!&1e!I@aJ1!`iEEqP2h zqk*$|mg=!kR1I)wsg7q((^1ZOGs8Zo;i7zcMmc-VO*1&9m=&VgE;#mchWBJOkKh%$ z1GHS83oyz=zo5zpZSV?s8)JL6C?V}5N8t=0ZeobGBq+fPL{0-S@X!S3GDL|EFhZn~ zE+y%OXewF3=AA&mDd=UAsVi?Ry|GeRZF~G~=yCey-e8Zw-9)2Luo*!EZUseA)^iY( z;pvW2cZ0l!)n~Cf?)WD;rC1cj8!Y)clp3flFWBx0y0J8_fbNiiDXTH==@p7_7(oygZ0j?dgo&wryl1c8`Eqb-}%4KCHN=o2nb5a!GABgMs@=&)0c>V zgzX3jjeP{s=vwht;1Ej_V6<@KBnz~ko}?WGx!4p1@fvH@KnHsL)L< zzY42-6;}BwtkM)qbMBiRG9VbuB?yJBkQ*+@H9PBQb|~<83M|@|I6>5j6S$TL;`le6 zR}x@{lVceZWzYmM0uRNzu*8cIcv1jedk#5`5xDVyMhjNvJ{32>xI|;5KS{*E-09BT z#C@Wv8ggfT8NRYqBXR(tR@E=_5FH3R(E$v&(*)6h;53qvq4=Xd-X_MwjE80P5Hg;)ch zT_-hz>@LsnZ$jJu2|bK|2ZCt3yMeafePi{JXDHCvksJPjM@SS0+emk zx0uysjzZZ)V=On!(1L zuktOWa1M}d%&yfJPG^w-WYJX61b3AKxEgFN7a{H@Aj|&|8W5-kz^4jJZ-Ek>s%V0Q zAzy@PLKy?>j#uJcC-7Ki!D*pEV(M-+dEVm?MIWP1kb5QPg@I;wJB{8Del7wifw-3n z*wg{Wc2K1-u=(=~^re;HqZuYfra?x&Ebbo_)Vy$-RFsS`X@6ufu*&(+;t4+`!H1JA zvuohoG2vL=nHokA1w_vR_Yy2b`XI5QOL04{;JfVGk?XsE(0`-<&hGaI-W#|dn|Q$U zfW`xe%Y&G6;+7PbC;kXr4st#lmy2(0hp&fR?E|BkSG_)Z&qVc&GZ3acgjU|3XVtSBlpjXfy_K|^Fi zRQuz9M$;@G{NG-lTQJm|p(!QZ`GbfI{Y-#X+KSuKoG7~ee52!I1*Yg<0Z`0n`wr+Q z*fYV9ofUc(0_5;_0?!c*SbV08IUwtPmSs2gFl_tJnccr&Vn1h6|HF)|wGUkF`JjE^ zUi-lH-qrR4iwQVBU-Y3b!S>cJZ!qw);fI2CEO@wQo@Jk?jb1spbnqd36=$D#c)>rz TvYl5CEg!pf>|+M=bSA$CZ1ZH& literal 0 HcmV?d00001 diff --git a/src/models/layers.py b/src/hirad/models/layers.py similarity index 100% rename from src/models/layers.py rename to src/hirad/models/layers.py diff --git a/src/models/preconditioning.py b/src/hirad/models/preconditioning.py similarity index 100% rename from src/models/preconditioning.py rename to src/hirad/models/preconditioning.py diff --git a/src/models/song_unet.py b/src/hirad/models/song_unet.py similarity index 100% rename from src/models/song_unet.py rename to src/hirad/models/song_unet.py diff --git a/src/models/unet.py b/src/hirad/models/unet.py similarity index 100% rename from src/models/unet.py rename to src/hirad/models/unet.py diff --git a/src/models/utils.py b/src/hirad/models/utils.py similarity index 100% rename from src/models/utils.py rename to src/hirad/models/utils.py diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/utils/__init__.py b/src/hirad/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/capture.py b/src/hirad/utils/capture.py similarity index 100% rename from src/utils/capture.py rename to src/hirad/utils/capture.py diff --git a/src/utils/checkpoint.py b/src/hirad/utils/checkpoint.py similarity index 100% rename from src/utils/checkpoint.py rename to src/hirad/utils/checkpoint.py diff --git a/src/utils/console.py b/src/hirad/utils/console.py similarity index 100% rename from src/utils/console.py rename to src/hirad/utils/console.py diff --git a/src/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py similarity index 100% rename from src/utils/deterministic_sampler.py rename to src/hirad/utils/deterministic_sampler.py diff --git a/src/utils/function_utils.py b/src/hirad/utils/function_utils.py similarity index 100% rename from src/utils/function_utils.py rename to src/hirad/utils/function_utils.py diff --git a/src/utils/inference_utils.py b/src/hirad/utils/inference_utils.py similarity index 100% rename from src/utils/inference_utils.py rename to src/hirad/utils/inference_utils.py diff --git a/src/utils/model_utils.py b/src/hirad/utils/model_utils.py similarity index 100% rename from src/utils/model_utils.py rename to src/hirad/utils/model_utils.py diff --git a/src/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py similarity index 100% rename from src/utils/stochastic_sampler.py rename to src/hirad/utils/stochastic_sampler.py diff --git a/src/utils/train_helpers.py b/src/hirad/utils/train_helpers.py similarity index 100% rename from src/utils/train_helpers.py rename to src/hirad/utils/train_helpers.py diff --git a/src/hirad_gen.egg-info/PKG-INFO b/src/hirad_gen.egg-info/PKG-INFO new file mode 100644 index 0000000..9c4c18f --- /dev/null +++ b/src/hirad_gen.egg-info/PKG-INFO @@ -0,0 +1,38 @@ +Metadata-Version: 2.4 +Name: hirad-gen +Version: 0.1.0 +Summary: High resolution atmospheric downscaling using generative machine learning +Author-email: Petar Stamenkovic +License: BSD 3-Clause License + + Copyright (c) 2025, MeteoSwiss + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Requires-Python: >=3.12 +Description-Content-Type: text/markdown +License-File: LICENSE +Dynamic: license-file diff --git a/src/hirad_gen.egg-info/SOURCES.txt b/src/hirad_gen.egg-info/SOURCES.txt new file mode 100644 index 0000000..1645d32 --- /dev/null +++ b/src/hirad_gen.egg-info/SOURCES.txt @@ -0,0 +1,28 @@ +LICENSE +README.md +pyproject.toml +src/distributed/__init__.py +src/distributed/config.py +src/distributed/manager.py +src/hirad_gen.egg-info/PKG-INFO +src/hirad_gen.egg-info/SOURCES.txt +src/hirad_gen.egg-info/dependency_links.txt +src/hirad_gen.egg-info/top_level.txt +src/losses/__init__.py +src/losses/loss.py +src/models/__init__.py +src/models/layers.py +src/models/preconditioning.py +src/models/song_unet.py +src/models/unet.py +src/models/utils.py +src/utils/__init__.py +src/utils/capture.py +src/utils/checkpoint.py +src/utils/console.py +src/utils/deterministic_sampler.py +src/utils/function_utils.py +src/utils/inference_utils.py +src/utils/model_utils.py +src/utils/stochastic_sampler.py +src/utils/train_helpers.py \ No newline at end of file diff --git a/src/hirad_gen.egg-info/dependency_links.txt b/src/hirad_gen.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/hirad_gen.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/hirad_gen.egg-info/top_level.txt b/src/hirad_gen.egg-info/top_level.txt new file mode 100644 index 0000000..b778691 --- /dev/null +++ b/src/hirad_gen.egg-info/top_level.txt @@ -0,0 +1,7 @@ +distributed +evaluation +losses +metrics +models +training +utils From 05e3a08a75e41c5c6a5c63422c1516dab6cdb091 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic <56728083+PetarStam@users.noreply.github.com> Date: Mon, 14 Apr 2025 17:15:40 +0200 Subject: [PATCH 07/66] Delete src/hirad_gen.egg-info directory --- src/hirad_gen.egg-info/PKG-INFO | 38 --------------------- src/hirad_gen.egg-info/SOURCES.txt | 28 --------------- src/hirad_gen.egg-info/dependency_links.txt | 1 - src/hirad_gen.egg-info/top_level.txt | 7 ---- 4 files changed, 74 deletions(-) delete mode 100644 src/hirad_gen.egg-info/PKG-INFO delete mode 100644 src/hirad_gen.egg-info/SOURCES.txt delete mode 100644 src/hirad_gen.egg-info/dependency_links.txt delete mode 100644 src/hirad_gen.egg-info/top_level.txt diff --git a/src/hirad_gen.egg-info/PKG-INFO b/src/hirad_gen.egg-info/PKG-INFO deleted file mode 100644 index 9c4c18f..0000000 --- a/src/hirad_gen.egg-info/PKG-INFO +++ /dev/null @@ -1,38 +0,0 @@ -Metadata-Version: 2.4 -Name: hirad-gen -Version: 0.1.0 -Summary: High resolution atmospheric downscaling using generative machine learning -Author-email: Petar Stamenkovic -License: BSD 3-Clause License - - Copyright (c) 2025, MeteoSwiss - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -Requires-Python: >=3.12 -Description-Content-Type: text/markdown -License-File: LICENSE -Dynamic: license-file diff --git a/src/hirad_gen.egg-info/SOURCES.txt b/src/hirad_gen.egg-info/SOURCES.txt deleted file mode 100644 index 1645d32..0000000 --- a/src/hirad_gen.egg-info/SOURCES.txt +++ /dev/null @@ -1,28 +0,0 @@ -LICENSE -README.md -pyproject.toml -src/distributed/__init__.py -src/distributed/config.py -src/distributed/manager.py -src/hirad_gen.egg-info/PKG-INFO -src/hirad_gen.egg-info/SOURCES.txt -src/hirad_gen.egg-info/dependency_links.txt -src/hirad_gen.egg-info/top_level.txt -src/losses/__init__.py -src/losses/loss.py -src/models/__init__.py -src/models/layers.py -src/models/preconditioning.py -src/models/song_unet.py -src/models/unet.py -src/models/utils.py -src/utils/__init__.py -src/utils/capture.py -src/utils/checkpoint.py -src/utils/console.py -src/utils/deterministic_sampler.py -src/utils/function_utils.py -src/utils/inference_utils.py -src/utils/model_utils.py -src/utils/stochastic_sampler.py -src/utils/train_helpers.py \ No newline at end of file diff --git a/src/hirad_gen.egg-info/dependency_links.txt b/src/hirad_gen.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/hirad_gen.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/hirad_gen.egg-info/top_level.txt b/src/hirad_gen.egg-info/top_level.txt deleted file mode 100644 index b778691..0000000 --- a/src/hirad_gen.egg-info/top_level.txt +++ /dev/null @@ -1,7 +0,0 @@ -distributed -evaluation -losses -metrics -models -training -utils From dbf101ef0c50e1dff50ba140b85124a6d407f965 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 14 Apr 2025 17:17:57 +0200 Subject: [PATCH 08/66] add gitignore --- .gitignore | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c514c5d --- /dev/null +++ b/.gitignore @@ -0,0 +1,171 @@ +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json \ No newline at end of file From 5de83bc5f7bb25f8bff91ffda69b2eedd97738fd Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 16 Apr 2025 11:57:45 +0200 Subject: [PATCH 09/66] add train script with missing parts --- .../{train_regression.yaml => training.yaml} | 0 src/hirad/distributed/manager.py | 6 +- src/hirad/losses/__init__.py | 1 + src/hirad/models/__init__.py | 1 + src/hirad/training/train.py | 462 ++++++++++++++++++ src/hirad/utils/capture.py | 2 +- 6 files changed, 468 insertions(+), 4 deletions(-) rename src/hirad/conf/{train_regression.yaml => training.yaml} (100%) diff --git a/src/hirad/conf/train_regression.yaml b/src/hirad/conf/training.yaml similarity index 100% rename from src/hirad/conf/train_regression.yaml rename to src/hirad/conf/training.yaml diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index facb466..e80ce13 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -25,7 +25,7 @@ import torch import torch.distributed as dist -from src.distributed.config import ProcessGroupConfig, ProcessGroupNode +from hirad.distributed.config import ProcessGroupConfig, ProcessGroupNode warnings.simplefilter("default", DeprecationWarning) @@ -393,7 +393,7 @@ def initialize(): else: os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" initialization_method = os.getenv( - "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD" + "DISTRIBUTED_INITIALIZATION_METHOD" ) if initialization_method is None: try: @@ -419,7 +419,7 @@ def initialize(): "Unknown initialization method " f"{initialization_method}. " "Supported values for " - "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD are " + "DISTRIBUTED_INITIALIZATION_METHOD are " "ENV, SLURM and OPENMPI" ) diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py index e69de29..185527b 100644 --- a/src/hirad/losses/__init__.py +++ b/src/hirad/losses/__init__.py @@ -0,0 +1 @@ +from .loss import ResLoss, RegressionLoss, RegressionLossCE \ No newline at end of file diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index 6b790ae..3b494c6 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,3 +1,4 @@ from .unet import UNet from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .preconditioning import EDMPrecondSR from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index e69de29..a47910b 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -0,0 +1,462 @@ +import os +import time + +import psutil +import hydra +from omegaconf import DictConfig, OmegaConf +import torch +from hydra.utils import to_absolute_path +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ + set_patch_shape, compute_num_accumulation_rounds, \ + is_time_for_periodic_task, handle_and_clip_gradients +from hirad.models import UNet, EDMPrecondSR +from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE + +@hydra.main(version_base=None, config_path="conf", config_name="training") +def main(cfg: DictConfig) -> None: + + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + if dist.rank==0: + writer = SummaryWriter(log_dir='tensorboard') + logger = PythonLogger("main") # general logger + logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger + + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if hasattr(cfg, "validation"): + train_test_split = True + validation_dataset_cfg = OmegaConf.to_container(cfg.validation) + else: + train_test_split = False + validation_dataset_cfg = None + fp_optimizations = cfg.training.perf.fp_optimizations + songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger.info(f"Saving the outputs in {os.getcwd()}") + checkpoint_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" + ) + if cfg.training.hp.batch_size_per_gpu == "auto": + cfg.training.hp.batch_size_per_gpu = ( + cfg.training.hp.total_batch_size // dist.world_size + ) + + set_seed(dist.rank) + configure_cuda_for_consistent_precision() + + ### Write our own dataloader ### + ( + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator + ) = None, None, None, None + + dataset_channels = None #len(dataset.input_channels()) + img_in_channels = None #dataset_channels + img_shape = None #dataset.image_shape() + img_out_channels = None #len(dataset.output_channels()) + + prob_channels = None + + # Parse the patch shape + if ( + cfg.model.name == "patched_diffusion" + or cfg.model.name == "lt_aware_patched_diffusion" + ): + patch_shape_x = cfg.training.hp.patch_shape_x + patch_shape_y = cfg.training.hp.patch_shape_y + else: + patch_shape_x = None + patch_shape_y = None + patch_shape = (patch_shape_y, patch_shape_x) + img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if patch_shape != img_shape: + logger0.info("Patch-based training enabled") + else: + logger0.info("Patch-based training disabled") + # interpolate global channel if patch-based model is used + if img_shape[1] != patch_shape[1]: + img_in_channels += dataset_channels + + + # Instantiate the model and move to device. + if cfg.model.name not in ( + "regression", + "lt_aware_ce_regression", + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + raise ValueError("Invalid model") + model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(img_shape), + "use_fp16": fp16, + } + standard_model_cfgs = { # default parameters for different network types + "regression": { + "img_channels": 4, + "N_grid_channels": 4, + "embedding_type": "zero", + "checkpoint_level": songunet_checkpoint_level, + }, + "lt_aware_ce_regression": { + "img_channels": 4, + "N_grid_channels": 4, + "embedding_type": "zero", + "lead_time_channels": 4, + "lead_time_steps": 9, + "prob_channels": prob_channels, + "checkpoint_level": songunet_checkpoint_level, + "model_type": "SongUNetPosLtEmbd", + }, + "diffusion": { + "img_channels": img_out_channels, + "gridtype": "sinusoidal", + "N_grid_channels": 4, + "checkpoint_level": songunet_checkpoint_level, + }, + "patched_diffusion": { + "img_channels": img_out_channels, + "gridtype": "learnable", + "N_grid_channels": 100, + "checkpoint_level": songunet_checkpoint_level, + }, + "lt_aware_patched_diffusion": { + "img_channels": img_out_channels, + "gridtype": "learnable", + "N_grid_channels": 100, + "lead_time_channels": 20, + "lead_time_steps": 9, + "checkpoint_level": songunet_checkpoint_level, + "model_type": "SongUNetPosLtEmbd", + }, + } + + + model_args.update(standard_model_cfgs[cfg.model.name]) + if cfg.model.name in ( + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + model_args["scale_cond_input"] = cfg.model.scale_cond_input + if hasattr(cfg.model, "model_args"): # override defaults from config file + model_args.update(OmegaConf.to_container(cfg.model.model_args)) + if cfg.model.name == "regression": + model = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + elif cfg.model.name == "lt_aware_ce_regression": + model = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + elif cfg.model.name == "lt_aware_patched_diffusion": + model = EDMPrecondSR( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + else: # diffusion or patched diffusion + model = EDMPrecondSR( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + + model.train().requires_grad_(True).to(dist.device) + + # Enable distributed data parallel if applicable + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=dist.find_unused_parameters, + ) + + # Load the regression checkpoint if applicable + if hasattr(cfg.training.io, "regression_checkpoint_path"): + regression_checkpoint_path = to_absolute_path( + cfg.training.io.regression_checkpoint_path + ) + if not os.path.exists(regression_checkpoint_path): + raise FileNotFoundError( + f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" + ) + regression_net = torch.nn.Module() #Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + regression_net.eval().requires_grad_(False).to(dist.device) + logger0.success("Loaded the pre-trained regression model") + + # Instantiate the loss function + patch_num = getattr(cfg.training.hp, "patch_num", 1) + if cfg.model.name in ( + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + loss_fn = ResLoss( + regression_net=regression_net, + img_shape_x=img_shape[1], + img_shape_y=img_shape[0], + patch_shape_x=patch_shape[1], + patch_shape_y=patch_shape[0], + patch_num=patch_num, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + ) + elif cfg.model.name == "regression": + loss_fn = RegressionLoss() + elif cfg.model.name == "lt_aware_ce_regression": + loss_fn = RegressionLossCE(prob_channels=prob_channels) + + # Instantiate the optimizer + optimizer = torch.optim.Adam( + params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8 + ) + + # Record the current time to measure the duration of subsequent operations. + start_time = time.time() + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + + ## Resume training from previous checkpoints if exists + if dist.world_size > 1: + torch.distributed.barrier() + try: + cur_nimg = 0 + # Fix loading and saving checkpoint + #load_checkpoint( + # path=checkpoint_dir, + # models=model, + # optimizer=optimizer, + # device=dist.device, + # ) + except: + cur_nimg = 0 + + ############################################################################ + # MAIN TRAINING LOOP # + ############################################################################ + + logger0.info(f"Training for {cfg.training.hp.training_duration} images...") + done = False + + # init variables to monitor running mean of average loss since last periodic + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for _ in range(num_accumulation_rounds): + img_clean, img_lr, labels, *lead_time_label = next(dataset_iterator) # what are labels and lead_time_label + img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() + img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() + labels = labels.to(dist.device).contiguous() + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "labels": labels, + "augment_pipe": None, + } + if lead_time_label: + lead_time_label = lead_time_label[0].to(dist.device).contiguous() + loss_fn_kwargs.update({"lead_time_label": lead_time_label}) + else: + lead_time_label = None + with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): + loss = loss_fn(**loss_fn_kwargs) + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + loss.backward() + + + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 + + if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) + writer.add_scalar( + "training_loss_running_mean", average_loss_running_mean, cur_nimg + ) + + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + img_clean_valid, img_lr_valid, labels_valid = next( + validation_dataset_iterator + ) + + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(torch.float32) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device).to(torch.float32).contiguous() + ) + labels_valid = labels_valid.to(dist.device).contiguous() + loss_valid = loss_fn( + net=model, + img_clean=img_clean_valid, + img_lr=img_lr_valid, + labels=labels_valid, + augment_pipe=None, + ) + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu).cpu().item() + ) + valid_loss_accum += ( + loss_valid / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + logger0.info(" ".join(fields)) + torch.cuda.reset_peak_memory_stats() + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # figure out how to do save and load checkpoint + #save_checkpoint( + # path=checkpoint_dir, + # models=model, + # optimizer=optimizer, + # epoch=cur_nimg, + # ) + pass + + # Done. + logger0.info("Training Completed.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/hirad/utils/capture.py b/src/hirad/utils/capture.py index 50057f9..9c38d5a 100644 --- a/src/hirad/utils/capture.py +++ b/src/hirad/utils/capture.py @@ -24,7 +24,7 @@ import torch -from src.distributed import DistributedManager +from hirad.distributed import DistributedManager float16 = NewType("float16", torch.float16) bfloat16 = NewType("bfloat16", torch.bfloat16) From 134ed835f7cfde077dddb5d8eeee2313cf994611 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 16 Apr 2025 14:33:48 +0200 Subject: [PATCH 10/66] add abstract dataset class --- src/hirad/datasets/base.py | 85 ++++++++++++++++++++++++++ src/hirad/datasets/dataset.py | 111 ++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 src/hirad/datasets/base.py create mode 100644 src/hirad/datasets/dataset.py diff --git a/src/hirad/datasets/base.py b/src/hirad/datasets/base.py new file mode 100644 index 0000000..22b00d2 --- /dev/null +++ b/src/hirad/datasets/base.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Tuple + +import numpy as np +import torch + + +@dataclass +class ChannelMetadata: + """Metadata describing a data channel.""" + + name: str + level: str = "" + auxiliary: bool = False + + +class DownscalingDataset(torch.utils.data.Dataset, ABC): + """An abstract class that defines the interface for downscaling datasets.""" + + @abstractmethod + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + pass + + @abstractmethod + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + pass + + @abstractmethod + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def time(self) -> List: + """Get time values from the dataset.""" + pass + + @abstractmethod + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + pass + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return x + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return x + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x + + def info(self) -> dict: + """Get information about the dataset.""" + return {} diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py new file mode 100644 index 0000000..2a26630 --- /dev/null +++ b/src/hirad/datasets/dataset.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Tuple, Union +import copy +import torch + +from hirad.utils.function_utils import InfiniteSampler +from hirad.distributed import DistributedManager + +from . import base, cwb, hrrrmini, gefs_hrrr + + +# this maps all known dataset types to the corresponding init function +known_datasets = { + "cwb": cwb.get_zarr_dataset, + "hrrr_mini": hrrrmini.HRRRMiniDataset, + "gefs_hrrr": gefs_hrrr.HrrrForecastGEFSDataset, +} + + +def init_train_valid_datasets_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, + validation_dataset_cfg: Union[dict, None] = None, + train_test_split: bool = True, +) -> Tuple[ + base.DownscalingDataset, + Iterable, + Union[base.DownscalingDataset, None], + Union[Iterable, None], +]: + """ + A wrapper function for managing the train-test split for the CWB dataset. + + Parameters: + - dataset_cfg (dict): Configuration for the dataset. + - dataloader_cfg (dict, optional): Configuration for the dataloader. Defaults to None. + - batch_size (int): The number of samples in each batch of data. Defaults to 1. + - seed (int): The random seed for dataset shuffling. Defaults to 0. + - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True. + + Returns: + - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True. + """ + + config = copy.deepcopy(dataset_cfg) + (dataset, dataset_iter) = init_dataset_from_config( + config, dataloader_cfg, batch_size=batch_size, seed=seed + ) + if train_test_split: + valid_dataset_cfg = copy.deepcopy(config) + if validation_dataset_cfg: + valid_dataset_cfg.update(validation_dataset_cfg) + (valid_dataset, valid_dataset_iter) = init_dataset_from_config( + valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed + ) + else: + valid_dataset = valid_dataset_iter = None + + return dataset, dataset_iter, valid_dataset, valid_dataset_iter + + +def init_dataset_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, +) -> Tuple[base.DownscalingDataset, Iterable]: + dataset_cfg = copy.deepcopy(dataset_cfg) + dataset_type = dataset_cfg.pop("type", "cwb") + if "train_test_split" in dataset_cfg: + # handled by init_train_valid_datasets_from_config + del dataset_cfg["train_test_split"] + dataset_init_func = known_datasets[dataset_type] + + dataset_obj = dataset_init_func(**dataset_cfg) + if dataloader_cfg is None: + dataloader_cfg = {} + + dist = DistributedManager() + dataset_sampler = InfiniteSampler( + dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed + ) + + dataset_iterator = iter( + torch.utils.data.DataLoader( + dataset=dataset_obj, + sampler=dataset_sampler, + batch_size=batch_size, + worker_init_fn=None, + **dataloader_cfg, + ) + ) + + return (dataset_obj, dataset_iterator) From de4e7c55374b45d637754f5bf8be4f6726e762c7 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 23 Apr 2025 16:09:28 +0200 Subject: [PATCH 11/66] adapt checkpoint saving and loading --- src/hirad/utils/checkpoint.py | 51 ++++++++++------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 8ec70fa..5ce194a 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -23,11 +23,11 @@ from torch.cuda.amp import GradScaler from torch.optim.lr_scheduler import _LRScheduler -from src.distributed import DistributedManager -from src.utils.console import PythonLogger -from src.utils.capture import _StaticCapture +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger +from hirad.utils.capture import _StaticCapture -optimizer = NewType("optimizer", torch.optim) +optimizer = NewType("optimizer", torch.optim.Optimizer) scheduler = NewType("scheduler", _LRScheduler) scaler = NewType("scaler", GradScaler) @@ -39,7 +39,7 @@ def _get_checkpoint_filename( base_name: str = "checkpoint", index: Union[int, None] = None, saving: bool = False, - model_type: str = "mdlus", + model_type: str = "pt", ) -> str: """Gets the file name /path of checkpoint @@ -91,7 +91,7 @@ def _get_checkpoint_filename( ) # File extension for PhysicsNeMo models or PyTorch models - file_extension = ".mdlus" if model_type == "mdlus" else ".pt" + file_extension = "."+model_type # If epoch is provided load that file if index is not None: @@ -157,8 +157,6 @@ def _unique_model_names( model0 = model0.module # Base name of model is meta.name unless pytorch model base_name = model0.__class__.__name__ - if isinstance(model0, physicsnemo.models.Module): - base_name = model0.meta.name # If we have multiple models of the same name, introduce another index if base_name in model_dict: model_dict[base_name].append(model0) @@ -189,8 +187,8 @@ def save_checkpoint( """Training checkpoint saving utility This will save a training checkpoint in the provided path following the file naming - convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint - method in PhysicsNeMo core can then be used to read this file. + convention "checkpoint.{model parallel id}.{epoch/index}.pt". The load checkpoint + method can then be used to read this file. Parameters ---------- @@ -224,21 +222,13 @@ def save_checkpoint( models = [models] models = _unique_model_names(models) for name, model in models.items(): - # Get model type - model_type = ( - "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" - ) - # Get full file path / name file_name = _get_checkpoint_filename( - path, name, index=epoch, saving=True, model_type=model_type + path, name, index=epoch, saving=True, model_type="pt" ) # Save state dictionary - if isinstance(model, physicsnemo.models.Module): - model.save(file_name) - else: - torch.save(model.state_dict(), file_name) + torch.save(model.state_dict(), file_name) checkpoint_logging.success(f"Saved model state dictionary: {file_name}") # == Saving training checkpoint == @@ -251,12 +241,9 @@ def save_checkpoint( if scheduler: checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict() - # Scheduler state dict + # Scaler state dict if scaler: checkpoint_dict["scaler_state_dict"] = scaler.state_dict() - # Static capture is being used, save its grad scaler - if _StaticCapture._amp_scalers: - checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() # Output file name output_filename = _get_checkpoint_filename( @@ -288,8 +275,7 @@ def load_checkpoint( ) -> int: """Checkpoint loading utility - This loader is designed to be used with the save checkpoint utility in PhysicsNeMo - Launch. Given a path, this method will try to find a checkpoint and load state + This loader is designed to be used with the save checkpoint utility. Given a path, this method will try to find a checkpoint and load state dictionaries into the provided training objects. Parameters @@ -331,9 +317,7 @@ def load_checkpoint( models = _unique_model_names(models) for name, model in models.items(): # Get model type - model_type = ( - "mdlus" if isinstance(model, physicsnemo.models.Module) else "pt" - ) + model_type = "pt" # Get full file path / name file_name = _get_checkpoint_filename( @@ -345,10 +329,7 @@ def load_checkpoint( ) continue # Load state dictionary - if isinstance(model, physicsnemo.models.Module): - model.load(file_name) - else: - model.load_state_dict(torch.load(file_name, map_location=device)) + model.load_state_dict(torch.load(file_name, map_location=device)) checkpoint_logging.success( f"Loaded model state dictionary {file_name} to device {device}" @@ -382,10 +363,6 @@ def load_checkpoint( scaler.load_state_dict(checkpoint_dict["scaler_state_dict"]) checkpoint_logging.success("Loaded grad scaler state dictionary") - if "static_capture_state_dict" in checkpoint_dict: - _StaticCapture.load_state_dict(checkpoint_dict["static_capture_state_dict"]) - checkpoint_logging.success("Loaded static capture state dictionary") - epoch = 0 if "epoch" in checkpoint_dict: epoch = checkpoint_dict["epoch"] From cff9da486d60d77764e2996632ba12d4642fc47f Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:51:17 +0200 Subject: [PATCH 12/66] adapt checkpoint loading and saving --- src/hirad/models/__init__.py | 5 +- src/hirad/training/train.py | 105 ++++++++++++++++++++++--------- src/hirad/utils/checkpoint.py | 115 +++++++++++----------------------- 3 files changed, 115 insertions(+), 110 deletions(-) diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index 3b494c6..f17e5ce 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,4 +1,5 @@ from .unet import UNet from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd -from .preconditioning import EDMPrecondSR -from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding \ No newline at end of file +from .preconditioning import EDMPrecondSR, EDMPrecond +from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding +from .meta import ModelMetaData \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index a47910b..d6fe563 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -4,6 +4,7 @@ import psutil import hydra from omegaconf import DictConfig, OmegaConf +import json import torch from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter @@ -14,8 +15,10 @@ from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ set_patch_shape, compute_num_accumulation_rounds, \ is_time_for_periodic_task, handle_and_clip_gradients +from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.models import UNet, EDMPrecondSR from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE +from hirad.datasets import init_train_valid_datasets_from_config @hydra.main(version_base=None, config_path="conf", config_name="training") def main(cfg: DictConfig) -> None: @@ -54,20 +57,38 @@ def main(cfg: DictConfig) -> None: set_seed(dist.rank) configure_cuda_for_consistent_precision() - ### Write our own dataloader ### + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.training.perf.dataloader_workers, + "prefetch_factor": 2, + } ( - dataset, - dataset_iterator, - validation_dataset, - validation_dataset_iterator - ) = None, None, None, None + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator, + ) = init_train_valid_datasets_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.training.hp.batch_size_per_gpu, + seed=0, + validation_dataset_cfg=validation_dataset_cfg, + train_test_split=train_test_split, + ) - dataset_channels = None #len(dataset.input_channels()) - img_in_channels = None #dataset_channels - img_shape = None #dataset.image_shape() - img_out_channels = None #len(dataset.output_channels()) + # Parse image configuration & update model args + dataset_channels = len(dataset.input_channels()) + img_in_channels = dataset_channels + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + if cfg.model.hr_mean_conditioning: + img_in_channels += img_out_channels - prob_channels = None + if cfg.model.name == "lt_aware_ce_regression": + prob_channels = dataset.get_prob_channel_index() + else: + prob_channels = None # Parse the patch shape if ( @@ -181,6 +202,10 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) + if not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): + with open(os.path.join(checkpoint_dir, 'model_args.json'), 'w') as f: + json.dump(model_args, f) + # Enable distributed data parallel if applicable if dist.world_size > 1: model = DistributedDataParallel( @@ -196,11 +221,38 @@ def main(cfg: DictConfig) -> None: regression_checkpoint_path = to_absolute_path( cfg.training.io.regression_checkpoint_path ) - if not os.path.exists(regression_checkpoint_path): + if not os.path.isdir(regression_checkpoint_path): raise FileNotFoundError( f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" ) - regression_net = torch.nn.Module() #Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) + #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder + regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + if cfg.model.name == "lt_aware_patched_diffusion": + regression_net = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **regression_model_args, + ) + else: + regression_net = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **regression_model_args, + ) + + _ = load_checkpoint( + path=regression_checkpoint_path, + model=regression_net, + device=dist.device + ) regression_net.eval().requires_grad_(False).to(dist.device) logger0.success("Loaded the pre-trained regression model") @@ -248,14 +300,12 @@ def main(cfg: DictConfig) -> None: if dist.world_size > 1: torch.distributed.barrier() try: - cur_nimg = 0 - # Fix loading and saving checkpoint - #load_checkpoint( - # path=checkpoint_dir, - # models=model, - # optimizer=optimizer, - # device=dist.device, - # ) + cur_nimg = load_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + device=dist.device, + ) except: cur_nimg = 0 @@ -445,13 +495,12 @@ def main(cfg: DictConfig) -> None: dist.rank, rank_0_only=True, ): - # figure out how to do save and load checkpoint - #save_checkpoint( - # path=checkpoint_dir, - # models=model, - # optimizer=optimizer, - # epoch=cur_nimg, - # ) + save_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + epoch=cur_nimg, + ) pass # Done. diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 5ce194a..03b423d 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -17,7 +17,7 @@ import glob import re from pathlib import Path -from typing import Any, Dict, List, NewType, Optional, Union +from typing import Any, Dict, List, NewType, Optional, Union, Tuple import torch from torch.cuda.amp import GradScaler @@ -25,7 +25,6 @@ from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger -from hirad.utils.capture import _StaticCapture optimizer = NewType("optimizer", torch.optim.Optimizer) scheduler = NewType("scheduler", _LRScheduler) @@ -133,51 +132,9 @@ def _get_checkpoint_filename( return checkpoint_filename -def _unique_model_names( - models: List[torch.nn.Module], -) -> Dict[str, torch.nn.Module]: - """Util to clean model names and index if repeat names, will also strip DDP wrappers - if they exist. - - Parameters - ---------- - model : List[torch.nn.Module] - List of models to generate names for - - Returns - ------- - Dict[str, torch.nn.Module] - Dictionary of model names and respective modules - """ - # Loop through provided models and set up base names - model_dict = {} - for model0 in models: - if hasattr(model0, "module"): - # Strip out DDP layer - model0 = model0.module - # Base name of model is meta.name unless pytorch model - base_name = model0.__class__.__name__ - # If we have multiple models of the same name, introduce another index - if base_name in model_dict: - model_dict[base_name].append(model0) - else: - model_dict[base_name] = [model0] - - # Set up unique model names if needed - output_dict = {} - for key, model in model_dict.items(): - if len(model) > 1: - for i, model0 in enumerate(model): - output_dict[key + str(i)] = model0 - else: - output_dict[key] = model[0] - - return output_dict - - def save_checkpoint( path: str, - models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + model: Union[torch.nn.Module, None] = None, optimizer: Union[optimizer, None] = None, scheduler: Union[scheduler, None] = None, scaler: Union[scaler, None] = None, @@ -217,19 +174,20 @@ def save_checkpoint( Path(path).mkdir(parents=True, exist_ok=True) # == Saving model checkpoint == - if models: - if not isinstance(models, list): - models = [models] - models = _unique_model_names(models) - for name, model in models.items(): - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, saving=True, model_type="pt" - ) + if model: + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Base name of model is meta.name unless pytorch model + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type="pt" + ) - # Save state dictionary - torch.save(model.state_dict(), file_name) - checkpoint_logging.success(f"Saved model state dictionary: {file_name}") + # Save state dictionary + torch.save(model.state_dict(), file_name) + checkpoint_logging.success(f"Saved model state dictionary: {file_name}") # == Saving training checkpoint == checkpoint_dict = {} @@ -265,7 +223,7 @@ def save_checkpoint( def load_checkpoint( path: str, - models: Union[torch.nn.Module, List[torch.nn.Module], None] = None, + model: torch.nn.Module, optimizer: Union[optimizer, None] = None, scheduler: Union[scheduler, None] = None, scaler: Union[scaler, None] = None, @@ -311,29 +269,26 @@ def load_checkpoint( return 0 # == Loading model checkpoint == - if models: - if not isinstance(models, list): - models = [models] - models = _unique_model_names(models) - for name, model in models.items(): - # Get model type - model_type = "pt" - - # Get full file path / name - file_name = _get_checkpoint_filename( - path, name, index=epoch, model_type=model_type - ) - if not Path(file_name).exists(): - checkpoint_logging.error( - f"Could not find valid model file {file_name}, skipping load" - ) - continue - # Load state dictionary - model.load_state_dict(torch.load(file_name, map_location=device)) + if hasattr(model, "module"): + # Strip out DDP layer + model = model.module + # Base name of model is meta.name unless pytorch model + name = model.__class__.__name__ + # Get full file path / name + file_name = _get_checkpoint_filename( + path, name, index=epoch, + ) + if not Path(file_name).exists(): + checkpoint_logging.error( + f"Could not find valid model file {file_name}, skipping load" + ) + else: + # Load state dictionary + model.load_state_dict(torch.load(file_name, map_location=device)) - checkpoint_logging.success( - f"Loaded model state dictionary {file_name} to device {device}" - ) + checkpoint_logging.success( + f"Loaded model state dictionary {file_name} to device {device}" + ) # == Loading training checkpoint == checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") From 1029d4c7bb16f105e34181810a959797a1f32c8e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:52:27 +0200 Subject: [PATCH 13/66] fix model imports and dependency on module metadata --- src/hirad/models/meta.py | 50 +++++++++++++++++++++++++++++ src/hirad/models/preconditioning.py | 14 ++++---- src/hirad/models/song_unet.py | 6 ++-- src/hirad/models/unet.py | 6 ++-- 4 files changed, 63 insertions(+), 13 deletions(-) create mode 100644 src/hirad/models/meta.py diff --git a/src/hirad/models/meta.py b/src/hirad/models/meta.py new file mode 100644 index 0000000..aab8e45 --- /dev/null +++ b/src/hirad/models/meta.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +@dataclass +class ModelMetaData: + """Data class for storing essential meta data needed for all Hirad Models""" + + # Model info + name: str = "HiradModule" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp: bool = False + amp_cpu: bool = None + amp_gpu: bool = None + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + onnx_gpu: bool = None + onnx_cpu: bool = None + onnx_runtime: bool = False + trt: bool = False + # Physics informed + var_dim: int = -1 + func_torch: bool = False + auto_grad: bool = False + + def __post_init__(self): + self.amp_cpu = self.amp if self.amp_cpu is None else self.amp_cpu + self.amp_gpu = self.amp if self.amp_gpu is None else self.amp_gpu + self.onnx_cpu = self.onnx if self.onnx_cpu is None else self.onnx_cpu + self.onnx_gpu = self.onnx if self.onnx_gpu is None else self.onnx_gpu diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index 7b621e2..b0e924e 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -29,11 +29,11 @@ import torch import torch.nn as nn -from src.models import ( +from hirad.models import ( DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) -from physicsnemo.models.meta import ModelMetaData +from hirad.models import ModelMetaData network_module = importlib.import_module("physicsnemo.models.diffusion") @@ -105,7 +105,7 @@ def __init__( model_type: str = "SongUNet", **model_kwargs: dict, ): - super().__init__(meta=VPPrecondMetaData) + super().__init__() #meta=VPPrecondMetaData self.img_resolution = img_resolution self.img_channels = img_channels self.label_dim = label_dim @@ -282,7 +282,7 @@ def __init__( model_type: str = "SongUNet", **model_kwargs: dict, ): - super().__init__(meta=VEPrecondMetaData) + super().__init__() #meta=VEPrecondMetaData self.img_resolution = img_resolution self.img_channels = img_channels self.label_dim = label_dim @@ -414,7 +414,7 @@ def __init__( model_type="DhariwalUNet", **model_kwargs, ): - super().__init__(meta=iDDPMPrecondMetaData) + super().__init__() #meta=iDDPMPrecondMetaData self.img_resolution = img_resolution self.img_channels = img_channels self.label_dim = label_dim @@ -601,7 +601,7 @@ def __init__( img_out_channels=None, **model_kwargs, ): - super().__init__(meta=EDMPrecondMetaData) + super().__init__() #meta=EDMPrecondMetaData self.img_resolution = img_resolution if img_in_channels is not None: img_in_channels = img_in_channels @@ -767,7 +767,7 @@ def __init__( scale_cond_input=True, **model_kwargs, ): - super().__init__(meta=EDMPrecondSRMetaData) + super().__init__() #meta=EDMPrecondSRMetaData self.img_resolution = img_resolution self.img_channels = img_channels # TODO: this is not used, remove it self.img_in_channels = img_in_channels diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index 68adbda..5bfca8a 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -29,7 +29,7 @@ from torch.utils.checkpoint import checkpoint import torch.nn as nn -from src.models import ( +from hirad.models import ( Conv2d, FourierEmbedding, GroupNorm, @@ -37,7 +37,7 @@ PositionalEmbedding, UNetBlock, ) -from physicsnemo.models.meta import ModelMetaData +from hirad.models import ModelMetaData @dataclass @@ -175,7 +175,7 @@ def __init__( f"Invalid decoder_type: {decoder_type}. Must be one of {valid_decoder_types}." ) - super().__init__(meta=MetaData()) + super().__init__() #meta=MetaData() self.label_dropout = label_dropout self.embedding_type = embedding_type emb_channels = model_channels * channel_mult_emb diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index db8e4f8..1333bda 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from physicsnemo.models.meta import ModelMetaData +from hirad.models import ModelMetaData network_module = importlib.import_module("src.models") @@ -92,7 +92,7 @@ def __init__( model_type="SongUNetPosEmbd", **model_kwargs, ): - super().__init__(meta=MetaData) + super().__init__() #meta=MetaData self.img_channels = img_channels @@ -207,7 +207,7 @@ def __init__( model_type="SongUNet", **model_kwargs, ): - super().__init__(meta=MetaData("StormCastUNet")) + super().__init__() #meta=MetaData("StormCastUNet") if isinstance(img_resolution, int): self.img_shape_x = self.img_shape_y = img_resolution From a1daebec9efd89f08604af2323d92eb39b04047d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:53:10 +0200 Subject: [PATCH 14/66] add generate utils --- src/hirad/utils/deterministic_sampler.py | 2 +- src/hirad/utils/generate_utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 src/hirad/utils/generate_utils.py diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py index 4b2f32b..9fcea1d 100644 --- a/src/hirad/utils/deterministic_sampler.py +++ b/src/hirad/utils/deterministic_sampler.py @@ -19,7 +19,7 @@ import nvtx import torch -from src.models import EDMPrecond +from hirad.models import EDMPrecond # ruff: noqa: E731 diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py new file mode 100644 index 0000000..29f7eb4 --- /dev/null +++ b/src/hirad/utils/generate_utils.py @@ -0,0 +1,24 @@ +import datetime +from hirad.datasets import init_dataset_from_config +from hirad.utils.function_utils import convert_datetime_to_cftime + + +def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): + """ + Get a dataset and sampler for generation. + """ + (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) + if has_lead_time: + plot_times = times + else: + plot_times = [ + convert_datetime_to_cftime( + datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + ) + for time in times + ] + all_times = dataset.time() + time_indices = [all_times.index(t) for t in plot_times] + sampler = time_indices + + return dataset, sampler \ No newline at end of file From 81647669d5cc090d59eae5611871b0b645b11ec0 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 16:53:40 +0200 Subject: [PATCH 15/66] add sceleton for era5-cosmo dataset --- src/hirad/datasets/__init__.py | 3 +++ src/hirad/datasets/dataset.py | 6 ++---- src/hirad/datasets/era5_cosmo.py | 5 +++++ 3 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 src/hirad/datasets/__init__.py create mode 100644 src/hirad/datasets/era5_cosmo.py diff --git a/src/hirad/datasets/__init__.py b/src/hirad/datasets/__init__.py new file mode 100644 index 0000000..706284e --- /dev/null +++ b/src/hirad/datasets/__init__.py @@ -0,0 +1,3 @@ +from .dataset import init_train_valid_datasets_from_config, init_dataset_from_config +from .era5_cosmo import ERA5_COSMO +from .base import DownscalingDataset \ No newline at end of file diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 2a26630..1928e4d 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -21,14 +21,12 @@ from hirad.utils.function_utils import InfiniteSampler from hirad.distributed import DistributedManager -from . import base, cwb, hrrrmini, gefs_hrrr +from hirad.datasets import ERA5_COSMO # this maps all known dataset types to the corresponding init function known_datasets = { - "cwb": cwb.get_zarr_dataset, - "hrrr_mini": hrrrmini.HRRRMiniDataset, - "gefs_hrrr": gefs_hrrr.HrrrForecastGEFSDataset, + "era5_cosmo": ERA5_COSMO, } diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py new file mode 100644 index 0000000..2ec94be --- /dev/null +++ b/src/hirad/datasets/era5_cosmo.py @@ -0,0 +1,5 @@ +from hirad.datasets.base import DownscalingDataset, ChannelMetadata + +class ERA5_COSMO(DownscalingDataset): + def __init__(self): + super().__init__() \ No newline at end of file From a9f70349f5b07c4c436da53fa239e1dbb5295e59 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 25 Apr 2025 17:25:02 +0200 Subject: [PATCH 16/66] add in_channels to arg saving list --- src/hirad/training/train.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index d6fe563..0539aad 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -180,6 +180,7 @@ def main(cfg: DictConfig) -> None: img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] elif cfg.model.name == "lt_aware_ce_regression": model = UNet( img_in_channels=img_in_channels @@ -187,6 +188,7 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] elif cfg.model.name == "lt_aware_patched_diffusion": model = EDMPrecondSR( img_in_channels=img_in_channels @@ -194,15 +196,17 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] else: # diffusion or patched diffusion model = EDMPrecondSR( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) - + model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model.train().requires_grad_(True).to(dist.device) - if not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): + if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): with open(os.path.join(checkpoint_dir, 'model_args.json'), 'w') as f: json.dump(model_args, f) @@ -235,18 +239,7 @@ def main(cfg: DictConfig) -> None: with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) - if cfg.model.name == "lt_aware_patched_diffusion": - regression_net = UNet( - img_in_channels=img_in_channels - + model_args["N_grid_channels"] - + model_args["lead_time_channels"], - **regression_model_args, - ) - else: - regression_net = UNet( - img_in_channels=img_in_channels + model_args["N_grid_channels"], - **regression_model_args, - ) + regression_net = UNet(**regression_model_args) _ = load_checkpoint( path=regression_checkpoint_path, From e0059c83ab8ad7f30d8c10ecc46b744811fe10c8 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Apr 2025 13:47:39 +0200 Subject: [PATCH 17/66] add dataset era5_cosmo --- src/hirad/datasets/dataset.py | 9 +-- src/hirad/datasets/era5_cosmo.py | 95 +++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 1928e4d..6e402d9 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -21,7 +21,8 @@ from hirad.utils.function_utils import InfiniteSampler from hirad.distributed import DistributedManager -from hirad.datasets import ERA5_COSMO +from .era5_cosmo import ERA5_COSMO +from .base import DownscalingDataset # this maps all known dataset types to the corresponding init function @@ -38,9 +39,9 @@ def init_train_valid_datasets_from_config( validation_dataset_cfg: Union[dict, None] = None, train_test_split: bool = True, ) -> Tuple[ - base.DownscalingDataset, + DownscalingDataset, Iterable, - Union[base.DownscalingDataset, None], + Union[DownscalingDataset, None], Union[Iterable, None], ]: """ @@ -79,7 +80,7 @@ def init_dataset_from_config( dataloader_cfg: Union[dict, None] = None, batch_size: int = 1, seed: int = 0, -) -> Tuple[base.DownscalingDataset, Iterable]: +) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) dataset_type = dataset_cfg.pop("type", "cwb") if "train_test_split" in dataset_cfg: diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 2ec94be..597bdfc 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -1,5 +1,94 @@ -from hirad.datasets.base import DownscalingDataset, ChannelMetadata +from .base import DownscalingDataset, ChannelMetadata +import os +import numpy as np +import torch +from typing import List, Tuple +import yaml class ERA5_COSMO(DownscalingDataset): - def __init__(self): - super().__init__() \ No newline at end of file + def __init__(self, dataset_path: str): + super().__init__() + + #TODO switch hanbdling paths to Path rather than pure strings + self._dataset_path = dataset_path + self._era5_path = os.path.join(dataset_path, 'era-interpolated') + self._cosmo_path = os.path.join(dataset_path, 'cosmo') + self._info_path = os.path.join(dataset_path, 'info') + + # load file list (each file is one date-time state) + self._file_list = os.listdir(self._cosmo_path) + + # Load cosmo info and channel names + with open(os.path.join(self._info_path,'cosmo.yaml'), 'r') as file: + self._cosmo_info = yaml.safe_load(file) + self._cosmo_channels = [ChannelMetadata(name) for name in self._cosmo_info['select']] + + # Load era5 info and channel names + with open(os.path.join(self._info_path,'era.yaml'), 'r') as file: + self._era_info = yaml.safe_load(file) + self._era_channels = [ChannelMetadata(name) if len(name.split('_'))==1 + else ChannelMetadata(name.split('_')[0],name.split('_')[1]) + for name in self._era_info['select']] + + # Load stats for normalizing channels of input and output + + cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False) + print(cosmo_stats) + + + def __len__(self): + return len(self._file_list) + + + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lon_lat[:,0] + + + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lon_lat[:,1] + + + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + return self._era_channels + + + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. A list of ChannelMetadata, one for each channel""" + return self._cosmo_channels + + + def time(self) -> List: + """Get time values from the dataset.""" + #TODO Choose the time format and convert to that, currently it's a string from a filename + return [file.split('.')[0] for file in self._file_list] + + + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + #TODO load from info, I hardcode it for now + return 390,582 + + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return (x - self.input_mean) / self.input_std + + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x * self.input_std + self.input_mean + + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return (x - self.output_mean) / self.output_std + + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x * self.output_std + self.output_mean \ No newline at end of file From da4cb6ccc9abc901128ea1db2d401387ff0374e1 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 28 Apr 2025 13:48:06 +0200 Subject: [PATCH 18/66] fix imports --- src/hirad/distributed/manager.py | 2 +- src/hirad/inference/generate.py | 159 ++++++++++++++++++++++++++++ src/hirad/models/layers.py | 2 +- src/hirad/models/preconditioning.py | 4 +- src/hirad/models/song_unet.py | 4 +- src/hirad/models/unet.py | 2 +- src/hirad/utils/checkpoint.py | 2 +- src/hirad/utils/generate_utils.py | 2 +- src/hirad/utils/inference_utils.py | 2 +- 9 files changed, 169 insertions(+), 10 deletions(-) create mode 100644 src/hirad/inference/generate.py diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index e80ce13..647d054 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -25,7 +25,7 @@ import torch import torch.distributed as dist -from hirad.distributed.config import ProcessGroupConfig, ProcessGroupNode +from .config import ProcessGroupConfig, ProcessGroupNode warnings.simplefilter("default", DeprecationWarning) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py new file mode 100644 index 0000000..6e5273a --- /dev/null +++ b/src/hirad/inference/generate.py @@ -0,0 +1,159 @@ +import hydra +import os +import json +from omegaconf import OmegaConf, DictConfig +import torch +import torch._dynamo +import nvtx +import numpy as np +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from einops import rearrange +from torch.distributed import gather + + +from hydra.utils import to_absolute_path +from hirad.models import EDMPrecond, UNet +from hirad.utils.stochastic_sampler import stochastic_sampler +from hirad.utils.deterministic_sampler import deterministic_sampler +from hirad.utils.inference_utils import ( + get_time_from_range, + regression_step, + diffusion_step, +) +from hirad.utils.checkpoint import load_checkpoint + + +from hirad.utils.generate_utils import ( + get_dataset_and_sampler +) + +from hirad.utils.train_helpers import set_patch_shape + + +@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate") +def main(cfg: DictConfig) -> None: + """Generate random dowscaled atmospheric states using the techniques described in the paper + "Elucidating the Design Space of Diffusion-Based Generative Models". + """ + + # Initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + device = dist.device + + # Initialize logger + logger = PythonLogger("generate") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) + logger.file_logging("generate.log") + + # Handle the batch size + seeds = list(np.arange(cfg.generation.num_ensembles)) + num_batches = ( + (len(seeds) - 1) // (cfg.generation.seed_batch_size * dist.world_size) + 1 + ) * dist.world_size + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + rank_batches = all_batches[dist.rank :: dist.world_size] + + # Synchronize + if dist.world_size > 1: + torch.distributed.barrier() + + # Parse the inference input times + if cfg.generation.times_range and cfg.generation.times: + raise ValueError("Either times_range or times must be provided, but not both") + if cfg.generation.times_range: + times = get_time_from_range(cfg.generation.times_range) #TODO check what time formats we are using and adapt + else: + times = cfg.generation.times + + # Create dataset object + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if "has_lead_time" in cfg.generation: + has_lead_time = cfg.generation["has_lead_time"] + else: + has_lead_time = False + dataset, sampler = get_dataset_and_sampler( + dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time + ) + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + + # Parse the patch shape + if hasattr(cfg.generation, "patch_shape_x"): # TODO better config handling + patch_shape_x = cfg.generation.patch_shape_x + else: + patch_shape_x = None + if hasattr(cfg.generation, "patch_shape_y"): + patch_shape_y = cfg.generation.patch_shape_y + else: + patch_shape_y = None + patch_shape = (patch_shape_y, patch_shape_x) + img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if patch_shape != img_shape: + logger0.info("Patch-based training enabled") + else: + logger0.info("Patch-based training disabled") + + # Parse the inference mode + if cfg.generation.inference_mode == "regression": + load_net_reg, load_net_res = True, False + elif cfg.generation.inference_mode == "diffusion": + load_net_reg, load_net_res = False, True + elif cfg.generation.inference_mode == "all": + load_net_reg, load_net_res = True, True + else: + raise ValueError(f"Invalid inference mode {cfg.generation.inference_mode}") + + # Load diffusion network, move to device, change precision + if load_net_res: + res_ckpt_path = cfg.generation.io.res_ckpt_path + logger0.info(f'Loading residual network from "{res_ckpt_path}"...') + + diffusion_model_args_path = os.path.join(res_ckpt_path, 'model_args.json') + if not os.path.isfile(diffusion_model_args_path): + raise FileNotFoundError(f"Missing config file at '{diffusion_model_args_path}'.") + with open(diffusion_model_args_path, 'r') as f: + diffusion_model_args = json.load(f) + + net_res = EDMPrecond(**diffusion_model_args) + + _ = load_checkpoint( + path=res_ckpt_path, + model=net_res, + device=dist.device + ) + + net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) + if cfg.generation.perf.force_fp16: + net_res.use_fp16 = True + else: + net_res = None + + # load regression network, move to device, change precision + if load_net_reg: + reg_ckpt_path = cfg.generation.io.reg_ckpt_path + logger0.info(f'Loading network from "{reg_ckpt_path}"...') + + + regression_model_args_path = os.path.join(reg_ckpt_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + net_reg = EDMPrecond(**regression_model_args) + + _ = load_checkpoint( + path=reg_ckpt_path, + model=net_reg, + device=dist.device + ) + + net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) + if cfg.generation.perf.force_fp16: + net_reg.use_fp16 = True + else: + net_reg = None \ No newline at end of file diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index d5a1ab2..ddb23b6 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -26,7 +26,7 @@ from einops import rearrange from torch.nn.functional import silu -from src.utils.model_utils import weight_init +from hirad.utils.model_utils import weight_init class Linear(torch.nn.Module): diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index b0e924e..9c10004 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -29,11 +29,11 @@ import torch import torch.nn as nn -from hirad.models import ( +from .song_unet import ( DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) -from hirad.models import ModelMetaData +from .meta import ModelMetaData network_module = importlib.import_module("physicsnemo.models.diffusion") diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index 5bfca8a..6267dfc 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -29,7 +29,7 @@ from torch.utils.checkpoint import checkpoint import torch.nn as nn -from hirad.models import ( +from .layers import ( Conv2d, FourierEmbedding, GroupNorm, @@ -37,7 +37,7 @@ PositionalEmbedding, UNetBlock, ) -from hirad.models import ModelMetaData +from .meta import ModelMetaData @dataclass diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index 1333bda..d81a734 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from hirad.models import ModelMetaData +from .meta import ModelMetaData network_module = importlib.import_module("src.models") diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index 03b423d..e0f8d58 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -24,7 +24,7 @@ from torch.optim.lr_scheduler import _LRScheduler from hirad.distributed import DistributedManager -from hirad.utils.console import PythonLogger +from .console import PythonLogger optimizer = NewType("optimizer", torch.optim.Optimizer) scheduler = NewType("scheduler", _LRScheduler) diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py index 29f7eb4..b99852f 100644 --- a/src/hirad/utils/generate_utils.py +++ b/src/hirad/utils/generate_utils.py @@ -1,6 +1,6 @@ import datetime from hirad.datasets import init_dataset_from_config -from hirad.utils.function_utils import convert_datetime_to_cftime +from .function_utils import convert_datetime_to_cftime def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 842bdd3..b158ec0 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -21,7 +21,7 @@ import torch import tqdm -from src.utils.function_utils import StackedRandomGenerator, time_range +from .function_utils import StackedRandomGenerator, time_range ############################################################################ # CorrDiff Generation Utilities # From c284d0aa948bca579290fc9abe1ebb75212e3ad4 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 29 Apr 2025 14:51:43 +0200 Subject: [PATCH 19/66] add getitem to dataset --- src/hirad/datasets/era5_cosmo.py | 40 +++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 597bdfc..4d9187d 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -33,8 +33,28 @@ def __init__(self, dataset_path: str): # Load stats for normalizing channels of input and output cosmo_stats = torch.load(os.path.join(self._info_path,'cosmo-stats'), weights_only=False) - print(cosmo_stats) - + self.output_mean = cosmo_stats['mean'] + self.output_std = cosmo_stats['stdev'] + + era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) + #TODO Switch from cosmo to era stats once era-interpolated has all channels + self.input_mean = cosmo_stats['mean'] + self.input_std = cosmo_stats['stdev'] + + + def __getitem__(self, idx): + # get era5 data point + era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ + .squeeze()\ + .reshape(-1,*self.image_shape()) + era5_data = self.normalize_input(era5_data) + # get cosmo data point + cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ + .squeeze()\ + .reshape(-1,*self.image_shape()) + cosmo_data = self.normalize_output(cosmo_data) + # return samples + return cosmo_data, era5_data, 0 def __len__(self): return len(self._file_list) @@ -70,25 +90,29 @@ def time(self) -> List: def image_shape(self) -> Tuple[int, int]: """Get the (height, width) of the data (same for input and output).""" - #TODO load from info, I hardcode it for now - return 390,582 + #TODO load from info, I hardcode it for now (cosmo from anemoi-datasets minus trim-edge=20) + return 350,542 def normalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from physical units to normalized data.""" - return (x - self.input_mean) / self.input_std + return (x - self.input_mean.reshape((self.input_mean.shape[0],1,1))) \ + / self.input_std.reshape((self.input_std.shape[0],1,1)) def denormalize_input(self, x: np.ndarray) -> np.ndarray: """Convert input from normalized data to physical units.""" - return x * self.input_std + self.input_mean + return x * self.input_std.reshape((self.input_std.shape[0],1,1)) \ + + self.input_mean.reshape((self.input_mean.shape[0],1,1)) def normalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from physical units to normalized data.""" - return (x - self.output_mean) / self.output_std + return (x - self.output_mean.reshape((self.output_mean.shape[0],1,1))) \ + / self.output_std.reshape((self.output_std.shape[0],1,1)) def denormalize_output(self, x: np.ndarray) -> np.ndarray: """Convert output from normalized data to physical units.""" - return x * self.output_std + self.output_mean \ No newline at end of file + return x * self.output_std.reshape((self.output_std.shape[0],1,1)) \ + + self.output_mean.reshape((self.output_mean.shape[0],1,1)) \ No newline at end of file From e7d5b1b6324a8affba89bc216074395ca7b8af25 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 29 Apr 2025 15:11:52 +0200 Subject: [PATCH 20/66] add grid flip to start at top left corner --- src/hirad/datasets/era5_cosmo.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 4d9187d..89a0581 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -44,14 +44,19 @@ def __init__(self, dataset_path: str): def __getitem__(self, idx): # get era5 data point - era5_data = torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ - .squeeze()\ - .reshape(-1,*self.image_shape()) + # squeeze the ensemble dimesnsion + # reshape to image_shape + # flip so that it starts in top-left corner (by default it is bottom left) + era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ + .squeeze() \ + .reshape(-1,*self.image_shape()), + 1) era5_data = self.normalize_input(era5_data) # get cosmo data point - cosmo_data = torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ - .squeeze()\ - .reshape(-1,*self.image_shape()) + cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ + .squeeze() \ + .reshape(-1,*self.image_shape()), + 1) cosmo_data = self.normalize_output(cosmo_data) # return samples return cosmo_data, era5_data, 0 From d093e662e23f079f891accbd656b258fa7e5b8ac Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 29 Apr 2025 15:26:51 +0200 Subject: [PATCH 21/66] small fix --- src/hirad/datasets/dataset.py | 2 +- src/hirad/datasets/era5_cosmo.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 6e402d9..6cc6165 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -82,7 +82,7 @@ def init_dataset_from_config( seed: int = 0, ) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) - dataset_type = dataset_cfg.pop("type", "cwb") + dataset_type = dataset_cfg.pop("type", "era5_cosmo") if "train_test_split" in dataset_cfg: # handled by init_train_valid_datasets_from_config del dataset_cfg["train_test_split"] diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 89a0581..e7de456 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -13,7 +13,7 @@ def __init__(self, dataset_path: str): self._dataset_path = dataset_path self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') - self._info_path = os.path.join(dataset_path, 'info') + self._info_path = os.path.join(dataset_path, 'old/info') # load file list (each file is one date-time state) self._file_list = os.listdir(self._cosmo_path) @@ -37,9 +37,8 @@ def __init__(self, dataset_path: str): self.output_std = cosmo_stats['stdev'] era_stats = torch.load(os.path.join(self._info_path,'era-stats'), weights_only=False) - #TODO Switch from cosmo to era stats once era-interpolated has all channels - self.input_mean = cosmo_stats['mean'] - self.input_std = cosmo_stats['stdev'] + self.input_mean = era_stats['mean'] + self.input_std = era_stats['stdev'] def __getitem__(self, idx): @@ -59,7 +58,7 @@ def __getitem__(self, idx): 1) cosmo_data = self.normalize_output(cosmo_data) # return samples - return cosmo_data, era5_data, 0 + return torch.tensor(cosmo_data), torch.tensor(era5_data), 0 def __len__(self): return len(self._file_list) From 7e695f0305f6190aa4484c020926cdafb62d64e1 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 7 May 2025 15:19:32 +0200 Subject: [PATCH 22/66] update everything for training --- src/hirad/conf/dataset/era_cosmo.yaml | 2 + src/hirad/conf/model/era_cosmo_diffusion.yaml | 5 + .../conf/model/era_cosmo_regression.yaml | 2 + src/hirad/conf/training.yaml | 0 .../conf/training/era_cosmo_diffusion.yaml | 41 +++ .../conf/training/era_cosmo_regression.yaml | 38 +++ .../conf/training_era_cosmo_diffusion.yaml | 19 ++ .../conf/training_era_cosmo_regression.yaml | 19 ++ src/hirad/datasets/era5_cosmo.py | 13 +- src/hirad/distributed/config.py | 2 +- src/hirad/distributed/manager.py | 34 ++- src/hirad/models/__init__.py | 7 +- src/hirad/models/dhariwal_unet.py | 259 ++++++++++++++++++ src/hirad/models/preconditioning.py | 6 +- src/hirad/models/unet.py | 2 +- src/hirad/testrun.sh | 41 +++ src/hirad/training/train.py | 37 +-- 17 files changed, 485 insertions(+), 42 deletions(-) create mode 100644 src/hirad/conf/dataset/era_cosmo.yaml create mode 100644 src/hirad/conf/model/era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/model/era_cosmo_regression.yaml delete mode 100644 src/hirad/conf/training.yaml create mode 100644 src/hirad/conf/training/era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/training/era_cosmo_regression.yaml create mode 100644 src/hirad/conf/training_era_cosmo_diffusion.yaml create mode 100644 src/hirad/conf/training_era_cosmo_regression.yaml create mode 100644 src/hirad/models/dhariwal_unet.py create mode 100644 src/hirad/testrun.sh diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml new file mode 100644 index 0000000..854b775 --- /dev/null +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -0,0 +1,2 @@ +type: era5_cosmo +dataset_path: /store_new/mch/msopr/hirad-gen/basic-torch \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..06aa2a4 --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -0,0 +1,5 @@ +name: diffusion + # Name of the preconditioner +hr_mean_conditioning: True + # High-res mean (regression's output) as additional condition +scale_cond_input: False \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_regression.yaml b/src/hirad/conf/model/era_cosmo_regression.yaml new file mode 100644 index 0000000..487eb4b --- /dev/null +++ b/src/hirad/conf/model/era_cosmo_regression.yaml @@ -0,0 +1,2 @@ +name: regression +hr_mean_conditioning: False \ No newline at end of file diff --git a/src/hirad/conf/training.yaml b/src/hirad/conf/training.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml new file mode 100644 index 0000000..b61603a --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -0,0 +1,41 @@ +# Hyperparameters +hp: + training_duration: 128 + # Training duration based on the number of processed samples + total_batch_size: 16 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: null + # no gradient clipping for defualt non-patch-based training + lr_decay: 1 + # LR decay rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 4 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + regression_checkpoint_path: /scratch/mch/pstamenk/output/regression/checkpoints_regression + # Where to load the regression checkpoint + print_progress_freq: 32 + # How often to print progress + save_checkpoint_freq: 5000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 5000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # how many loss evaluations are used to compute the validation loss per checkpoint + # how many loss evaluations are used to compute the validation loss per checkpoint + checkpoint_dir: /scratch/mch/pstamenk/output/diffusion \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml new file mode 100644 index 0000000..7c443f0 --- /dev/null +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -0,0 +1,38 @@ +# Hyperparameters +hp: + training_duration: 16 + # Training duration based on the number of processed samples + total_batch_size: 16 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU + lr: 0.0002 + # Learning rate + grad_clip_threshold: null + # no gradient clipping for defualt non-patch-based training + lr_decay: 1 + # LR decay rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 4 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# I/O +io: + print_progress_freq: 32 + # How often to print progress + save_checkpoint_freq: 5000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 5000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # how many loss evaluations are used to compute the validation loss per checkpoint + checkpoint_dir: /scratch/mch/pstamenk/output/regression \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml new file mode 100644 index 0000000..7ee7dba --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -0,0 +1,19 @@ +hydra: + job: + chdir: true + name: diffusion + run: + dir: /scratch/mch/pstamenk/output/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_diffusion + + # Training + - training/era_cosmo_diffusion \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml new file mode 100644 index 0000000..d857d12 --- /dev/null +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -0,0 +1,19 @@ +hydra: + job: + chdir: true + name: regression + run: + dir: /scratch/mch/pstamenk/output/${hydra:job.name} + +# Get defaults +defaults: + - _self_ + + # Dataset + - dataset/era_cosmo + + # Model + - model/era_cosmo_regression + + # Training + - training/era_cosmo_regression \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index e7de456..8b0d60f 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -4,6 +4,7 @@ import torch from typing import List, Tuple import yaml +import torch.nn.functional as F class ERA5_COSMO(DownscalingDataset): def __init__(self, dataset_path: str): @@ -42,23 +43,27 @@ def __init__(self, dataset_path: str): def __getitem__(self, idx): + """Get cosmo and era5 interpolated to cosmo grid""" # get era5 data point # squeeze the ensemble dimesnsion # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) + orig_shape = [350,542] #TODO currently padding to be divisible by 16 era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ .squeeze() \ - .reshape(-1,*self.image_shape()), + .reshape(-1,*orig_shape), 1) era5_data = self.normalize_input(era5_data) # get cosmo data point cosmo_data = np.flip(torch.load(os.path.join(self._cosmo_path,self._file_list[idx]), weights_only=False)\ .squeeze() \ - .reshape(-1,*self.image_shape()), + .reshape(-1,*orig_shape), 1) cosmo_data = self.normalize_output(cosmo_data) # return samples - return torch.tensor(cosmo_data), torch.tensor(era5_data), 0 + return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ + F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ + 0 def __len__(self): return len(self._file_list) @@ -95,7 +100,7 @@ def time(self) -> List: def image_shape(self) -> Tuple[int, int]: """Get the (height, width) of the data (same for input and output).""" #TODO load from info, I hardcode it for now (cosmo from anemoi-datasets minus trim-edge=20) - return 350,542 + return 352,544 #TODO 350,542 is orig size, UNet requires dimenions divisible by 16, for now, I just add zeros to orig images def normalize_input(self, x: np.ndarray) -> np.ndarray: diff --git a/src/hirad/distributed/config.py b/src/hirad/distributed/config.py index c5414b4..2808d92 100644 --- a/src/hirad/distributed/config.py +++ b/src/hirad/distributed/config.py @@ -84,7 +84,7 @@ class ProcessGroupConfig: Examples -------- - >>> from physicsnemo.distributed import ProcessGroupNode, ProcessGroupConfig + >>> from hirad.distributed import ProcessGroupNode, ProcessGroupConfig >>> >>> # Create world group that contains all processes that are part of this job >>> world = ProcessGroupNode("world") diff --git a/src/hirad/distributed/manager.py b/src/hirad/distributed/manager.py index 647d054..eca46c6 100644 --- a/src/hirad/distributed/manager.py +++ b/src/hirad/distributed/manager.py @@ -348,7 +348,7 @@ def initialize_slurm(port): rank = int(os.environ.get("SLURM_PROCID")) world_size = int(os.environ.get("SLURM_NPROCS")) local_rank = int(os.environ.get("SLURM_LOCALID")) - addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR") + addr = os.environ.get("MASTER_ADDR") DistributedManager.setup( rank=rank, @@ -388,6 +388,7 @@ def initialize(): port = os.getenv("MASTER_PORT", "12355") # https://pytorch.org/docs/master/notes/cuda.html#id5 # was changed in version 2.2 + #TODO why is setting this important? if torch.__version__ < (2, 2): os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" else: @@ -542,22 +543,25 @@ def setup( f"cuda:{manager.local_rank}" if torch.cuda.is_available() else "cpu" ) + #TODO device_id makes the init hang, couldn't figure out why if manager._distributed: # Setup distributed process group - try: - dist.init_process_group( - backend, - rank=manager.rank, - world_size=manager.world_size, - device_id=manager.device, - ) - except TypeError: - # device_id only introduced in PyTorch 2.3 - dist.init_process_group( - backend, - rank=manager.rank, - world_size=manager.world_size, - ) + # try: + dist.init_process_group( + backend, + rank=manager.rank, + world_size=manager.world_size, + ) + # rank=manager.rank, + # world_size=manager.world_size, + # device_id=manager.device, + # except TypeError: + # # device_id only introduced in PyTorch 2.3 + # dist.init_process_group( + # backend, + # rank=manager.rank, + # world_size=manager.world_size, + # ) if torch.cuda.is_available(): # Set device for this process and empty cache to optimize memory usage diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index f17e5ce..3ab4a6f 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,5 +1,6 @@ -from .unet import UNet +from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding +from .meta import ModelMetaData from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd +from .dhariwal_unet import DhariwalUNet +from .unet import UNet from .preconditioning import EDMPrecondSR, EDMPrecond -from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding -from .meta import ModelMetaData \ No newline at end of file diff --git a/src/hirad/models/dhariwal_unet.py b/src/hirad/models/dhariwal_unet.py new file mode 100644 index 0000000..3880cd0 --- /dev/null +++ b/src/hirad/models/dhariwal_unet.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model architectures used in the paper "Elucidating the Design Space of +Diffusion-Based Generative Models". +""" + +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +from torch.nn.functional import silu +import torch.nn as nn + +from .layers import ( + Conv2d, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) +from .meta import ModelMetaData + + +@dataclass +class MetaData(ModelMetaData): + name: str = "DhariwalUNet" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class DhariwalUNet(nn.Module): + """ + Reimplementation of the ADM architecture, a U-Net variant, with optional + self-attention. + + This model supports conditional and unconditional setups, as well as several + options for various internal architectural choices such as encoder and decoder + type, embedding type, etc., making it flexible and adaptable to different tasks + and configurations. + + Parameters + ----------- + img_resolution : int + The resolution of the input/output image. + in_channels : int + Number of channels in the input image. + out_channels : int + Number of channels in the output image. + label_dim : int, optional + Number of class labels; 0 indicates an unconditional model. By default 0. + augment_dim : int, optional + Dimensionality of augmentation labels; 0 means no augmentation. By default 0. + model_channels : int, optional + Base multiplier for the number of channels across the network, by default 192. + channel_mult : List[int], optional + Per-resolution multipliers for the number of channels. By default [1,2,3,4]. + channel_mult_emb : int, optional + Multiplier for the dimensionality of the embedding vector. By default 4. + num_blocks : int, optional + Number of residual blocks per resolution. By default 3. + attn_resolutions : List[int], optional + Resolutions at which self-attention layers are applied. By default [32, 16, 8]. + dropout : float, optional + Dropout probability applied to intermediate activations. By default 0.10. + label_dropout : float, optional + Dropout probability of class labels for classifier-free guidance. By default 0.0. + + Reference + ---------- + Reference: Dhariwal, P. and Nichol, A., 2021. Diffusion models beat gans on image + synthesis. Advances in neural information processing systems, 34, pp.8780-8794. + + Note + ----- + Equivalent to the original implementation by Dhariwal and Nichol, available at + https://github.com/openai/guided-diffusion + + Example + -------- + >>> model = DhariwalUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> noise_labels = torch.randn([1]) + >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> input_image = torch.ones([1, 2, 16, 16]) + >>> output_image = model(input_image, noise_labels, class_labels) + """ + + def __init__( + self, + img_resolution: int, + in_channels: int, + out_channels: int, + label_dim: int = 0, + augment_dim: int = 0, + model_channels: int = 192, + channel_mult: List[int] = [1, 2, 3, 4], + channel_mult_emb: int = 4, + num_blocks: int = 3, + attn_resolutions: List[int] = [32, 16, 8], + dropout: float = 0.10, + label_dropout: float = 0.0, + ): + super().__init__(meta=MetaData()) + self.label_dropout = label_dropout + emb_channels = model_channels * channel_mult_emb + init = dict( + init_mode="kaiming_uniform", + init_weight=np.sqrt(1 / 3), + init_bias=np.sqrt(1 / 3), + ) + init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0) + block_kwargs = dict( + emb_channels=emb_channels, + channels_per_head=64, + dropout=dropout, + init=init, + init_zero=init_zero, + ) + + # Mapping. + self.map_noise = PositionalEmbedding(num_channels=model_channels) + self.map_augment = ( + Linear( + in_features=augment_dim, + out_features=model_channels, + bias=False, + **init_zero, + ) + if augment_dim + else None + ) + self.map_layer0 = Linear( + in_features=model_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + self.map_label = ( + Linear( + in_features=label_dim, + out_features=emb_channels, + bias=False, + init_mode="kaiming_normal", + init_weight=np.sqrt(label_dim), + ) + if label_dim + else None + ) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_conv"] = Conv2d( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"{res}x{res}_down"] = UNetBlock( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + skips = [block.out_channels for block in self.enc.values()] + + # Decoder. + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f"{res}x{res}_in0"] = UNetBlock( + in_channels=cout, out_channels=cout, attention=True, **block_kwargs + ) + self.dec[f"{res}x{res}_in1"] = UNetBlock( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"{res}x{res}_up"] = UNetBlock( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( + in_channels=cin, + out_channels=cout, + attention=(res in attn_resolutions), + **block_kwargs, + ) + self.out_norm = GroupNorm(num_channels=cout) + self.out_conv = Conv2d( + in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + ) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + # Mapping. + emb = self.map_noise(noise_labels) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = self.map_layer1(emb) + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label(tmp) + emb = silu(emb) + + # Encoder. + skips = [] + for block in self.enc.values(): + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. + for block in self.dec.values(): + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + x = self.out_conv(silu(self.out_norm(x))) + return x diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index 9c10004..c66b6b6 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -30,12 +30,14 @@ import torch.nn as nn from .song_unet import ( - DhariwalUNet, # noqa: F401 for globals SongUNet, # noqa: F401 for globals ) +from .dhariwal_unet import ( + DhariwalUNet, # noqa: F401 for globals +) from .meta import ModelMetaData -network_module = importlib.import_module("physicsnemo.models.diffusion") +network_module = importlib.import_module("hirad.models") @dataclass diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index d81a734..10079ec 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -22,7 +22,7 @@ from .meta import ModelMetaData -network_module = importlib.import_module("src.models") +network_module = importlib.import_module("hirad.models") @dataclass diff --git a/src/hirad/testrun.sh b/src/hirad/testrun.sh new file mode 100644 index 0000000..ee4a977 --- /dev/null +++ b/src/hirad/testrun.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=4 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=16G +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/scratch/mch/pstamenk/logs/regression_test.log +#SBATCH --error=/scratch/mch/pstamenk/logs/regression_test.err + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=ENV + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Node: $(hostname)" +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# activate conda env +CONDA_ENV=train +source /users/pstamenk/.bashrc +mamba activate $CONDA_ENV + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +torchrun --nproc-per-node=$LOCAL_PROCS src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 0539aad..88b2118 100644 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -9,6 +9,7 @@ from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter from torch.nn.parallel import DistributedDataParallel +from torchinfo import summary from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper @@ -20,9 +21,10 @@ from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE from hirad.datasets import init_train_valid_datasets_from_config -@hydra.main(version_base=None, config_path="conf", config_name="training") +from matplotlib import pyplot as plt + +@hydra.main(version_base=None, config_path="../conf", config_name="training") def main(cfg: DictConfig) -> None: - # Initialize distributed environment for training DistributedManager.initialize() dist = DistributedManager() @@ -45,10 +47,12 @@ def main(cfg: DictConfig) -> None: fp16 = fp_optimizations == "fp16" enable_amp = fp_optimizations.startswith("amp") amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 - logger.info(f"Saving the outputs in {os.getcwd()}") + logger0.info(f"Saving the outputs in {os.getcwd()}") checkpoint_dir = os.path.join( cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" ) + if dist.rank==0 and not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) # added creating checkpoint dir if cfg.training.hp.batch_size_per_gpu == "auto": cfg.training.hp.batch_size_per_gpu = ( cfg.training.hp.total_batch_size // dist.world_size @@ -85,12 +89,14 @@ def main(cfg: DictConfig) -> None: if cfg.model.hr_mean_conditioning: img_in_channels += img_out_channels + if cfg.model.name == "lt_aware_ce_regression": - prob_channels = dataset.get_prob_channel_index() + prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader else: prob_channels = None # Parse the patch shape + #TODO figure out patched diffusion and how to use it if ( cfg.model.name == "patched_diffusion" or cfg.model.name == "lt_aware_patched_diffusion" @@ -109,9 +115,8 @@ def main(cfg: DictConfig) -> None: # interpolate global channel if patch-based model is used if img_shape[1] != patch_shape[1]: img_in_channels += dataset_channels - - # Instantiate the model and move to device. + # Instantiate the model and move to device. if cfg.model.name not in ( "regression", "lt_aware_ce_regression", @@ -180,7 +185,7 @@ def main(cfg: DictConfig) -> None: img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] elif cfg.model.name == "lt_aware_ce_regression": model = UNet( img_in_channels=img_in_channels @@ -188,7 +193,7 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] elif cfg.model.name == "lt_aware_patched_diffusion": model = EDMPrecondSR( img_in_channels=img_in_channels @@ -196,18 +201,21 @@ def main(cfg: DictConfig) -> None: + model_args["lead_time_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] else: # diffusion or patched diffusion model = EDMPrecondSR( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) - model_args["image_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] model.train().requires_grad_(True).to(dist.device) + # TODO write summry from rank=0 possibly + # summary(model, input_size=[(4,img_out_channels,*img_shape),(4,img_in_channels,*img_shape),(4,1),(4,1)]) + if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): - with open(os.path.join(checkpoint_dir, 'model_args.json'), 'w') as f: + with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: json.dump(model_args, f) # Enable distributed data parallel if applicable @@ -220,7 +228,7 @@ def main(cfg: DictConfig) -> None: find_unused_parameters=dist.find_unused_parameters, ) - # Load the regression checkpoint if applicable + # Load the regression checkpoint if applicable #TODO test when training correction if hasattr(cfg.training.io, "regression_checkpoint_path"): regression_checkpoint_path = to_absolute_path( cfg.training.io.regression_checkpoint_path @@ -286,8 +294,7 @@ def main(cfg: DictConfig) -> None: dist.world_size, ) batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu - logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation {"rounds" if num_accumulation_rounds>1 else "round"}.") ## Resume training from previous checkpoints if exists if dist.world_size > 1: @@ -313,7 +320,6 @@ def main(cfg: DictConfig) -> None: average_loss_running_mean = 0 n_average_loss_running_mean = 1 - while not done: tick_start_nimg = cur_nimg tick_start_time = time.time() @@ -494,7 +500,6 @@ def main(cfg: DictConfig) -> None: optimizer=optimizer, epoch=cur_nimg, ) - pass # Done. logger0.info("Training Completed.") From 7510cf1ad7f40c0794d881d06277de793bbdde94 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 7 May 2025 15:20:07 +0200 Subject: [PATCH 23/66] small fix --- src/hirad/utils/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hirad/utils/checkpoint.py b/src/hirad/utils/checkpoint.py index e0f8d58..a346b16 100644 --- a/src/hirad/utils/checkpoint.py +++ b/src/hirad/utils/checkpoint.py @@ -294,7 +294,7 @@ def load_checkpoint( checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") if not Path(checkpoint_filename).is_file(): checkpoint_logging.warning( - "Could not find valid checkpoint file, skipping load" + f"Could not find valid checkpoint file {checkpoint_filename} skipping load" ) return 0 From 75db04fb5b550787fac0cea394d34e090127ae77 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 7 May 2025 15:27:18 +0200 Subject: [PATCH 24/66] remove tracked .pyc files --- .../models/__pycache__/__init__.cpython-312.pyc | Bin 185 -> 0 bytes .../models/__pycache__/dummy.cpython-312.pyc | Bin 451 -> 0 bytes .../models/__pycache__/unet.cpython-312.pyc | Bin 8966 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/hirad/models/__pycache__/__init__.cpython-312.pyc delete mode 100644 src/hirad/models/__pycache__/dummy.cpython-312.pyc delete mode 100644 src/hirad/models/__pycache__/unet.cpython-312.pyc diff --git a/src/hirad/models/__pycache__/__init__.cpython-312.pyc b/src/hirad/models/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 70d5748263686e7ee01a00c9c9ffd61e4261990b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 185 zcmX@j%ge<81eP(8=gQRm475F^(AnLQ) zQDjD9>o~pA)a3YQB1#<>S}gTyk;VCqxJWL~XOnrH3tdDayH#g=oRqaZdZNZE447`7==k|mj}EK82PsyZ3&%#yR( zo!Q)(l}KR)vxSQ;+q#mP6r@`haex4oDiu*4g2+$KQ(yWbg(|!mH~@j92#mZ?pxh>Y zNzS?RvrBqy2Tp*#_zF6E&;2>~@7#0GcZUBJk4G3t8zaA+{guElzd*)K0VlI@7c$F? z%*bq!$>OiC$d-IrA8VKSvwm8~mG~@A%lySaDVPnijE_;6Ledr3_?5Yvtj%S^iqFR! zWn}(MMh+;!n|{0erqehZp|wG%4Jmw~t0BY9xjnKm+A<6+BVW)mzQCrVAEAKhG;4)r z$&_+M$uNM4PwBE!oKj3_5=t*c43f)~u*k@{1LChUOJ?Jrfmw!CW?7lZ`p&=}W&H}L z@Urh&sFD3ySQBuZ!-XivJ6xD@0m=ol5z2*}x+vwsPF;+05vMLrxv0Y>C>L|MHp<13 zTS%natdP4Fi#s!|m=>o=CFO|aFQ}#!%~fPcoF!6u&ai?~sVwHomD6@^wp_77rcQEm zVt(G@X7VEkEKb+7c`K5yXgQHqSz)PS>L@{;PA9F9D57*xv?8JiyH+V;J|>Fasz^n< zhJ;{ZmftXm)h3FPrs<|+s=8)~A_-#^H>2xCf?7$UNMcChND@fekf279b|5#H4WP^? zgToa=A;xgoFkyV{{P2tFsqx97qlz{R2RLjW;BZA#%uIQKpkhQ*i@+!Autvv;O zAA=ri?I;CuU2EHW7mr`%R=M7K*UrUbHS<2Vlk~U)Z_*_mN-=G`2;^^?v?$h;iL^ED>V%E33oSzeCFp*PuVKnBgv25Zbbn+?6fq@yGV;-vrJ7$zb>Ix_^P zFMI>ex~vc(uM=THCuCC1=Q9!7;G{%g?@ZW6$__OW<<-)xNEAaaR`9rlF+tT#T7kk( z&nZGvrJff|C{eX?#T?cvCj1C$Nt#tMo~Ya$oITvCErNvUN@)h>2NH665hm%jYxGn# z(JkE4xzU%7{FxG9Mqez#nRcsLfUb%8^2mWP0d;jy(95_bQqkLv!}Ch6Vk&eKDw?d2 z;)1Hp3RDdR$rO&9962DAiIP(d==l7gFtZ@YN?xiIp-w6q&ORCHY)KMJsx~I%i@IdC z7->q?)KaA+7?pBaC#E6`nywm(P*g4`FtaAP`|KTbWJ>d&BU76H17*@#4fLy8zCYb8 z9u4Ye%Rf9{R&qF1xU8Bi5gFEgL8Vpz&mA?#BoZtPwhSe0S;o`$NOUtrbD^w^fgsXi zZzj;B3uty+v0Kk#ljQ!%IfZ+{f8d`rsS-7af^r5Z- zR>w5xwZi#t9SE)%RgP0iULlH>Q?}kQS~~@MA!+KMFqIh;4naHpqH10m6sCcl(B~9w zP3~7Iq)g|iCtz)r-EUYVpZlX zu&*XqK|9@IZ#td_%T#?;Ms#0t#>D4`rYdWC@^9j7Ad$8zJI_LU*DGQmk0aBhZL6}f z56%0g7Od=d%Dk3UKeWI%kftgJTw`1wIB(=?061Ki-7*LqzJ)dy0uCQVn#?{g_lTRT zhTWN8j{2BtROYKJjLltw232@X-tku>-)H{;SLFF;3rspNeS@_)cm^bT3?JcM-SAt9 z&DP(FHW^_n)@)BnH}puiTRd1Kg`h_O7ukfz9rw+Op&ZwGTM;mS8XAF|g69qX2v!?& zQdtq_o0$cx4QrYQO=3awz*fjL%U0MkoL*J}x5X|+i(&<5VTC66APH~@;0~X6ycLV5 z_dOlAI72Ds$zw2Sl17q2GK^#&k`W}MNcIDfc$2WdANfKCzgWuLf!Ao-3@i$5B0D04%E2&p`M2U zCf0RT`+n|zbZ>q1g}bpAYLoY4y({K_?;2j)HFRCNw`(6QsW&rgyE53jw=V3%-@dhO zp~n4vyxy}1eiH7_j>mv&`>?C;9qoEywd-Il`m>a<&hTB4`p$H{yZ`Rqv3uQP^`7*D zAe60#n6{k{!jP+RZ$+rjlXG1hSOtU1Pp|75T&-oegn!v%I#-fm3eI)IhHe7I_2qps zd&}pkG+Sj0%`%;)jf-FKuBj`9Vcd&<87{(??_Kb&YeTZ^ms$LG(=jub;gT1;i`-~O zo~&|48s_d@S!H+R#?KrNbrEeT3nto;`6f;c$mfXt|Dl}F=rIam$0loB}q4D8Fk zrupIR*^=Cu7`+VxDF>^8e*q}jS}xRt|CM)F|5q%u`tn*9s-hbnnI?N+UVtq9vReA+ zcaSU&Vb=|x#mnfvSOIi@zy!u@+fbBdz%Tx!$8(d=p`a}VQx|3w;gbmXE#Ohl>Zb5X z7`$x5R3teASdX_snLj6)me16!Aa&68k6J-+&e5#e#_JGF$l~bOc8ie%xk$kDqu&ZJ z5x^rl%KJyjL1<#j^B)1z?w=Ir4ZKIL^%JP|gZb(NJT3WBj9Ng4TY!c}KbVL`W-Z!& zH`@1MGPNAO7G4=y$*(+ledK;}s5ZV9i(i>sn*2^`x$j!vYHU}HtwXEOwa`1M@Akdj zx0>w#AUS+5Iego9d;Ipd?(AMoK3n7015A5rIsT9F?-g#p{NvYt_}Y)p{qWpR@;`n0 zrH6hd-f?B>Z>N6Q-h~p^gJ5{qV@z95O}iypLdLR0I@x%i@*cMrjFgqSuk75p_?0gW1T+P4U!NikSjx8NqN&PhT z$SwYZf$@6-((E`kfO$S-6{gdBgAPOsq2?79_TH zT*)uxzjJ2!+_iJ7iAQVxdRs@$F}cl#6nz~0zcB~|{G%^0x9$Z1^<|~w&^Y3r(}b>F zPzbng!YsIjnm|ulxPaO)V5Ct$CBalQ16)G%@)QHSV(QpC7!&1e!I@aJ1!`iEEqP2h zqk*$|mg=!kR1I)wsg7q((^1ZOGs8Zo;i7zcMmc-VO*1&9m=&VgE;#mchWBJOkKh%$ z1GHS83oyz=zo5zpZSV?s8)JL6C?V}5N8t=0ZeobGBq+fPL{0-S@X!S3GDL|EFhZn~ zE+y%OXewF3=AA&mDd=UAsVi?Ry|GeRZF~G~=yCey-e8Zw-9)2Luo*!EZUseA)^iY( z;pvW2cZ0l!)n~Cf?)WD;rC1cj8!Y)clp3flFWBx0y0J8_fbNiiDXTH==@p7_7(oygZ0j?dgo&wryl1c8`Eqb-}%4KCHN=o2nb5a!GABgMs@=&)0c>V zgzX3jjeP{s=vwht;1Ej_V6<@KBnz~ko}?WGx!4p1@fvH@KnHsL)L< zzY42-6;}BwtkM)qbMBiRG9VbuB?yJBkQ*+@H9PBQb|~<83M|@|I6>5j6S$TL;`le6 zR}x@{lVceZWzYmM0uRNzu*8cIcv1jedk#5`5xDVyMhjNvJ{32>xI|;5KS{*E-09BT z#C@Wv8ggfT8NRYqBXR(tR@E=_5FH3R(E$v&(*)6h;53qvq4=Xd-X_MwjE80P5Hg;)ch zT_-hz>@LsnZ$jJu2|bK|2ZCt3yMeafePi{JXDHCvksJPjM@SS0+emk zx0uysjzZZ)V=On!(1L zuktOWa1M}d%&yfJPG^w-WYJX61b3AKxEgFN7a{H@Aj|&|8W5-kz^4jJZ-Ek>s%V0Q zAzy@PLKy?>j#uJcC-7Ki!D*pEV(M-+dEVm?MIWP1kb5QPg@I;wJB{8Del7wifw-3n z*wg{Wc2K1-u=(=~^re;HqZuYfra?x&Ebbo_)Vy$-RFsS`X@6ufu*&(+;t4+`!H1JA zvuohoG2vL=nHokA1w_vR_Yy2b`XI5QOL04{;JfVGk?XsE(0`-<&hGaI-W#|dn|Q$U zfW`xe%Y&G6;+7PbC;kXr4st#lmy2(0hp&fR?E|BkSG_)Z&qVc&GZ3acgjU|3XVtSBlpjXfy_K|^Fi zRQuz9M$;@G{NG-lTQJm|p(!QZ`GbfI{Y-#X+KSuKoG7~ee52!I1*Yg<0Z`0n`wr+Q z*fYV9ofUc(0_5;_0?!c*SbV08IUwtPmSs2gFl_tJnccr&Vn1h6|HF)|wGUkF`JjE^ zUi-lH-qrR4iwQVBU-Y3b!S>cJZ!qw);fI2CEO@wQo@Jk?jb1spbnqd36=$D#c)>rz TvYl5CEg!pf>|+M=bSA$CZ1ZH& From 4b5ecb4f7cfa0048c3d253ed50b9ae6537848245 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 9 May 2025 14:21:47 +0200 Subject: [PATCH 25/66] add small loggign changes --- src/hirad/training/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) mode change 100644 => 100755 src/hirad/training/train.py diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py old mode 100644 new mode 100755 index 88b2118..9ce619d --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -80,6 +80,7 @@ def main(cfg: DictConfig) -> None: validation_dataset_cfg=validation_dataset_cfg, train_test_split=train_test_split, ) + logger0.info(f"Training on dataset with size {len(dataset)}") # Parse image configuration & update model args dataset_channels = len(dataset.input_channels()) @@ -295,7 +296,7 @@ def main(cfg: DictConfig) -> None: ) batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu logger0.info(f"Using {num_accumulation_rounds} gradient accumulation {"rounds" if num_accumulation_rounds>1 else "round"}.") - + logger0.info(f"Batch size per gpu: {batch_size_per_gpu}") ## Resume training from previous checkpoints if exists if dist.world_size > 1: torch.distributed.barrier() From 3eb4d0a7b15af18b42ae647a2f08a1c5b45b5b8e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 9 May 2025 14:22:55 +0200 Subject: [PATCH 26/66] adapt sbatch script to slurm config --- src/hirad/testrun.sh | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/hirad/testrun.sh b/src/hirad/testrun.sh index ee4a977..ac631c8 100644 --- a/src/hirad/testrun.sh +++ b/src/hirad/testrun.sh @@ -5,20 +5,33 @@ ### HARDWARE ### #SBATCH --partition=debug #SBATCH --nodes=1 -#SBATCH --gres=gpu:4 -#SBATCH --ntasks-per-node=4 -#SBATCH --cpus-per-task=16 -#SBATCH --mem=16G +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 #SBATCH --time=00:30:00 #SBATCH --no-requeue #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/scratch/mch/pstamenk/logs/regression_test.log -#SBATCH --error=/scratch/mch/pstamenk/logs/regression_test.err +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a01 # Choose method to initialize dist in pythorch -export DISTRIBUTED_INITIALIZATION_METHOD=ENV +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" # Get number of physical cores using Python PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") @@ -27,15 +40,12 @@ LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} # Compute threads per process OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) export OMP_NUM_THREADS=$OMP_THREADS -echo "Node: $(hostname)" echo "Physical cores: $PHYSICAL_CORES" echo "Local processes: $LOCAL_PROCS" echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" -# activate conda env -CONDA_ENV=train -source /users/pstamenk/.bashrc -mamba activate $CONDA_ENV - # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -torchrun --nproc-per-node=$LOCAL_PROCS src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml \ No newline at end of file +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +" \ No newline at end of file From 9aaea9d2b2f7ddb9825613beb6303c32e10ed784 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 9 May 2025 14:23:40 +0200 Subject: [PATCH 27/66] adapt era5cosmo loader to trim_edge 19 --- src/hirad/datasets/era5_cosmo.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 8b0d60f..f8835d1 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -14,7 +14,7 @@ def __init__(self, dataset_path: str): self._dataset_path = dataset_path self._era5_path = os.path.join(dataset_path, 'era-interpolated') self._cosmo_path = os.path.join(dataset_path, 'cosmo') - self._info_path = os.path.join(dataset_path, 'old/info') + self._info_path = os.path.join(dataset_path, 'info') # load file list (each file is one date-time state) self._file_list = os.listdir(self._cosmo_path) @@ -48,7 +48,8 @@ def __getitem__(self, idx): # squeeze the ensemble dimesnsion # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) - orig_shape = [350,542] #TODO currently padding to be divisible by 16 + # orig_shape = [350,542] #TODO currently padding to be divisible by 16 + orig_shape = self.image_shape() era5_data = np.flip(torch.load(os.path.join(self._era5_path,self._file_list[idx]), weights_only=False)\ .squeeze() \ .reshape(-1,*orig_shape), @@ -61,9 +62,12 @@ def __getitem__(self, idx): 1) cosmo_data = self.normalize_output(cosmo_data) # return samples - return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ - F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ + return torch.tensor(cosmo_data),\ + torch.tensor(era5_data),\ 0 + # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ + # F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ + # 0 def __len__(self): return len(self._file_list) From dca7ff446056ce2596248559e3db2b853f576cd4 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 12 May 2025 17:32:56 +0200 Subject: [PATCH 28/66] add inference --- src/hirad/conf/generate_era_cosmo.yaml | 20 ++ src/hirad/conf/generation/era_cosmo.yaml | 37 ++++ src/hirad/conf/sampler/deterministic.yaml | 4 + src/hirad/conf/sampler/stochastic.yaml | 3 + src/hirad/inference/generate.py | 253 +++++++++++++++++++++- src/hirad/utils/generate_utils.py | 18 +- 6 files changed, 319 insertions(+), 16 deletions(-) create mode 100644 src/hirad/conf/generate_era_cosmo.yaml create mode 100644 src/hirad/conf/generation/era_cosmo.yaml create mode 100644 src/hirad/conf/sampler/deterministic.yaml create mode 100644 src/hirad/conf/sampler/stochastic.yaml diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml new file mode 100644 index 0000000..03650e2 --- /dev/null +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -0,0 +1,20 @@ +hydra: + job: + chdir: true + name: generation + run: + dir: ./outputs/${hydra:job.name} + +# Get defaults +defaults: + + # Dataset + - dataset/era_cosmo + + # Sampler + - sampler/stochastic + #- sampler/deterministic + + # Generation + - generation/era_cosmo + #- generation/patched_based \ No newline at end of file diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml new file mode 100644 index 0000000..2e37a63 --- /dev/null +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -0,0 +1,37 @@ +num_ensembles: 64 + # Number of ensembles to generate per input +seed_batch_size: 1 + # Size of the batched inference +inference_mode: regression + # Choose between "all" (regression + diffusion), "regression" or "diffusion" + # Patch size. Patch-based sampling will be utilized if these dimensions differ from + # img_shape_x and img_shape_y +overlap_pixels: 0 + # Number of overlapping pixels between adjacent patches +boundary_pixels: 0 + # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary + # artifact. +hr_mean_conditioning: False +sample_res: full + # Sampling resolution +times_range: null +times: + - 20160101-0000 + +perf: + force_fp16: false + # Whether to force fp16 precision for the model. If false, it'll use the precision + # specified upon training. + use_torch_compile: false + # whether to use torch.compile on the diffusion model + # this will make the first time stamp generation very slow due to compilation overheads + # but will significantly speed up subsequent inference runs + num_writer_workers: 1 + # number of workers to use for writing file + # To support multiple workers a threadsafe version of the netCDF library must be used + +io: + res_ckpt_path: diffusion_checkpoint + # Checkpoint filename for the diffusion model + reg_ckpt_path: regression_checkpoint + # Checkpoint filename for the mean predictor model \ No newline at end of file diff --git a/src/hirad/conf/sampler/deterministic.yaml b/src/hirad/conf/sampler/deterministic.yaml new file mode 100644 index 0000000..35bc0f6 --- /dev/null +++ b/src/hirad/conf/sampler/deterministic.yaml @@ -0,0 +1,4 @@ +type: deterministic +num_steps: 9 + # Number of denoising steps +solver: euler \ No newline at end of file diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml new file mode 100644 index 0000000..5e8fa88 --- /dev/null +++ b/src/hirad/conf/sampler/stochastic.yaml @@ -0,0 +1,3 @@ +type: stochastic +boundary_pix: 2 +overlap_pix: 4 \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 6e5273a..adb882e 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -10,6 +10,9 @@ from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor from functools import partial + +from matplotlib import pyplot as plt +import cartopy.crs as ccrs from einops import rearrange from torch.distributed import gather @@ -57,7 +60,7 @@ def main(cfg: DictConfig) -> None: all_batches = torch.as_tensor(seeds).tensor_split(num_batches) rank_batches = all_batches[dist.rank :: dist.world_size] - # Synchronize + # Synchronize if dist.world_size > 1: torch.distributed.barrier() @@ -65,7 +68,7 @@ def main(cfg: DictConfig) -> None: if cfg.generation.times_range and cfg.generation.times: raise ValueError("Either times_range or times must be provided, but not both") if cfg.generation.times_range: - times = get_time_from_range(cfg.generation.times_range) #TODO check what time formats we are using and adapt + times = get_time_from_range(cfg.generation.times_range, time_format="%Y%m%d-%H%M") #TODO check what time formats we are using and adapt else: times = cfg.generation.times @@ -110,7 +113,7 @@ def main(cfg: DictConfig) -> None: # Load diffusion network, move to device, change precision if load_net_res: res_ckpt_path = cfg.generation.io.res_ckpt_path - logger0.info(f'Loading residual network from "{res_ckpt_path}"...') + logger0.info(f'Loading correction network from "{res_ckpt_path}"...') diffusion_model_args_path = os.path.join(res_ckpt_path, 'model_args.json') if not os.path.isfile(diffusion_model_args_path): @@ -135,7 +138,7 @@ def main(cfg: DictConfig) -> None: # load regression network, move to device, change precision if load_net_reg: reg_ckpt_path = cfg.generation.io.reg_ckpt_path - logger0.info(f'Loading network from "{reg_ckpt_path}"...') + logger0.info(f'Loading regression network from "{reg_ckpt_path}"...') regression_model_args_path = os.path.join(reg_ckpt_path, 'model_args.json') @@ -144,7 +147,7 @@ def main(cfg: DictConfig) -> None: with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) - net_reg = EDMPrecond(**regression_model_args) + net_reg = UNet(**regression_model_args) _ = load_checkpoint( path=reg_ckpt_path, @@ -156,4 +159,242 @@ def main(cfg: DictConfig) -> None: if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True else: - net_reg = None \ No newline at end of file + net_reg = None + + # Reset since we are using a different mode. + if cfg.generation.perf.use_torch_compile: + torch._dynamo.reset() + # Only compile residual network + # Overhead of compiling regression network outweights any benefits + if net_res: + net_res = torch.compile(net_res, mode="reduce-overhead") + + # Partially instantiate the sampler based on the configs + if cfg.sampler.type == "deterministic": + if cfg.generation.hr_mean_conditioning: + raise NotImplementedError( + "High-res mean conditioning is not yet implemented for the deterministic sampler" + ) + sampler_fn = partial( + deterministic_sampler, + num_steps=cfg.sampler.num_steps, + # num_ensembles=cfg.generation.num_ensembles, + solver=cfg.sampler.solver, + ) + elif cfg.sampler.type == "stochastic": + sampler_fn = partial( + stochastic_sampler, + img_shape=img_shape[1], + patch_shape=patch_shape[1], + boundary_pix=cfg.sampler.boundary_pix, + overlap_pix=cfg.sampler.overlap_pix, + ) + else: + raise ValueError(f"Unknown sampling method {cfg.sampling.type}") + + + # Main generation definition + def generate_fn(image_lr, lead_time_label): + img_shape_y, img_shape_x = img_shape + with nvtx.annotate("generate_fn", color="green"): + if cfg.generation.sample_res == "full": + image_lr_patch = image_lr + else: + torch.cuda.nvtx.range_push("rearrange") + image_lr_patch = rearrange( + image_lr, + "b c (h1 h) (w1 w) -> (b h1 w1) c h w", + h1=img_shape_y // patch_shape[0], + w1=img_shape_x // patch_shape[1], + ) + torch.cuda.nvtx.range_pop() + image_lr_patch = image_lr_patch.to(memory_format=torch.channels_last) + + if net_reg: + with nvtx.annotate("regression_model", color="yellow"): + image_reg = regression_step( + net=net_reg, + img_lr=image_lr_patch, + latents_shape=( + cfg.generation.seed_batch_size, + img_out_channels, + img_shape[0], + img_shape[1], + ), + lead_time_label=lead_time_label, + ) + if net_res: + if cfg.generation.hr_mean_conditioning: + mean_hr = image_reg[0:1] + else: + mean_hr = None + with nvtx.annotate("diffusion model", color="purple"): + image_res = diffusion_step( + net=net_res, + sampler_fn=sampler_fn, + seed_batch_size=cfg.generation.seed_batch_size, + img_shape=img_shape, + img_out_channels=img_out_channels, + rank_batches=rank_batches, + img_lr=image_lr_patch.expand( + cfg.generation.seed_batch_size, -1, -1, -1 + ).to(memory_format=torch.channels_last), + rank=dist.rank, + device=device, + hr_mean=mean_hr, + lead_time_label=lead_time_label, + ) + if cfg.generation.inference_mode == "regression": + image_out = image_reg + elif cfg.generation.inference_mode == "diffusion": + image_out = image_res + else: + image_out = image_reg + image_res + + if cfg.generation.sample_res != "full": + image_out = rearrange( + image_out, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h1=img_shape_y // patch_shape[0], + w1=img_shape_x // patch_shape[1], + ) + # Gather tensors on rank 0 + if dist.world_size > 1: + if dist.rank == 0: + gathered_tensors = [ + torch.zeros_like( + image_out, dtype=image_out.dtype, device=image_out.device + ) + for _ in range(dist.world_size) + ] + else: + gathered_tensors = None + + torch.distributed.barrier() + gather( + image_out, + gather_list=gathered_tensors if dist.rank == 0 else None, + dst=0, + ) + + if dist.rank == 0: + return torch.cat(gathered_tensors) + else: + return None + else: + return image_out + + # generate images + output_path = getattr(cfg.generation.io, "output_path", "./outputs") + logger0.info(f"Generating images, saving results to {output_path}...") + batch_size = 1 + warmup_steps = min(len(times) - 1, 2) + # Generates model predictions from the input data using the specified + # `generate_fn`, and save the predictions to the provided NetCDF file. It iterates + # through the dataset using a data loader, computes predictions, and saves them along + # with associated metadata. + + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + + data_loader = torch.utils.data.DataLoader( + dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True + ) + time_index = -1 + if dist.rank == 0: + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers + ) + writer_threads = [] + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + times = dataset.time() + for image_tar, image_lr, index, *lead_time_label in iter(data_loader): + time_index += 1 + if dist.rank == 0: + logger0.info(f"starting index: {time_index}") + + if time_index == warmup_steps: + start.record() + + # continue + if lead_time_label: + lead_time_label = lead_time_label[0].to(dist.device).contiguous() + else: + lead_time_label = None + image_lr = ( + image_lr.to(device=device) + .to(torch.float32) + .to(memory_format=torch.channels_last) + ) + image_tar = image_tar.to(device=device).to(torch.float32) + image_out = generate_fn(image_lr,lead_time_label) + if dist.rank == 0: + batch_size = image_out.shape[0] + # write out data in a seperate thread so we don't hold up inferencing + writer_threads.append( + writer_executor.submit( + save_images, + output_path, + dataset, + image_out.cpu(), + image_tar.cpu(), + image_lr.cpu(), + ) + ) + end.record() + end.synchronize() + elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s + timed_steps = time_index + 1 - warmup_steps + if dist.rank == 0: + average_time_per_batch_element = elapsed_time / timed_steps / batch_size + logger.info( + f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" + ) + logger.info( + f"Average time per batch element = {average_time_per_batch_element} s" + ) + + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() + + if dist.rank == 0: + f.close() + logger0.info("Generation Completed.") + +def save_images(output_path, dataset, image_pred, image_hr, image_lr): + longitudes = dataset.longitude() + latitudes = dataset.latitude() + input_channels = dataset.input_channels() + output_channels = dataset.output_channels() + image_pred = np.flip(dataset.denormalize_output(image_pred.numpy()),1).reshape(len(output_channels),-1) + image_hr = np.flip(dataset.denormalize_output(image_hr.numpy()),1).reshape(len(output_channels),-1) + image_lr = np.flip(dataset.denormalize_input(image_lr.numpy()),1).reshape(len(input_channels),-1) + for idx, channel in enumerate(output_channels): + input_channel_idx = input_channels.index(channel) + _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{channel.name}-lr.jpg')) + _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{channel.name}-hr.jpg')) + _plot_projection(longitudes,latitudes,image_pred[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) + +def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + + """Plot observed or interpolated data in a scatter plot.""" + # TODO: Refactor this somehow, it's not really generalizing well across variables. + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="K", orientation="horizontal") + plt.savefig(filename) + plt.close('all') + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/src/hirad/utils/generate_utils.py b/src/hirad/utils/generate_utils.py index b99852f..43f83b6 100644 --- a/src/hirad/utils/generate_utils.py +++ b/src/hirad/utils/generate_utils.py @@ -8,17 +8,15 @@ def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): Get a dataset and sampler for generation. """ (dataset, _) = init_dataset_from_config(dataset_cfg, batch_size=1) - if has_lead_time: - plot_times = times - else: - plot_times = [ - convert_datetime_to_cftime( - datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") - ) - for time in times - ] + # if has_lead_time: + # plot_times = times + # else: + # plot_times = [ + # datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%S") + # for time in times + # ] all_times = dataset.time() - time_indices = [all_times.index(t) for t in plot_times] + time_indices = [all_times.index(t) for t in times] sampler = time_indices return dataset, sampler \ No newline at end of file From 8ba0c5a21e21a09ff1b248221f01378e4bc30754 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 12 May 2025 18:38:19 +0200 Subject: [PATCH 29/66] Plot absolute error onto a projection for a given date --- src/hirad/eval/metrics.py | 21 ++++++++++++++++++++ src/hirad/eval/plotting.py | 16 ++++++++++++++++ src/hirad/eval/run_scoring.py | 36 +++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 src/hirad/eval/metrics.py create mode 100644 src/hirad/eval/plotting.py create mode 100644 src/hirad/eval/run_scoring.py diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py new file mode 100644 index 0000000..c8fda43 --- /dev/null +++ b/src/hirad/eval/metrics.py @@ -0,0 +1,21 @@ +import numpy as np +import torch + + +# set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) + +# input will be a 2D tensor of values with the COSMO lat/lon. + +# Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py + +def absolute_error(pred, target): + return torch.abs(pred-target) + +def compute_mae(pred, target): + # Exclude any target NaNs (not expected, but precautionary) + mask = ~np.isnan(target) + pred = pred[:, mask] + target = target[mask] + + return torch.mean(absolute_error(pred, target)) + diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py new file mode 100644 index 0000000..141b5c9 --- /dev/null +++ b/src/hirad/eval/plotting.py @@ -0,0 +1,16 @@ +import logging + +import cartopy.crs as ccrs +import matplotlib.pyplot as plt +import numpy as np + +def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np.array, filename: str): + fig = plt.figure() + fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) + logging.info(f'plotting values to {filename}') + p = ax.scatter(x=longitudes, y=latitudes, c=values) + ax.coastlines() + ax.gridlines(draw_labels=True) + plt.colorbar(p, label="absolute error", orientation="horizontal") + plt.savefig(filename) + plt.close('all') diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py new file mode 100644 index 0000000..57ca69f --- /dev/null +++ b/src/hirad/eval/run_scoring.py @@ -0,0 +1,36 @@ +import os +import sys + +import metrics +import plotting +import torch +import yaml + + +def main(): + if len(sys.argv) < 4: + raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [date]') + + input_directory = sys.argv[1] + predictions_directory = sys.argv[2] + date = sys.argv[3] + + target = torch.load(os.path.join(input_directory, 'cosmo', date), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', date), weights_only=False) + #prediction_file = torch.load(os.path.join(predictions_directory, date), weights_only=False) + prediction = torch.load(os.path.join(input_directory, 'cosmo', '20160101-0000'), weights_only=False) + lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) + + # Reshape grides to be the same as prediction + #target = target.squeeze().reshape(-1,*prediction.shape), + target = torch.from_numpy(target) + prediction = torch.from_numpy(prediction) + #prediction = prediction.squeeze().reshape(-1,*prediction.shape) + latitudes = lat_lon[:,0] #.squeeze().reshape(-1,*prediction.shape) + longitudes = lat_lon[:,1] #squeeze().reshape(-1,*prediction.shape) + + errors = metrics.absolute_error(prediction[0,:,:], target[0,:,:]) + plotting.plot_error_projection(errors, latitudes, longitudes, os.path.join('plots/errors/', date)) + +if __name__ == "__main__": + main() \ No newline at end of file From c6e632a1bdc52d9f932baf15beab93abc187d94f Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 12 May 2025 18:39:35 +0200 Subject: [PATCH 30/66] Adjust how reshaping is done before error calcs --- src/hirad/eval/run_scoring.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index 57ca69f..ce2c343 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -17,18 +17,25 @@ def main(): target = torch.load(os.path.join(input_directory, 'cosmo', date), weights_only=False) baseline = torch.load(os.path.join(input_directory, 'era-interpolated', date), weights_only=False) - #prediction_file = torch.load(os.path.join(predictions_directory, date), weights_only=False) - prediction = torch.load(os.path.join(input_directory, 'cosmo', '20160101-0000'), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, date), weights_only=False) lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) - # Reshape grides to be the same as prediction - #target = target.squeeze().reshape(-1,*prediction.shape), + with open(os.path.join(input_directory), 'info', 'cosmo.yaml') as cosmo_file: + cosmo_config = yaml.safe_load(cosmo_file) + channels = cosmo_config['select'] + + # Reshape predictions, if necessary + # target is shape [channels, ensembles, points] + # prediction is shape [channels, ensembles, x, y] + prediction = prediction.reshape(*target.shape) + + latitudes = lat_lon[:,0] + longitudes = lat_lon[:,1] + + # convert to torch target = torch.from_numpy(target) prediction = torch.from_numpy(prediction) - #prediction = prediction.squeeze().reshape(-1,*prediction.shape) - latitudes = lat_lon[:,0] #.squeeze().reshape(-1,*prediction.shape) - longitudes = lat_lon[:,1] #squeeze().reshape(-1,*prediction.shape) - + errors = metrics.absolute_error(prediction[0,:,:], target[0,:,:]) plotting.plot_error_projection(errors, latitudes, longitudes, os.path.join('plots/errors/', date)) From d21ec35ab055d313d7128163a941b909a3b44745 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Mon, 12 May 2025 18:58:47 +0200 Subject: [PATCH 31/66] plot for all channels, and against baseline --- src/hirad/eval/run_scoring.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index ce2c343..0cc3c1d 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -20,9 +20,13 @@ def main(): prediction = torch.load(os.path.join(predictions_directory, date), weights_only=False) lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) - with open(os.path.join(input_directory), 'info', 'cosmo.yaml') as cosmo_file: + with open(os.path.join(input_directory, 'info', 'cosmo.yaml')) as cosmo_file: cosmo_config = yaml.safe_load(cosmo_file) - channels = cosmo_config['select'] + target_channels = cosmo_config['select'] + + with open(os.path.join(input_directory, 'info', 'era.yaml')) as era_file: + era_config = yaml.safe_load(era_file) + input_channels = era_config['select'] # Reshape predictions, if necessary # target is shape [channels, ensembles, points] @@ -34,10 +38,18 @@ def main(): # convert to torch target = torch.from_numpy(target) + baseline = torch.from_numpy(baseline) prediction = torch.from_numpy(prediction) - errors = metrics.absolute_error(prediction[0,:,:], target[0,:,:]) - plotting.plot_error_projection(errors, latitudes, longitudes, os.path.join('plots/errors/', date)) + # plot baseline error + for t_c in range(len(target_channels)): + b_c = input_channels.index(target_channels[t_c]) + if b_c > -1: + baseline_errors = metrics.absolute_error(baseline[b_c,:,:], target[t_c,:,:]) + plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) + prediction_errors = metrics.absolute_error(prediction[t_c,:,:], target[t_c,:,:]) + plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + if __name__ == "__main__": main() \ No newline at end of file From 5b32566bf13bc19ae4de8c0f957ac0f44c6ce892 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 13 May 2025 09:14:41 +0200 Subject: [PATCH 32/66] Add MAE output --- src/hirad/eval/metrics.py | 6 ++++-- src/hirad/eval/run_scoring.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index c8fda43..133cc21 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -8,7 +8,7 @@ # Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py -def absolute_error(pred, target): +def absolute_error(pred, target) -> tuple[float, np.ndarray]: return torch.abs(pred-target) def compute_mae(pred, target): @@ -17,5 +17,7 @@ def compute_mae(pred, target): pred = pred[:, mask] target = target[mask] - return torch.mean(absolute_error(pred, target)) + ae = absolute_error(pred, target) + + return torch.mean(absolute_error(pred, target)), ae diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index 0cc3c1d..fee984e 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -41,15 +41,15 @@ def main(): baseline = torch.from_numpy(baseline) prediction = torch.from_numpy(prediction) - # plot baseline error + # plot errors for t_c in range(len(target_channels)): b_c = input_channels.index(target_channels[t_c]) if b_c > -1: - baseline_errors = metrics.absolute_error(baseline[b_c,:,:], target[t_c,:,:]) + baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - prediction_errors = metrics.absolute_error(prediction[t_c,:,:], target[t_c,:,:]) + prediction_mae, prediction_errors = metrics.compute_mae(prediction[t_c,:,:], target[t_c,:,:]) plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - + print(f'baseline MAE={baseline_mae}, prediction MAE={prediction_mae}') if __name__ == "__main__": main() \ No newline at end of file From 60c8affbc46d240940d84255c9ab3b32e82cfc8b Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 13 May 2025 09:23:15 +0200 Subject: [PATCH 33/66] Fix indexing error --- src/hirad/eval/metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index 133cc21..e6e1afb 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -13,8 +13,9 @@ def absolute_error(pred, target) -> tuple[float, np.ndarray]: def compute_mae(pred, target): # Exclude any target NaNs (not expected, but precautionary) + # TODO: Fix the deprecated warning (index with dtype torch.bool instead of torch.uint8) mask = ~np.isnan(target) - pred = pred[:, mask] + pred = pred[mask] target = target[mask] ae = absolute_error(pred, target) From 5cc42e9c0de8b04f2f865c90a12bf928443e9ad5 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 14 May 2025 10:08:43 +0200 Subject: [PATCH 34/66] Try adding spectral graph --- .gitignore | 17 ++++++++++++++++- src/hirad/eval/plotting.py | 7 +++++++ src/hirad/eval/run_scoring.py | 19 +++++++++++++++---- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index c514c5d..dee6b07 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,19 @@ poetry.toml .ruff_cache/ # LSP config files -pyrightconfig.json \ No newline at end of file +pyrightconfig.json + +# output files +*.out +*.torch +plots/* +*.npz + +# conda +.conda/* + +# temp +temp.* + +# local script +interpolate.sh \ No newline at end of file diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 141b5c9..262109f 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -14,3 +14,10 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.colorbar(p, label="absolute error", orientation="horizontal") plt.savefig(filename) plt.close('all') + +def plot_power_spectrum(x, filename): + fig = plt.figure() + plt.psd(x) + logging.info(f'plotting values to {filename}') + plt.savefig(filename) + plt.close('all') \ No newline at end of file diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index fee984e..df37a9a 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -31,7 +31,8 @@ def main(): # Reshape predictions, if necessary # target is shape [channels, ensembles, points] # prediction is shape [channels, ensembles, x, y] - prediction = prediction.reshape(*target.shape) + prediction_1d = prediction.reshape(*target.shape) + prediction_2d = prediction.reshape(prediction.shape[0],352,544) latitudes = lat_lon[:,0] longitudes = lat_lon[:,1] @@ -39,7 +40,7 @@ def main(): # convert to torch target = torch.from_numpy(target) baseline = torch.from_numpy(baseline) - prediction = torch.from_numpy(prediction) + prediction_1d = torch.from_numpy(prediction_1d) # plot errors for t_c in range(len(target_channels)): @@ -47,9 +48,19 @@ def main(): if b_c > -1: baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - prediction_mae, prediction_errors = metrics.compute_mae(prediction[t_c,:,:], target[t_c,:,:]) - plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) + prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) + plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) print(f'baseline MAE={baseline_mae}, prediction MAE={prediction_mae}') + # Plot power spectra + freq, power = metrics.compute_power_spectrum(prediction, 1) + plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') + plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') + + + if __name__ == "__main__": main() \ No newline at end of file From f1645045a411c59d5a97c36e678894b1fde1decd Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 14 May 2025 12:19:39 +0200 Subject: [PATCH 35/66] Start some CE stuff --- .edf/ubuntu.toml | 3 + .gitignore | 186 +++++++++++++++++++++++++++++++++ config/containers/storage.conf | 17 +++ 3 files changed, 206 insertions(+) create mode 100644 .edf/ubuntu.toml create mode 100644 .gitignore create mode 100644 config/containers/storage.conf diff --git a/.edf/ubuntu.toml b/.edf/ubuntu.toml new file mode 100644 index 0000000..4f14ee9 --- /dev/null +++ b/.edf/ubuntu.toml @@ -0,0 +1,3 @@ +image = "library/ubuntu:24.04" +mounts = ["/capstor/scratch/cscs/mmcgloho:/capstor/scratch/cscs/mmcgloho"] +workdir = "/capstor/scratch/cscs/mmcgloho" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dee6b07 --- /dev/null +++ b/.gitignore @@ -0,0 +1,186 @@ +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# output files +*.out +*.torch +plots/* +*.npz + +# conda +.conda/* + +# temp +temp.* + +# local script +interpolate.sh \ No newline at end of file diff --git a/config/containers/storage.conf b/config/containers/storage.conf new file mode 100644 index 0000000..740925d --- /dev/null +++ b/config/containers/storage.conf @@ -0,0 +1,17 @@ +# https://confluence.cscs.ch/spaces/KB/pages/868834153/Building+container+images+on+Alps +# TOML format + +[storage] +driver = "overlay" +runroot = "/dev/shm/$USER/runroot" +graphroot = "/dev/shm/$USER/root" + +[storage.options.overlay] +mount_program = "/usr/bin/fuse-overlayfs-1.13" + +# In the above configuration, /dev/shm is used to store the container images. +# /dev/shm is the mount point of a tmpfs filesystem and is compatible with the +# user namespaces used by Podman. The limitation of this approach is that +# container images created during a job allocation are deleted when the job +# ends. Therefore, the image # needs to either be pushed to a container registry +# or imported by the Container Engine before the job allocation finishes. \ No newline at end of file From 5b77dbebd984a35e3797ed40e97c6cdb1e03669c Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 15 May 2025 15:38:46 +0200 Subject: [PATCH 36/66] fix inference for diffusion --- pyproject.toml | 10 ++- src/hirad/conf/dataset/era_cosmo.yaml | 2 +- src/hirad/conf/generate_era_cosmo.yaml | 6 +- src/hirad/conf/generation/era_cosmo.yaml | 7 ++- .../conf/training/era_cosmo_diffusion.yaml | 10 +-- .../conf/training/era_cosmo_regression.yaml | 11 ++-- src/hirad/datasets/era5_cosmo.py | 8 +-- src/hirad/generate.sh | 51 +++++++++++++++ src/hirad/inference/generate.py | 62 ++++++++++++------- src/hirad/models/layers.py | 3 +- src/hirad/train.sh | 51 +++++++++++++++ src/hirad/training/train.py | 2 +- src/hirad/utils/inference_utils.py | 13 ++-- src/hirad/utils/stochastic_sampler.py | 60 ++++++++++-------- 14 files changed, 218 insertions(+), 78 deletions(-) create mode 100644 src/hirad/generate.sh create mode 100644 src/hirad/train.sh diff --git a/pyproject.toml b/pyproject.toml index b2fa56c..1477899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,15 @@ requires-python = ">=3.12" license = {file = "LICENSE"} dependencies = [ - "torch>=2.6.0" + "cartopy>=0.24.1", + "cftime>=1.6.4", + "hydra-core>=1.3.2", + "matplotlib>=3.10.1", + "omegaconf>=2.3.0", + "tensorboard>=2.19.0", + "termcolor>=3.1.0", + "torchinfo>=1.8.0", + "treelib>=1.7.1" ] [tool.setuptools] diff --git a/src/hirad/conf/dataset/era_cosmo.yaml b/src/hirad/conf/dataset/era_cosmo.yaml index 854b775..63d7361 100644 --- a/src/hirad/conf/dataset/era_cosmo.yaml +++ b/src/hirad/conf/dataset/era_cosmo.yaml @@ -1,2 +1,2 @@ type: era5_cosmo -dataset_path: /store_new/mch/msopr/hirad-gen/basic-torch \ No newline at end of file +dataset_path: /iopsstor/scratch/cscs/pstamenk/datasets/basic-torch/trim_19_overfit \ No newline at end of file diff --git a/src/hirad/conf/generate_era_cosmo.yaml b/src/hirad/conf/generate_era_cosmo.yaml index 03650e2..5d7649d 100644 --- a/src/hirad/conf/generate_era_cosmo.yaml +++ b/src/hirad/conf/generate_era_cosmo.yaml @@ -1,13 +1,13 @@ hydra: job: chdir: true - name: generation + name: generation_full run: - dir: ./outputs/${hydra:job.name} + dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: - + - _self_ # Dataset - dataset/era_cosmo diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 2e37a63..5179520 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -31,7 +31,8 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: diffusion_checkpoint + res_ckpt_path: null # Checkpoint filename for the diffusion model - reg_ckpt_path: regression_checkpoint - # Checkpoint filename for the mean predictor model \ No newline at end of file + reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + # Checkpoint filename for the mean predictor model + output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index b61603a..b06ec61 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -1,8 +1,8 @@ # Hyperparameters hp: - training_duration: 128 + training_duration: 16 # Training duration based on the number of processed samples - total_batch_size: 16 + total_batch_size: 4 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU @@ -20,14 +20,14 @@ perf: fp_optimizations: amp-bf16 # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 4 + dataloader_workers: 8 # DataLoader worker processes songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint # I/O io: - regression_checkpoint_path: /scratch/mch/pstamenk/output/regression/checkpoints_regression + regression_checkpoint_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression # Where to load the regression checkpoint print_progress_freq: 32 # How often to print progress @@ -38,4 +38,4 @@ io: validation_steps: 10 # how many loss evaluations are used to compute the validation loss per checkpoint # how many loss evaluations are used to compute the validation loss per checkpoint - checkpoint_dir: /scratch/mch/pstamenk/output/diffusion \ No newline at end of file + checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_regression.yaml b/src/hirad/conf/training/era_cosmo_regression.yaml index 7c443f0..76bdc4e 100644 --- a/src/hirad/conf/training/era_cosmo_regression.yaml +++ b/src/hirad/conf/training/era_cosmo_regression.yaml @@ -1,12 +1,13 @@ # Hyperparameters hp: - training_duration: 16 + training_duration: 8 # Training duration based on the number of processed samples - total_batch_size: 16 + total_batch_size: 4 # Total batch size batch_size_per_gpu: "auto" # Batch size per GPU - lr: 0.0002 + lr: 0.001 + #0.0002 # Learning rate grad_clip_threshold: null # no gradient clipping for defualt non-patch-based training @@ -27,7 +28,7 @@ perf: # I/O io: - print_progress_freq: 32 + print_progress_freq: 128 # How often to print progress save_checkpoint_freq: 5000 # How often to save the checkpoints, measured in number of processed samples @@ -35,4 +36,4 @@ io: # how often to record the validation loss, measured in number of processed samples validation_steps: 10 # how many loss evaluations are used to compute the validation loss per checkpoint - checkpoint_dir: /scratch/mch/pstamenk/output/regression \ No newline at end of file + checkpoint_dir: . \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index f8835d1..674dbf0 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -75,14 +75,14 @@ def __len__(self): def longitude(self) -> np.ndarray: """Get longitude values from the dataset.""" - lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) - return lon_lat[:,0] + lat_lon = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lat_lon[:,1] def latitude(self) -> np.ndarray: """Get latitude values from the dataset.""" - lon_lat = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) - return lon_lat[:,1] + lat_lon = torch.load(os.path.join(self._info_path,'cosmo-lat-lon'), weights_only=False) + return lat_lon[:,0] def input_channels(self) -> List[ChannelMetadata]: diff --git a/src/hirad/generate.sh b/src/hirad/generate.sh new file mode 100644 index 0000000..87c8979 --- /dev/null +++ b/src/hirad/generate.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/full_generation.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml +" \ No newline at end of file diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index adb882e..5558a20 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -11,14 +11,14 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial -from matplotlib import pyplot as plt import cartopy.crs as ccrs +from matplotlib import pyplot as plt from einops import rearrange from torch.distributed import gather from hydra.utils import to_absolute_path -from hirad.models import EDMPrecond, UNet +from hirad.models import EDMPrecondSR, UNet from hirad.utils.stochastic_sampler import stochastic_sampler from hirad.utils.deterministic_sampler import deterministic_sampler from hirad.utils.inference_utils import ( @@ -36,12 +36,12 @@ from hirad.utils.train_helpers import set_patch_shape -@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate") +@hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig) -> None: """Generate random dowscaled atmospheric states using the techniques described in the paper "Elucidating the Design Space of Diffusion-Based Generative Models". """ - + torch.backends.cudnn.enabled = False # Initialize distributed manager DistributedManager.initialize() dist = DistributedManager() @@ -50,7 +50,7 @@ def main(cfg: DictConfig) -> None: # Initialize logger logger = PythonLogger("generate") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) - logger.file_logging("generate.log") + # logger.file_logging("generate.log") # Handle the batch size seeds = list(np.arange(cfg.generation.num_ensembles)) @@ -121,7 +121,7 @@ def main(cfg: DictConfig) -> None: with open(diffusion_model_args_path, 'r') as f: diffusion_model_args = json.load(f) - net_res = EDMPrecond(**diffusion_model_args) + net_res = EDMPrecondSR(**diffusion_model_args) _ = load_checkpoint( path=res_ckpt_path, @@ -129,7 +129,8 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) + #TODO fix to use channels_last which is optimal for H100 + net_res = net_res.eval().to(device)#.to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True else: @@ -155,7 +156,7 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) + net_reg = net_reg.eval().to(device)#.to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True else: @@ -184,8 +185,9 @@ def main(cfg: DictConfig) -> None: elif cfg.sampler.type == "stochastic": sampler_fn = partial( stochastic_sampler, - img_shape=img_shape[1], - patch_shape=patch_shape[1], + img_shape=img_shape, + patch_shape_x=patch_shape[0], + patch_shape_y=patch_shape[1], boundary_pix=cfg.sampler.boundary_pix, overlap_pix=cfg.sampler.overlap_pix, ) @@ -194,7 +196,7 @@ def main(cfg: DictConfig) -> None: # Main generation definition - def generate_fn(image_lr, lead_time_label): + def generate_fn(image_lr, labels, lead_time_label): img_shape_y, img_shape_x = img_shape with nvtx.annotate("generate_fn", color="green"): if cfg.generation.sample_res == "full": @@ -208,13 +210,14 @@ def generate_fn(image_lr, lead_time_label): w1=img_shape_x // patch_shape[1], ) torch.cuda.nvtx.range_pop() - image_lr_patch = image_lr_patch.to(memory_format=torch.channels_last) + image_lr_patch = image_lr_patch #.to(memory_format=torch.channels_last) if net_reg: with nvtx.annotate("regression_model", color="yellow"): image_reg = regression_step( net=net_reg, img_lr=image_lr_patch, + labels=labels, latents_shape=( cfg.generation.seed_batch_size, img_out_channels, @@ -238,7 +241,7 @@ def generate_fn(image_lr, lead_time_label): rank_batches=rank_batches, img_lr=image_lr_patch.expand( cfg.generation.seed_batch_size, -1, -1, -1 - ).to(memory_format=torch.channels_last), + ), #.to(memory_format=torch.channels_last), rank=dist.rank, device=device, hr_mean=mean_hr, @@ -282,7 +285,10 @@ def generate_fn(image_lr, lead_time_label): else: return None else: - return image_out + #TODO do this for multi-gpu setting above too + if cfg.generation.inference_mode != "regression": + return image_out, image_reg + return image_out, None # generate images output_path = getattr(cfg.generation.io, "output_path", "./outputs") @@ -311,7 +317,7 @@ def generate_fn(image_lr, lead_time_label): end = torch.cuda.Event(enable_timing=True) times = dataset.time() - for image_tar, image_lr, index, *lead_time_label in iter(data_loader): + for image_tar, image_lr, labels, *lead_time_label in iter(data_loader): time_index += 1 if dist.rank == 0: logger0.info(f"starting index: {time_index}") @@ -327,10 +333,11 @@ def generate_fn(image_lr, lead_time_label): image_lr = ( image_lr.to(device=device) .to(torch.float32) - .to(memory_format=torch.channels_last) + #.to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - image_out = generate_fn(image_lr,lead_time_label) + labels = labels.to(device).to(torch.float32).contiguous() + image_out, image_reg = generate_fn(image_lr,labels,lead_time_label) if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing @@ -342,6 +349,7 @@ def generate_fn(image_lr, lead_time_label): image_out.cpu(), image_tar.cpu(), image_lr.cpu(), + image_reg.cpu(), ) ) end.record() @@ -368,19 +376,29 @@ def generate_fn(image_lr, lead_time_label): f.close() logger0.info("Generation Completed.") -def save_images(output_path, dataset, image_pred, image_hr, image_lr): +def save_images(output_path, dataset, image_pred, image_hr, image_lr, mean_pred): longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() - image_pred = np.flip(dataset.denormalize_output(image_pred.numpy()),1).reshape(len(output_channels),-1) - image_hr = np.flip(dataset.denormalize_output(image_hr.numpy()),1).reshape(len(output_channels),-1) - image_lr = np.flip(dataset.denormalize_input(image_lr.numpy()),1).reshape(len(input_channels),-1) + image_pred = image_pred.numpy() + image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) + image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) + image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) + if mean_pred is not None: + mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) + os.makedirs(output_path, exist_ok=True) for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{channel.name}-lr.jpg')) _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{channel.name}-hr.jpg')) - _plot_projection(longitudes,latitudes,image_pred[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) + _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) + _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-0.jpg')) + _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-mid.jpg')) + if mean_pred is not None: + _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{channel.name}-mean-pred.jpg')) def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index ddb23b6..8612da7 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -221,6 +221,8 @@ def forward(self, x): padding=f_pad, ) if w is not None: + #TODO during inference, model breaks here for some reason + # current fix is to disable torch.backends.cudnn.enabled = False x = torch.nn.functional.conv2d(x, w, padding=w_pad) if b is not None: x = x.add_(b.reshape(1, -1, 1, 1)) @@ -473,7 +475,6 @@ def forward(self, x, emb): torch.cuda.nvtx.range_push("UNetBlock") orig = x x = self.conv0(silu(self.norm0(x))) - params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) if self.adaptive_scale: scale, shift = params.chunk(chunks=2, dim=1) diff --git a/src/hirad/train.sh b/src/hirad/train.sh new file mode 100644 index 0000000..a31cec0 --- /dev/null +++ b/src/hirad/train.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +echo "Master node : $MASTER_ADDR" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +echo "Master address : $MASTER_ADDR" +export MASTER_ADDR +export MASTER_PORT=29500 +echo "Master port: $MASTER_PORT" + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +# Use SLURM_NTASKS (number of processes to be launched by torchrun) +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute threads per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS +echo "Physical cores: $PHYSICAL_CORES" +echo "Local processes: $LOCAL_PROCS" +echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" \ No newline at end of file diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 9ce619d..37e6110 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -213,7 +213,7 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) # TODO write summry from rank=0 possibly - # summary(model, input_size=[(4,img_out_channels,*img_shape),(4,img_in_channels,*img_shape),(4,1),(4,1)]) + # summary(model, input_size=[(1,img_out_channels,*img_shape),(1,img_in_channels,*img_shape),(1,1)]) if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index b158ec0..4831bdd 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -31,6 +31,7 @@ def regression_step( net: torch.nn.Module, img_lr: torch.Tensor, + labels: torch.Tensor, latents_shape: torch.Size, lead_time_label: torch.Tensor = None, ) -> torch.Tensor: @@ -50,15 +51,15 @@ def regression_step( torch.Tensor: Predicted output at the next time step. """ # Create a tensor of zeros with the given shape and move it to the appropriate device - x_hat = torch.zeros(latents_shape, dtype=torch.float64, device=net.device) - t_hat = torch.tensor(1.0, dtype=torch.float64, device=net.device) + x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) + t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device).reshape((1,1,1,1)) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat[0:1], img_lr, t_hat, lead_time_label=lead_time_label) + x = net(x_hat, img_lr, t_hat, labels, lead_time_label=lead_time_label) else: - x = net(x_hat[0:1], img_lr, t_hat) + x = net(x_hat, img_lr, t_hat, labels) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -100,7 +101,7 @@ def diffusion_step( # TODO generalize the module and add defaults torch.Tensor: Generated images concatenated across batches. """ - img_lr = img_lr.to(memory_format=torch.channels_last) + img_lr = img_lr #.to(memory_format=torch.channels_last) # Handling of the high-res mean additional_args = {} @@ -128,7 +129,7 @@ def diffusion_step( # TODO generalize the module and add defaults img_shape[1], ], device=device, - ).to(memory_format=torch.channels_last) + )#.to(memory_format=torch.channels_last) with torch.inference_mode(): images = sampler_fn( diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py index ddcf9cc..ac5c13b 100644 --- a/src/hirad/utils/stochastic_sampler.py +++ b/src/hirad/utils/stochastic_sampler.py @@ -292,8 +292,9 @@ def stochastic_sampler( img_lr: Tensor, class_labels: Optional[Tensor] = None, randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - img_shape: int = 448, - patch_shape: int = 448, + img_shape: tuple[int,int] = (448,448), + patch_shape_x: int = 448, + patch_shape_y: int = 448, overlap_pix: int = 4, boundary_pix: int = 2, mean_hr: Optional[Tensor] = None, @@ -360,12 +361,13 @@ def stochastic_sampler( "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." sigma_min = max(sigma_min, net.sigma_min) sigma_max = min(sigma_max, net.sigma_max) - if isinstance(img_shape, tuple): - img_shape_y, img_shape_x = img_shape - else: - img_shape_x = img_shape_y = img_shape - if patch_shape > img_shape_x or patch_shape > img_shape_y: - patch_shape = min(img_shape_x, img_shape_y) + # if isinstance(img_shape, tuple): + # img_shape_y, img_shape_x = img_shape + # else: + # img_shape_x = img_shape_y = img_shape + img_shape_x, img_shape_y = img_shape + patch_shape_x = min(img_shape_x, patch_shape_x) + patch_shape_y = min(img_shape_y, patch_shape_y) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) @@ -394,16 +396,16 @@ def stochastic_sampler( global_index = None # input and position padding + patching - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: input_interp = torch.nn.functional.interpolate( - img_lr, (patch_shape, patch_shape), mode="bilinear" + img_lr, (patch_shape_x, patch_shape_y), mode="bilinear" ) x_lr = image_batching( x_lr, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -413,8 +415,8 @@ def stochastic_sampler( grid.float(), img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -433,13 +435,13 @@ def stochastic_sampler( # Euler step. Perform patching operation on score tensor if patch-based generation is used # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: x_hat_batch = image_batching( x_hat, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -461,6 +463,12 @@ def stochastic_sampler( global_index=global_index, ).to(torch.float64) else: + # print("Sizes") + # print(x_hat_batch.shape) + # print(x_lr.shape) + # print(t_hat) + # print(class_labels) + # print(global_index) denoised = net( x_hat_batch, x_lr, @@ -468,14 +476,14 @@ def stochastic_sampler( class_labels, global_index=global_index, ).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: denoised = image_fuse( denoised, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -485,13 +493,13 @@ def stochastic_sampler( # Apply 2nd order correction. if i < num_steps - 1: - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: x_next_batch = image_batching( x_next, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, @@ -517,13 +525,13 @@ def stochastic_sampler( class_labels, global_index=global_index, ).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: denoised = image_fuse( denoised, img_shape_y, img_shape_x, - patch_shape, - patch_shape, + patch_shape_x, + patch_shape_y, batch_size, overlap_pix, boundary_pix, From 83716f45821e6f15f04422b140bc90f493a3be79 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 15 May 2025 15:55:32 +0200 Subject: [PATCH 37/66] clean up --- src/hirad/conf/generation/era_cosmo.yaml | 11 ++--- src/hirad/testrun.sh | 51 ------------------------ 2 files changed, 6 insertions(+), 56 deletions(-) delete mode 100644 src/hirad/testrun.sh diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index 5179520..a0c5a40 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -2,7 +2,7 @@ num_ensembles: 64 # Number of ensembles to generate per input seed_batch_size: 1 # Size of the batched inference -inference_mode: regression +inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y @@ -11,7 +11,7 @@ overlap_pixels: 0 boundary_pixels: 0 # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary # artifact. -hr_mean_conditioning: False +hr_mean_conditioning: True sample_res: full # Sampling resolution times_range: null @@ -19,10 +19,10 @@ times: - 20160101-0000 perf: - force_fp16: false + force_fp16: False # Whether to force fp16 precision for the model. If false, it'll use the precision # specified upon training. - use_torch_compile: false + use_torch_compile: False # whether to use torch.compile on the diffusion model # this will make the first time stamp generation very slow due to compilation overheads # but will significantly speed up subsequent inference runs @@ -31,8 +31,9 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: null + res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_test/checkpoints_diffusion # Checkpoint filename for the diffusion model reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model output_path: ./images \ No newline at end of file diff --git a/src/hirad/testrun.sh b/src/hirad/testrun.sh deleted file mode 100644 index ac631c8..0000000 --- a/src/hirad/testrun.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="testrun" - -### HARDWARE ### -#SBATCH --partition=debug -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=1 -#SBATCH --cpus-per-task=72 -#SBATCH --time=00:30:00 -#SBATCH --no-requeue -#SBATCH --exclusive - -### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err - -### ENVIRONMENT #### -#SBATCH --uenv=pytorch/v2.6.0:/user-environment -#SBATCH --view=default -#SBATCH -A a-a01 - -# Choose method to initialize dist in pythorch -export DISTRIBUTED_INITIALIZATION_METHOD=SLURM - -MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" -echo "Master node : $MASTER_ADDR" -# Get IP for hostname. -MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" -echo "Master address : $MASTER_ADDR" -export MASTER_ADDR -export MASTER_PORT=29500 -echo "Master port: $MASTER_PORT" - -# Get number of physical cores using Python -PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# Use SLURM_NTASKS (number of processes to be launched by torchrun) -LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute threads per process -OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) -export OMP_NUM_THREADS=$OMP_THREADS -echo "Physical cores: $PHYSICAL_CORES" -echo "Local processes: $LOCAL_PROCS" -echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" - -# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml -srun bash -c " - . ./train_env/bin/activate - python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml -" \ No newline at end of file From 69f10ddc0f3964eebb116233489d0d55f6b34077 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 15 May 2025 16:45:29 +0200 Subject: [PATCH 38/66] Add power spectrum plots --- src/hirad/eval/metrics.py | 39 +++++++- src/hirad/eval/plotting.py | 11 ++- src/hirad/eval/run_scoring.py | 165 +++++++++++++++++++++++++++------- 3 files changed, 178 insertions(+), 37 deletions(-) diff --git a/src/hirad/eval/metrics.py b/src/hirad/eval/metrics.py index e6e1afb..1170ea1 100644 --- a/src/hirad/eval/metrics.py +++ b/src/hirad/eval/metrics.py @@ -1,6 +1,10 @@ +import logging + import numpy as np import torch +from scipy.signal import periodogram + # set up MAE calculation to be run for each channel for a given date/time (for target COSMO, prediction, and ERA interpolated) @@ -9,7 +13,7 @@ # Extracted from physicsnemo/examples/weather/regen/paper_figures/score_inference.py def absolute_error(pred, target) -> tuple[float, np.ndarray]: - return torch.abs(pred-target) + return np.abs(pred-target) def compute_mae(pred, target): # Exclude any target NaNs (not expected, but precautionary) @@ -20,5 +24,34 @@ def compute_mae(pred, target): ae = absolute_error(pred, target) - return torch.mean(absolute_error(pred, target)), ae - + # TODO, consider adding axis=-1 to choose what axis to average + return np.mean(absolute_error(pred, target)), ae + +def average_power_spectrum(data: np.ndarray, d=2.0): # d=2km by default + """ + Compute the average power spectrum of a data array. + + This function calculates the power spectrum for each row of the input data and + then averages them to obtain the overall power spectrum, repeating until + dimensionality is reduced to 1D. + The power spectrum represents the distribution of signal power as a function of frequency. + + Parameters: + data (numpy.ndarray): Input data array. + d (float): Sampling interval (time between data points). + + Returns: + tuple: A tuple containing the frequency values and the average power spectrum. + - freqs (numpy.ndarray): Frequency values corresponding to the power spectrum. + - power_spectra (numpy.ndarray): Average power spectrum of the input data. + """ + # Compute the power spectrum along the highest dimension for each row + freqs, power_spectra = periodogram(data, fs=1 / d, axis=-1) + logging.info(f'freqs.shape={freqs.shape}, power_spectra.shape={power_spectra.shape}') + + # Average along the first dimension + while power_spectra.ndim > 1: + power_spectra = power_spectra.mean(axis=0) + logging.info(f'power spectra shape={power_spectra.shape}') + + return freqs, power_spectra diff --git a/src/hirad/eval/plotting.py b/src/hirad/eval/plotting.py index 262109f..1ca11c2 100644 --- a/src/hirad/eval/plotting.py +++ b/src/hirad/eval/plotting.py @@ -15,9 +15,16 @@ def plot_error_projection(values: np.array, latitudes: np.array, longitudes: np. plt.savefig(filename) plt.close('all') -def plot_power_spectrum(x, filename): +def plot_power_spectra(freqs: dict, spec: dict, channel_name, filename): fig = plt.figure() - plt.psd(x) + for k in freqs.keys(): + plt.loglog(freqs[k], spec[k], label=k) + plt.title(channel_name) + plt.legend() + plt.xlabel("Frequency (1/km)") + plt.ylabel("Power Spectrum") + plt.ylim(bottom=1e-1) + #plt.psd(x) logging.info(f'plotting values to {filename}') plt.savefig(filename) plt.close('all') \ No newline at end of file diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index df37a9a..da98a04 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -2,23 +2,22 @@ import sys import metrics +import numpy as np import plotting import torch import yaml +X = 352 # length of grid from N-S +Y = 544 # length of grid from E-W def main(): - if len(sys.argv) < 4: - raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [date]') + # TODO: Better arg parsing. + if len(sys.argv) < 3: + raise ValueError('Expected call run_scoring.py [input data directory] [predictions directory] [output plot directory]') input_directory = sys.argv[1] predictions_directory = sys.argv[2] - date = sys.argv[3] - - target = torch.load(os.path.join(input_directory, 'cosmo', date), weights_only=False) - baseline = torch.load(os.path.join(input_directory, 'era-interpolated', date), weights_only=False) - prediction = torch.load(os.path.join(predictions_directory, date), weights_only=False) - lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) + output_directory = sys.argv[3] with open(os.path.join(input_directory, 'info', 'cosmo.yaml')) as cosmo_file: cosmo_config = yaml.safe_load(cosmo_file) @@ -28,37 +27,139 @@ def main(): era_config = yaml.safe_load(era_file) input_channels = era_config['select'] - # Reshape predictions, if necessary - # target is shape [channels, ensembles, points] - # prediction is shape [channels, ensembles, x, y] - prediction_1d = prediction.reshape(*target.shape) - prediction_2d = prediction.reshape(prediction.shape[0],352,544) - + lat_lon = torch.load(os.path.join(input_directory, 'info', 'cosmo-lat-lon'), weights_only=False) latitudes = lat_lon[:,0] longitudes = lat_lon[:,1] - - # convert to torch - target = torch.from_numpy(target) - baseline = torch.from_numpy(baseline) - prediction_1d = torch.from_numpy(prediction_1d) - # plot errors + # Iterate over all files in the ground truth directory + files = os.listdir(os.path.join(input_directory, 'cosmo')) + files = sorted(files) + + + # Plot power spectra + # TODO: Handle ensembles + prediction_tensor = np.ndarray([len(files), len(target_channels), X, Y]) + baseline_tensor = np.ndarray([len(files), len(input_channels), X, Y]) + target_tensor = np.ndarray([len(files), len(target_channels), X, Y]) + + for i in range(len(files)): + datetime = files[i] + target = torch.load(os.path.join(input_directory, 'cosmo', datetime), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', datetime), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, datetime), weights_only=False) + + # TODO: Handle ensembles + prediction_1d = prediction.reshape(prediction.shape[0], X*Y) + prediction_2d = prediction.reshape(prediction.shape[0], X, Y) + + baseline_1d = baseline.reshape(baseline.shape[0], X*Y) + baseline_2d = baseline.reshape(baseline.shape[0], X, Y) + + target_1d = target.reshape(target.shape[0], X*Y) + target_2d = target.reshape(target.shape[0], X, Y) + + baseline_tensor[i, :] = baseline_2d + prediction_tensor[i, :] = prediction_2d + target_tensor[i,:] = target_2d + + + # Calc spectra for t_c in range(len(target_channels)): b_c = input_channels.index(target_channels[t_c]) + freqs = {} + power = {} if b_c > -1: - baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) - plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) - prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) - plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) - plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) - print(f'baseline MAE={baseline_mae}, prediction MAE={prediction_mae}') + b_freq, b_power = metrics.average_power_spectrum(baseline_tensor[:,b_c,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = metrics.average_power_spectrum(target_tensor[:,t_c,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + #p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + #freqs['prediction'] = p_freq + #power['prediction'] = p_power + plotting.plot_power_spectra(freqs, power, target_channels[t_c], os.path.join(output_directory, 'spectra', target_channels[t_c] + '-alldates')) + + # store MAE as tensor of date:channel:ensembles:points + # TODO: Handle ensembles + baseline_absolute_error = np.ndarray([len(files),len(target_channels),1,X*Y]) + prediction_absolute_error = np.ndarray([len(files),len(target_channels),1,X*Y]) + + for i in range(len(files)): + datetime = files[i] + target = torch.load(os.path.join(input_directory, 'cosmo', datetime), weights_only=False) + baseline = torch.load(os.path.join(input_directory, 'era-interpolated', datetime), weights_only=False) + prediction = torch.load(os.path.join(predictions_directory, datetime), weights_only=False) + + + prediction_1d = prediction.reshape(prediction.shape[0], 1, X*Y) + prediction_2d = prediction.reshape(prediction.shape[0], 1, X, Y) + + # Get MAE + for t_c in range(len(target_channels)): + b_c = input_channels.index(target_channels[t_c]) + if b_c > -1: + _, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) + baseline_absolute_error[i, t_c, :, :] = baseline_errors + #plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) + #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) + _, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + prediction_absolute_error[i, t_c, :, :] = prediction_errors + #plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) + #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) + + + print(f'baseline_absolute_error.shape={baseline_absolute_error.shape}, prediction_absolute_error.shape={prediction_absolute_error.shape}') + # Average errors over ensembles + baseline_mae = np.mean(baseline_absolute_error, axis=2) + prediction_mae = np.mean(prediction_absolute_error, axis=2) + + # Average errors over time + baseline_mae = np.mean(baseline_mae, axis=0) + prediction_mae = np.mean(prediction_mae, axis = 0) + + print(f'baseline mean error = {np.mean(baseline_mae, axis=-1)}') + print(f'prediction mean error = {np.mean(prediction_mae, axis=-1)}') + + # Plot the mean error onto the grid. + for t_c in range(len(target_channels)): + plotting.plot_error_projection(baseline_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'baseline-error' + target_channels[t_c] + '-' + 'average_over_time')) + plotting.plot_error_projection(prediction_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'prediction-error' + target_channels[t_c] + '-' + 'average_over_time')) + + + + + #for i in range(4): + # dates = ['20160101-0000', '20160115-0000', '20160201-0000', '20160215-0000'] + # pred = torch.load(os.path.join(predictions_directory, dates[i]), weights_only=False) + # base = torch.load(os.path.join(input_directory, 'era-interpolated', dates[i]), weights_only=False) + # pred_2d = pred.reshape(pred.shape[0],352,544) + # base_2d = base.reshape(baseline.shape[0],352,544) + # base_2d = np.transpose(base_2d, (0,-1,-2)) + # preds_tensor[i,:] = pred_2d + # baseline_tensor[i,:] = base_2d + #for t_c in range(len(target_channels)): + # freq, power = metrics.average_power_spectrum(baseline_tensor[:,t_c,:,:].squeeze(), 2) + # b_c = input_channels.index(target_channels[t_c]) + ## if b_c > -1: + # plotting.plot_power_spectrum(freq, power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + date)) + + + # plot errors + #for t_c in range(len(target_channels)): + # b_c = input_channels.index(target_channels[t_c]) + # if b_c > -1: + # baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) + # plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) + # #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) + # prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) + # plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) + # #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) + #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) + - # Plot power spectra - freq, power = metrics.compute_power_spectrum(prediction, 1) - plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') - plotting.plot_power_spectrum(prediction, 'plots/errors/powerspec-prediction') From b4c97c51baf4d543f6365d193afeab68b01dd317 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 15 May 2025 16:46:17 +0200 Subject: [PATCH 39/66] clean up a bit --- src/hirad/eval/run_scoring.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index da98a04..fd7c2ab 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -129,39 +129,5 @@ def main(): plotting.plot_error_projection(prediction_mae[t_c,:], latitudes, longitudes, os.path.join(output_directory, 'prediction-error' + target_channels[t_c] + '-' + 'average_over_time')) - - - #for i in range(4): - # dates = ['20160101-0000', '20160115-0000', '20160201-0000', '20160215-0000'] - # pred = torch.load(os.path.join(predictions_directory, dates[i]), weights_only=False) - # base = torch.load(os.path.join(input_directory, 'era-interpolated', dates[i]), weights_only=False) - # pred_2d = pred.reshape(pred.shape[0],352,544) - # base_2d = base.reshape(baseline.shape[0],352,544) - # base_2d = np.transpose(base_2d, (0,-1,-2)) - # preds_tensor[i,:] = pred_2d - # baseline_tensor[i,:] = base_2d - #for t_c in range(len(target_channels)): - # freq, power = metrics.average_power_spectrum(baseline_tensor[:,t_c,:,:].squeeze(), 2) - # b_c = input_channels.index(target_channels[t_c]) - ## if b_c > -1: - # plotting.plot_power_spectrum(freq, power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + date)) - - - # plot errors - #for t_c in range(len(target_channels)): - # b_c = input_channels.index(target_channels[t_c]) - # if b_c > -1: - # baseline_mae, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) - # plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - # #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) - # prediction_mae, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) - # plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - # #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) - #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) - - - - - if __name__ == "__main__": main() \ No newline at end of file From 90c7e28719a83d8534c102240faa756b0883f01e Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 15 May 2025 16:47:28 +0200 Subject: [PATCH 40/66] clean up a bit --- src/hirad/eval/run_scoring.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/hirad/eval/run_scoring.py b/src/hirad/eval/run_scoring.py index fd7c2ab..4f2fcd8 100644 --- a/src/hirad/eval/run_scoring.py +++ b/src/hirad/eval/run_scoring.py @@ -76,7 +76,8 @@ def main(): t_freq, t_power = metrics.average_power_spectrum(target_tensor[:,t_c,:,:].squeeze(), 2.0) freqs['target'] = t_freq power['target'] = t_power - #p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + p_freq, p_power = metrics.average_power_spectrum(prediction_tensor[:,t_c,:,:].squeeze(), 2.0) + # TODO: Uncomment when we have predictions #freqs['prediction'] = p_freq #power['prediction'] = p_power plotting.plot_power_spectra(freqs, power, target_channels[t_c], os.path.join(output_directory, 'spectra', target_channels[t_c] + '-alldates')) @@ -102,14 +103,9 @@ def main(): if b_c > -1: _, baseline_errors = metrics.compute_mae(baseline[b_c,:,:], target[t_c,:,:]) baseline_absolute_error[i, t_c, :, :] = baseline_errors - #plotting.plot_error_projection(baseline_errors, latitudes, longitudes, os.path.join('plots/errors/', 'baseline', target_channels[t_c] + '-' + date)) - #plotting.plot_power_spectrum(baseline[b_c,:,:], os.path.join('plots/spectra/', 'baseline', target_channels[t_c] + date)) _, prediction_errors = metrics.compute_mae(prediction_1d[t_c,:,:], target[t_c,:,:]) prediction_absolute_error[i, t_c, :, :] = prediction_errors - #plotting.plot_error_projection(prediction_errors, latitudes, longitudes, os.path.join('plots/errors/', 'prediction', target_channels[t_c] + '-' + date)) - #plotting.plot_power_spectrum(prediction[t_c,0,:], os.path.join('plots/spectra/', 'prediction', target_channels[t_c] + date)) - #plotting.plot_power_spectrum(prediction_2d[t_c,:,:], os.path.join('plots/spectra/', 'prediction2d', target_channels[t_c] + date)) - + print(f'baseline_absolute_error.shape={baseline_absolute_error.shape}, prediction_absolute_error.shape={prediction_absolute_error.shape}') # Average errors over ensembles From a9056027b6ea337c6417a2997e6a9419740e33b8 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 15 May 2025 17:49:32 +0200 Subject: [PATCH 41/66] add readme for training --- README.md | 110 ++++++++++++++++++ .../conf/training/era_cosmo_diffusion.yaml | 2 +- src/hirad/train_diffusion.sh | 45 +++++++ src/hirad/{train.sh => train_regression.sh} | 16 +-- 4 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 src/hirad/train_diffusion.sh rename src/hirad/{train.sh => train_regression.sh} (73%) diff --git a/README.md b/README.md index e69de29..3e66062 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,110 @@ +# HiRAD-Gen + +HiRAD-Gen is short for high-resolution atmospheric downscaling using generative models. This repository contains the code and configuration required to train and use the model. + +## Installation (Alps) + +To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these steps: + +1. **Start the PyTorch user environment**: + ```bash + uenv start pytorch/v2.6.0:v1 --view=default + ``` + +2. **Create a Python virtual environment** (replace `{env_name}` with your desired environment name): + ```bash + python -m venv ./{env_name} + ``` + +3. **Activate the virtual environment**: + ```bash + source ./{env_name}/bin/activate + ``` + +4. **Install project dependencies**: + ```bash + pip install -e . + ``` + +This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure. + +## Run regression model training (Alps) + +1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. +Inside this script set the following: +```bash +### OUTPUT ### +#SBATCH --output=your_path_to_output_log +#SBATCH --error=your_path_to_output_error +``` +```bash +#SBATCH -A your_compute_group +``` +```bash +srun bash -c " + . ./{your_env_name}/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml +" +``` + +2. Setup the following config files in `src/hirad/conf`: + +- In `training_era_cosmo_regression.yaml` set: +``` +hydra: + run: + dir: your_path_to_save_training_output +``` +- In `training/era_cosmo_regression.yaml` set: +``` +hp: + training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4) +``` +- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. + +3. Submit the job with: +```bash +sbatch src/hirad/train_regression.sh +``` + +## Run diffusion model training (Alps) +Before training diffusion model, checkpoint for regression model has to exist. + +1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. +Inside this script set the following: +```bash +### OUTPUT ### +#SBATCH --output=your_path_to_output_log +#SBATCH --error=your_path_to_output_error +``` +```bash +#SBATCH -A your_compute_group +``` +```bash +srun bash -c " + . ./{your_env_name}/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" +``` + +2. Setup the following config files in `src/hirad/conf`: + +- In `training_era_cosmo_diffusion.yaml` set: +``` +hydra: + run: + dir: your_path_to_save_training_output +``` +- In `training/era_cosmo_regression.yaml` set: +``` +hp: + training_duration: number of samples to train for (set to 4 for debugging, 512 fits into 30 minutes on 1 gpu with total_batch_size: 4) +io: + regression_checkpoint_path: path_to_directory_containing_regression_training_model_checkpoints +``` +- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. + +3. Submit the job with: +```bash +sbatch src/hirad/train_diffusion.sh +``` \ No newline at end of file diff --git a/src/hirad/conf/training/era_cosmo_diffusion.yaml b/src/hirad/conf/training/era_cosmo_diffusion.yaml index b06ec61..f8d19e6 100644 --- a/src/hirad/conf/training/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training/era_cosmo_diffusion.yaml @@ -29,7 +29,7 @@ perf: io: regression_checkpoint_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression # Where to load the regression checkpoint - print_progress_freq: 32 + print_progress_freq: 128 # How often to print progress save_checkpoint_freq: 5000 # How often to save the checkpoints, measured in number of processed samples diff --git a/src/hirad/train_diffusion.sh b/src/hirad/train_diffusion.sh new file mode 100644 index 0000000..cf2f88f --- /dev/null +++ b/src/hirad/train_diffusion.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +#SBATCH --job-name="testrun" + +### HARDWARE ### +#SBATCH --partition=debug +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --time=00:30:00 +#SBATCH --no-requeue +#SBATCH --exclusive + +### OUTPUT ### +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/diffusion.err + +### ENVIRONMENT #### +#SBATCH --uenv=pytorch/v2.6.0:/user-environment +#SBATCH --view=default +#SBATCH -A a-a122 + +# Choose method to initialize dist in pythorch +export DISTRIBUTED_INITIALIZATION_METHOD=SLURM + +# Get master node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +# Get IP for hostname. +MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" +export MASTER_ADDR +export MASTER_PORT=29500 + +# Get number of physical cores using Python +PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") +LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} +# Compute cores per process +OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) +export OMP_NUM_THREADS=$OMP_THREADS + +# python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml +srun bash -c " + . ./train_env/bin/activate + python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml +" \ No newline at end of file diff --git a/src/hirad/train.sh b/src/hirad/train_regression.sh similarity index 73% rename from src/hirad/train.sh rename to src/hirad/train_regression.sh index a31cec0..c065477 100644 --- a/src/hirad/train.sh +++ b/src/hirad/train_regression.sh @@ -13,8 +13,8 @@ #SBATCH --exclusive ### OUTPUT ### -#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.log -#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression_test.err +#SBATCH --output=/iopsstor/scratch/cscs/pstamenk/logs/regression.log +#SBATCH --error=/iopsstor/scratch/cscs/pstamenk/logs/regression.err ### ENVIRONMENT #### #SBATCH --uenv=pytorch/v2.6.0:/user-environment @@ -24,28 +24,22 @@ # Choose method to initialize dist in pythorch export DISTRIBUTED_INITIALIZATION_METHOD=SLURM +# Get master node. MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" -echo "Master node : $MASTER_ADDR" # Get IP for hostname. MASTER_ADDR="$(getent ahosts "$MASTER_ADDR" | awk '{ print $1; exit }')" -echo "Master address : $MASTER_ADDR" export MASTER_ADDR export MASTER_PORT=29500 -echo "Master port: $MASTER_PORT" # Get number of physical cores using Python PHYSICAL_CORES=$(python -c "import psutil; print(psutil.cpu_count(logical=False))") -# Use SLURM_NTASKS (number of processes to be launched by torchrun) LOCAL_PROCS=${SLURM_NTASKS_PER_NODE:-1} -# Compute threads per process +# Compute cores per process OMP_THREADS=$(( PHYSICAL_CORES / LOCAL_PROCS )) export OMP_NUM_THREADS=$OMP_THREADS -echo "Physical cores: $PHYSICAL_CORES" -echo "Local processes: $LOCAL_PROCS" -echo "Setting OMP_NUM_THREADS=$OMP_NUM_THREADS" # python src/hirad/training/train.py --config-name=training_era_cosmo_testrun.yaml srun bash -c " . ./train_env/bin/activate - python src/hirad/training/train.py --config-name=training_era_cosmo_diffusion.yaml + python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml " \ No newline at end of file From 573dc2387af0851ee16de51c6bbda4e16519d57d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 16 May 2025 13:16:20 +0200 Subject: [PATCH 42/66] update readme for inference --- README.md | 65 +++++++++++++++++-- .../conf/training_era_cosmo_diffusion.yaml | 4 +- .../conf/training_era_cosmo_regression.yaml | 2 +- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 3e66062..b0dbd2e 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,9 @@ To set up the environment for **HiRAD-Gen** on Alps supercomputer, follow these This will set up the necessary environment to run HiRAD-Gen within the Alps infrastructure. -## Run regression model training (Alps) +## Training + +### Run regression model training (Alps) 1. Script for running the training of regression model is in `src/hirad/train_regression.sh`. Inside this script set the following: @@ -47,7 +49,7 @@ srun bash -c " " ``` -2. Setup the following config files in `src/hirad/conf`: +2. Set up the following config files in `src/hirad/conf`: - In `training_era_cosmo_regression.yaml` set: ``` @@ -67,7 +69,7 @@ hp: sbatch src/hirad/train_regression.sh ``` -## Run diffusion model training (Alps) +### Run diffusion model training (Alps) Before training diffusion model, checkpoint for regression model has to exist. 1. Script for running the training of diffusion model is in `src/hirad/train_diffusion.sh`. @@ -87,7 +89,7 @@ srun bash -c " " ``` -2. Setup the following config files in `src/hirad/conf`: +2. Set up the following config files in `src/hirad/conf`: - In `training_era_cosmo_diffusion.yaml` set: ``` @@ -107,4 +109,59 @@ io: 3. Submit the job with: ```bash sbatch src/hirad/train_diffusion.sh +``` + +## Inference + +### Running inference on Alps + +1. Script for running the inference is in `src/hirad/generate.sh`. +Inside this script set the following: +```bash +### OUTPUT ### +#SBATCH --output=your_path_to_output_log +#SBATCH --error=your_path_to_output_error +``` +```bash +#SBATCH -A your_compute_group +``` +```bash +srun bash -c " + . ./{your_env_name}/bin/activate + python src/hirad/inference/generate.py --config-name=generate_era_cosmo.yaml +" +``` + +2. Set up the following config files in `src/hirad/conf`: + +- In `generate_era_cosmo.yaml` set: +``` +hydra: + run: + dir: your_path_to_save_inference_output +``` +- In `generation/era_cosmo.yaml`: +Choose the inference mode: +``` +inference_mode: all/regression/diffusion +``` +by default `all` does both regression and diffusion. Depending on mode, regression and/or diffusion model pretrained weights should be provided: +``` +io: + res_ckpt_path: path_to_directory_containing_diffusion_training_model_checkpoints + reg_ckpt_path: path_to_directory_containing_regression_training_model_checkpoints +``` +Finally, from the dataset, subset of time steps can be chosen to do inference for. + +One way is to list steps under `times:` in format `%Y%m%d-%H%M` for era5_cosmo dataset. + +The other way is to specify `times_range:` with three items: first time step (`%Y%m%d-%H%M`), last time step (`%Y%m%d-%H%M`), hour shift (int). Hour shift specifies distance in hours between closest time steps for specific dataset (6 for era_cosmo). + +By default, inference is done for one time step `20160101-0000` + +- In `dataset/era_cosmo.yaml` set the `dataset_path` if different from default. + +3. Submit the job with: +```bash +sbatch src/hirad/generate.sh ``` \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 7ee7dba..2c8d37f 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -1,9 +1,9 @@ hydra: job: chdir: true - name: diffusion + name: diffusion_test run: - dir: /scratch/mch/pstamenk/output/${hydra:job.name} + dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index d857d12..dc498ce 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -3,7 +3,7 @@ hydra: chdir: true name: regression run: - dir: /scratch/mch/pstamenk/output/${hydra:job.name} + dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} # Get defaults defaults: From f4d856b0e562e78fa002b5f9780d9c7d878a14d9 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 16 May 2025 18:04:37 +0200 Subject: [PATCH 43/66] small fix for inference on multiple time steps --- .../conf/training_era_cosmo_diffusion.yaml | 2 +- src/hirad/inference/generate.py | 23 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 2c8d37f..4271e44 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -1,7 +1,7 @@ hydra: job: chdir: true - name: diffusion_test + name: diffusion run: dir: /iopsstor/scratch/cscs/pstamenk/outputs/${hydra:job.name} diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 5558a20..7cb9685 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -345,11 +345,12 @@ def generate_fn(image_lr, labels, lead_time_label): writer_executor.submit( save_images, output_path, + times[sampler[time_index]], dataset, image_out.cpu(), image_tar.cpu(), image_lr.cpu(), - image_reg.cpu(), + image_reg.cpu() if image_reg is not None else None, ) ) end.record() @@ -376,15 +377,16 @@ def generate_fn(image_lr, labels, lead_time_label): f.close() logger0.info("Generation Completed.") -def save_images(output_path, dataset, image_pred, image_hr, image_lr, mean_pred): +def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() image_pred = image_pred.numpy() image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) + if image_pred.shape[0]>1: + image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) if mean_pred is not None: @@ -392,13 +394,14 @@ def save_images(output_path, dataset, image_pred, image_hr, image_lr, mean_pred) os.makedirs(output_path, exist_ok=True) for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) - _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{channel.name}-lr.jpg')) - _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{channel.name}-hr.jpg')) - _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred.jpg')) - _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-0.jpg')) - _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{channel.name}-hr-pred-mid.jpg')) + _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) + _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) + _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) + if image_pred.shape[0]>1: + _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) + _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) if mean_pred is not None: - _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{channel.name}-mean-pred.jpg')) + _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): From dcc2a067c1e7f0a7dc1a2da7b58ee5568d4dd6a6 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Wed, 21 May 2025 13:00:23 +0200 Subject: [PATCH 44/66] enable validation during training --- src/hirad/datasets/dataset.py | 8 ++++---- src/hirad/inference/generate.py | 7 ++++--- src/hirad/training/train.py | 5 +---- src/hirad/utils/inference_utils.py | 6 +++--- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 6cc6165..380f797 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -36,7 +36,6 @@ def init_train_valid_datasets_from_config( dataloader_cfg: Union[dict, None] = None, batch_size: int = 1, seed: int = 0, - validation_dataset_cfg: Union[dict, None] = None, train_test_split: bool = True, ) -> Tuple[ DownscalingDataset, @@ -59,13 +58,14 @@ def init_train_valid_datasets_from_config( """ config = copy.deepcopy(dataset_cfg) + del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( config, dataloader_cfg, batch_size=batch_size, seed=seed ) if train_test_split: - valid_dataset_cfg = copy.deepcopy(config) - if validation_dataset_cfg: - valid_dataset_cfg.update(validation_dataset_cfg) + valid_dataset_cfg = copy.deepcopy(dataset_cfg) + valid_dataset_cfg["dataset_path"] = valid_dataset_cfg["validation_path"] + del valid_dataset_cfg['validation_path'] (valid_dataset, valid_dataset_iter) = init_dataset_from_config( valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed ) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 7cb9685..ce8ed7b 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -50,7 +50,6 @@ def main(cfg: DictConfig) -> None: # Initialize logger logger = PythonLogger("generate") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) - # logger.file_logging("generate.log") # Handle the batch size seeds = list(np.arange(cfg.generation.num_ensembles)) @@ -252,7 +251,7 @@ def generate_fn(image_lr, labels, lead_time_label): elif cfg.generation.inference_mode == "diffusion": image_out = image_res else: - image_out = image_reg + image_res + image_out = image_reg[0:1,::] + image_res if cfg.generation.sample_res != "full": image_out = rearrange( @@ -385,8 +384,9 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, image_pred = image_pred.numpy() image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) if image_pred.shape[0]>1: + image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[32,::].squeeze()),1).reshape(len(output_channels),-1) + image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) if mean_pred is not None: @@ -398,6 +398,7 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) if image_pred.shape[0]>1: + _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) if mean_pred is not None: diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 37e6110..664a6a5 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -36,12 +36,10 @@ def main(cfg: DictConfig) -> None: OmegaConf.resolve(cfg) dataset_cfg = OmegaConf.to_container(cfg.dataset) - if hasattr(cfg, "validation"): + if hasattr(cfg.dataset, "validation_path"): train_test_split = True - validation_dataset_cfg = OmegaConf.to_container(cfg.validation) else: train_test_split = False - validation_dataset_cfg = None fp_optimizations = cfg.training.perf.fp_optimizations songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level fp16 = fp_optimizations == "fp16" @@ -77,7 +75,6 @@ def main(cfg: DictConfig) -> None: data_loader_kwargs, batch_size=cfg.training.hp.batch_size_per_gpu, seed=0, - validation_dataset_cfg=validation_dataset_cfg, train_test_split=train_test_split, ) logger0.info(f"Training on dataset with size {len(dataset)}") diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index 4831bdd..ace05ba 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -52,14 +52,14 @@ def regression_step( """ # Create a tensor of zeros with the given shape and move it to the appropriate device x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) - t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device).reshape((1,1,1,1)) + t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device)#.reshape((1,1,1,1)) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat, img_lr, t_hat, labels, lead_time_label=lead_time_label) + x = net(x_hat[0:1], img_lr, t_hat, labels, lead_time_label=lead_time_label) else: - x = net(x_hat, img_lr, t_hat, labels) + x = net(x_hat[0:1], img_lr, t_hat, labels) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: From cadccd502fb913e18db7d7ea4eba2787f8fca5c5 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Thu, 22 May 2025 15:57:25 +0200 Subject: [PATCH 45/66] change generate eval to new functions --- src/hirad/inference/generate.py | 107 +++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 37 deletions(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index ce8ed7b..fbfd8cf 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -35,6 +35,7 @@ from hirad.utils.train_helpers import set_patch_shape +from hirad.eval import compute_mae, average_power_spectrum, plot_error_projection, plot_power_spectra @hydra.main(version_base="1.2", config_path="../conf", config_name="config_generate") def main(cfg: DictConfig) -> None: @@ -346,10 +347,10 @@ def generate_fn(image_lr, labels, lead_time_label): output_path, times[sampler[time_index]], dataset, - image_out.cpu(), - image_tar.cpu(), - image_lr.cpu(), - image_reg.cpu() if image_reg is not None else None, + image_out.cpu().numpy(), + image_tar.cpu().numpy(), + image_lr.cpu().numpy(), + image_reg.cpu().numpy() if image_reg is not None else None, ) ) end.record() @@ -381,41 +382,73 @@ def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, latitudes = dataset.latitude() input_channels = dataset.input_channels() output_channels = dataset.output_channels() - image_pred = image_pred.numpy() - image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) - if image_pred.shape[0]>1: - image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) - image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) - image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) - image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) - image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) - if mean_pred is not None: - mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) - os.makedirs(output_path, exist_ok=True) + + target = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze()),1) #.reshape(len(output_channels),-1) + prediction = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1) #.reshape(len(output_channels),-1) + baseline = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze()),1)# .reshape(len(input_channels),-1) + + freqs = {} + power = {} for idx, channel in enumerate(output_channels): input_channel_idx = input_channels.index(channel) - _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) - _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) - _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) - if image_pred.shape[0]>1: - _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) - _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) - _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) - if mean_pred is not None: - _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) - -def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): - - """Plot observed or interpolated data in a scatter plot.""" - # TODO: Refactor this somehow, it's not really generalizing well across variables. - fig = plt.figure() - fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) - ax.coastlines() - ax.gridlines(draw_labels=True) - plt.colorbar(p, label="K", orientation="horizontal") - plt.savefig(filename) - plt.close('all') + _, baseline_errors = compute_mae(baseline[input_channel_idx,:,:], target[idx,:,:]) + _, prediction_errors = compute_mae(prediction[idx,:,:], target[idx,:,:]) + + plot_error_projection(baseline_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-baseline-error.jpg')) + plot_error_projection(prediction_errors.reshape(-1), latitudes, longitudes, os.path.join(output_path, f'{time_step}-{channel.name}-prediction-error.jpg')) + + b_freq, b_power = average_power_spectrum(baseline[input_channel_idx,:,:].squeeze(), 2.0) + freqs['baseline'] = b_freq + power['baseline'] = b_power + #plotting.plot_power_spectrum(b_freq, b_power, target_channels[t_c], os.path.join('plots/spectra/baseline2dt', target_channels[t_c] + '-all_dates')) + t_freq, t_power = average_power_spectrum(target[idx,:,:].squeeze(), 2.0) + freqs['target'] = t_freq + power['target'] = t_power + p_freq, p_power = average_power_spectrum(prediction[idx,:,:].squeeze(), 2.0) + freqs['prediction'] = p_freq + power['prediction'] = p_power + plot_power_spectra(freqs, power, channel.name, os.path.join(output_path, f'{time_step}-{channel.name}-spectra.jpg')) + +# def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): +# longitudes = dataset.longitude() +# latitudes = dataset.latitude() +# input_channels = dataset.input_channels() +# output_channels = dataset.output_channels() +# image_pred = image_pred.numpy() +# image_pred_final = np.flip(dataset.denormalize_output(image_pred[-1,::].squeeze()),1).reshape(len(output_channels),-1) +# if image_pred.shape[0]>1: +# image_pred_mean = np.flip(dataset.denormalize_output(image_pred.mean(axis=0)),1).reshape(len(output_channels),-1) +# image_pred_first_step = np.flip(dataset.denormalize_output(image_pred[0,::].squeeze()),1).reshape(len(output_channels),-1) +# image_pred_mid_step = np.flip(dataset.denormalize_output(image_pred[image_pred.shape[0]//2,::].squeeze()),1).reshape(len(output_channels),-1) +# image_hr = np.flip(dataset.denormalize_output(image_hr[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) +# image_lr = np.flip(dataset.denormalize_input(image_lr[0,::].squeeze().numpy()),1).reshape(len(input_channels),-1) +# if mean_pred is not None: +# mean_pred = np.flip(dataset.denormalize_output(mean_pred[0,::].squeeze().numpy()),1).reshape(len(output_channels),-1) +# os.makedirs(output_path, exist_ok=True) +# for idx, channel in enumerate(output_channels): +# input_channel_idx = input_channels.index(channel) +# _plot_projection(longitudes,latitudes,image_lr[input_channel_idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-lr.jpg')) +# _plot_projection(longitudes,latitudes,image_hr[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_final[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred.jpg')) +# if image_pred.shape[0]>1: +# _plot_projection(longitudes,latitudes,image_pred_mean[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mean.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_first_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-0.jpg')) +# _plot_projection(longitudes,latitudes,image_pred_mid_step[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-hr-pred-mid.jpg')) +# if mean_pred is not None: +# _plot_projection(longitudes,latitudes,mean_pred[idx,:],os.path.join(output_path,f'{time_step}-{channel.name}-mean-pred.jpg')) + +# def _plot_projection(longitudes: np.array, latitudes: np.array, values: np.array, filename: str, cmap=None, vmin = None, vmax = None): + +# """Plot observed or interpolated data in a scatter plot.""" +# # TODO: Refactor this somehow, it's not really generalizing well across variables. +# fig = plt.figure() +# fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) +# p = ax.scatter(x=longitudes, y=latitudes, c=values, cmap=cmap, vmin=vmin, vmax=vmax) +# ax.coastlines() +# ax.gridlines(draw_labels=True) +# plt.colorbar(p, label="K", orientation="horizontal") +# plt.savefig(filename) +# plt.close('all') if __name__ == "__main__": main() From 010bf26601769a376f66367395fda36fa2867434 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 22 May 2025 17:14:15 +0200 Subject: [PATCH 46/66] Add pytorch toml files --- .edf/README.md | 10 ++++++++++ .edf/gemma-pytorch.toml | 14 ++++++++++++++ .edf/ngc-pytorch.toml | 6 ++++++ 3 files changed, 30 insertions(+) create mode 100644 .edf/README.md create mode 100644 .edf/gemma-pytorch.toml create mode 100644 .edf/ngc-pytorch.toml diff --git a/.edf/README.md b/.edf/README.md new file mode 100644 index 0000000..4da3cd1 --- /dev/null +++ b/.edf/README.md @@ -0,0 +1,10 @@ +run: +``` +export EDF_PATH=`pwd`/.edf +``` +This adds the repository path to the EDF search path. + +run: +``` +srun -A a-a122 --environment=ubuntu2 cat /etc/os-release +``` \ No newline at end of file diff --git a/.edf/gemma-pytorch.toml b/.edf/gemma-pytorch.toml new file mode 100644 index 0000000..3c4723a --- /dev/null +++ b/.edf/gemma-pytorch.toml @@ -0,0 +1,14 @@ +image = "/iopsstor/scratch/cscs/${USER}/pytorch-24.01-py3-venv/pytorch-24.01-py3-venv.sqsh" + +mounts = ["/capstor", "/users","/iopsstor/scratch/cscs/mmcgloho"] + +writable = true + +[annotations] +com.hooks.aws_ofi_nccl.enabled = "true" +com.hooks.aws_ofi_nccl.variant = "cuda12" + +[env] +FI_CXI_DISABLE_HOST_REGISTER = "1" +FI_MR_CACHE_MONITOR = "userfaultfd" +NCCL_DEBUG = "INFO" diff --git a/.edf/ngc-pytorch.toml b/.edf/ngc-pytorch.toml new file mode 100644 index 0000000..4f790ca --- /dev/null +++ b/.edf/ngc-pytorch.toml @@ -0,0 +1,6 @@ +# https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch +image = "nvcr.io#nvidia/pytorch:22.01-py3" +mounts = ["/capstor/scratch/cscs/${USER}:/capstor/scratch/cscs/${USER}"] +workdir = "/capstor/scratch/cscs/${USER}" + +# Maybe above should be iopsstor \ No newline at end of file From 92b08b7fad3c2fdf09b07ea2a82a69ba9175f62e Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 23 May 2025 14:28:30 +0200 Subject: [PATCH 47/66] fix average training loss tracking --- src/hirad/training/train.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 664a6a5..559d800 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -366,18 +366,6 @@ def main(cfg: DictConfig) -> None: "training_loss_running_mean", average_loss_running_mean, cur_nimg ) - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 # Update weights. lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate @@ -481,6 +469,19 @@ def main(cfg: DictConfig) -> None: logger0.info(" ".join(fields)) torch.cuda.reset_peak_memory_stats() + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + # Save checkpoints if dist.world_size > 1: torch.distributed.barrier() From 97970fcc634f7ca1934cda32ee58b8c27d79dc02 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Fri, 23 May 2025 14:33:58 +0200 Subject: [PATCH 48/66] fix validation bug --- src/hirad/datasets/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hirad/datasets/dataset.py b/src/hirad/datasets/dataset.py index 380f797..7ba8833 100644 --- a/src/hirad/datasets/dataset.py +++ b/src/hirad/datasets/dataset.py @@ -58,7 +58,8 @@ def init_train_valid_datasets_from_config( """ config = copy.deepcopy(dataset_cfg) - del config['validation_path'] + if 'validation_path': + del config['validation_path'] (dataset, dataset_iter) = init_dataset_from_config( config, dataloader_cfg, batch_size=batch_size, seed=seed ) @@ -83,6 +84,8 @@ def init_dataset_from_config( ) -> Tuple[DownscalingDataset, Iterable]: dataset_cfg = copy.deepcopy(dataset_cfg) dataset_type = dataset_cfg.pop("type", "era5_cosmo") + if "validation_path" in dataset_cfg: + del dataset_cfg['validation_path'] if "train_test_split" in dataset_cfg: # handled by init_train_valid_datasets_from_config del dataset_cfg["train_test_split"] From 937e7c97e129ee6c0d696ef318300372b17f29a3 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 26 May 2025 18:20:09 +0200 Subject: [PATCH 49/66] update to latest corrdiff version --- src/hirad/conf/generation/era_cosmo.yaml | 19 +- src/hirad/conf/model/era_cosmo_diffusion.yaml | 12 +- .../conf/model/era_cosmo_regression.yaml | 10 +- src/hirad/conf/model_size/mini.yaml | 26 + src/hirad/conf/model_size/normal.yaml | 26 + src/hirad/conf/sampler/stochastic.yaml | 4 +- .../conf/training_era_cosmo_regression.yaml | 2 + src/hirad/datasets/era5_cosmo.py | 3 +- src/hirad/inference/generate.py | 132 ++- src/hirad/losses/__init__.py | 2 +- src/hirad/losses/loss.py | 630 ++++++----- src/hirad/models/__init__.py | 12 +- src/hirad/models/layers.py | 354 +++++-- src/hirad/models/preconditioning.py | 352 +++++-- src/hirad/models/song_unet.py | 996 ++++++++++++------ src/hirad/models/unet.py | 177 +++- src/hirad/training/train.py | 741 ++++++++----- src/hirad/utils/deterministic_sampler.py | 162 ++- src/hirad/utils/function_utils.py | 37 +- src/hirad/utils/inference_utils.py | 145 ++- src/hirad/utils/patching.py | 767 ++++++++++++++ src/hirad/utils/stochastic_sampler.py | 524 +++------ src/hirad/utils/train_helpers.py | 14 +- 23 files changed, 3584 insertions(+), 1563 deletions(-) create mode 100644 src/hirad/conf/model_size/mini.yaml create mode 100644 src/hirad/conf/model_size/normal.yaml create mode 100644 src/hirad/utils/patching.py diff --git a/src/hirad/conf/generation/era_cosmo.yaml b/src/hirad/conf/generation/era_cosmo.yaml index a0c5a40..be4219d 100644 --- a/src/hirad/conf/generation/era_cosmo.yaml +++ b/src/hirad/conf/generation/era_cosmo.yaml @@ -1,22 +1,26 @@ -num_ensembles: 64 +num_ensembles: 8 # Number of ensembles to generate per input -seed_batch_size: 1 +seed_batch_size: 4 # Size of the batched inference inference_mode: all # Choose between "all" (regression + diffusion), "regression" or "diffusion" # Patch size. Patch-based sampling will be utilized if these dimensions differ from # img_shape_x and img_shape_y -overlap_pixels: 0 +# overlap_pixels: 0 # Number of overlapping pixels between adjacent patches -boundary_pixels: 0 +# boundary_pixels: 0 # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary # artifact. +patching: False hr_mean_conditioning: True -sample_res: full +# sample_res: full # Sampling resolution times_range: null times: - 20160101-0000 + # - 20160101-0600 + # - 20160101-1200 +has_laed_time: False perf: force_fp16: False @@ -31,9 +35,10 @@ perf: # To support multiple workers a threadsafe version of the netCDF library must be used io: - res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_test/checkpoints_diffusion + res_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/diffusion_refactoring/checkpoints_diffusion + # res_ckpt_path: null # Checkpoint filename for the diffusion model - reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_overfit/checkpoints_regression + reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_refactoring/checkpoints_regression # reg_ckpt_path: /iopsstor/scratch/cscs/pstamenk/outputs/regression_test/checkpoints_regression # Checkpoint filename for the mean predictor model output_path: ./images \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_diffusion.yaml b/src/hirad/conf/model/era_cosmo_diffusion.yaml index 06aa2a4..441239e 100644 --- a/src/hirad/conf/model/era_cosmo_diffusion.yaml +++ b/src/hirad/conf/model/era_cosmo_diffusion.yaml @@ -2,4 +2,14 @@ name: diffusion # Name of the preconditioner hr_mean_conditioning: True # High-res mean (regression's output) as additional condition -scale_cond_input: False \ No newline at end of file + +# Standard model parameters. +model_args: + gridtype: "sinusoidal" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 4 + # Number of channels for positional grid embeddings + embedding_type: "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model/era_cosmo_regression.yaml b/src/hirad/conf/model/era_cosmo_regression.yaml index 487eb4b..29b43e8 100644 --- a/src/hirad/conf/model/era_cosmo_regression.yaml +++ b/src/hirad/conf/model/era_cosmo_regression.yaml @@ -1,2 +1,10 @@ name: regression -hr_mean_conditioning: False \ No newline at end of file +hr_mean_conditioning: False + +# Default regression model parameters. Do not modify. +model_args: + "N_grid_channels": 4 + # Number of channels for positional grid embeddings + "embedding_type": "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/src/hirad/conf/model_size/mini.yaml b/src/hirad/conf/model_size/mini.yaml new file mode 100644 index 0000000..2eb8f8a --- /dev/null +++ b/src/hirad/conf/model_size/mini.yaml @@ -0,0 +1,26 @@ +# @package _global_.model + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 64 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2] + # Resolutions at which self-attention layers are applied. + attn_resolutions: [16] \ No newline at end of file diff --git a/src/hirad/conf/model_size/normal.yaml b/src/hirad/conf/model_size/normal.yaml new file mode 100644 index 0000000..b81fe15 --- /dev/null +++ b/src/hirad/conf/model_size/normal.yaml @@ -0,0 +1,26 @@ +# @package _global_.model + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 128 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2, 2, 2] + # Resolutions at which self-attention layers are applied. + attn_resolutions: [28] \ No newline at end of file diff --git a/src/hirad/conf/sampler/stochastic.yaml b/src/hirad/conf/sampler/stochastic.yaml index 5e8fa88..2481cd3 100644 --- a/src/hirad/conf/sampler/stochastic.yaml +++ b/src/hirad/conf/sampler/stochastic.yaml @@ -1,3 +1,3 @@ type: stochastic -boundary_pix: 2 -overlap_pix: 4 \ No newline at end of file +# boundary_pix: 2 +# overlap_pix: 4 \ No newline at end of file diff --git a/src/hirad/conf/training_era_cosmo_regression.yaml b/src/hirad/conf/training_era_cosmo_regression.yaml index dc498ce..1de83d9 100644 --- a/src/hirad/conf/training_era_cosmo_regression.yaml +++ b/src/hirad/conf/training_era_cosmo_regression.yaml @@ -15,5 +15,7 @@ defaults: # Model - model/era_cosmo_regression + - model_size/normal + # Training - training/era_cosmo_regression \ No newline at end of file diff --git a/src/hirad/datasets/era5_cosmo.py b/src/hirad/datasets/era5_cosmo.py index 674dbf0..f97dbc6 100644 --- a/src/hirad/datasets/era5_cosmo.py +++ b/src/hirad/datasets/era5_cosmo.py @@ -63,8 +63,7 @@ def __getitem__(self, idx): cosmo_data = self.normalize_output(cosmo_data) # return samples return torch.tensor(cosmo_data),\ - torch.tensor(era5_data),\ - 0 + torch.tensor(era5_data), # return F.pad(torch.tensor(cosmo_data), pad=(1,1,1,1), mode='constant', value=0), \ # F.pad(torch.tensor(era5_data), pad=(1,1,1,1), mode='constant', value=0), \ # 0 diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index fbfd8cf..8fed809 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -6,6 +6,8 @@ import torch._dynamo import nvtx import numpy as np +import contextlib + from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from concurrent.futures import ThreadPoolExecutor @@ -18,7 +20,8 @@ from hydra.utils import to_absolute_path -from hirad.models import EDMPrecondSR, UNet +from hirad.models import EDMPrecondSuperResolution, UNet +from hirad.utils.patching import GridPatching2D from hirad.utils.stochastic_sampler import stochastic_sampler from hirad.utils.deterministic_sampler import deterministic_sampler from hirad.utils.inference_utils import ( @@ -85,19 +88,23 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) # Parse the patch shape - if hasattr(cfg.generation, "patch_shape_x"): # TODO better config handling + if cfg.generation.patching: patch_shape_x = cfg.generation.patch_shape_x - else: - patch_shape_x = None - if hasattr(cfg.generation, "patch_shape_y"): patch_shape_y = cfg.generation.patch_shape_y else: - patch_shape_y = None + patch_shape_x, patch_shape_y = None, None patch_shape = (patch_shape_y, patch_shape_x) - img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if patch_shape != img_shape: + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + patching = GridPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) logger0.info("Patch-based training enabled") else: + patching = None logger0.info("Patch-based training disabled") # Parse the inference mode @@ -121,7 +128,7 @@ def main(cfg: DictConfig) -> None: with open(diffusion_model_args_path, 'r') as f: diffusion_model_args = json.load(f) - net_res = EDMPrecondSR(**diffusion_model_args) + net_res = EDMPrecondSuperResolution(**diffusion_model_args) _ = load_checkpoint( path=res_ckpt_path, @@ -130,9 +137,13 @@ def main(cfg: DictConfig) -> None: ) #TODO fix to use channels_last which is optimal for H100 - net_res = net_res.eval().to(device)#.to(memory_format=torch.channels_last) + net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True + + # Disable AMP for inference (even if model is trained with AMP) + if hasattr(net_res, "amp_mode"): + net_res.amp_mode = False else: net_res = None @@ -156,9 +167,13 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_reg = net_reg.eval().to(device)#.to(memory_format=torch.channels_last) + net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True + + # Disable AMP for inference (even if model is trained with AMP) + if hasattr(net_reg, "amp_mode"): + net_reg.amp_mode = False else: net_reg = None @@ -183,47 +198,28 @@ def main(cfg: DictConfig) -> None: solver=cfg.sampler.solver, ) elif cfg.sampler.type == "stochastic": - sampler_fn = partial( - stochastic_sampler, - img_shape=img_shape, - patch_shape_x=patch_shape[0], - patch_shape_y=patch_shape[1], - boundary_pix=cfg.sampler.boundary_pix, - overlap_pix=cfg.sampler.overlap_pix, - ) + sampler_fn = partial(stochastic_sampler, patching=patching) else: raise ValueError(f"Unknown sampling method {cfg.sampling.type}") # Main generation definition - def generate_fn(image_lr, labels, lead_time_label): - img_shape_y, img_shape_x = img_shape + def generate_fn(image_lr, lead_time_label): with nvtx.annotate("generate_fn", color="green"): - if cfg.generation.sample_res == "full": - image_lr_patch = image_lr - else: - torch.cuda.nvtx.range_push("rearrange") - image_lr_patch = rearrange( - image_lr, - "b c (h1 h) (w1 w) -> (b h1 w1) c h w", - h1=img_shape_y // patch_shape[0], - w1=img_shape_x // patch_shape[1], - ) - torch.cuda.nvtx.range_pop() - image_lr_patch = image_lr_patch #.to(memory_format=torch.channels_last) + # (1, C, H, W) + image_lr = image_lr.to(memory_format=torch.channels_last) if net_reg: with nvtx.annotate("regression_model", color="yellow"): image_reg = regression_step( net=net_reg, - img_lr=image_lr_patch, - labels=labels, + img_lr=image_lr, latents_shape=( cfg.generation.seed_batch_size, img_out_channels, img_shape[0], img_shape[1], - ), + ), # (batch_size, C, H, W) lead_time_label=lead_time_label, ) if net_res: @@ -235,16 +231,15 @@ def generate_fn(image_lr, labels, lead_time_label): image_res = diffusion_step( net=net_res, sampler_fn=sampler_fn, - seed_batch_size=cfg.generation.seed_batch_size, img_shape=img_shape, img_out_channels=img_out_channels, rank_batches=rank_batches, - img_lr=image_lr_patch.expand( + img_lr=image_lr.expand( cfg.generation.seed_batch_size, -1, -1, -1 ), #.to(memory_format=torch.channels_last), rank=dist.rank, device=device, - hr_mean=mean_hr, + mean_hr=mean_hr, lead_time_label=lead_time_label, ) if cfg.generation.inference_mode == "regression": @@ -254,13 +249,6 @@ def generate_fn(image_lr, labels, lead_time_label): else: image_out = image_reg[0:1,::] + image_res - if cfg.generation.sample_res != "full": - image_out = rearrange( - image_out, - "(b h1 w1) c h w -> b c (h1 h) (w1 w)", - h1=img_shape_y // patch_shape[0], - w1=img_shape_x // patch_shape[1], - ) # Gather tensors on rank 0 if dist.world_size > 1: if dist.rank == 0: @@ -300,8 +288,18 @@ def generate_fn(image_lr, labels, lead_time_label): # through the dataset using a data loader, computes predictions, and saves them along # with associated metadata. - with torch.cuda.profiler.profile(): - with torch.autograd.profiler.emit_nvtx(): + torch_cuda_profiler = ( + torch.cuda.profiler.profile() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + torch_nvtx_profiler = ( + torch.autograd.profiler.emit_nvtx() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + with torch_cuda_profiler: + with torch_nvtx_profiler: data_loader = torch.utils.data.DataLoader( dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True @@ -313,11 +311,29 @@ def generate_fn(image_lr, labels, lead_time_label): ) writer_threads = [] - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) + # Create timer objects only if CUDA is available + use_cuda_timing = torch.cuda.is_available() + if use_cuda_timing: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + else: + # Dummy no-op functions for CPU case + class DummyEvent: + def record(self): + pass + + def synchronize(self): + pass + + def elapsed_time(self, _): + return 0 + + start = end = DummyEvent() times = dataset.time() - for image_tar, image_lr, labels, *lead_time_label in iter(data_loader): + for index, (image_tar, image_lr, *lead_time_label) in enumerate( + iter(data_loader) + ): time_index += 1 if dist.rank == 0: logger0.info(f"starting index: {time_index}") @@ -333,11 +349,10 @@ def generate_fn(image_lr, labels, lead_time_label): image_lr = ( image_lr.to(device=device) .to(torch.float32) - #.to(memory_format=torch.channels_last) + .to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - labels = labels.to(device).to(torch.float32).contiguous() - image_out, image_reg = generate_fn(image_lr,labels,lead_time_label) + image_out, image_reg = generate_fn(image_lr,lead_time_label) if dist.rank == 0: batch_size = image_out.shape[0] # write out data in a seperate thread so we don't hold up inferencing @@ -355,9 +370,11 @@ def generate_fn(image_lr, labels, lead_time_label): ) end.record() end.synchronize() - elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s + elapsed_time = ( + start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0 + ) # Convert ms to s timed_steps = time_index + 1 - warmup_steps - if dist.rank == 0: + if dist.rank == 0 and use_cuda_timing: average_time_per_batch_element = elapsed_time / timed_steps / batch_size logger.info( f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" @@ -378,6 +395,9 @@ def generate_fn(image_lr, labels, lead_time_label): logger0.info("Generation Completed.") def save_images(output_path, time_step, dataset, image_pred, image_hr, image_lr, mean_pred): + + os.makedirs(output_path, exist_ok=True) + longitudes = dataset.longitude() latitudes = dataset.latitude() input_channels = dataset.input_channels() diff --git a/src/hirad/losses/__init__.py b/src/hirad/losses/__init__.py index 185527b..868ffdf 100644 --- a/src/hirad/losses/__init__.py +++ b/src/hirad/losses/__init__.py @@ -1 +1 @@ -from .loss import ResLoss, RegressionLoss, RegressionLossCE \ No newline at end of file +from .loss import ResidualLoss, RegressionLoss, RegressionLossCE \ No newline at end of file diff --git a/src/hirad/losses/loss.py b/src/hirad/losses/loss.py index 18dde13..fb65960 100644 --- a/src/hirad/losses/loss.py +++ b/src/hirad/losses/loss.py @@ -18,12 +18,12 @@ """Loss functions used in the paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" -import random -from typing import Callable, Optional, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch +from hirad.utils.patching import RandomPatching2D class VPLoss: """ @@ -333,7 +333,7 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - # augment for conditional generaiton + # augment for conditional generation img_tot = torch.cat((img_clean, img_lr), dim=1) y_tot, augment_labels = ( augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) @@ -349,16 +349,13 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): class RegressionLoss: """ - Regression loss function for the U-Net for deterministic predictions. + Regression loss function for the deterministic predictions. + Note: this loss does not apply any reduction. - Parameters + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. + sigma_data: float + Standard deviation for data. Deprecated and ignored. Note ---- @@ -368,43 +365,68 @@ class RegressionLoss: arXiv preprint arXiv:2309.15214. """ - def __init__( - self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 - ): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data + def __init__(self): + """ + Arguments + ---------- + """ + return - def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + def __call__( + self, + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: """ - Calculate and return the loss for the U-Net for deterministic predictions. + Calculate and return the regression loss for + deterministic predictions. - Parameters: + Parameters ---------- - net: torch.nn.Module + net : torch.nn.Module The neural network model that will make predictions. + Expected signature: `net(x, img_lr, + augment_labels=augment_labels, force_fp32=False)`, where: + x (torch.Tensor): Tensor of shape (B, C_hr, H, W). Is zero-filled. + img_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + force_fp32 (bool, optional): Whether to force the model to use + fp32, by default False. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + augment_pipe : callable, optional + An optional data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. - - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + A tensor representing the per-sample element-wise squared + difference between the network's predictions and the high + resolution images `img_clean` (possibly data-augmented by + `augment_pipe`). + Shape: (B, C_hr, H, W), same as `img_clean`. """ - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = ( 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 ) @@ -416,100 +438,214 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): y = y_tot[:, : img_clean.shape[1], :, :] y_lr = y_tot[:, img_clean.shape[1] :, :, :] - input = torch.zeros_like(y, device=img_clean.device) - D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels) + zero_input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels) loss = weight * ((D_yn - y) ** 2) return loss -class ResLoss: +class ResidualLoss: """ Mixture loss function for denoising score matching. - Parameters + This class implements a loss function that combines deterministic + regression with denoising score matching. It uses a pre-trained regression + network to compute residuals before applying the diffusion process. + + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. + regression_net : torch.nn.Module + The regression network used for computing residuals. + P_mean : float + Mean value for noise level computation. + P_std : float + Standard deviation for noise level computation. + sigma_data : float + Standard deviation for data weighting. + hr_mean_conditioning : bool + Flag indicating whether to use high-resolution mean for conditioning. Note ---- Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., - Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. - Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. - arXiv preprint arXiv:2309.15214. + Liu, C.C., Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric + Downscaling. arXiv preprint arXiv:2309.15214. """ def __init__( self, - regression_net, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - patch_num, + regression_net: torch.nn.Module, P_mean: float = 0.0, P_std: float = 1.2, sigma_data: float = 0.5, hr_mean_conditioning: bool = False, ): - self.unet = regression_net + """ + Arguments + ---------- + regression_net : torch.nn.Module + Pre-trained regression network used to compute residuals. + Expected signature: `net(zero_input, y_lr, + lead_time_label=lead_time_label, augment_labels=augment_labels)` or + `net(zero_input, y_lr, augment_labels=augment_labels)`, where: + zero_input (torch.Tensor): Zero tensor of shape (B, C_hr, H, W) + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time labels + augment_labels (torch.Tensor, optional): Optional augmentation labels + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + P_mean : float, optional + Mean value for noise level computation, by default 0.0. + + P_std : float, optional + Standard deviation for noise level computation, by default 1.2. + + sigma_data : float, optional + Standard deviation for data weighting, by default 0.5. + + hr_mean_conditioning : bool, optional + Whether to use high-resolution mean for conditioning predicted, by default False. + When True, the mean prediction from `regression_net` is channel-wise + concatenated with `img_lr` for conditioning. + """ + self.regression_net = regression_net self.P_mean = P_mean self.P_std = P_std self.sigma_data = sigma_data - self.img_shape_x = img_shape_x - self.img_shape_y = img_shape_y - self.patch_shape_x = patch_shape_x - self.patch_shape_y = patch_shape_y - self.patch_num = patch_num self.hr_mean_conditioning = hr_mean_conditioning + self.y_mean = None def __call__( self, - net, - img_clean, - img_lr, - labels=None, - lead_time_label=None, - augment_pipe=None, - ): + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + patching: Optional[RandomPatching2D] = None, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + use_patch_grad_acc: bool = False, + ) -> torch.Tensor: """ Calculate and return the loss for denoising score matching. - Parameters: - ---------- - net: torch.nn.Module - The neural network model that will make predictions. - - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - labels: torch.Tensor - Ground truth labels for the input images. + This method computes a mixture loss that combines deterministic + regression with denoising score matching. It first computes residuals + using the regression network, then applies the diffusion process to + these residuals. + + In addition to the standard denoising score matching loss, this method + also supports optional patching for multi-diffusion. In this case, the spatial + dimensions of the input are decomposed into `P` smaller patches of shape + (H_patch, W_patch), that are grouped along the batch dimension, and the + model is applied to each patch individually. In the following, if `patching` + is not provided, then the input is not patched and `P=1` and `(H_patch, + W_patch) = (H, W)`. When patching is used, the original non-patched conditioning is + interpolated onto a spatial grid of shape `(H_patch, W_patch)` and channel-wise + concatenated to the patched conditioning. This ensures that each patch + maintains global information from the entire domain. + + The diffusion model `net` is expected to be conditioned on an input with + `C_cond` channels, which should be: + - `C_cond = C_lr` if `hr_mean_conditioning` is `False` and + `patching` is None. + - `C_cond = C_hr + C_lr` if `hr_mean_conditioning` is `True` and + `patching` is None. + - `C_cond = C_hr + 2*C_lr` if `hr_mean_conditioning` is `True` and + `patching` is not None. + - `C_cond = 2*C_lr` if `hr_mean_conditioning` is `False` and + `patching` is not None. + Additionally, `C_cond` should also include any embedding channels, + such as positional embeddings or time embeddings. + + Note: this loss function does not apply any reduction. - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. + Parameters + ---------- + net : torch.nn.Module + The neural network model for the diffusion process. + Expected signature: `net(latent, y_lr, sigma, + embedding_selector=embedding_selector, lead_time_label=lead_time_label, + augment_labels=augment_labels)`, where: + latent (torch.Tensor): Noisy input of shape (B[*P], C_hr, H_patch, W_patch) + y_lr (torch.Tensor): Conditioning of shape (B[*P], C_cond, H_patch, W_patch) + sigma (torch.Tensor): Noise level of shape (B[*P], 1, 1, 1) + embedding_selector (callable, optional): Function to select + positional embeddings. Only used if `patching` is provided. + lead_time_label (torch.Tensor, optional): Lead time labels. + augment_labels (torch.Tensor, optional): Augmentation labels + Returns: + torch.Tensor: Predictions of shape (B[*P], C_hr, H_patch, W_patch) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the regression network and conditioning for the + diffusion process. + + patching : Optional[RandomPatching2D], optional + Patching strategy for processing large images, by default None. See + :class:`physicsnemo.utils.patching.RandomPatching2D` for details. + When provided, the patching strategy is used for both image patches + and positional embeddings selection in the diffusion model `net`. + Transforms tensors from shape (B, C, H, W) to (B*P, C, H_patch, + W_patch). + + lead_time_label : Optional[torch.Tensor], optional + Labels for lead-time aware predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution images + of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels + use_patch_grad_acc: bool, optional + A boolean flag indicating whether to enable multi-iterations of patching accumulations + for amortizing regression cost. Default False. - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + If patching is not used: + A tensor of shape (B, C_hr, H, W) representing the per-sample loss. + If patching is used: + A tensor of shape (B*P, C_hr, H_patch, W_patch) representing + the per-patch loss. + + Raises + ------ + ValueError + If patching is provided but is not an instance of RandomPatching2D. + If shapes of img_clean and img_lr are incompatible. """ - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + # Safety check: enforce patching object + if patching and not isinstance(patching, RandomPatching2D): + raise ValueError("patching must be a 'RandomPatching2D' object.") + # Safety check: enforce shapes + if ( + img_clean.shape[0] != img_lr.shape[0] + or img_clean.shape[2:] != img_lr.shape[2:] + ): + raise ValueError( + f"Shape mismatch between img_clean {img_clean.shape} and " + f"img_lr {img_lr.shape}. " + f"Batch size, height and width must match." + ) - # augment for conditional generaiton + # augment for conditional generation img_tot = torch.cat((img_clean, img_lr), dim=1) y_tot, augment_labels = ( augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) @@ -517,114 +653,71 @@ def __call__( y = y_tot[:, : img_clean.shape[1], :, :] y_lr = y_tot[:, img_clean.shape[1] :, :, :] y_lr_res = y_lr - - # global index - b = y.shape[0] - Nx = torch.arange(self.img_shape_x).int() - Ny = torch.arange(self.img_shape_y).int() - grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) - - # form residual - if lead_time_label is not None: - y_mean = self.unet( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - sigma, - labels, - lead_time_label=lead_time_label, - augment_labels=augment_labels, - ) + batch_size = y.shape[0] + + # if using multi-iterations of patching, switch to optimized version + if use_patch_grad_acc: + # form residual + if self.y_mean is None: + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) + self.y_mean = y_mean + + # if on full domain, or if using patching without multi-iterations else: - y_mean = self.unet( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - sigma, - labels, - augment_labels=augment_labels, - ) + # form residual + if lead_time_label is not None: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.regression_net( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + augment_labels=augment_labels, + ) - y = y - y_mean + self.y_mean = y_mean + + y = y - self.y_mean if self.hr_mean_conditioning: - y_lr = torch.cat((y_mean, y_lr), dim=1).contiguous() - global_index = None + y_lr = torch.cat((self.y_mean, y_lr), dim=1) + # patchified training # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 - if ( - self.img_shape_x != self.patch_shape_x - or self.img_shape_y != self.patch_shape_y - ): - c_in = y_lr.shape[1] - c_out = y.shape[1] - rnd_normal = torch.randn( - [img_clean.shape[0] * self.patch_num, 1, 1, 1], device=img_clean.device - ) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / ( - sigma * self.sigma_data - ) ** 2 - - # global interpolation - input_interp = torch.nn.functional.interpolate( - img_lr, - (self.patch_shape_y, self.patch_shape_x), - mode="bilinear", - ) + # removed patch_embedding_selector due to compilation issue with dynamo. + if patching: + # Patched residual + # (batch_size * patch_num, c_out, patch_shape_y, patch_shape_x) + y_patched = patching.apply(input=y) + # Patched conditioning on y_lr and interp(img_lr) + # (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x) + y_lr_patched = patching.apply(input=y_lr, additional_input=img_lr) + + y = y_patched + y_lr = y_lr_patched + + # Noise + rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - # patch generation from a single sample (not from random samples due to memory consumption of regression) - y_new = torch.zeros( - b * self.patch_num, - c_out, - self.patch_shape_y, - self.patch_shape_x, - device=img_clean.device, - ) - y_lr_new = torch.zeros( - b * self.patch_num, - c_in + input_interp.shape[1], - self.patch_shape_y, - self.patch_shape_x, - device=img_clean.device, - ) - global_index = torch.zeros( - b * self.patch_num, - 2, - self.patch_shape_y, - self.patch_shape_x, - dtype=torch.int, - device=img_clean.device, - ) - for i in range(self.patch_num): - rnd_x = random.randint(0, self.img_shape_x - self.patch_shape_x) - rnd_y = random.randint(0, self.img_shape_y - self.patch_shape_y) - y_new[b * i : b * (i + 1),] = y[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ] - global_index[b * i : b * (i + 1),] = grid[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ] - y_lr_new[b * i : b * (i + 1),] = torch.cat( - ( - y_lr[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ], - input_interp, - ), - 1, - ) - y = y_new - y_lr = y_lr_new + # Input + noise latent = y + torch.randn_like(y) * sigma if lead_time_label is not None: @@ -632,8 +725,10 @@ def __call__( latent, y_lr, sigma, - labels, - global_index=global_index, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -642,8 +737,10 @@ def __call__( latent, y_lr, sigma, - labels, - global_index=global_index, + embedding_selector=None, + global_index=patching.global_index(batch_size, img_clean.device) + if patching is not None + else None, augment_labels=augment_labels, ) loss = weight * ((D_yn - y) ** 2) @@ -651,6 +748,7 @@ def __call__( return loss + class VELoss_dfsr: """ Loss function for dfsr model, modified from class VELoss. @@ -792,20 +890,19 @@ def __call__(self, net, images, labels, augment_pipe=None): class RegressionLossCE: """ - A regression loss function for the GEFS-HRRR model with probability channels, adapted - from RegressionLoss. In this version, probability channels are evaluated using - CrossEntropyLoss instead of MSELoss. - - Parameters + A regression loss function for deterministic predictions with probability + channels and lead time labels. Adapted from + :class:`physicsnemo.metrics.diffusion.loss.RegressionLoss`. In this version, + probability channels are evaluated using CrossEntropyLoss instead of + squared error. + Note: this loss does not apply any reduction. + + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. - prob_channels: list, optional - A index list of output probability channels. + entropy : torch.nn.CrossEntropyLoss + Cross entropy loss function used for probability channels. + prob_channels : list[int] + List of channel indices to be treated as probability channels. Note ---- @@ -817,62 +914,86 @@ class RegressionLossCE: def __init__( self, - P_mean: float = -1.2, - P_std: float = 1.2, - sigma_data: float = 0.5, - prob_channels: list = [4, 5, 6, 7, 8], + prob_channels: list[int] = [4, 5, 6, 7, 8], ): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data + """ + Arguments + ---------- + prob_channels: list[int], optional + List of channel indices from the target tensor to be treated as + probability channels. Cross entropy loss is computed over these + channels, while the remaining channels are treated as scalar + channels and the squared error loss is computed over them. By + default, [4, 5, 6, 7, 8]. + """ self.entropy = torch.nn.CrossEntropyLoss(reduction="none") self.prob_channels = prob_channels def __call__( self, - net, - img_clean, - img_lr, - lead_time_label=None, - labels=None, - augment_pipe=None, - ): + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: """ - Calculate and return the loss for the U-Net for deterministic predictions. + Calculate and return the loss for deterministic + predictions, treating specific channels as probability distributions. - Parameters: + Parameters ---------- - net: torch.nn.Module + net : torch.nn.Module The neural network model that will make predictions. + Expected signature: `net(input, img_lr, lead_time_label=lead_time_label, augment_labels=augment_labels)`, + where: + input (torch.Tensor): Tensor of shape (B, C_hr, H, W). Zero-filled. + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time + labels. If provided, should be of shape (B,). + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if `augment_pipe` is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + lead_time_label : Optional[torch.Tensor], optional + Lead time labels for temporal predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W). + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - lead_time_label: torch.Tensor - Lead time labels for input batches. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. - - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + A tensor of shape (B, C_loss, H, W) representing the pixel-wise + loss., where `C_loss = C_hr - len(prob_channels) + 1`. More + specifically, the last channel of the output tensor corresponds to + the cross-entropy loss computed over the channels specified in + `prob_channels`, while the first `C_hr - len(prob_channels)` + channels of the output tensor correspond to the squared error loss. """ all_channels = list(range(img_clean.shape[1])) # [0, 1, 2, ..., 10] scalar_channels = [ item for item in all_channels if item not in self.prob_channels ] - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = ( 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 ) @@ -890,8 +1011,6 @@ def __call__( D_yn = net( input, y_lr, - sigma, - labels, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -899,11 +1018,10 @@ def __call__( D_yn = net( input, y_lr, - sigma, - labels, + lead_time_label=lead_time_label, augment_labels=augment_labels, ) - loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2) + loss1 = weight * (D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2 loss2 = ( weight * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[ @@ -911,4 +1029,4 @@ def __call__( ] ) loss = torch.cat((loss1, loss2), dim=1) - return loss + return loss \ No newline at end of file diff --git a/src/hirad/models/__init__.py b/src/hirad/models/__init__.py index 3ab4a6f..b00a477 100644 --- a/src/hirad/models/__init__.py +++ b/src/hirad/models/__init__.py @@ -1,6 +1,14 @@ -from .layers import Linear, Conv2d, GroupNorm, AttentionOp, UNetBlock, PositionalEmbedding, FourierEmbedding +from .layers import ( + Linear, + Conv2d, + GroupNorm, + AttentionOp, + UNetBlock, + PositionalEmbedding, + FourierEmbedding +) from .meta import ModelMetaData from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd from .dhariwal_unet import DhariwalUNet from .unet import UNet -from .preconditioning import EDMPrecondSR, EDMPrecond +from .preconditioning import EDMPrecondSuperResolution, EDMPrecondSR, EDMPrecond diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index 8612da7..d7e63d7 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -19,15 +19,27 @@ Diffusion-Based Generative Models". """ +import contextlib +import importlib from typing import Any, Dict, List import numpy as np +import nvtx import torch +import torch.cuda.amp as amp from einops import rearrange -from torch.nn.functional import silu +from torch.nn.functional import elu, gelu, leaky_relu, relu, sigmoid, silu, tanh from hirad.utils.model_utils import weight_init +_is_apex_available = False +if torch.cuda.is_available(): + try: + apex_gn_module = importlib.import_module("apex.contrib.group_norm") + ApexGroupNorm = getattr(apex_gn_module, "GroupNorm") + _is_apex_available = True + except ImportError: + pass class Linear(torch.nn.Module): """ @@ -56,6 +68,8 @@ class Linear(torch.nn.Module): A scaling factor to multiply with the initialized weights. By default 1. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -66,10 +80,12 @@ def __init__( init_mode: str = "kaiming_normal", init_weight: int = 1, init_bias: int = 0, + amp_mode: bool = False, ): super().__init__() self.in_features = in_features self.out_features = out_features + self.amp_mode = amp_mode init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) self.weight = torch.nn.Parameter( weight_init([out_features, in_features], **init_kwargs) * init_weight @@ -81,9 +97,16 @@ def __init__( ) def forward(self, x): - x = x @ self.weight.to(x.dtype).t() + weight, bias = self.weight, self.bias + # pdb.set_trace() + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + x = x @ weight.t() if self.bias is not None: - x = x.add_(self.bias.to(x.dtype)) + x = x.add_(bias) return x @@ -128,6 +151,10 @@ class Conv2d(torch.nn.Module): A scaling factor to multiply with the initialized weights. By default 1.0. init_bias : float, optional A scaling factor to multiply with the initialized biases. By default 0.0. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -143,9 +170,16 @@ def __init__( init_mode: str = "kaiming_normal", init_weight: float = 1.0, init_bias: float = 0.0, + fused_conv_bias: bool = False, + amp_mode: bool = False, ): if up and down: raise ValueError("Both 'up' and 'down' cannot be true at the same time.") + if not kernel and fused_conv_bias: + print( + "Warning: Kernel is required when fused_conv_bias is enabled. Setting fused_conv_bias to False." + ) + fused_conv_bias = False super().__init__() self.in_channels = in_channels @@ -153,6 +187,8 @@ def __init__( self.up = up self.down = down self.fused_resample = fused_resample + self.fused_conv_bias = fused_conv_bias + self.amp_mode = amp_mode init_kwargs = dict( mode=init_mode, fan_in=in_channels * kernel * kernel, @@ -176,13 +212,21 @@ def __init__( self.register_buffer("resample_filter", f if up or down else None) def forward(self, x): - w = self.weight.to(x.dtype) if self.weight is not None else None - b = self.bias.to(x.dtype) if self.bias is not None else None - f = ( - self.resample_filter.to(x.dtype) - if self.resample_filter is not None - else None - ) + weight, bias, resample_filter = self.weight, self.bias, self.resample_filter + if not self.amp_mode: + if self.weight is not None and self.weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if ( + self.resample_filter is not None + and self.resample_filter.dtype != x.dtype + ): + resample_filter = self.resample_filter.to(x.dtype) + + w = weight if weight is not None else None + b = bias if bias is not None else None + f = resample_filter if resample_filter is not None else None w_pad = w.shape[-1] // 2 if w is not None else 0 f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 @@ -194,15 +238,29 @@ def forward(self, x): stride=2, padding=max(f_pad - w_pad, 0), ) - x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, w, padding=max(w_pad - f_pad, 0), bias=b + ) + else: + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) elif self.fused_resample and self.down and w is not None: x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) - x = torch.nn.functional.conv2d( - x, - f.tile([self.out_channels, 1, 1, 1]), - groups=self.out_channels, - stride=2, - ) + if self.fused_conv_bias: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + bias=b, + ) + else: + x = torch.nn.functional.conv2d( + x, + f.tile([self.out_channels, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) else: if self.up: x = torch.nn.functional.conv_transpose2d( @@ -220,11 +278,15 @@ def forward(self, x): stride=2, padding=f_pad, ) - if w is not None: - #TODO during inference, model breaks here for some reason - # current fix is to disable torch.backends.cudnn.enabled = False - x = torch.nn.functional.conv2d(x, w, padding=w_pad) - if b is not None: + + #TODO during inference, model breaks here for some reason + # current fix is to disable torch.backends.cudnn.enabled = False + if w is not None: # ask in corrdiff channel whether w will ever be none + if self.fused_conv_bias: + x = torch.nn.functional.conv2d(x, w, padding=w_pad, bias=b) + else: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None and not self.fused_conv_bias: x = x.add_(b.reshape(1, -1, 1, 1)) return x @@ -251,7 +313,15 @@ class GroupNorm(torch.nn.Module): eps : float, optional A small number added to the variance to prevent division by zero, by default 1e-5. - + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + fused_act : bool, optional + Whether to fuse the activation function with GroupNorm. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. Notes ----- If `num_channels` is not divisible by `num_groups`, the actual number of groups @@ -264,28 +334,71 @@ def __init__( num_groups: int = 32, min_channels_per_group: int = 4, eps: float = 1e-5, + use_apex_gn: bool = False, + fused_act: bool = False, + act: str = None, + amp_mode: bool = False, ): + if fused_act and act is None: + raise ValueError("'act' must be specified when 'fused_act' is set to True.") + super().__init__() self.num_groups = min(num_groups, num_channels // min_channels_per_group) self.eps = eps self.weight = torch.nn.Parameter(torch.ones(num_channels)) self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + if use_apex_gn and not _is_apex_available: + raise ValueError("'apex' is not installed, set `use_apex_gn=False`") + self.use_apex_gn = use_apex_gn + self.fused_act = fused_act + self.act = act.lower() if act else act + self.act_fn = None + self.amp_mode = amp_mode + if self.use_apex_gn: + if self.act: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + act=self.act, + ) + + else: + self.gn = ApexGroupNorm( + num_groups=self.num_groups, + num_channels=num_channels, + eps=self.eps, + affine=True, + ) + if self.fused_act: + self.act_fn = self.get_activation_function() def forward(self, x): - if self.training: + weight, bias = self.weight, self.bias + if not self.amp_mode: + if not self.use_apex_gn: + if weight.dtype != x.dtype: + weight = self.weight.to(x.dtype) + if bias.dtype != x.dtype: + bias = self.bias.to(x.dtype) + if self.use_apex_gn: + x = self.gn(x) + elif self.training: # Use default torch implementation of GroupNorm for training # This does not support channels last memory format x = torch.nn.functional.group_norm( x, num_groups=self.num_groups, - weight=self.weight.to(x.dtype), - bias=self.bias.to(x.dtype), + weight=weight, + bias=bias, eps=self.eps, ) + if self.fused_act: + x = self.act_fn(x) else: # Use custom GroupNorm implementation that supports channels last # memory layout for inference - dtype = x.dtype x = x.float() x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) @@ -295,12 +408,33 @@ def forward(self, x): x = (x - mean) * (var + self.eps).rsqrt() x = rearrange(x, "b g c h w -> b (g c) h w") - weight = rearrange(self.weight, "c -> 1 c 1 1") - bias = rearrange(self.bias, "c -> 1 c 1 1") + weight = rearrange(weight, "c -> 1 c 1 1") + bias = rearrange(bias, "c -> 1 c 1 1") x = x * weight + bias - x = x.type(dtype) + if self.fused_act: + x = self.act_fn(x) return x + + def get_activation_function(self): + """ + Get activation function given string input + """ + + activation_map = { + "silu": silu, + "relu": relu, + "leaky_relu": leaky_relu, + "sigmoid": sigmoid, + "tanh": tanh, + "gelu": gelu, + "elu": elu, + } + + act_fn = activation_map.get(self.act, None) + if act_fn is None: + raise ValueError(f"Unknown activation function: {self.act}") + return act_fn class AttentionOp(torch.autograd.Function): @@ -333,6 +467,7 @@ def backward(ctx, dw): dim=2, input_dtype=torch.float32, ) + dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( q.dtype ) / np.sqrt(k.shape[1]) @@ -385,6 +520,17 @@ class UNetBlock(torch.nn.Module): init_attn : dict, optional Initialization parameters specific to attention mechanism layers. Defaults to 'init' if not provided. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + fused_conv_bias: bool, optional + A boolean flag indicating whether bias will be passed as a parameter of conv2d. By default False. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( @@ -406,6 +552,11 @@ def __init__( init: Dict[str, Any] = dict(), init_zero: Dict[str, Any] = dict(init_weight=0), init_attn: Any = None, + use_apex_gn: bool = False, + act: str = "silu", + fused_conv_bias: bool = False, + profile_mode: bool = False, + amp_mode: bool = False, ): super().__init__() @@ -423,7 +574,16 @@ def __init__( self.skip_scale = skip_scale self.adaptive_scale = adaptive_scale - self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.profile_mode = profile_mode + self.amp_mode = amp_mode + self.norm0 = GroupNorm( + num_channels=in_channels, + eps=eps, + use_apex_gn=use_apex_gn, + fused_act=True, + act=act, + amp_mode=amp_mode, + ) self.conv0 = Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -431,21 +591,45 @@ def __init__( up=up, down=down, resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init, ) self.affine = Linear( in_features=emb_channels, out_features=out_channels * (2 if adaptive_scale else 1), + amp_mode=amp_mode, **init, ) - self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + if self.adaptive_scale: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) + else: + self.norm1 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + act=act, + fused_act=True, + amp_mode=amp_mode, + ) self.conv1 = Conv2d( - in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero + in_channels=out_channels, + out_channels=out_channels, + kernel=3, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, + **init_zero, ) self.skip = None if out_channels != in_channels or up or down: kernel = 1 if resample_proj or out_channels != in_channels else 0 + fused_conv_bias = fused_conv_bias if kernel != 0 else False self.skip = Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -453,55 +637,75 @@ def __init__( up=up, down=down, resample_filter=resample_filter, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init, ) if self.num_heads: - self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.norm2 = GroupNorm( + num_channels=out_channels, + eps=eps, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, + ) self.qkv = Conv2d( in_channels=out_channels, out_channels=out_channels * 3, kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **(init_attn if init_attn is not None else init), ) self.proj = Conv2d( in_channels=out_channels, out_channels=out_channels, kernel=1, + fused_conv_bias=fused_conv_bias, + amp_mode=amp_mode, **init_zero, ) def forward(self, x, emb): - torch.cuda.nvtx.range_push("UNetBlock") - orig = x - x = self.conv0(silu(self.norm0(x))) - params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) - if self.adaptive_scale: - scale, shift = params.chunk(chunks=2, dim=1) - x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) - else: - x = silu(self.norm1(x.add_(params))) - - x = self.conv1( - torch.nn.functional.dropout(x, p=self.dropout, training=self.training) - ) - x = x.add_(self.skip(orig) if self.skip is not None else orig) - x = x * self.skip_scale - - if self.num_heads: - q, k, v = ( - self.qkv(self.norm2(x)) - .reshape( - x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1 - ) - .unbind(2) + with nvtx.annotate( + message="UNetBlock", color="purple" + ) if self.profile_mode else contextlib.nullcontext(): + orig = x + x = self.conv0(self.norm0(x)) + params = self.affine(emb).unsqueeze(2).unsqueeze(3) + if not self.amp_mode: + if params.dtype != x.dtype: + params = params.to(x.dtype) + + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = self.norm1(x.add_(params)) + + x = self.conv1( + torch.nn.functional.dropout(x, p=self.dropout, training=self.training) ) - w = AttentionOp.apply(q, k) - a = torch.einsum("nqk,nck->ncq", w, v) - x = self.proj(a.reshape(*x.shape)).add_(x) + x = x.add_(self.skip(orig) if self.skip is not None else orig) x = x * self.skip_scale - torch.cuda.nvtx.range_pop() - return x + + if self.num_heads: + q, k, v = ( + self.qkv(self.norm2(x)) + .reshape( + x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1 + ) + .unbind(3) + ) + # w = AttentionOp.apply(q, k) + # a = torch.einsum("nqk,nck->ncq", w, v) + # Compute attention in one step + with amp.autocast(enabled=self.amp_mode): + attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.proj(attn.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + + return x class PositionalEmbedding(torch.nn.Module): @@ -517,16 +721,23 @@ class PositionalEmbedding(torch.nn.Module): Maximum number of positions for the embeddings, by default 10000. endpoint : bool, optional If True, the embedding considers the endpoint. By default False. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ def __init__( - self, num_channels: int, max_positions: int = 10000, endpoint: bool = False + self, + num_channels: int, + max_positions: int = 10000, + endpoint: bool = False, + amp_mode: bool = False, ): super().__init__() self.num_channels = num_channels self.max_positions = max_positions self.endpoint = endpoint + self.amp_mode = amp_mode def forward(self, x): freqs = torch.arange( @@ -534,7 +745,10 @@ def forward(self, x): ) freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) freqs = (1 / self.max_positions) ** freqs - x = x.ger(freqs.to(x.dtype)) + if not self.amp_mode: + if freqs.dtype != x.dtype: + freqs = freqs.to(x.dtype) + x = x.ger(freqs) x = torch.cat([x.cos(), x.sin()], dim=1) return x @@ -556,13 +770,21 @@ class FourierEmbedding(torch.nn.Module): scale : int, optional A scale factor applied to the random frequencies, controlling their range and thereby the frequency of oscillations in the embedding space. By default 16. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. """ - def __init__(self, num_channels: int, scale: int = 16): + def __init__(self, num_channels: int, scale: int = 16, amp_mode: bool = False): super().__init__() self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + self.amp_mode = amp_mode def forward(self, x): - x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + freqs = self.freqs + if not self.amp_mode: + if x.dtype != self.freqs.dtype: + freqs = self.freqs.to(x.dtype) + + x = x.ger((2 * np.pi * freqs)) x = torch.cat([x.cos(), x.sin()], dim=1) return x diff --git a/src/hirad/models/preconditioning.py b/src/hirad/models/preconditioning.py index c66b6b6..74496a5 100644 --- a/src/hirad/models/preconditioning.py +++ b/src/hirad/models/preconditioning.py @@ -22,19 +22,12 @@ import importlib import warnings from dataclasses import dataclass -from typing import List, Union +from typing import List, Literal, Tuple, Union import numpy as np -import nvtx import torch import torch.nn as nn -from .song_unet import ( - SongUNet, # noqa: F401 for globals -) -from .dhariwal_unet import ( - DhariwalUNet, # noqa: F401 for globals -) from .meta import ModelMetaData network_module = importlib.import_module("hirad.models") @@ -694,12 +687,11 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]): """ return torch.as_tensor(sigma) - @dataclass -class EDMPrecondSRMetaData(ModelMetaData): +class EDMPrecondSuperResolutionMetaData(ModelMetaData): """EDMPrecondSR meta data""" - name: str = "EDMPrecondSR" + name: str = "EDMPrecondSuperResolution" # Optimization jit: bool = False cuda_graphs: bool = False @@ -715,33 +707,40 @@ class EDMPrecondSRMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecondSR(nn.Module): +class EDMPrecondSuperResolution(nn.Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) for super-resolution tasks + Diffusion-Based Generative Models" (EDM). + + This is a variant of `EDMPrecond` that is specifically designed for super-resolution + tasks. It wraps a neural network that predicts the denoised high-resolution image + given a noisy high-resolution image, and additional conditioning that includes a + low-resolution image, and a noise level. Parameters ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. + img_resolution : Union[int, Tuple[int, int]] + Spatial resolution `(H, W)` of the image. If a single int is provided, + the image is assumed to be square. img_in_channels : int - Number of input color channels. + Number of input channels in the low-resolution input image. img_out_channels : int - Number of output color channels. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float + Number of output channels in the high-resolution output image. + use_fp16 : bool, optional + Whether to use half-precision floating point (FP16) for model execution, + by default False. + model_type : str, optional + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. + sigma_data : float, optional + Expected standard deviation of the training data, by default 0.5. + sigma_min : float, optional Minimum supported noise level, by default 0.0. - sigma_max : float + sigma_max : float, optional Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "SongUNetPosEmbd". **model_kwargs : dict - Keyword arguments for the underlying model. + Keyword arguments passed to the underlying model `__init__` method. Note ---- @@ -757,28 +756,26 @@ class EDMPrecondSR(nn.Module): def __init__( self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + sigma_data: float = 0.5, sigma_min=0.0, sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - scale_cond_input=True, - **model_kwargs, + **model_kwargs: dict, ): super().__init__() #meta=EDMPrecondSRMetaData self.img_resolution = img_resolution - self.img_channels = img_channels # TODO: this is not used, remove it self.img_in_channels = img_in_channels self.img_out_channels = img_out_channels self.use_fp16 = use_fp16 + self.sigma_data = sigma_data self.sigma_min = sigma_min self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.scale_cond_input = scale_cond_input model_class = getattr(network_module, model_type) self.model = model_class( @@ -787,39 +784,73 @@ def __init__( out_channels=img_out_channels, **model_kwargs, ) # TODO needs better handling - self.scaling_fn = self._get_scaling_fn() - - def _get_scaling_fn(self): - if self.scale_cond_input: - warnings.warn( - "scale_cond_input=True does not properly scale the conditional input. " - "(see https://github.com/NVIDIA/modulus/issues/229). " - "This setup will be deprecated. " - "Please set scale_cond_input=False.", - DeprecationWarning, - ) - return self._legacy_scaling_fn - else: - return self._scaling_fn + self.scaling_fn = self._scaling_fn @staticmethod - def _scaling_fn(x, img_lr, c_in): - return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + def _scaling_fn( + x: torch.Tensor, img_lr: torch.Tensor, c_in: torch.Tensor + ) -> torch.Tensor: + """ + Scale input tensors by first scaling the high-resolution tensor and then + concatenating with the low-resolution tensor. - @staticmethod - def _legacy_scaling_fn(x, img_lr, c_in): - return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution image of shape (B, C_lr, H, W). + c_in : torch.Tensor + Scaling factor of shape (B, 1, 1, 1). + + Returns + ------- + torch.Tensor + Scaled and concatenated tensor of shape (B, C_in+C_out, H, W). + """ + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) - @nvtx.annotate(message="EDMPrecondSR", color="orange") def forward( self, - x, - img_lr, - sigma, - force_fp32=False, - **model_kwargs, - ): - # Concatenate input channels + x: torch.Tensor, + img_lr: torch.Tensor, + sigma: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the EDMPrecondSuperResolution model wrapper. + + This method applies the EDM preconditioning to compute the denoised image + from a noisy high-resolution image and low-resolution conditioning image. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). The number of + channels `C_hr` should be equal to `img_out_channels`. + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). The number + of channels `C_lr` should be equal to `img_in_channels`. + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ x = x.to(torch.float32) sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) dtype = ( @@ -855,13 +886,190 @@ def forward( return D_x @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): + def round_sigma(sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: """ Convert a given sigma value(s) to a tensor representation. - See EDMPrecond.round_sigma + + Parameters + ---------- + sigma : Union[float, List, torch.Tensor] + Sigma value(s) to convert. + + Returns + ------- + torch.Tensor + Tensor representation of sigma values. + + See Also + -------- + EDMPrecond.round_sigma """ return EDMPrecond.round_sigma(sigma) + @property + def amp_mode(self): + """ + Return the *amp_mode* flag of the wrapped model or *None*. + """ + return getattr(self.model, "amp_mode", None) + + @amp_mode.setter + def amp_mode(self, value: bool): + """ + Propagate *amp_mode* to the model and all its sub-modules. + """ + + if not isinstance(value, bool): + raise TypeError("amp_mode must be a boolean value.") + + if hasattr(self.model, "amp_mode"): + self.model.amp_mode = value + + for sub_module in self.model.modules(): + if hasattr(sub_module, "amp_mode"): + sub_module.amp_mode = value + +# NOTE: This is a deprecated version of the EDMPrecondSuperResolution model. +# This was used to maintain backwards compatibility and allow loading old models. +@dataclass +class EDMPrecondSRMetaData(ModelMetaData): + """EDMPrecondSR meta data""" + + name: str = "EDMPrecondSR" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = False + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class EDMPrecondSR(EDMPrecondSuperResolution): + """ + Improved preconditioning proposed in the paper "Elucidating the Design Space of + Diffusion-Based Generative Models" (EDM) for super-resolution tasks + + Parameters + ---------- + img_resolution : int + Image resolution. + img_channels : int + Number of color channels. + img_in_channels : int + Number of input color channels. + img_out_channels : int + Number of output color channels. + use_fp16 : bool + Execute the underlying model at FP16 precision?, by default False. + sigma_min : float + Minimum supported noise level, by default 0.0. + sigma_max : float + Maximum supported noise level, by default inf. + sigma_data : float + Expected standard deviation of the training data, by default 0.5. + model_type :str + Class name of the underlying model, by default "SongUNetPosEmbd". + **model_kwargs : dict + Keyword arguments for the underlying model. + + Note + ---- + References: + - Karras, T., Aittala, M., Aila, T. and Laine, S., 2022. Elucidating the + design space of diffusion-based generative models. Advances in Neural Information + Processing Systems, 35, pp.26565-26577. + - Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., + Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. + arXiv preprint arXiv:2309.15214. + """ + + def __init__( + self, + img_resolution, + img_channels, #deprecated + img_in_channels, + img_out_channels, + use_fp16=False, + sigma_min=0.0, + sigma_max=float("inf"), + sigma_data=0.5, + model_type="SongUNetPosEmbd", + scale_cond_input=True, #deprecated + **model_kwargs, + ): + warnings.warn( + "EDMPrecondSR is deprecated and will be removed in a future version. " + "Please use EDMPrecondSuperResolution instead.", + DeprecationWarning, + stacklevel=2, + ) + + if scale_cond_input: + warnings.warn( + "scale_cond_input=True does not properly scale the conditional input. " + "(see https://github.com/NVIDIA/modulus/issues/229). " + "This setup will be deprecated. " + "Please set scale_cond_input=False.", + DeprecationWarning, + ) + + super().__init__( + img_resolution=img_resolution, + img_in_channels=img_in_channels, + img_out_channels=img_out_channels, + use_fp16=use_fp16, + sigma_min=sigma_min, + sigma_max=sigma_max, + sigma_data=sigma_data, + model_type=model_type, + **model_kwargs, + ) + + # Store deprecated parameters for backward compatibility + self.img_channels = img_channels + self.scale_cond_input = scale_cond_input + + def forward( + self, + x, + img_lr, + sigma, + force_fp32=False, + **model_kwargs, + ): + """ + Forward pass of the EDMPrecondSR model wrapper. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + """ + return super().forward( + x=x, img_lr=img_lr, sigma=sigma, force_fp32=force_fp32, **model_kwargs + ) class VEPrecond_dfsr(nn.Module): """ @@ -912,7 +1120,8 @@ def __init__( self.img_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 - self.model = globals()[model_type]( + model_class = getattr(network_module, model_type) + self.model = model_class( img_resolution=img_resolution, in_channels=self.img_channels, out_channels=img_channels, @@ -1011,7 +1220,8 @@ def __init__( self.img_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 - self.model = globals()[model_type]( + model_class = getattr(network_module, model_type) + self.model = model_class( img_resolution=img_resolution, in_channels=model_kwargs["model_channels"] * 2, out_channels=img_channels, diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index 6267dfc..a56f861 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -19,8 +19,9 @@ Diffusion-Based Generative Models". """ +import contextlib from dataclasses import dataclass -from typing import List, Union +from typing import Callable, List, Optional, Union import numpy as np import nvtx @@ -71,7 +72,8 @@ class SongUNet(nn.Module): Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -81,7 +83,7 @@ class SongUNet(nn.Module): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional Per-resolution multipliers for the number of channels. By default [1,2,2,2]. channel_mult_emb : int, optional @@ -93,29 +95,39 @@ class SongUNet(nn.Module): dropout : float, optional Dropout probability applied to intermediate activations. By default 0.10. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional - Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. , 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - checkpoint_level : int, optional (default=0) - How many layers should use gradient checkpointing, 0 is None - additive_pos_embed: bool = False, - Set to True to add a learned position embedding after the first conv (used in StormCast) + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. Reference ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B., 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456. @@ -156,6 +168,10 @@ def __init__( resample_filter: List[int] = [1, 1], checkpoint_level: int = 0, additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, ): valid_embedding_types = ["fourier", "positional", "zero"] if embedding_type not in valid_embedding_types: @@ -196,7 +212,14 @@ def __init__( init=init, init_zero=init_zero, init_attn=init_attn, + use_apex_gn=use_apex_gn, + act=act, + fused_conv_bias=True, + profile_mode=profile_mode, + amp_mode=amp_mode, ) + self.profile_mode = profile_mode + self.amp_mode = amp_mode # for compatibility with older versions that took only 1 dimension self.img_resolution = img_resolution @@ -220,12 +243,19 @@ def __init__( # Mapping. if self.embedding_type != "zero": self.map_noise = ( - PositionalEmbedding(num_channels=noise_channels, endpoint=True) + PositionalEmbedding( + num_channels=noise_channels, endpoint=True, amp_mode=amp_mode + ) if embedding_type == "positional" - else FourierEmbedding(num_channels=noise_channels) + else FourierEmbedding(num_channels=noise_channels, amp_mode=amp_mode) ) self.map_label = ( - Linear(in_features=label_dim, out_features=noise_channels, **init) + Linear( + in_features=label_dim, + out_features=noise_channels, + amp_mode=amp_mode, + **init, + ) if label_dim else None ) @@ -234,16 +264,23 @@ def __init__( in_features=augment_dim, out_features=noise_channels, bias=False, + amp_mode=amp_mode, **init, ) if augment_dim else None ) self.map_layer0 = Linear( - in_features=noise_channels, out_features=emb_channels, **init + in_features=noise_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, ) self.map_layer1 = Linear( - in_features=emb_channels, out_features=emb_channels, **init + in_features=emb_channels, + out_features=emb_channels, + amp_mode=amp_mode, + **init, ) # Encoder. @@ -256,7 +293,12 @@ def __init__( cin = cout cout = model_channels self.enc[f"{res}x{res}_conv"] = Conv2d( - in_channels=cin, out_channels=cout, kernel=3, **init + in_channels=cin, + out_channels=cout, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, ) else: self.enc[f"{res}x{res}_down"] = UNetBlock( @@ -269,9 +311,15 @@ def __init__( kernel=0, down=True, resample_filter=resample_filter, + amp_mode=amp_mode, ) self.enc[f"{res}x{res}_aux_skip"] = Conv2d( - in_channels=caux, out_channels=cout, kernel=1, **init + in_channels=caux, + out_channels=cout, + kernel=1, + fused_conv_bias=True, + amp_mode=amp_mode, + **init, ) if encoder_type == "residual": self.enc[f"{res}x{res}_aux_residual"] = Conv2d( @@ -281,6 +329,8 @@ def __init__( down=True, resample_filter=resample_filter, fused_resample=True, + fused_conv_bias=True, + amp_mode=amp_mode, **init, ) caux = cout @@ -325,107 +375,138 @@ def __init__( kernel=0, up=True, resample_filter=resample_filter, + amp_mode=amp_mode, ) self.dec[f"{res}x{res}_aux_norm"] = GroupNorm( - num_channels=cout, eps=1e-6 + num_channels=cout, + eps=1e-6, + use_apex_gn=use_apex_gn, + amp_mode=amp_mode, ) self.dec[f"{res}x{res}_aux_conv"] = Conv2d( - in_channels=cout, out_channels=out_channels, kernel=3, **init_zero + in_channels=cout, + out_channels=out_channels, + kernel=3, + fused_conv_bias=True, + amp_mode=amp_mode, + **init_zero, ) - @nvtx.annotate(message="SongUNet", color="blue") def forward(self, x, noise_labels, class_labels, augment_labels=None): - if self.embedding_type != "zero": - # Mapping. - emb = self.map_noise(noise_labels) - emb = ( - emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) - ) # swap sin/cos - if self.map_label is not None: - tmp = class_labels - if self.training and self.label_dropout: - tmp = tmp * ( - torch.rand([x.shape[0], 1], device=x.device) - >= self.label_dropout - ).to(tmp.dtype) - emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) - if self.map_augment is not None and augment_labels is not None: - emb = emb + self.map_augment(augment_labels) - emb = silu(self.map_layer0(emb)) - emb = silu(self.map_layer1(emb)) - else: - emb = torch.zeros( - (noise_labels.shape[0], self.emb_channels), device=x.device - ) + with nvtx.annotate( + message="SongUNet", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if self.embedding_type != "zero": + # Mapping. + emb = self.map_noise(noise_labels) + emb = ( + emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) + ) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * ( + torch.rand([x.shape[0], 1], device=x.device) + >= self.label_dropout + ).to(tmp.dtype) + emb = emb + self.map_label( + tmp * np.sqrt(self.map_label.in_features) + ) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = torch.zeros( + (noise_labels.shape[0], self.emb_channels), device=x.device + ) - # Encoder. - skips = [] - aux = x - for name, block in self.enc.items(): - with nvtx.annotate(f"SongUNet encoder: {name}", color="blue"): - if "aux_down" in name: - aux = block(aux) - elif "aux_skip" in name: - x = skips[-1] = x + block(aux) - elif "aux_residual" in name: - x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) - elif "_conv" in name: - x = block(x) - if self.additive_pos_embed: - x = x + self.spatial_emb.to(dtype=x.dtype) - skips.append(x) - else: - # For UNetBlocks check if we should use gradient checkpointing - if isinstance(block, UNetBlock): - if x.shape[-1] > self.checkpoint_threshold: - x = checkpoint(block, x, emb, use_reentrant=False) - else: - x = block(x, emb) - else: + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + with nvtx.annotate( + f"SongUNet encoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_down" in name: + aux = block(aux) + elif "aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: x = block(x) - skips.append(x) + if self.additive_pos_embed: + x = x + self.spatial_emb.to(dtype=x.dtype) + skips.append(x) + else: + # For UNetBlocks check if we should use gradient checkpointing + if isinstance(block, UNetBlock): + if x.shape[-1] > self.checkpoint_threshold: + # self.checkpoint = checkpoint? + # else: self.checkpoint = lambda(block,x,emb:block(x,emb)) + x = checkpoint(block, x, emb, use_reentrant=False) + else: + # AssertionError: Only support NHWC layout. + x = block(x, emb) + else: + x = block(x) + skips.append(x) - # Decoder. - aux = None - tmp = None - for name, block in self.dec.items(): - with nvtx.annotate(f"SongUNet decoder: {name}", color="blue"): - if "aux_up" in name: - aux = block(aux) - elif "aux_norm" in name: - tmp = block(x) - elif "aux_conv" in name: - tmp = block(silu(tmp)) - aux = tmp if aux is None else tmp + aux - else: - if x.shape[1] != block.in_channels: - x = torch.cat([x, skips.pop()], dim=1) - # check for checkpointing on decoder blocks and up sampling blocks - if ( - x.shape[-1] > self.checkpoint_threshold and "_block" in name - ) or ( - x.shape[-1] > (self.checkpoint_threshold / 2) and "_up" in name - ): - x = checkpoint(block, x, emb, use_reentrant=False) + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + with nvtx.annotate( + f"SongUNet decoder: {name}", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if "aux_up" in name: + aux = block(aux) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux else: - x = block(x, emb) - return aux + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + # check for checkpointing on decoder blocks and up sampling blocks + if ( + x.shape[-1] > self.checkpoint_threshold and "_block" in name + ) or ( + x.shape[-1] > (self.checkpoint_threshold / 2) + and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + return aux class SongUNetPosEmbd(SongUNet): - """ - Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with - optional self-attention,embeddings, and encoder-decoder components. + """Extends SongUNet with positional embeddings. This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations. + This model adds positional embeddings to the base SongUNet architecture. The embeddings + can be selected using either a selector function or global indices, with the selector + approach being more computationally efficient. + + The model provides two methods for selecting positional embeddings: + + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. + Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -435,39 +516,63 @@ class SongUNetPosEmbd(SongUNet): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional - Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. channel_mult_emb : int, optional Multiplier for the dimensionality of the embedding vector. By default 4. num_blocks : int, optional Number of residual blocks per resolution. By default 4. attn_resolutions : List[int], optional - Resolutions at which self-attention layers are applied. By default [16]. + Resolutions at which self-attention layers are applied. By default [28]. dropout : float, optional Dropout probability applied to intermediate activations. By default 0.13. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. , 'skip' for skip connections. + By default'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - - - Reference - ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. + lead_time_mode : bool, optional + A boolean flag indicating whether we are running SongUNet with lead time embedding. Defaults to False. + lead_time_channels : int, optional + Number of channels in the lead time embedding. These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. Note ----- @@ -476,13 +581,41 @@ class SongUNetPosEmbd(SongUNet): Example -------- - >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include both original input channels (2) + >>> # and the positional embedding channels (N_grid_channels=4 by default) + >>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2) >>> noise_labels = torch.randn([1]) >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings are + >>> # added automatically inside the forward method >>> input_image = torch.ones([1, 2, 16, 16]) >>> output_image = model(input_image, noise_labels, class_labels) >>> output_image.shape torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a global index to select all positional embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... global_index=global_index + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a custom embedding selector to select all positional embeddings + >>> def patch_embedding_selector(emb): + ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... embedding_selector=patch_embedding_selector + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) """ def __init__( @@ -507,6 +640,15 @@ def __init__( gridtype: str = "sinusoidal", N_grid_channels: int = 4, checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, + lead_time_mode: bool = False, + lead_time_channels: int = None, + lead_time_steps: int = 9, + prob_channels: List[int] = [], ): super().__init__( img_resolution, @@ -527,49 +669,286 @@ def __init__( decoder_type, resample_filter, checkpoint_level, + additive_pos_embed, + use_apex_gn, + act, + profile_mode, + amp_mode, ) self.gridtype = gridtype self.N_grid_channels = N_grid_channels - self.pos_embd = self._get_positional_embedding() + if self.gridtype == "learnable": + self.pos_embd = self._get_positional_embedding() + else: + self.register_buffer("pos_embd", self._get_positional_embedding().float()) + self.lead_time_mode = lead_time_mode + if self.lead_time_mode: + self.lead_time_channels = lead_time_channels + self.lead_time_steps = lead_time_steps + self.lt_embd = self._get_lead_time_embedding() + self.prob_channels = prob_channels + if self.prob_channels: + self.scalar = torch.nn.Parameter( + torch.ones((1, len(self.prob_channels), 1, 1)) + ) - @nvtx.annotate(message="SongUNet", color="blue") def forward( - self, x, noise_labels, class_labels, global_index=None, augment_labels=None + self, + x, + noise_labels, + class_labels, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, + augment_labels=None, + lead_time_label=None, ): - # append positional embedding to input conditioning - if self.pos_embd is not None: - selected_pos_embd = self.positional_embedding_indexing(x, global_index) - x = torch.cat((x, selected_pos_embd), dim=1) + with nvtx.annotate( + message="SongUNetPosEmbd", color="blue" + ) if self.profile_mode else contextlib.nullcontext(): + if embedding_selector is not None and global_index is not None: + raise ValueError( + "Cannot provide both embedding_selector and global_index. " + "embedding_selector is the preferred approach for better efficiency." + ) + + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + # Append positional embedding to input conditioning + if self.pos_embd is not None: + # Select positional embeddings with a selector function + if embedding_selector is not None: + selected_pos_embd = self.positional_embedding_selector( + x, embedding_selector + ) + # Select positional embeddings using global indices (selects all + # embeddings if global_index is None) + else: + selected_pos_embd = self.positional_embedding_indexing( + x, global_index=global_index, lead_time_label=lead_time_label + ) + x = torch.cat((x, selected_pos_embd), dim=1) + + out = super().forward(x, noise_labels, class_labels, augment_labels) + + if self.lead_time_mode: + # if training mode, let crossEntropyLoss do softmax. The model outputs logits. + # if eval mode, the model outputs probability + all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] + scalar_channels = [ + item for item in all_channels if item not in self.prob_channels + ] + if self.prob_channels and (not self.training): + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar).softmax(dim=1), + ), + dim=1, + ) + elif self.prob_channels and self.training: + out_final = torch.cat( + ( + out[:, scalar_channels], + (out[:, self.prob_channels] * self.scalar), + ), + dim=1, + ) + else: + out_final = out + return out_final + + return out + + def positional_embedding_indexing( + self, + x: torch.Tensor, + global_index: Optional[torch.Tensor] = None, + lead_time_label=None, + ) -> torch.Tensor: + """Select positional embeddings using global indices. - return super().forward(x, noise_labels, class_labels, augment_labels) + This method either uses global indices to select specific embeddings or expands + the embeddings for the full input when no indices are provided. + + Typically used in patch-based training, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W), used to determine batch size + and device. + global_index : Optional[torch.Tensor] + Optional tensor of indices for selecting embeddings. These should + correspond to the spatial indices of the batch elements in the + input tensor x. When provided, should have shape (P, 2, H, W) where + the second dimension contains y,x coordinates (indices of the + positional embedding grid). + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape: + - If global_index provided: (B, N_pe, H, W) + - If global_index is None: (B, N_pe, H_pe, W_pe) + where N_pe is the number of positional embedding channels, and H_pe + and W_pe are the height and width of the positional embedding grid. + + Example + ------- + >>> # Create global indices using patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> global_index = patching.global_index(batch_size=3) + >>> print(global_index.shape) + torch.Size([4, 2, 8, 8]) + + See Also + -------- + :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index` + For generating random patch indices. + :meth:`physicsnemo.utils.patching.GridPatching2D.global_index` + For generating deterministic grid-based patch indices. + See these methods for possible ways to generate the global_index parameter. + """ + # If no global indices are provided, select all embeddings and expand + # to match the batch size of the input + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) - def positional_embedding_indexing(self, x, global_index): if global_index is None: - selected_pos_embd = ( - self.pos_embd.to(x.dtype) - .to(x.device)[None] - .expand((x.shape[0], -1, -1, -1)) - ) + if self.lead_time_mode: + selected_pos_embd = [] + if self.pos_embd is not None: + selected_pos_embd.append( + self.pos_embd[None].expand((x.shape[0], -1, -1, -1)) + ) + if self.lt_embd is not None: + selected_pos_embd.append( + torch.reshape( + self.lt_embd[lead_time_label.int()], + ( + x.shape[0], + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ), + ) + ) + if len(selected_pos_embd) > 0: + selected_pos_embd = torch.cat(selected_pos_embd, dim=1) + else: + selected_pos_embd = self.pos_embd[None].expand( + (x.shape[0], -1, -1, -1) + ) # (B, N_pe, H, W) + else: - B = global_index.shape[0] - X = global_index.shape[2] - Y = global_index.shape[3] + P = global_index.shape[0] + B = x.shape[0] // P + H = global_index.shape[2] + W = global_index.shape[3] + global_index = torch.reshape( torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, X, Y) to (2, B*X*Y) - selected_pos_embd = self.pos_embd.to(x.device)[ + ) # (P, 2, X, Y) to (2, P*X*Y) + selected_pos_embd = self.pos_embd[ :, global_index[0], global_index[1] - ] # (N_pe, B*X*Y) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, X, Y)), - (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, X, Y) + ] # (N_pe, P*X*Y) + selected_pos_embd = torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)), + (1, 0, 2, 3), + ) # (P, N_pe, X, Y) + + selected_pos_embd = selected_pos_embd.repeat( + B, 1, 1, 1 + ) # (B*P, N_pe, X, Y) + + # Append positional and lead time embeddings to input conditioning + if self.lead_time_mode: + embeds = [] + if self.pos_embd is not None: + embeds.append(selected_pos_embd) # reuse code below + if self.lt_embd is not None: + lt_embds = self.lt_embd[ + lead_time_label.int() + ] # (B, self.lead_time_channels, self.img_shape_y, self.img_shape_x), + + selected_lt_pos_embd = lt_embds[ + :, :, global_index[0], global_index[1] + ] # (B, N_lt, P*X*Y) + selected_lt_pos_embd = torch.reshape( + torch.permute( + torch.reshape( + selected_lt_pos_embd, + (B, self.lead_time_channels, P, H, W), + ), + (0, 2, 1, 3, 4), + ).contiguous(), + (B * P, self.lead_time_channels, H, W), + ) # (B*P, N_pe, X, Y) + embeds.append(selected_lt_pos_embd) + + if len(embeds) > 0: + selected_pos_embd = torch.cat(embeds, dim=1) + return selected_pos_embd + + def positional_embedding_selector( + self, + x: torch.Tensor, + embedding_selector: Callable[[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Select positional embeddings using a selector function. + + Similar to positional_embedding_indexing, but uses a selector function + to select the embeddings. This method provides a more efficient way to + select embeddings for batches of data. + Typically used with patch-based processing, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W) only used to determine dtype and + device. + embedding_selector : Callable + Function that takes as input an embedding tensor of shape (N_pe, + H_pe, W_pe) and returns selected embeddings with shape (batch_size, N_pe, H, W). + Each selected embedding should correspond to the positional + information of each batch element in x. + For patch-based processing, typically this should be based on + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to + maintain consistency with patch extraction. + embeds : Optional[torch.Tensor] + Optional tensor for combined positional and lead time embeddings tensor + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape (B, N_pe, H, W) + where N_pe is the number of positional embedding channels. + + Example + ------- + >>> # Define a selector function with a patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> batch_size = 4 + >>> def embedding_selector(emb): + ... return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + >>> + + See Also + -------- + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` + For the base patching method typically used in embedding_selector. + """ + if x.dtype != self.pos_embd.dtype: + self.pos_embd = self.pos_embd.to(x.dtype) + + return embedding_selector(self.pos_embd) # (B, N_pe, H, W) def _get_positional_embedding(self): if self.N_grid_channels == 0: @@ -577,14 +956,16 @@ def _get_positional_embedding(self): elif self.gridtype == "learnable": grid = torch.nn.Parameter( torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) - ) + ) # (N_grid_channels, img_shape_y, img_shape_x) elif self.gridtype == "linear": if self.N_grid_channels != 2: raise ValueError("N_grid_channels must be set to 2 for gridtype linear") x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) grid_x, grid_y = np.meshgrid(y, x) - grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid = torch.from_numpy( + np.stack((grid_x, grid_y), axis=0) + ) # (2, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: # print('sinusuidal grid added ......') @@ -600,7 +981,7 @@ def _get_positional_embedding(self): np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 ) ) - ) + ) # (4, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: if self.N_grid_channels % 4 != 0: @@ -616,28 +997,50 @@ def _get_positional_embedding(self): for p_fn in [np.sin, np.cos]: grid_list.append(p_fn(grid_x * freq)) grid_list.append(p_fn(grid_y * freq)) - grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid = torch.from_numpy( + np.stack(grid_list, axis=0) + ) # (N_grid_channels, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "test" and self.N_grid_channels == 2: idx_x = torch.arange(self.img_shape_y) idx_y = torch.arange(self.img_shape_x) mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - grid = torch.stack((mesh_x, mesh_y), dim=0) + grid = torch.stack((mesh_x, mesh_y), dim=0) # (2, img_shape_y, img_shape_x) else: raise ValueError("Gridtype not supported.") return grid + + def _get_lead_time_embedding(self): + if (self.lead_time_steps is None) or (self.lead_time_channels is None): + return None + grid = torch.nn.Parameter( + torch.randn( + self.lead_time_steps, + self.lead_time_channels, + self.img_shape_y, + self.img_shape_x, + ) + ) # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x) + return grid -class SongUNetPosLtEmbd(SongUNet): +class SongUNetPosLtEmbd(SongUNetPosEmbd): """ - This model is adapated from SongUNetPosEmbd, with the incoporatation of lead-time aware - embedding for the GEFS-HRRR model. The lead-time embedding is activated by setting the - lead_time_channels and lead_time_steps parameters. + This model is adapted from SongUNetPosEmbd, with the incorporation of lead-time aware + embeddings. The lead-time embedding is activated by setting the + `lead_time_channels` and `lead_time_steps` parameters. + + Like SongUNetPosEmbd, this model provides two methods for selecting positional embeddings: + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -647,44 +1050,63 @@ class SongUNetPosLtEmbd(SongUNet): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional - Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. channel_mult_emb : int, optional Multiplier for the dimensionality of the embedding vector. By default 4. num_blocks : int, optional Number of residual blocks per resolution. By default 4. attn_resolutions : List[int], optional - Resolutions at which self-attention layers are applied. By default [16]. + Resolutions at which self-attention layers are applied. By default [28]. dropout : float, optional Dropout probability applied to intermediate activations. By default 0.13. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - lead_time_channels: int, optional - Length of lead time embedding vector - lead_time_steps: int, optional - Total number of lead times + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + lead_time_channels : int, optional + Number of channels in the lead time embedding. These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. + use_apex_gn : bool, optional + A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. + Need to set this as False on cpu. Defaults to False. + act : str, optional + The activation function to use when fusing activation with GroupNorm. Defaults to None. + profile_mode: + A boolean flag indicating whether to enable all nvtx annotations during profiling. + amp_mode : bool, optional + A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False. - Reference - ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. - Note ----- Equivalent to the original implementation by Song et al., available at @@ -692,13 +1114,54 @@ class SongUNetPosLtEmbd(SongUNet): Example -------- - >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include original input channels (2), + >>> # positional embedding channels (N_grid_channels=4 by default) and + >>> # lead time embedding channels (4) + >>> model = SongUNetPosLtEmbd( + ... img_resolution=16, in_channels=2+4+4, out_channels=2, + ... lead_time_channels=4, lead_time_steps=9 + ... ) >>> noise_labels = torch.randn([1]) >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings and + >>> # lead time embeddings are added automatically inside the forward method >>> input_image = torch.ones([1, 2, 16, 16]) - >>> output_image = model(input_image, noise_labels, class_labels) + >>> lead_time_label = torch.tensor([3]) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using global_index to select all the positional and lead time embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label, + ... global_index=global_index + ... ) >>> output_image.shape torch.Size([1, 2, 16, 16]) + + # NOTE: commented out doctest for embedding_selector due to compatibility issue + # >>> + # >>> # Using custom embedding selector to select all the positional and lead time embeddings + # >>> def patch_embedding_selector(emb): + # ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + # >>> output_image = model( + # ... input_image, noise_labels, class_labels, + # ... lead_time_label=lead_time_label, + # ... embedding_selector=patch_embedding_selector + # ... ) + # >>> output_image.shape + # torch.Size([1, 2, 16, 16]) + """ def __init__( @@ -726,6 +1189,11 @@ def __init__( lead_time_steps: int = 9, prob_channels: List[int] = [], checkpoint_level: int = 0, + additive_pos_embed: bool = False, + use_apex_gn: bool = False, + act: str = "silu", + profile_mode: bool = False, + amp_mode: bool = False, ): super().__init__( img_resolution, @@ -745,162 +1213,38 @@ def __init__( encoder_type, decoder_type, resample_filter, + gridtype, + N_grid_channels, checkpoint_level, + additive_pos_embed, + use_apex_gn, + act, + profile_mode, + amp_mode, + True, # Note: lead_time_mode=True is enforced here + lead_time_channels, + lead_time_steps, + prob_channels, ) - self.gridtype = gridtype - self.N_grid_channels = N_grid_channels - self.pos_embd = self._get_positional_embedding() - self.lead_time_channels = lead_time_channels - self.lead_time_steps = lead_time_steps - self.lt_embd = self._get_lead_time_embedding() - self.prob_channels = prob_channels - if self.prob_channels: - self.scalar = torch.nn.Parameter( - torch.ones((1, len(self.prob_channels), 1, 1)) - ) - - @nvtx.annotate(message="SongUNet", color="blue") def forward( self, x, noise_labels, class_labels, lead_time_label=None, - global_index=None, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, augment_labels=None, ): - # append positional embedding to input conditioning - embeds = [] - if self.pos_embd is not None: - embeds.append(self.pos_embd.to(x.device)) - if self.lt_embd is not None: - embeds.append( - torch.reshape( - self.lt_embd[lead_time_label.int()], - (self.lead_time_channels, self.img_shape_y, self.img_shape_x), - ).to(x.device) - ) - if len(embeds) > 0: - embeds = torch.cat(embeds, dim=0) - selected_pos_embd = self.positional_embedding_indexing( - x, embeds, global_index - ) - x = torch.cat((x, selected_pos_embd), dim=1) - out = super().forward(x, noise_labels, class_labels, augment_labels) - # if training mode, let crossEntropyLoss do softmax. The model outputs logits. - # if eval mode, the model outputs probability - all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] - scalar_channels = [ - item for item in all_channels if item not in self.prob_channels - ] - if self.prob_channels and (not self.training): - out_final = torch.cat( - ( - out[:, scalar_channels], - (out[:, self.prob_channels] * self.scalar).softmax(dim=1), - ), - dim=1, - ) - elif self.prob_channels and self.training: - out_final = torch.cat( - (out[:, scalar_channels], (out[:, self.prob_channels] * self.scalar)), - dim=1, - ) - else: - out_final = out - return out_final - - def positional_embedding_indexing(self, x, pos_embd, global_index): - if global_index is None: - selected_pos_embd = ( - pos_embd.to(x.dtype).to(x.device)[None].expand((x.shape[0], -1, -1, -1)) - ) - else: - B = global_index.shape[0] - X = global_index.shape[2] - Y = global_index.shape[3] - global_index = torch.reshape( - torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, X, Y) to (2, B*X*Y) - selected_pos_embd = pos_embd.to(x.device)[ - :, global_index[0], global_index[1] - ] # (N_pe, B*X*Y) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (pos_embd.shape[0], B, X, Y)), - (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, X, Y) - return selected_pos_embd - - def _get_positional_embedding(self): - if self.N_grid_channels == 0: - return None - elif self.gridtype == "learnable": - grid = torch.nn.Parameter( - torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) - ) - elif self.gridtype == "linear": - if self.N_grid_channels != 2: - raise ValueError("N_grid_channels must be set to 2 for gridtype linear") - x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) - y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) - grid_x, grid_y = np.meshgrid(y, x) - grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) - grid.requires_grad = False - elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: - # print('sinusuidal grid added ......') - x1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_y))) - x2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_y))) - y1 = np.meshgrid(np.sin(np.linspace(0, 2 * np.pi, self.img_shape_x))) - y2 = np.meshgrid(np.cos(np.linspace(0, 2 * np.pi, self.img_shape_x))) - grid_x1, grid_y1 = np.meshgrid(y1, x1) - grid_x2, grid_y2 = np.meshgrid(y2, x2) - grid = torch.squeeze( - torch.from_numpy( - np.expand_dims( - np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 - ) - ) - ) - grid.requires_grad = False - elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: - if self.N_grid_channels % 4 != 0: - raise ValueError("N_grid_channels must be a factor of 4") - num_freq = self.N_grid_channels // 4 - freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) - grid_list = [] - grid_x, grid_y = np.meshgrid( - np.linspace(0, 2 * np.pi, self.img_shape_x), - np.linspace(0, 2 * np.pi, self.img_shape_y), - ) - for freq in freq_bands: - for p_fn in [np.sin, np.cos]: - grid_list.append(p_fn(grid_x * freq)) - grid_list.append(p_fn(grid_y * freq)) - grid = torch.from_numpy(np.stack(grid_list, axis=0)) - grid.requires_grad = False - elif self.gridtype == "test" and self.N_grid_channels == 2: - idx_x = torch.arange(self.img_shape_y) - idx_y = torch.arange(self.img_shape_x) - mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - grid = torch.stack((mesh_x, mesh_y), dim=0) - else: - raise ValueError("Gridtype not supported.") - return grid - - def _get_lead_time_embedding(self): - if (self.lead_time_steps is None) or (self.lead_time_channels is None): - return None - grid = torch.nn.Parameter( - torch.randn( - self.lead_time_steps, - self.lead_time_channels, - self.img_shape_y, - self.img_shape_x, - ) + return super().forward( + x=x, + noise_labels=noise_labels, + class_labels=class_labels, + global_index=global_index, + embedding_selector=embedding_selector, + augment_labels=augment_labels, + lead_time_label=lead_time_label, ) - return grid + + # Nothing else is re-implemented, because everything is already in the parent SongUNetPosEmb \ No newline at end of file diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index 10079ec..e0a447a 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -16,6 +16,7 @@ import importlib from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Tuple, Union import torch import torch.nn as nn @@ -45,31 +46,35 @@ class MetaData(ModelMetaData): class UNet(nn.Module): # TODO a lot of redundancy, need to clean up """ - U-Net Wrapper for CorrDiff. + U-Net Wrapper for CorrDiff deterministic regression model. Parameters ----------- - img_resolution : int - The resolution of the input/output image. - img_channels : int - Number of color channels. + img_resolution : Union[int, Tuple[int, int]] + The resolution of the input/output image. If a single int is provided, + then the image is assumed to be square. img_in_channels : int - Number of input color channels. + Number of channels in the input image. img_out_channels : int - Number of output color channels. + Number of channels in the output image. use_fp16: bool, optional - Execute the underlying model at FP16 precision?, by default False. - sigma_min: float, optional - Minimum supported noise level, by default 0. - sigma_max: float, optional - Maximum supported noise level, by default float('inf'). - sigma_data: float, optional - Expected standard deviation of the training data, by default 0.5. + Execute the underlying model at FP16 precision, by default False. model_type: str, optional - Class name of the underlying model, by default 'DhariwalUNet'. + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. **model_kwargs : dict - Keyword arguments for the underlying model. + Keyword arguments passed to the underlying model `__init__` method. + + See Also + -------- + For information on model types and their usage: + :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models + :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + Please refer to the documentation of these classes for details on how to call + and use these models directly. References ---------- @@ -79,37 +84,66 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up arXiv preprint arXiv:2309.15214. """ + @classmethod + def _backward_compat_arg_mapper( + cls, version: str, args: Dict[str, Any] + ) -> Dict[str, Any]: + """Map arguments from older versions to current version format. + + Parameters + ---------- + version : str + Version of the checkpoint being loaded + args : Dict[str, Any] + Arguments dictionary from the checkpoint + + Returns + ------- + Dict[str, Any] + Updated arguments dictionary compatible with current version + """ + # Call parent class method first + args = super()._backward_compat_arg_mapper(version, args) + + if version == "0.1.0": + # In version 0.1.0, img_channels was unused + if "img_channels" in args: + _ = args.pop("img_channels") + + # Sigma parameters are also unused + if "sigma_min" in args: + _ = args.pop("sigma_min") + if "sigma_max" in args: + _ = args.pop("sigma_max") + if "sigma_data" in args: + _ = args.pop("sigma_data") + + return args + def __init__( self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, - sigma_min=0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - **model_kwargs, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + **model_kwargs: dict, ): super().__init__() #meta=MetaData - self.img_channels = img_channels - # for compatibility with older versions that took only 1 dimension if isinstance(img_resolution, int): self.img_shape_x = self.img_shape_y = img_resolution else: - self.img_shape_x = img_resolution[0] - self.img_shape_y = img_resolution[1] + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] self.img_in_channels = img_in_channels self.img_out_channels = img_out_channels self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data model_class = getattr(network_module, model_type) self.model = model_class( img_resolution=img_resolution, @@ -118,13 +152,47 @@ def __init__( **model_kwargs, ) - def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): + def forward( + self, + x: torch.Tensor, + img_lr: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the UNet wrapper model. + + This method concatenates the input tensor with the low-resolution conditioning tensor + and passes the result through the underlying model. + + Parameters + ---------- + x : torch.Tensor + The input tensor, typically zero-filled, of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Output tensor (prediction) of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ + # SR: concatenate input channels if img_lr is not None: x = torch.cat((x, img_lr), dim=1) - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) dtype = ( torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") @@ -133,29 +201,27 @@ def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): F_x = self.model( x.to(dtype), # (c_in * x).to(dtype), - torch.zeros( - sigma.numel(), dtype=sigma.dtype, device=sigma.device - ), # c_noise.flatten() + torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten() class_labels=None, **model_kwargs, ) if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead." ) - # skip connection - for SR there's size mismatch bwtween input and output + # skip connection D_x = F_x.to(torch.float32) return D_x - def round_sigma(self, sigma): + def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: """ Convert a given sigma value(s) to a tensor representation. Parameters ---------- - sigma : Union[float list, torch.Tensor] + sigma : Union[float, List, torch.Tensor] The sigma value(s) to convert. Returns @@ -164,8 +230,31 @@ def round_sigma(self, sigma): The tensor representation of the provided sigma value(s). """ return torch.as_tensor(sigma) + + @property + def amp_mode(self): + """ + Return the *amp_mode* flag of the underlying model if present. + """ + return getattr(self.model, "amp_mode", None) + + @amp_mode.setter + def amp_mode(self, value: bool): + """ + Update *amp_mode* on the wrapped model and its sub-modules. + """ + if not isinstance(value, bool): + raise TypeError("amp_mode must be a boolean value.") + + if hasattr(self.model, "amp_mode"): + self.model.amp_mode = value + # Recursively update sub-modules that define *amp_mode*. + for sub_module in self.model.modules(): + if hasattr(sub_module, "amp_mode"): + sub_module.amp_mode = value +# TODO: implement amp_mode property for StormCastUNet (same as UNet) class StormCastUNet(nn.Module): """ U-Net wrapper for StormCast; used so the same Song U-Net network can be re-used for this model. @@ -189,7 +278,7 @@ class StormCastUNet(nn.Module): sigma_data: float, optional Expected standard deviation of the training data, by default 0.5. model_type: str, optional - Class name of the underlying model, by default 'DhariwalUNet'. + Class name of the underlying model, by default 'SongUNet'. **model_kwargs : dict Keyword arguments for the underlying model. diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 559d800..794dd55 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -5,6 +5,8 @@ import hydra from omegaconf import DictConfig, OmegaConf import json +from contextlib import nullcontext +import nvtx import torch from hydra.utils import to_absolute_path from torch.utils.tensorboard import SummaryWriter @@ -17,12 +19,44 @@ set_patch_shape, compute_num_accumulation_rounds, \ is_time_for_periodic_task, handle_and_clip_gradients from hirad.utils.checkpoint import load_checkpoint, save_checkpoint -from hirad.models import UNet, EDMPrecondSR -from hirad.losses import ResLoss, RegressionLoss, RegressionLossCE +from hirad.utils.patching import RandomPatching2D +from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR +from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE from hirad.datasets import init_train_valid_datasets_from_config from matplotlib import pyplot as plt +torch._dynamo.reset() +# Increase the cache size limit +torch._dynamo.config.cache_size_limit = 264 # Set to a higher value +torch._dynamo.config.verbose = True # Enable verbose logging +torch._dynamo.config.suppress_errors = False # Forces the error to show all details +torch._logging.set_logs(recompiles=True, graph_breaks=True) + +# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available +def cuda_profiler(): + if torch.cuda.is_available(): + return torch.cuda.profiler.profile() + else: + return nullcontext() + + +def cuda_profiler_start(): + if torch.cuda.is_available(): + torch.cuda.profiler.start() + + +def cuda_profiler_stop(): + if torch.cuda.is_available(): + torch.cuda.profiler.stop() + + +def profiler_emit_nvtx(): + if torch.cuda.is_available(): + return torch.autograd.profiler.emit_nvtx() + else: + return nullcontext() + @hydra.main(version_base=None, config_path="../conf", config_name="training") def main(cfg: DictConfig) -> None: # Initialize distributed environment for training @@ -63,7 +97,7 @@ def main(cfg: DictConfig) -> None: data_loader_kwargs = { "pin_memory": True, "num_workers": cfg.training.perf.dataloader_workers, - "prefetch_factor": 2, + "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None, } ( dataset, @@ -104,80 +138,64 @@ def main(cfg: DictConfig) -> None: else: patch_shape_x = None patch_shape_y = None + if ( + patch_shape_x + and patch_shape_y + and patch_shape_y >= img_shape[0] + and patch_shape_x >= img_shape[1] + ): + logger0.warning( + f"Patch shape {patch_shape_y}x{patch_shape_x} is larger than \ + the image shape {img_shape[0]}x{img_shape[1]}. Patching will not be used." + ) patch_shape = (patch_shape_y, patch_shape_x) - img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if patch_shape != img_shape: + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + # Utility to perform patches extraction and batching + patching = RandomPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + patch_num=getattr(cfg.training.hp, "patch_num", 1), + ) logger0.info("Patch-based training enabled") else: + patching = None logger0.info("Patch-based training disabled") # interpolate global channel if patch-based model is used - if img_shape[1] != patch_shape[1]: + if use_patching: img_in_channels += dataset_channels # Instantiate the model and move to device. - if cfg.model.name not in ( - "regression", - "lt_aware_ce_regression", - "diffusion", - "patched_diffusion", - "lt_aware_patched_diffusion", - ): - raise ValueError("Invalid model") model_args = { # default parameters for all networks "img_out_channels": img_out_channels, "img_resolution": list(img_shape), "use_fp16": fp16, + "checkpoint_level": songunet_checkpoint_level, } - standard_model_cfgs = { # default parameters for different network types - "regression": { - "img_channels": 4, - "N_grid_channels": 4, - "embedding_type": "zero", - "checkpoint_level": songunet_checkpoint_level, - }, - "lt_aware_ce_regression": { - "img_channels": 4, - "N_grid_channels": 4, - "embedding_type": "zero", - "lead_time_channels": 4, - "lead_time_steps": 9, - "prob_channels": prob_channels, - "checkpoint_level": songunet_checkpoint_level, - "model_type": "SongUNetPosLtEmbd", - }, - "diffusion": { - "img_channels": img_out_channels, - "gridtype": "sinusoidal", - "N_grid_channels": 4, - "checkpoint_level": songunet_checkpoint_level, - }, - "patched_diffusion": { - "img_channels": img_out_channels, - "gridtype": "learnable", - "N_grid_channels": 100, - "checkpoint_level": songunet_checkpoint_level, - }, - "lt_aware_patched_diffusion": { - "img_channels": img_out_channels, - "gridtype": "learnable", - "N_grid_channels": 100, - "lead_time_channels": 20, - "lead_time_steps": 9, - "checkpoint_level": songunet_checkpoint_level, - "model_type": "SongUNetPosLtEmbd", - }, - } - - - model_args.update(standard_model_cfgs[cfg.model.name]) - if cfg.model.name in ( - "diffusion", - "patched_diffusion", - "lt_aware_patched_diffusion", - ): - model_args["scale_cond_input"] = cfg.model.scale_cond_input + if cfg.model.name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels + if hasattr(cfg.model, "model_args"): # override defaults from config file model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + use_torch_compile = False + use_apex_gn = False + profile_mode = False + + if hasattr(cfg.training.perf, "torch_compile"): + use_torch_compile = cfg.training.perf.torch_compile + if hasattr(cfg.training.perf, "use_apex_gn"): + use_apex_gn = cfg.training.perf.use_apex_gn + model_args["use_apex_gn"] = use_apex_gn + + if hasattr(cfg.training.perf, "profile_mode"): + profile_mode = cfg.training.perf.profile_mode + model_args["profile_mode"] = profile_mode + + if enable_amp: + model_args["amp_mode"] = enable_amp + + if cfg.model.name == "regression": model = UNet( img_in_channels=img_in_channels + model_args["N_grid_channels"], @@ -193,7 +211,7 @@ def main(cfg: DictConfig) -> None: ) model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] elif cfg.model.name == "lt_aware_patched_diffusion": - model = EDMPrecondSR( + model = EDMPrecondSuperResolution( img_in_channels=img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"], @@ -201,7 +219,7 @@ def main(cfg: DictConfig) -> None: ) model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] else: # diffusion or patched diffusion - model = EDMPrecondSR( + model = EDMPrecondSuperResolution( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) @@ -209,6 +227,18 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) + # param_to_name = {} + # ppp = False + # for name, param in model.named_parameters(): + # pid = id(param) + # if pid in param_to_name: + # print(f"[SHARED PARAM] {name} == {param_to_name[pid]}") + # ppp = True + # break + # else: + # param_to_name[pid] = name + # print(f'There are shared parameters: {ppp}') + # TODO write summry from rank=0 possibly # summary(model, input_size=[(1,img_out_channels,*img_shape),(1,img_in_channels,*img_shape),(1,1)]) @@ -216,6 +246,18 @@ def main(cfg: DictConfig) -> None: with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: json.dump(model_args, f) + if use_apex_gn: + model.to(memory_format=torch.channels_last) + + # Check if regression model is used with patching + if ( + cfg.model.name in ["regression", "lt_aware_ce_regression"] + and patching is not None + ): + raise ValueError( + f"Regression model ({cfg.model.name}) cannot be used with patch-based training. " + ) + # Enable distributed data parallel if applicable if dist.world_size > 1: model = DistributedDataParallel( @@ -223,7 +265,9 @@ def main(cfg: DictConfig) -> None: device_ids=[dist.local_rank], broadcast_buffers=True, output_device=dist.device, - find_unused_parameters=dist.find_unused_parameters, + find_unused_parameters=True, # dist.find_unused_parameters, + bucket_cap_mb=35, + gradient_as_bucket_view=True, ) # Load the regression checkpoint if applicable #TODO test when training correction @@ -245,6 +289,12 @@ def main(cfg: DictConfig) -> None: with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) + regression_model_args.update({ + "use_apex_gn": use_apex_gn, + "profile_mode": profile_mode, + "amp_mode": enable_amp, + }) + regression_net = UNet(**regression_model_args) _ = load_checkpoint( @@ -253,22 +303,81 @@ def main(cfg: DictConfig) -> None: device=dist.device ) regression_net.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + regression_net.to(memory_format=torch.channels_last) logger0.success("Loaded the pre-trained regression model") + else: + regression_net = None + + # Compile the model and regression net if applicable + if use_torch_compile: + model = torch.compile(model) + if regression_net: + regression_net = torch.compile(regression_net) + + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - # Instantiate the loss function patch_num = getattr(cfg.training.hp, "patch_num", 1) + max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) + + # calculate patch per iter + if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) # Ensure at least 1 patch per iter + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + print( + f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" + ) + else: + patch_nums_iter = [patch_num] + + # Set patch gradient accumulation only for patched diffusion models + if cfg.model.name in { + "patched_diffusion", + "lt_aware_patched_diffusion", + }: + if len(patch_nums_iter) > 1: + if not patching: + logger0.info( + "Patching is not enabled: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = False + else: + use_patch_grad_acc = True + else: + use_patch_grad_acc = False + # Automatically disable patch gradient accumulation for non-patched models + else: + logger0.info( + "Training a non-patched model: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = None + + + # Instantiate the loss function if cfg.model.name in ( "diffusion", "patched_diffusion", "lt_aware_patched_diffusion", ): - loss_fn = ResLoss( + loss_fn = ResidualLoss( regression_net=regression_net, - img_shape_x=img_shape[1], - img_shape_y=img_shape[0], - patch_shape_x=patch_shape[1], - patch_shape_y=patch_shape[0], - patch_num=patch_num, hr_mean_conditioning=cfg.model.hr_mean_conditioning, ) elif cfg.model.name == "regression": @@ -278,23 +387,17 @@ def main(cfg: DictConfig) -> None: # Instantiate the optimizer optimizer = torch.optim.Adam( - params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8 + params=model.parameters(), + lr=cfg.training.hp.lr, + betas=[0.9, 0.999], + eps=1e-8, + fused=True, ) # Record the current time to measure the duration of subsequent operations. start_time = time.time() - # Compute the number of required gradient accumulation rounds - # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size - batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( - cfg.training.hp.total_batch_size, - cfg.training.hp.batch_size_per_gpu, - dist.world_size, - ) - batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu - logger0.info(f"Using {num_accumulation_rounds} gradient accumulation {"rounds" if num_accumulation_rounds>1 else "round"}.") - logger0.info(f"Batch size per gpu: {batch_size_per_gpu}") - ## Resume training from previous checkpoints if exists + # Load optimizer checkpoint if it exists if dist.world_size > 1: torch.distributed.barrier() try: @@ -317,188 +420,308 @@ def main(cfg: DictConfig) -> None: # init variables to monitor running mean of average loss since last periodic average_loss_running_mean = 0 n_average_loss_running_mean = 1 - - while not done: - tick_start_nimg = cur_nimg - tick_start_time = time.time() - # Compute & accumulate gradients - optimizer.zero_grad(set_to_none=True) - loss_accum = 0 - for _ in range(num_accumulation_rounds): - img_clean, img_lr, labels, *lead_time_label = next(dataset_iterator) # what are labels and lead_time_label - img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() - img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() - labels = labels.to(dist.device).contiguous() - loss_fn_kwargs = { - "net": model, - "img_clean": img_clean, - "img_lr": img_lr, - "labels": labels, - "augment_pipe": None, - } - if lead_time_label: - lead_time_label = lead_time_label[0].to(dist.device).contiguous() - loss_fn_kwargs.update({"lead_time_label": lead_time_label}) - else: - lead_time_label = None - with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): - loss = loss_fn(**loss_fn_kwargs) - loss = loss.sum() / batch_size_per_gpu - loss_accum += loss / num_accumulation_rounds - loss.backward() + start_nimg = cur_nimg + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # enable profiler: + with cuda_profiler(): + with profiler_emit_nvtx(): + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + + if cur_nimg - start_nimg == 24 * cfg.training.hp.total_batch_size: + logger0.info(f"Starting Profiler at {cur_nimg}") + cuda_profiler_start() + + if cur_nimg - start_nimg == 25 * cfg.training.hp.total_batch_size: + logger0.info(f"Stopping Profiler at {cur_nimg}") + cuda_profiler_stop() + + with nvtx.annotate("Training iteration", color="green"): + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for n_i in range(num_accumulation_rounds): + with nvtx.annotate( + f"accumulation round {n_i}", color="Magenta" + ): + with nvtx.annotate("loading data", color="green"): + img_clean, img_lr, *lead_time_label = next( + dataset_iterator + ) + if use_apex_gn: + img_clean = img_clean.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr = img_lr.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean = ( + img_clean.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr = ( + img_lr.to(dist.device) + .to(input_dtype) + .contiguous() + ) + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_fn_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + + if lead_time_label: + lead_time_label = ( + lead_time_label[0].to(dist.device).contiguous() + ) + loss_fn_kwargs.update( + {"lead_time_label": lead_time_label} + ) + else: + lead_time_label = None + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update({"patching": patching}) + with nvtx.annotate(f"loss forward", color="green"): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss = loss_fn(**loss_fn_kwargs) + + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + with nvtx.annotate(f"loss backward", color="yellow"): + loss.backward() - loss_sum = torch.tensor([loss_accum], device=dist.device) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM) - average_loss = (loss_sum / dist.world_size).cpu().item() - - # update running mean of average loss since last periodic task - average_loss_running_mean += ( - average_loss - average_loss_running_mean - ) / n_average_loss_running_mean - n_average_loss_running_mean += 1 - - if dist.rank == 0: - writer.add_scalar("training_loss", average_loss, cur_nimg) - writer.add_scalar( - "training_loss_running_mean", average_loss_running_mean, cur_nimg - ) + with nvtx.annotate(f"loss aggregate", color="green"): + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 - - # Update weights. - lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate - for g in optimizer.param_groups: - if lr_rampup > 0: - g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) - if cur_nimg >= lr_rampup: - g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) - current_lr = g["lr"] - if dist.rank == 0: - writer.add_scalar("learning_rate", current_lr, cur_nimg) - handle_and_clip_gradients( - model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold - ) - optimizer.step() - - cur_nimg += cfg.training.hp.total_batch_size - done = cur_nimg >= cfg.training.hp.training_duration - - # Validation - if validation_dataset_iterator is not None: - valid_loss_accum = 0 - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.validation_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - ): - with torch.no_grad(): - for _ in range(cfg.training.io.validation_steps): - img_clean_valid, img_lr_valid, labels_valid = next( - validation_dataset_iterator - ) - - img_clean_valid = ( - img_clean_valid.to(dist.device) - .to(torch.float32) - .contiguous() - ) - img_lr_valid = ( - img_lr_valid.to(dist.device).to(torch.float32).contiguous() - ) - labels_valid = labels_valid.to(dist.device).contiguous() - loss_valid = loss_fn( - net=model, - img_clean=img_clean_valid, - img_lr=img_lr_valid, - labels=labels_valid, - augment_pipe=None, - ) - loss_valid = ( - (loss_valid.sum() / batch_size_per_gpu).cpu().item() - ) - valid_loss_accum += ( - loss_valid / cfg.training.io.validation_steps - ) - valid_loss_sum = torch.tensor( - [valid_loss_accum], device=dist.device - ) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce( - valid_loss_sum, op=torch.distributed.ReduceOp.SUM - ) - average_valid_loss = valid_loss_sum / dist.world_size if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) writer.add_scalar( - "validation_loss", average_valid_loss, cur_nimg + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, ) - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" - ] - logger0.info(" ".join(fields)) - torch.cuda.reset_peak_memory_stats() - - ptt = is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ) - if ptt: - # reset running mean of average loss - average_loss_running_mean = 0 - n_average_loss_running_mean = 1 - - # Save checkpoints - if dist.world_size > 1: - torch.distributed.barrier() - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.save_checkpoint_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - save_checkpoint( - path=checkpoint_dir, - model=model, - optimizer=optimizer, - epoch=cur_nimg, - ) + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + with nvtx.annotate("update weights", color="blue"): + + lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + with nvtx.annotate("optimizer step", color="blue"): + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + with nvtx.annotate("validation", color="red"): + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + ( + img_clean_valid, + img_lr_valid, + *lead_time_label_valid, + ) = next(validation_dataset_iterator) + + if use_apex_gn: + img_clean_valid = img_clean_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_valid = img_lr_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + + else: + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_valid_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0] + .to(dist.device) + .contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update( + {"patching": patching} + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss_valid = loss_fn(**loss_valid_kwargs) + + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu) + .cpu() + .item() + ) + valid_loss_accum += ( + loss_valid + / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + save_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + epoch=cur_nimg, + ) # Done. logger0.info("Training Completed.") diff --git a/src/hirad/utils/deterministic_sampler.py b/src/hirad/utils/deterministic_sampler.py index 9fcea1d..e502875 100644 --- a/src/hirad/utils/deterministic_sampler.py +++ b/src/hirad/utils/deterministic_sampler.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Literal, Optional import numpy as np import nvtx @@ -26,33 +27,142 @@ @nvtx.annotate(message="deterministic_sampler", color="red") def deterministic_sampler( - net, - latents, - img_lr, - img_shape=None, - class_labels=None, - randn_like=torch.randn_like, - num_steps=18, - sigma_min=None, - sigma_max=None, - rho=7, - solver="heun", - discretization="edm", - schedule="linear", - scaling="none", - epsilon_s=1e-3, - C_1=0.001, - C_2=0.008, - M=1000, - alpha=1, - S_churn=0, - S_min=0, - S_max=float("inf"), - S_noise=1, -): + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + randn_like: Callable = torch.randn_like, + num_steps: int = 18, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + rho: float = 7.0, + solver: Literal["heun", "euler"] = "heun", + discretization: Literal["vp", "ve", "iddpm", "edm"] = "edm", + schedule: Literal["vp", "ve", "linear"] = "linear", + scaling: Literal["vp", "none"] = "none", + epsilon_s: float = 1e-3, + C_1: float = 0.001, + C_2: float = 0.008, + M: int = 1000, + alpha: float = 1.0, + S_churn: int = 0, + S_min: float = 0.0, + S_max: float = float("inf"), + S_noise: float = 1.0, +) -> torch.Tensor: """ - Generalized sampler, representing the superset of all sampling methods discussed - in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + Generalized sampler, representing the superset of all sampling methods + discussed in the paper "Elucidating the Design Space of Diffusion-Based + Generative Models" (EDM). + - https://arxiv.org/abs/2206.00364 + + This function integrates an ODE (probability flow) or SDE over multiple + time-steps to generate samples from the diffusion model provided by the + argument 'net'. It can be used to combine multiple choices to + design a custom sampler, including multiple integration solver, + discretization method, noise schedule, and so on. + + Parameters: + ----------- + net : torch.nn.Module + The diffusion model to use in the sampling process. + latents : torch.Tensor + The latent random noise used as the initial condition for the + stochastic ODE. + img_lr : torch.Tensor + Low-resolution input image for conditioning the diffusion process. + Passed as a keywork argument to the model 'net'. + class_labels : Optional[torch.Tensor] + Labels of the classes used as input to a class-conditionned + diffusion model. Passed as a keyword argument to the model 'net'. + If provided, it must be a tensor containing integer values. + Defaults to None, in which case it is ignored. + randn_like: Callable + Random Number Generator to generate random noise that is added + during the stochastic sampling. Must have the same signature as + torch.randn_like and return torch.Tensor. Defaults to + torch.randn_like. + num_steps : Optional[int] + Number of time-steps for the stochastic ODE integration. Defaults + to 18. + sigma_min : Optional[float] + Minimum noise level for the diffusion process. 'sigma_min', + 'sigma_max', and 'rho' are used to compute the time-step + discretization, based on the choice of discretization. For the + default choice ("discretization='heun'"), the noise level schedule + is computed as: + :math:`\sigma_i = (\sigma_{max}^{1/\rho} + i / (num_steps - 1) * (\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}))^{rho}`. + For other choices of 'discretization', see details in the EDM + paper. Defaults to None, in which case defaults values depending + of the specified discretization are used. + sigma_max : Optional[float] + Maximum noise level for the diffusion process. See sigma_min for + details. Defaults to None, in which case defaults values depending + of the specified discretization are used. + rho : float, optional + Exponent used in the noise schedule. See sigma_min for details. + Only used when 'discretization' is 'heun'. Values in the range [5, + 10] produce better images. Lower values lead to truncation errors + equalized over all time steps. Defaults to 7. + solver : Literal["heun", "euler"] + The numerical method used to integrate the stochastic ODE. "euler" + is 1st order solver, which is faster but produces lower-quality + images. "heun" is 2nd order, more expensive, but produces + higher-quality images. Defaults to "heun". + discretization : Literal["vp", "ve", "iddpm", "edm"] + The method to discretize time-steps :math:`t_i` in the + diffusion process. See the EDM papper for details. Defaults to + "edm". + schedule : Literal["vp", "ve", "linear"] + The type of noise level schedule. Defaults to "linear". If + schedule='ve', then :math:`\sigma(t) = \sqrt{t}`. If + schedule='linear', then :math:`\sigma(t) = t`. If schedule='vp', + see EDM paper for details. Defaults to "linear". + scaling : Literal["vp", "none"] + The type of time-dependent signal scaling :math:`s(t)`, such that + :math:`x = s(t) \hat{x}`. See EDM paper for details on the 'vp' + scaling. Defaults to 'none', in which case :math:`s(t)=1`. + epsilon_s : float, optional + Parameter to compute both the noise level schedule and the + time-step discetization. Only used when discretization='vp' or + schedule='vp'. Ignored in other cases. Defaults to 1e-3. + C_1 : float, optional + Parameters to compute the time-step discetization. Only used when + discretization='iddpm'. Defaults to 0.001. + C_2 : float, optional + Same as for C_1. Only used when discretization='iddpm'. Defaults to + 0.008. + M : int, optional + Same as for C_1 and C_2. Only used when discretization='iddpm'. + Defaults to 1000. + alpha : float, optional + Controls (i.e. multiplies) the step size :math:`t_{i+1} - + \hat{t}_i` in the stochastic sampler, where :math:`\hat{t}_i` is + the temporarily increased noise level. Defaults to 1.0, which is + the recommended value. + S_churn : int, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Larger values of S_churn lead to larger values + of :math:`\hat{t}_i`, which in turn lead to injecting more + stochasticity in the SDE by Defaults to 0, which means no + stochasticity is injected. + S_min : float, optional + S_min and S_max control the time-step range obver which + stochasticty is injected in the SDE. Stochasticity is injected + through `\hat{t}_i` for time-steps :math:`t_i` such that + :math:`S_{min} \leq t_i \leq S_{max}`. Defaults to 0.0. + S_max : float, optional + See S_min. Defaults to float("inf"). + S_noise : float, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Added signal noise is proportinal to + :math:`\epsilon_i` where `\epsilon_i ~ N(0, S_{noise}^2)`. Defaults + to 1.0. + + Returns + ------- + torch.Tensor: + Generated batch of samples. Same shape as the input 'latents'. """ # conditioning diff --git a/src/hirad/utils/function_utils.py b/src/hirad/utils/function_utils.py index dcbb127..347457c 100644 --- a/src/hirad/utils/function_utils.py +++ b/src/hirad/utils/function_utils.py @@ -29,7 +29,7 @@ import sys import types import warnings -from typing import Any, List, Tuple, Union +from typing import Any, Iterator, List, Tuple, Union import cftime import numpy as np @@ -553,14 +553,37 @@ def decorator(*args, **kwargs): # indefinitely, shuffling items as it goes. -class InfiniteSampler(torch.utils.data.Sampler): # pragma: no cover - """ - Sampler for torch.utils.data.DataLoader that loops over the dataset - indefinitely, shuffling items as it goes. +class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover + """Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely. + + This sampler yields indices indefinitely, optionally shuffling items as it goes. + It can also perform distributed sampling when rank and num_replicas are specified. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + The dataset to sample from + rank : int, default=0 + The rank of the current process within num_replicas processes + num_replicas : int, default=1 + The number of processes participating in distributed sampling + shuffle : bool, default=True + Whether to shuffle the indices + seed : int, default=0 + Random seed for reproducibility when shuffling + window_size : float, default=0.5 + Fraction of dataset to use as window for shuffling. Must be between 0 and 1. + A larger window means more thorough shuffling but slower iteration. """ def __init__( - self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 + self, + dataset: torch.utils.data.Dataset, + rank: int = 0, + num_replicas: int = 1, + shuffle: bool = True, + seed: int = 0, + window_size: float = 0.5, ): if not len(dataset) > 0: raise ValueError("Dataset must contain at least one item") @@ -578,7 +601,7 @@ def __init__( self.seed = seed self.window_size = window_size - def __iter__(self): + def __iter__(self) -> Iterator[int]: order = np.arange(len(self.dataset)) rnd = None window = 0 diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index ace05ba..8665536 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import datetime +from typing import Optional import cftime import nvtx @@ -23,6 +24,9 @@ from .function_utils import StackedRandomGenerator, time_range +from .stochastic_sampler import stochastic_sampler +from .deterministic_sampler import deterministic_sampler + ############################################################################ # CorrDiff Generation Utilities # ############################################################################ @@ -31,35 +35,56 @@ def regression_step( net: torch.nn.Module, img_lr: torch.Tensor, - labels: torch.Tensor, latents_shape: torch.Size, - lead_time_label: torch.Tensor = None, + lead_time_label: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Given a low-res input, performs a regression step to produce ensemble mean. - This function performs the regression on a single instance and then replicates - the results across the batch dimension. - - Args: - net (torch.nn.Module): U-Net model for regression. - img_lr (torch.Tensor): Low-resolution input. - latents_shape (torch.Size): Shape of the latent representation. Typically - (batch_size, out_channels, image_shape_x, image_shape_y). - - - Returns: - torch.Tensor: Predicted output at the next time step. + Perform a regression step to produce ensemble mean prediction. + + This function takes a low-resolution input and performs a regression step to produce + an ensemble mean prediction. It processes a single instance and then replicates + the results across the batch dimension if needed. + + Parameters + ---------- + net : torch.nn.Module + U-Net model for regression. + img_lr : torch.Tensor + Low-resolution input to the network with shape (1, channels, height, width). + Must have a batch dimension of 1. + latents_shape : torch.Size + Shape of the latent representation with format + (batch_size, out_channels, image_shape_y, image_shape_x). + lead_time_label : Optional[torch.Tensor], optional + Lead time label tensor for lead time conditioning, + with shape (1, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Predicted ensemble mean at the next time step with shape matching latents_shape. + + Raises + ------ + ValueError + If img_lr has a batch size greater than 1. """ # Create a tensor of zeros with the given shape and move it to the appropriate device x_hat = torch.zeros(latents_shape, dtype=img_lr.dtype, device=img_lr.device) - t_hat = torch.tensor(1.0, dtype=img_lr.dtype, device=img_lr.device)#.reshape((1,1,1,1)) + + # Safety check: avoid silently ignoring batch elements in img_lr + if img_lr.shape[0] > 1: + raise ValueError( + f"Expected img_lr to have a batch size of 1, " + f"but found {img_lr.shape[0]}." + ) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat[0:1], img_lr, t_hat, labels, lead_time_label=lead_time_label) + x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label) else: - x = net(x_hat[0:1], img_lr, t_hat, labels) + x = net(x=x_hat[0:1], img_lr=img_lr) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -68,48 +93,85 @@ def regression_step( return x -def diffusion_step( # TODO generalize the module and add defaults +def diffusion_step( net: torch.nn.Module, sampler_fn: callable, - seed_batch_size: int, img_shape: tuple, img_out_channels: int, rank_batches: list, img_lr: torch.Tensor, rank: int, device: torch.device, - hr_mean: torch.Tensor = None, + mean_hr: torch.Tensor = None, lead_time_label: torch.Tensor = None, ) -> torch.Tensor: """ Generate images using diffusion techniques as described in the relevant paper. - Args: - net (torch.nn.Module): The diffusion model network. - sampler_fn (callable): Function used to sample images from the diffusion model. - seed_batch_size (int): Number of seeds per batch. - img_shape (tuple): Shape of the images, (height, width). - img_out_channels (int): Number of output channels for the image. - rank_batches (list): List of batches of seeds to process. - img_lr (torch.Tensor): Low-resolution input image. - rank (int): Rank of the current process for distributed processing. - device (torch.device): Device to perform computations. - mean_hr (torch.Tensor, optional): High-resolution mean tensor, to be used as an additional input. By default None. - - Returns: - torch.Tensor: Generated images concatenated across batches. + This function applies a diffusion model to generate high-resolution images based on + low-resolution inputs. It supports optional conditioning on high-resolution mean + predictions and lead time labels. + + For each low-resolution sample in `img_lr`, the function generates multiple + high-resolution samples, with different random seeds, specified in `rank_batches`. + The function then concatenates these high-resolution samples across the batch dimension. + + Parameters + ---------- + net : torch.nn.Module + The diffusion model network. + sampler_fn : callable + Function used to sample images from the diffusion model. + img_shape : tuple + Shape of the images, (height, width). + img_out_channels : int + Number of output channels for the image. + rank_batches : list + List of batches of seeds to process. + img_lr : torch.Tensor + Low-resolution input image with shape (seed_batch_size, channels_lr, height, width). + rank : int, optional + Rank of the current process for distributed processing. + device : torch.device, optional + Device to perform computations. + mean_hr : torch.Tensor, optional + High-resolution mean tensor to be used as an additional input, + with shape (1, channels_hr, height, width). Default is None. + lead_time_label : torch.Tensor, optional + Lead time label tensor for temporal conditioning, + with shape (batch_size, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Generated images concatenated across batches with shape + (seed_batch_size * len(rank_batches), out_channels, height, width). """ - img_lr = img_lr #.to(memory_format=torch.channels_last) + # Check img_lr dimensions match expected shape + if img_lr.shape[2:] != img_shape: + raise ValueError( + f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + + # Check mean_hr dimensions if provided + if mean_hr is not None: + if mean_hr.shape[2:] != img_shape: + raise ValueError( + f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + if mean_hr.shape[0] != 1: + raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}") + + img_lr = img_lr.to(memory_format=torch.channels_last) # Handling of the high-res mean additional_args = {} - if hr_mean is not None: - additional_args["mean_hr"] = hr_mean + if mean_hr is not None: + additional_args["mean_hr"] = mean_hr if lead_time_label is not None: additional_args["lead_time_label"] = lead_time_label - additional_args["img_shape"] = img_shape # Loop over batches all_images = [] @@ -123,7 +185,7 @@ def diffusion_step( # TODO generalize the module and add defaults rnd = StackedRandomGenerator(device, batch_seeds) latents = rnd.randn( [ - seed_batch_size, + img_lr.shape[0], img_out_channels, img_shape[0], img_shape[1], @@ -139,6 +201,9 @@ def diffusion_step( # TODO generalize the module and add defaults return torch.cat(all_images) +def generate(): + pass + ############################################################################ # CorrDiff writer utilities # ############################################################################ diff --git a/src/hirad/utils/patching.py b/src/hirad/utils/patching.py new file mode 100644 index 0000000..6f4bc4d --- /dev/null +++ b/src/hirad/utils/patching.py @@ -0,0 +1,767 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import random +import warnings +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import Tensor + +""" +This module defines utilities, including classes and functions, for domain +decomposition. +""" + + +class BasePatching2D(ABC): + """ + Abstract base class for 2D image patching operations. + + This class provides a foundation for implementing various image patching + strategies. + It handles basic validation and provides abstract methods that must be + implemented by subclasses. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int] + ) -> None: + # Check that img_shape and patch_shape are 2D + if len(img_shape) != 2: + raise ValueError(f"img_shape must be 2D, got {len(img_shape)}D") + if len(patch_shape) != 2: + raise ValueError(f"patch_shape must be 2D, got {len(patch_shape)}D") + + # Make sure patches fit within the image + if any(p > i for p, i in zip(patch_shape, img_shape)): + warnings.warn( + f"Patch shape {patch_shape} is larger than " + f"image shape {img_shape}. " + f"Patches will be cropped to fit within the image." + ) + self.img_shape = img_shape + self.patch_shape = tuple(min(p, i) for p, i in zip(patch_shape, img_shape)) + + @abstractmethod + def apply(self, input: Tensor, **kwargs) -> Tensor: + """ + Apply the patching operation to the input tensor. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + **kwargs : dict + Additional keyword arguments specific to the patching + implementation. + + Returns + ------- + Tensor + Patched tensor, shape depends on specific implementation. + """ + pass + + def fuse(self, input: Tensor, **kwargs) -> Tensor: + """ + Fuse patches back into a complete image. + + Parameters + ---------- + input : Tensor + Input tensor containing patches. + **kwargs : dict + Additional keyword arguments specific to the fusion implementation. + + Returns + ------- + Tensor + Fused tensor, shape depends on specific implementation. + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + """ + raise NotImplementedError("'fuse' method must be implemented in subclasses.") + + def global_index( + self, batch_size: int, device: Union[torch.device, str] = "cpu" + ) -> Tensor: + """ + Returns a tensor containing the global indices for each patch. + + Global indices correspond to (y, x) global grid coordinates of each + element within the original image (before patching). It is typically + used to keep track of the original position of each patch in the + original image. + + Parameters + ---------- + batch_size : int + The size of the batch of images to patch. + device : Union[torch.device, str] + Proper device to initialize global_index on. Default to `cpu` + + Returns + ------- + Tensor + A tensor of shape (self.patch_num, 2, patch_shape_y, + patch_shape_x). `global_index[:, 0, :, :]` contains the + y-coordinate (height), and `global_index[:, 1, :, :]` contains the + x-coordinate (width). + """ + Ny = torch.arange(self.img_shape[0], device=device).int() + Nx = torch.arange(self.img_shape[1], device=device).int() + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0).unsqueeze(0) + global_index = self.apply(grid).long() + return global_index + + +class RandomPatching2D(BasePatching2D): + """ + Class for randomly extracting patches from 2D images. + + This class provides utilities to randomly extract patches from images + represented as 4D tensors. It maintains a list of random patch indices + that can be reset as needed. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + patch_num : int + The number of patches to extract. + + Attributes + ---------- + patch_indices : List[Tuple[int, int]] + The indices of the patches to extract from the images. These indices + correspond to the (y, x) coordinates of the lower left corner of each + patch. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.GridPatching2D` + Alternative patching strategy using deterministic patch locations. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int], patch_num: int + ) -> None: + """ + Initialize the RandomPatching2D object with the provided image shape, + patch shape, and number of patches to extract. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, + img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) + to extract. + patch_num : int + The number of patches to extract. + + Returns + ------- + None + """ + super().__init__(img_shape, patch_shape) + self._patch_num = patch_num + # Generate the indices of the patches to extract + self.reset_patch_indices() + + @property + def patch_num(self) -> int: + """ + Get the number of patches to extract. + + Returns + ------- + int + The number of patches to extract. + """ + return self._patch_num + + def set_patch_num(self, value: int) -> None: + """ + Set the number of patches to extract and reset patch indices. + This is the only way to modify the patch_num value. + + Parameters + ---------- + value : int + The new number of patches to extract. + """ + self._patch_num = value + self.reset_patch_indices() + + def reset_patch_indices(self) -> None: + """ + Generate new random indices for the patches to extract. These are the + starting indices of the patches to extract (upper left corner). + + Returns + ------- + None + """ + self.patch_indices = [ + ( + random.randint(0, self.img_shape[0] - self.patch_shape[0]), + random.randint(0, self.img_shape[1] - self.patch_shape[1]), + ) + for _ in range(self.patch_num) + ] + return + + def get_patch_indices(self) -> List[Tuple[int, int]]: + """ + Get the current list of patch starting indices. + + These are the upper-left coordinates of each extracted patch + from the full image. + + Returns + ------- + List[Tuple[int, int]] + A list of (row, column) tuples representing patch starting positions. + """ + return self.patch_indices + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Applies the patching operation by extracting patches specified by + `self.patch_indices` from the `input` Tensor. Extracted patches are + batched along the first dimension of the output. The layout of the + output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. + + Arguments + --------- + input : Tensor + The input tensor representing the full image with shape + (batch_size, channels_in, img_shape_y, img_shape_x). + additional_input : Optional[Tensor], optional + If provided, it is concatenated to each patch along `dim=1`. + Must have same batch size as `input`. Bilinear interpolation + is used to interpolate `additional_input` onto a 2D grid of shape + (patch_shape_y, patch_shape_x). + + Returns + ------- + Tensor + A tensor of shape (batch_size * self.patch_num, channels [+ + additional_channels], patch_shape_y, patch_shape_x). If + `additional_input` is provided, its channels are concatenated + along the channel dimension. + """ + B = input.shape[0] + out = torch.zeros( + B * self.patch_num, + ( + input.shape[1] + + (additional_input.shape[1] if additional_input is not None else 0) + ), + self.patch_shape[0], + self.patch_shape[1], + device=input.device, + ) + out = out.to( + memory_format=torch.channels_last + if input.is_contiguous(memory_format=torch.channels_last) + else torch.contiguous_format + ) + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + + for i, (py, px) in enumerate(self.patch_indices): + if additional_input is not None: + out[B * i : B * (i + 1),] = torch.cat( + ( + input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ], + add_input_interp, + ), + dim=1, + ) + else: + out[B * i : B * (i + 1),] = input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ] + return out + + +class GridPatching2D(BasePatching2D): + """ + Class for deterministically extracting patches from 2D images in a grid pattern. + + This class provides utilities to extract patches from images in a + deterministic manner, with configurable overlap and boundary pixels. + The patches are extracted in a grid-like pattern covering the entire image. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + overlap_pix : int, optional + Number of pixels to overlap between adjacent patches, by default 0. + boundary_pix : int, optional + Number of pixels to crop as boundary from each patch, by default 0. + + Attributes + ---------- + patch_num : int + Total number of patches that will be extracted from the image, + calculated as patch_num_x * patch_num_y. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.RandomPatching2D` + Alternative patching strategy using random patch locations. + """ + + def __init__( + self, + img_shape: Tuple[int, int], + patch_shape: Tuple[int, int], + overlap_pix: int = 0, + boundary_pix: int = 0, + ): + super().__init__(img_shape, patch_shape) + self.overlap_pix = overlap_pix + self.boundary_pix = boundary_pix + patch_num_x = math.ceil( + img_shape[1] / (patch_shape[1] - overlap_pix - boundary_pix) + ) + patch_num_y = math.ceil( + img_shape[0] / (patch_shape[0] - overlap_pix - boundary_pix) + ) + self.patch_num = patch_num_x * patch_num_y + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Apply deterministic patching to the input tensor. + + Splits the input tensor into patches in a grid-like pattern. Can + optionally concatenate additional interpolated data to each patch. + Extracted patches are batched along the first dimension of the output. + The layout of the output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. The patches can be reconstructed back into the original image + using the fuse method. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + additional_input : Optional[Tensor], optional + Additional data to concatenate to each patch. Will be interpolated + to match patch dimensions. Shape must be (batch_size, + additional_channels, H, W), by default None. + + Returns + ------- + Tensor + Tensor containing patches with shape (batch_size * patch_num, + channels [+ additional_channels], patch_shape_y, patch_shape_x). + If additional_input is provided, its channels are concatenated + along the channel dimension. + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The underlying function used to perform the patching operation. + """ + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + else: + add_input_interp = None + out = image_batching( + input=input, + patch_shape_y=self.patch_shape[0], + patch_shape_x=self.patch_shape[1], + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + input_interp=add_input_interp, + ) + return out + + def fuse(self, input: Tensor, batch_size: int) -> Tensor: + """ + Fuse patches back into a complete image. + + Reconstructs the original image by stitching together patches, + accounting for overlapping regions and boundary pixels. In overlapping + regions, values are averaged. + + Parameters + ---------- + input : Tensor + Input tensor containing patches with shape (batch_size * patch_num, + channels, patch_shape_y, patch_shape_x). + batch_size : int + The original batch size before patching. + + Returns + ------- + Tensor + Reconstructed image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_fuse` + The underlying function used to perform the fusion operation. + """ + out = image_fuse( + input=input, + img_shape_y=self.img_shape[0], + img_shape_x=self.img_shape[1], + batch_size=batch_size, + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + ) + return out + + +def image_batching( + input: Tensor, + patch_shape_y: int, + patch_shape_x: int, + overlap_pix: int, + boundary_pix: int, + input_interp: Optional[Tensor] = None, +) -> Tensor: + """ + Splits a full image into a batch of patched images. + + This function takes a full image and splits it into patches, adding padding + where necessary. It can also concatenate additional interpolated data to + each patch if provided. + + Parameters + ---------- + input : Tensor + The input tensor representing the full image with shape (batch_size, + channels, img_shape_y, img_shape_x). + patch_shape_y : int + The height (y-dimension) of each image patch. + patch_shape_x : int + The width (x-dimension) of each image patch. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + input_interp : Optional[Tensor], optional + Optional additional data to concatenate to each patch with shape + (batch_size, interp_channels, patch_shape_y, patch_shape_x). + By default None. + + Returns + ------- + Tensor + A tensor containing the image patches, with shape (total_patches * + batch_size, channels [+ interp_channels], patch_shape_x, + patch_shape_y). + """ + # Infer sizes from input image + batch_size, _, img_shape_y, img_shape_x = input.shape + + # Safety check: make sure patch_shapes are large enough to accommodate + # overlaps and boundaries pixels + if (patch_shape_x - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + if (patch_shape_y - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + # Safety check: validate input_interp dimensions if provided + if input_interp is not None: + if input_interp.shape[0] != batch_size: + raise ValueError( + f"input_interp batch size ({input_interp.shape[0]}) must match " + f"input batch size ({batch_size})" + ) + if (input_interp.shape[2] != patch_shape_y) or ( + input_interp.shape[3] != patch_shape_x + ): + raise ValueError( + f"input_interp patch shape ({input_interp.shape[2]}, {input_interp.shape[3]}) " + f"must match specified patch shape ({patch_shape_y}, {patch_shape_x})" + ) + + # Safety check: make sure patch_shape is large enough in comparison to + # overlap_pix and boundary_pix. Otherwise, number of patches extracted by + # unfold differs from the expected number of patches. + if patch_shape_x <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_x ({patch_shape_x}) must verify " + f"patch_shape_x ({patch_shape_x}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + if patch_shape_y <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_y ({patch_shape_y}) must verify " + f"patch_shape_y ({patch_shape_y}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + image_padding = torch.nn.ReflectionPad2d( + (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) + input_padded = image_padding(input) + patch_num = patch_num_x * patch_num_y + x_unfold = torch.nn.functional.unfold( + input=input_padded.view(_cast_type(input_padded)), # Cast to float + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ).to(input_padded.dtype) + x_unfold = rearrange( + x_unfold, + "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + if input_interp is not None: + input_interp_repeated = rearrange( + torch.repeat_interleave( + input=input_interp, + repeats=patch_num, + dim=0, + output_size=x_unfold.shape[0], + ), + "(b p) c h w -> (p b) c h w", + p=patch_num, + ) + return torch.cat((x_unfold, input_interp_repeated), dim=1) + else: + return x_unfold + + +def image_fuse( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, +) -> Tensor: + """ + Reconstructs a full image from a batch of patched images. Reverts the patching + operation performed by image_batching(). + + This function takes a batch of image patches and reconstructs the full + image by stitching the patches together. The function accounts for + overlapping and boundary pixels, ensuring that overlapping areas are + averaged. + + Parameters + ---------- + input : Tensor + The input tensor containing the image patches with shape (patch_num * batch_size, channels, patch_shape_y, patch_shape_x). + img_shape_y : int + The height (y-dimension) of the original full image. + img_shape_x : int + The width (x-dimension) of the original full image. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + + Returns + ------- + Tensor + The reconstructed full image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The function this reverses, which splits images into patches. + """ + + # Infer sizes from input image shape + patch_shape_y, patch_shape_x = input.shape[2], input.shape[3] + + # Calculate the number of patches in each dimension + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + + # Calculate the shape of the input after padding + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + # Calculate the shape of the padding to add to input + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + pad = (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + + # Count local overlaps between patches + input_ones = torch.ones( + (batch_size, input.shape[1], padded_shape_y, padded_shape_x), + device=input.device, + ) + overlap_count = torch.nn.functional.unfold( + input=input_ones, + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + overlap_count = torch.nn.functional.fold( + input=overlap_count, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Reshape input to make it 3D to apply fold + x = rearrange( + input, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + # Stitch patches together (by summing over overlapping patches) + x_folded = torch.nn.functional.fold( + input=x, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Remove padding + x_no_padding = x_folded[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + overlap_count_no_padding = overlap_count[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + + # Normalize by overlap count + return x_no_padding / overlap_count_no_padding + + +def _cast_type(input: Tensor) -> torch.dtype: + """Return float type based on input tensor type. + + Parameters + ---------- + input : Tensor + Input tensor to determine float type from + + Returns + ------- + torch.dtype + Float type corresponding to input tensor type for int32/64, + otherwise returns original dtype + """ + if input.dtype == torch.int32: + return torch.float32 + elif input.dtype == torch.int64: + return torch.float64 + else: + return input.dtype diff --git a/src/hirad/utils/stochastic_sampler.py b/src/hirad/utils/stochastic_sampler.py index ac5c13b..198fde4 100644 --- a/src/hirad/utils/stochastic_sampler.py +++ b/src/hirad/utils/stochastic_sampler.py @@ -15,290 +15,23 @@ # limitations under the License. -import math -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch from torch import Tensor - -def image_batching( - input: Tensor, - img_shape_y: int, - img_shape_x: int, - patch_shape_y: int, - patch_shape_x: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, - input_interp: Optional[Tensor] = None, -) -> Tensor: - """ - Splits a full image into a batch of patched images. - - This function takes a full image and splits it into patches, adding padding where necessary. - It can also concatenate additional interpolated data to each patch if provided. - - Parameters - ---------- - input : Tensor - The input tensor representing the full image with shape (batch_size, channels, img_shape_x, img_shape_y). - img_shape_x : int - The width (x-dimension) of the original full image. - img_shape_y : int - The height (y-dimension) of the original full image. - patch_shape_x : int - The width (x-dimension) of each image patch. - patch_shape_y : int - The height (y-dimension) of each image patch. - batch_size : int - The original batch size before patching. - overlap_pix : int - The number of overlapping pixels between adjacent patches. - boundary_pix : int - The number of pixels to crop as a boundary from each patch. - input_interp : Optional[Tensor], optional - Optional additional data to concatenate to each patch with shape (batch_size, interp_channels, patch_shape_x, patch_shape_y). - By default None. - - Returns - ------- - Tensor - A tensor containing the image patches, with shape (total_patches * batch_size, channels [+ interp_channels], patch_shape_x, patch_shape_y). - """ - patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) - patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) - padded_shape_x = ( - (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) - + patch_shape_x - + boundary_pix - ) - padded_shape_y = ( - (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) - + patch_shape_y - + boundary_pix - ) - pad_x_right = padded_shape_x - img_shape_x - boundary_pix - pad_y_right = padded_shape_y - img_shape_y - boundary_pix - input_padded = torch.zeros( - input.shape[0], input.shape[1], padded_shape_y, padded_shape_x - ).to(input.device) - image_padding = torch.nn.ReflectionPad2d( - (boundary_pix, pad_x_right, boundary_pix, pad_y_right) - ).to( - input.device - ) # (padding_left,padding_right,padding_top,padding_bottom) - input_padded = image_padding(input) - patch_num = patch_num_x * patch_num_y - if input_interp is not None: - output = torch.zeros( - patch_num * batch_size, - input.shape[1] + input_interp.shape[1], - patch_shape_y, - patch_shape_x, - ).to(input.device) - else: - output = torch.zeros( - patch_num * batch_size, input.shape[1], patch_shape_y, patch_shape_x - ).to(input.device) - for x_index in range(patch_num_x): - for y_index in range(patch_num_y): - x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) - y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) - if input_interp is not None: - output[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - ] = torch.cat( - ( - input_padded[ - :, - :, - y_start : y_start + patch_shape_y, - x_start : x_start + patch_shape_x, - ], - input_interp, - ), - dim=1, - ) - else: - output[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - ] = input_padded[ - :, - :, - y_start : y_start + patch_shape_y, - x_start : x_start + patch_shape_x, - ] - return output - - -def image_fuse( - input: Tensor, - img_shape_y: int, - img_shape_x: int, - patch_shape_y: int, - patch_shape_x: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, -) -> Tensor: - """ - Reconstructs a full image from a batch of patched images. - - This function takes a batch of image patches and reconstructs the full image - by stitching the patches together. The function accounts for overlapping and - boundary pixels, ensuring that overlapping areas are averaged. - - Parameters - ---------- - input : Tensor - The input tensor containing the image patches with shape (total_patches * batch_size, channels, patch_shape_x, patch_shape_y). - img_shape_x : int - The width (x-dimension) of the original full image. - img_shape_y : int - The height (y-dimension) of the original full image. - patch_shape_x : int - The width (x-dimension) of each image patch. - patch_shape_y : int - The height (y-dimension) of each image patch. - batch_size : int - The original batch size before patching. - overlap_pix : int - The number of overlapping pixels between adjacent patches. - boundary_pix : int - The number of pixels to crop as a boundary from each patch. - - Returns - ------- - Tensor - The reconstructed full image tensor with shape (batch_size, channels, img_shape_x, img_shape_y). - - """ - patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) - patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) - padded_shape_x = ( - (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) - + patch_shape_x - + boundary_pix - ) - padded_shape_y = ( - (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) - + patch_shape_y - + boundary_pix - ) - pad_x_right = padded_shape_x - img_shape_x - boundary_pix - pad_y_right = padded_shape_y - img_shape_y - boundary_pix - residual_x = patch_shape_x - pad_x_right # residual pixels in the last patch - residual_y = patch_shape_y - pad_y_right # residual pixels in the last patch - output = torch.zeros( - batch_size, input.shape[1], img_shape_y, img_shape_x, device=input.device - ) - one_map = torch.ones(1, 1, input.shape[2], input.shape[3], device=input.device) - count_map = torch.zeros( - 1, 1, img_shape_y, img_shape_x, device=input.device - ) # to count the overlapping times - for x_index in range(patch_num_x): - for y_index in range(patch_num_y): - x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) - y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) - if (x_index == patch_num_x - 1) and (y_index != patch_num_y - 1): - output[ - :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - count_map[ - :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: - ] += one_map[ - :, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - elif (y_index == patch_num_y - 1) and ((x_index != patch_num_x - 1)): - output[ - :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - count_map[ - :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix - ] += one_map[ - :, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - elif x_index == patch_num_x - 1 and y_index == patch_num_y - 1: - output[:, :, y_start:, x_start:] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - count_map[:, :, y_start:, x_start:] += one_map[ - :, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - else: - output[ - :, - :, - y_start : y_start + patch_shape_y - 2 * boundary_pix, - x_start : x_start + patch_shape_x - 2 * boundary_pix, - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - count_map[ - :, - :, - y_start : y_start + patch_shape_y - 2 * boundary_pix, - x_start : x_start + patch_shape_x - 2 * boundary_pix, - ] += one_map[ - :, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - return output / count_map +from hirad.utils.patching import GridPatching2D def stochastic_sampler( - net: Any, - latents: Tensor, - img_lr: Tensor, + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, class_labels: Optional[Tensor] = None, randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - img_shape: tuple[int,int] = (448,448), - patch_shape_x: int = 448, - patch_shape_y: int = 448, - overlap_pix: int = 4, - boundary_pix: int = 2, - mean_hr: Optional[Tensor] = None, - lead_time_label: Optional[Tensor] = None, + patching: Optional[GridPatching2D] = None, + mean_hr: Optional[torch.Tensor] = None, + lead_time_label: Optional[torch.Tensor] = None, num_steps: int = 18, sigma_min: float = 0.002, sigma_max: float = 800, @@ -307,33 +40,63 @@ def stochastic_sampler( S_min: float = 0, S_max: float = float("inf"), S_noise: float = 1, -) -> Tensor: +) -> torch.Tensor: """ - Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution and patch-based diffusion. + Proposed EDM sampler (Algorithm 2) with minor changes to enable + super-resolution and patch-based diffusion. Parameters ---------- - net : Any - The neural network model that generates denoised images from noisy inputs. + net : torch.nn.Module + The neural network model that generates denoised images from noisy + inputs. + Expected signature: `net(x, x_lr, t_hat, class_labels, + lead_time_label=lead_time_label, embedding_selector=embedding_selector)`, + where: + x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) + x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) + t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar + class_labels (torch.Tensor, optional): Optional class labels + lead_time_label (torch.Tensor, optional): Optional lead time labels + embedding_selector (callable, optional): Function to select + positional embeddings. Used for patch-based diffusion. + Returns: + torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) + + Required attributes: + sigma_min (float): Minimum supported noise level for the model + sigma_max (float): Maximum supported noise level for the model + round_sigma (callable): Method to convert sigma values to tensor representation latents : Tensor - The latent variables (e.g., noise) used as the initial input for the sampler. + The latent variables (e.g., noise) used as the initial input for the + sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). img_lr : Tensor - Low-resolution input image for conditioning the super-resolution process. + Low-resolution input image for conditioning the super-resolution + process. Must have shape (batch_size, C_lr, img_lr_ shape_y, + img_lr_shape_x). class_labels : Optional[Tensor], optional - Class labels for conditional generation, if required by the model. By default None. + Class labels for conditional generation, if required by the model. By + default None. randn_like : Callable[[Tensor], Tensor] - Function to generate random noise with the same shape as the input tensor. + Function to generate random noise with the same shape as the input + tensor. By default torch.randn_like. - img_shape : int - The height and width of the full image (assumed to be square). By default 448. - patch_shape : int - The height and width of each patch (assumed to be square). By default 448. - overlap_pix : int - Number of overlapping pixels between adjacent patches. By default 4. - boundary_pix : int - Number of pixels to be cropped as a boundary from each patch. By default 2. + patching : Optional[GridPatching2D], optional + A patching utility for patch-based diffusion. Implements methods to + extract patches from an image and batch the patches along `dim=0`. + Should also implement a `fuse` method to reconstruct the original image + from a batch of patches. See + :class:`physicsnemo.utils.patching.GridPatching2D` for details. By + default None, in which case non-patched diffusion is used. mean_hr : Optional[Tensor], optional - Optional tensor containing mean high-resolution images for conditioning. By default None. + Optional tensor containing mean high-resolution images for + conditioning. Must have same height and width as `img_lr`, with shape + (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x) where the batch dimension + B_hr can be either 1, either equal to batch_size, or can be omitted. If + B_hr = 1 or is omitted, `mean_hr` will be expanded to match the shape + of `img_lr`. By default None. + lead_time_label : Optional[Tensor], optional + Optional lead time labels. By default None. num_steps : int Number of time steps for the sampler. By default 18. sigma_min : float @@ -343,7 +106,8 @@ def stochastic_sampler( rho : float Exponent used in the time step discretization. By default 7. S_churn : float - Churn parameter controlling the level of noise added in each step. By default 0. + Churn parameter controlling the level of noise added in each step. By + default 0. S_min : float Minimum time step for applying churn. By default 0. S_max : float @@ -354,20 +118,40 @@ def stochastic_sampler( Returns ------- Tensor - The final denoised image produced by the sampler. + The final denoised image produced by the sampler. Same shape as + `latents`: (batch_size, C_out, img_shape_y, img_shape_x). + + See Also + -------- + :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model + wrapper that provides preconditioning for super-resolution diffusion + models and implements the required interface for this sampler. """ # Adjust noise levels based on what's supported by the network. - "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." + # Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution. sigma_min = max(sigma_min, net.sigma_min) sigma_max = min(sigma_max, net.sigma_max) - # if isinstance(img_shape, tuple): - # img_shape_y, img_shape_x = img_shape - # else: - # img_shape_x = img_shape_y = img_shape - img_shape_x, img_shape_y = img_shape - patch_shape_x = min(img_shape_x, patch_shape_x) - patch_shape_y = min(img_shape_y, patch_shape_y) + + if patching is not None and not isinstance(patching, GridPatching2D): + raise ValueError("patching must be an instance of GridPatching2D.") + + # Safety check: if patching is used then img_lr and latents must have same + # height and width, otherwise there is mismatch in the number + # of patches extracted to form the final batch_size. + if patching: + if img_lr.shape[-2:] != latents.shape[-2:]: + raise ValueError( + f"img_lr and latents must have the same height and width, " + f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. " + ) + # img_lr and latents must also have the same batch_size, otherwise mismatch + # when processed by the network + if img_lr.shape[0] != latents.shape[0]: + raise ValueError( + f"img_lr and latents must have the same batch size, but found " + f"{img_lr.shape[0]} vs {latents.shape[0]}." + ) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) @@ -381,46 +165,32 @@ def stochastic_sampler( [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] ) # t_N = 0 - b = latents.shape[0] - Nx = torch.arange(img_shape_x) - Ny = torch.arange(img_shape_y) - grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) + batch_size = img_lr.shape[0] # conditioning = [mean_hr, img_lr, global_lr, pos_embd] - batch_size = img_lr.shape[0] x_lr = img_lr if mean_hr is not None: + if mean_hr.shape[-2:] != img_lr.shape[-2:]: + raise ValueError( + f"mean_hr and img_lr must have the same height and width, " + f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}." + ) x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) - global_index = None # input and position padding + patching - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - input_interp = torch.nn.functional.interpolate( - img_lr, (patch_shape_x, patch_shape_y), mode="bilinear" - ) - x_lr = image_batching( - x_lr, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - input_interp, - ) - global_index = image_batching( - grid.float(), - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ).int() + if patching: + # Patched conditioning [x_lr, mean_hr] + # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) + x_lr = patching.apply(input=x_lr, additional_input=img_lr) + + # Function to select the correct positional embedding for each patch + def patch_embedding_selector(emb): + # emb: (N_pe, image_shape_y, image_shape_x) + # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) + return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + + else: + patch_embedding_selector = None # Main sampling loop. x_next = latents.to(torch.float64) * t_steps[0] @@ -432,26 +202,14 @@ def stochastic_sampler( x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) - # Euler step. Perform patching operation on score tensor if patch-based generation is used - # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr + # Euler step. Perform patching operation on score tensor if patch-based + # generation is used denoised = net(x_hat, t_hat, + # class_labels,lead_time_label=lead_time_label).to(torch.float64) - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - x_hat_batch = image_batching( - x_hat, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - else: - x_hat_batch = x_hat - x_hat_batch = x_hat_batch.to(latents.device) + x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to( + latents.device + ) x_lr = x_lr.to(latents.device) - if global_index is not None: - global_index = global_index.to(latents.device) if lead_time_label is not None: denoised = net( @@ -460,7 +218,7 @@ def stochastic_sampler( t_hat, class_labels, lead_time_label=lead_time_label, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) else: # print("Sizes") @@ -474,40 +232,24 @@ def stochastic_sampler( x_lr, t_hat, class_labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) - denoised = image_fuse( - denoised, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - x_next_batch = image_batching( - x_next, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - else: - x_next_batch = x_next - # ask about this fix - x_next_batch = x_next_batch.to(latents.device) + # Patched input + # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x) + x_next_batch = (patching.apply(input=x_next) if patching else x_next).to( + latents.device + ) + if lead_time_label is not None: denoised = net( x_next_batch, @@ -515,7 +257,7 @@ def stochastic_sampler( t_next, class_labels, lead_time_label=lead_time_label, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) else: denoised = net( @@ -523,19 +265,13 @@ def stochastic_sampler( x_lr, t_next, class_labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: - denoised = image_fuse( - denoised, - img_shape_y, - img_shape_x, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) + d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) return x_next diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index d4529ac..218d6f1 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -17,6 +17,7 @@ import torch import numpy as np from omegaconf import ListConfig +import warnings def set_patch_shape(img_shape, patch_shape): @@ -26,12 +27,21 @@ def set_patch_shape(img_shape, patch_shape): patch_shape_x = img_shape_x if (patch_shape_y is None) or (patch_shape_y > img_shape_y): patch_shape_y = img_shape_y - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x == img_shape_x and patch_shape_y == img_shape_y: + use_patching = False + else: + use_patching = True + if use_patching: if patch_shape_x != patch_shape_y: + warnings.warn( + f"You are using rectangular patches " + f"of shape {(patch_shape_y, patch_shape_x)}, " + f"which are an experimental feature." + ) raise NotImplementedError("Rectangular patch not supported yet") if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: raise ValueError("Patch shape needs to be a multiple of 32") - return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) def set_seed(rank): From 5db5e47088b534b5ddc0362439fa6699f8361263 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Mon, 26 May 2025 18:22:20 +0200 Subject: [PATCH 50/66] small config fix --- src/hirad/conf/training_era_cosmo_diffusion.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hirad/conf/training_era_cosmo_diffusion.yaml b/src/hirad/conf/training_era_cosmo_diffusion.yaml index 4271e44..0a069e9 100644 --- a/src/hirad/conf/training_era_cosmo_diffusion.yaml +++ b/src/hirad/conf/training_era_cosmo_diffusion.yaml @@ -15,5 +15,7 @@ defaults: # Model - model/era_cosmo_diffusion + - model_size/normal + # Training - training/era_cosmo_diffusion \ No newline at end of file From e8bd5cd7b3b9ae61f000f5fd676cda0780526f7d Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 27 May 2025 12:09:35 +0200 Subject: [PATCH 51/66] fix generation on distributed --- src/hirad/inference/generate.py | 4 +- src/hirad/training/train.py | 83 ++++++++++++++------------------- 2 files changed, 38 insertions(+), 49 deletions(-) diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index 8fed809..ec385dc 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -269,9 +269,11 @@ def generate_fn(image_lr, lead_time_label): ) if dist.rank == 0: + if cfg.generation.inference_mode != "regression": + return torch.cat(gathered_tensors), image_reg[0:1,::] return torch.cat(gathered_tensors) else: - return None + return None, None else: #TODO do this for multi-gpu setting above too if cfg.generation.inference_mode != "regression": diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 794dd55..39b3653 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -227,21 +227,6 @@ def main(cfg: DictConfig) -> None: model.train().requires_grad_(True).to(dist.device) - # param_to_name = {} - # ppp = False - # for name, param in model.named_parameters(): - # pid = id(param) - # if pid in param_to_name: - # print(f"[SHARED PARAM] {name} == {param_to_name[pid]}") - # ppp = True - # break - # else: - # param_to_name[pid] = name - # print(f'There are shared parameters: {ppp}') - - # TODO write summry from rank=0 possibly - # summary(model, input_size=[(1,img_out_channels,*img_shape),(1,img_in_channels,*img_shape),(1,1)]) - if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: json.dump(model_args, f) @@ -572,6 +557,41 @@ def main(cfg: DictConfig) -> None: cur_nimg += cfg.training.hp.total_batch_size done = cur_nimg >= cfg.training.hp.training_duration + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + logger0.info(img_clean.shape) + logger0.info(img_lr.shape) + with nvtx.annotate("validation", color="red"): # Validation if validation_dataset_iterator is not None: @@ -671,39 +691,6 @@ def main(cfg: DictConfig) -> None: "validation_loss", average_valid_loss, cur_nimg ) - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.print_progress_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - rank_0_only=True, - ): - # Print stats if we crossed the printing threshold with this batch - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" - ] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - if torch.cuda.is_available(): - fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" - ] - torch.cuda.reset_peak_memory_stats() - logger0.info(" ".join(fields)) - # Save checkpoints if dist.world_size > 1: From 692dfe25d91af282d4f0b72f3a05237720d1fb76 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 27 May 2025 12:18:08 +0200 Subject: [PATCH 52/66] delete unnecessary logging --- src/hirad/training/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 39b3653..3a2fe2e 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -589,8 +589,6 @@ def main(cfg: DictConfig) -> None: ] torch.cuda.reset_peak_memory_stats() logger0.info(" ".join(fields)) - logger0.info(img_clean.shape) - logger0.info(img_lr.shape) with nvtx.annotate("validation", color="red"): # Validation From e3bab90489a6cc8e96fea2610e4b749fae62b5a7 Mon Sep 17 00:00:00 2001 From: Petar Stamenkovic Date: Tue, 27 May 2025 12:18:08 +0200 Subject: [PATCH 53/66] delete unnecessary logging --- src/hirad/training/train.py | 716 ++++++++++++++++++++++++++++++++++++ 1 file changed, 716 insertions(+) create mode 100755 src/hirad/training/train.py diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py new file mode 100755 index 0000000..3a2fe2e --- /dev/null +++ b/src/hirad/training/train.py @@ -0,0 +1,716 @@ +import os +import time + +import psutil +import hydra +from omegaconf import DictConfig, OmegaConf +import json +from contextlib import nullcontext +import nvtx +import torch +from hydra.utils import to_absolute_path +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parallel import DistributedDataParallel +from torchinfo import summary + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper +from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ + set_patch_shape, compute_num_accumulation_rounds, \ + is_time_for_periodic_task, handle_and_clip_gradients +from hirad.utils.checkpoint import load_checkpoint, save_checkpoint +from hirad.utils.patching import RandomPatching2D +from hirad.models import UNet, EDMPrecondSuperResolution, EDMPrecondSR +from hirad.losses import ResidualLoss, RegressionLoss, RegressionLossCE +from hirad.datasets import init_train_valid_datasets_from_config + +from matplotlib import pyplot as plt + +torch._dynamo.reset() +# Increase the cache size limit +torch._dynamo.config.cache_size_limit = 264 # Set to a higher value +torch._dynamo.config.verbose = True # Enable verbose logging +torch._dynamo.config.suppress_errors = False # Forces the error to show all details +torch._logging.set_logs(recompiles=True, graph_breaks=True) + +# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available +def cuda_profiler(): + if torch.cuda.is_available(): + return torch.cuda.profiler.profile() + else: + return nullcontext() + + +def cuda_profiler_start(): + if torch.cuda.is_available(): + torch.cuda.profiler.start() + + +def cuda_profiler_stop(): + if torch.cuda.is_available(): + torch.cuda.profiler.stop() + + +def profiler_emit_nvtx(): + if torch.cuda.is_available(): + return torch.autograd.profiler.emit_nvtx() + else: + return nullcontext() + +@hydra.main(version_base=None, config_path="../conf", config_name="training") +def main(cfg: DictConfig) -> None: + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + if dist.rank==0: + writer = SummaryWriter(log_dir='tensorboard') + logger = PythonLogger("main") # general logger + logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger + + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container(cfg.dataset) + if hasattr(cfg.dataset, "validation_path"): + train_test_split = True + else: + train_test_split = False + fp_optimizations = cfg.training.perf.fp_optimizations + songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger0.info(f"Saving the outputs in {os.getcwd()}") + checkpoint_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" + ) + if dist.rank==0 and not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) # added creating checkpoint dir + if cfg.training.hp.batch_size_per_gpu == "auto": + cfg.training.hp.batch_size_per_gpu = ( + cfg.training.hp.total_batch_size // dist.world_size + ) + + set_seed(dist.rank) + configure_cuda_for_consistent_precision() + + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.training.perf.dataloader_workers, + "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None, + } + ( + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator, + ) = init_train_valid_datasets_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.training.hp.batch_size_per_gpu, + seed=0, + train_test_split=train_test_split, + ) + logger0.info(f"Training on dataset with size {len(dataset)}") + + # Parse image configuration & update model args + dataset_channels = len(dataset.input_channels()) + img_in_channels = dataset_channels + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + if cfg.model.hr_mean_conditioning: + img_in_channels += img_out_channels + + + if cfg.model.name == "lt_aware_ce_regression": + prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader + else: + prob_channels = None + + # Parse the patch shape + #TODO figure out patched diffusion and how to use it + if ( + cfg.model.name == "patched_diffusion" + or cfg.model.name == "lt_aware_patched_diffusion" + ): + patch_shape_x = cfg.training.hp.patch_shape_x + patch_shape_y = cfg.training.hp.patch_shape_y + else: + patch_shape_x = None + patch_shape_y = None + if ( + patch_shape_x + and patch_shape_y + and patch_shape_y >= img_shape[0] + and patch_shape_x >= img_shape[1] + ): + logger0.warning( + f"Patch shape {patch_shape_y}x{patch_shape_x} is larger than \ + the image shape {img_shape[0]}x{img_shape[1]}. Patching will not be used." + ) + patch_shape = (patch_shape_y, patch_shape_x) + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + # Utility to perform patches extraction and batching + patching = RandomPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + patch_num=getattr(cfg.training.hp, "patch_num", 1), + ) + logger0.info("Patch-based training enabled") + else: + patching = None + logger0.info("Patch-based training disabled") + # interpolate global channel if patch-based model is used + if use_patching: + img_in_channels += dataset_channels + + # Instantiate the model and move to device. + model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(img_shape), + "use_fp16": fp16, + "checkpoint_level": songunet_checkpoint_level, + } + if cfg.model.name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels + + if hasattr(cfg.model, "model_args"): # override defaults from config file + model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + use_torch_compile = False + use_apex_gn = False + profile_mode = False + + if hasattr(cfg.training.perf, "torch_compile"): + use_torch_compile = cfg.training.perf.torch_compile + if hasattr(cfg.training.perf, "use_apex_gn"): + use_apex_gn = cfg.training.perf.use_apex_gn + model_args["use_apex_gn"] = use_apex_gn + + if hasattr(cfg.training.perf, "profile_mode"): + profile_mode = cfg.training.perf.profile_mode + model_args["profile_mode"] = profile_mode + + if enable_amp: + model_args["amp_mode"] = enable_amp + + + if cfg.model.name == "regression": + model = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + elif cfg.model.name == "lt_aware_ce_regression": + model = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + elif cfg.model.name == "lt_aware_patched_diffusion": + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + else: # diffusion or patched diffusion + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + + model.train().requires_grad_(True).to(dist.device) + + if dist.rank==0 and not os.path.exists(os.path.join(checkpoint_dir, 'model_args.json')): + with open(os.path.join(checkpoint_dir, f'model_args.json'), 'w') as f: + json.dump(model_args, f) + + if use_apex_gn: + model.to(memory_format=torch.channels_last) + + # Check if regression model is used with patching + if ( + cfg.model.name in ["regression", "lt_aware_ce_regression"] + and patching is not None + ): + raise ValueError( + f"Regression model ({cfg.model.name}) cannot be used with patch-based training. " + ) + + # Enable distributed data parallel if applicable + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=True, # dist.find_unused_parameters, + bucket_cap_mb=35, + gradient_as_bucket_view=True, + ) + + # Load the regression checkpoint if applicable #TODO test when training correction + if hasattr(cfg.training.io, "regression_checkpoint_path"): + regression_checkpoint_path = to_absolute_path( + cfg.training.io.regression_checkpoint_path + ) + if not os.path.isdir(regression_checkpoint_path): + raise FileNotFoundError( + f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" + ) + #regression_net = torch.nn.Module() #TODO Module.from_checkpoint(regression_checkpoint_path) figure out how to save and load models, also, some basic functions like num_params, device + #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) + #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder + regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + regression_model_args.update({ + "use_apex_gn": use_apex_gn, + "profile_mode": profile_mode, + "amp_mode": enable_amp, + }) + + regression_net = UNet(**regression_model_args) + + _ = load_checkpoint( + path=regression_checkpoint_path, + model=regression_net, + device=dist.device + ) + regression_net.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + regression_net.to(memory_format=torch.channels_last) + logger0.success("Loaded the pre-trained regression model") + else: + regression_net = None + + # Compile the model and regression net if applicable + if use_torch_compile: + model = torch.compile(model) + if regression_net: + regression_net = torch.compile(regression_net) + + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.training.hp.total_batch_size, + cfg.training.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + patch_num = getattr(cfg.training.hp, "patch_num", 1) + max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", 1) + + # calculate patch per iter + if hasattr(cfg.training.hp, "max_patch_per_gpu") and max_patch_per_gpu > 1: + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) # Ensure at least 1 patch per iter + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + print( + f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" + ) + else: + patch_nums_iter = [patch_num] + + # Set patch gradient accumulation only for patched diffusion models + if cfg.model.name in { + "patched_diffusion", + "lt_aware_patched_diffusion", + }: + if len(patch_nums_iter) > 1: + if not patching: + logger0.info( + "Patching is not enabled: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = False + else: + use_patch_grad_acc = True + else: + use_patch_grad_acc = False + # Automatically disable patch gradient accumulation for non-patched models + else: + logger0.info( + "Training a non-patched model: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = None + + + # Instantiate the loss function + if cfg.model.name in ( + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ): + loss_fn = ResidualLoss( + regression_net=regression_net, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + ) + elif cfg.model.name == "regression": + loss_fn = RegressionLoss() + elif cfg.model.name == "lt_aware_ce_regression": + loss_fn = RegressionLossCE(prob_channels=prob_channels) + + # Instantiate the optimizer + optimizer = torch.optim.Adam( + params=model.parameters(), + lr=cfg.training.hp.lr, + betas=[0.9, 0.999], + eps=1e-8, + fused=True, + ) + + # Record the current time to measure the duration of subsequent operations. + start_time = time.time() + + # Load optimizer checkpoint if it exists + if dist.world_size > 1: + torch.distributed.barrier() + try: + cur_nimg = load_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + device=dist.device, + ) + except: + cur_nimg = 0 + + ############################################################################ + # MAIN TRAINING LOOP # + ############################################################################ + + logger0.info(f"Training for {cfg.training.hp.training_duration} images...") + done = False + + # init variables to monitor running mean of average loss since last periodic + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + start_nimg = cur_nimg + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # enable profiler: + with cuda_profiler(): + with profiler_emit_nvtx(): + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + + if cur_nimg - start_nimg == 24 * cfg.training.hp.total_batch_size: + logger0.info(f"Starting Profiler at {cur_nimg}") + cuda_profiler_start() + + if cur_nimg - start_nimg == 25 * cfg.training.hp.total_batch_size: + logger0.info(f"Stopping Profiler at {cur_nimg}") + cuda_profiler_stop() + + with nvtx.annotate("Training iteration", color="green"): + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for n_i in range(num_accumulation_rounds): + with nvtx.annotate( + f"accumulation round {n_i}", color="Magenta" + ): + with nvtx.annotate("loading data", color="green"): + img_clean, img_lr, *lead_time_label = next( + dataset_iterator + ) + if use_apex_gn: + img_clean = img_clean.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr = img_lr.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean = ( + img_clean.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr = ( + img_lr.to(dist.device) + .to(input_dtype) + .contiguous() + ) + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_fn_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + + if lead_time_label: + lead_time_label = ( + lead_time_label[0].to(dist.device).contiguous() + ) + loss_fn_kwargs.update( + {"lead_time_label": lead_time_label} + ) + else: + lead_time_label = None + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update({"patching": patching}) + with nvtx.annotate(f"loss forward", color="green"): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss = loss_fn(**loss_fn_kwargs) + + loss = loss.sum() / batch_size_per_gpu + loss_accum += loss / num_accumulation_rounds + with nvtx.annotate(f"loss backward", color="yellow"): + loss.backward() + + + with nvtx.annotate(f"loss aggregate", color="green"): + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 + + if dist.rank == 0: + writer.add_scalar("training_loss", average_loss, cur_nimg) + writer.add_scalar( + "training_loss_running_mean", + average_loss_running_mean, + cur_nimg, + ) + + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. + with nvtx.annotate("update weights", color="blue"): + + lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // 5e6) + current_lr = g["lr"] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + handle_and_clip_gradients( + model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold + ) + with nvtx.annotate("optimizer step", color="blue"): + optimizer.step() + + cur_nimg += cfg.training.hp.total_batch_size + done = cur_nimg >= cfg.training.hp.training_duration + + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.print_progress_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + with nvtx.annotate("validation", color="red"): + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + with torch.no_grad(): + for _ in range(cfg.training.io.validation_steps): + ( + img_clean_valid, + img_lr_valid, + *lead_time_label_valid, + ) = next(validation_dataset_iterator) + + if use_apex_gn: + img_clean_valid = img_clean_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_valid = img_lr_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + + else: + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + } + if use_patch_grad_acc is not None: + loss_valid_kwargs[ + "use_patch_grad_acc" + ] = use_patch_grad_acc + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0] + .to(dist.device) + .contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update( + {"patching": patching} + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss_valid = loss_fn(**loss_valid_kwargs) + + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu) + .cpu() + .item() + ) + valid_loss_accum += ( + loss_valid + / cfg.training.io.validation_steps + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + writer.add_scalar( + "validation_loss", average_valid_loss, cur_nimg + ) + + + # Save checkpoints + if dist.world_size > 1: + torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.training.io.save_checkpoint_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + save_checkpoint( + path=checkpoint_dir, + model=model, + optimizer=optimizer, + epoch=cur_nimg, + ) + + # Done. + logger0.info("Training Completed.") + + +if __name__ == "__main__": + main() \ No newline at end of file From 996f1368af5795c64e6146a0701c1425d1089925 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 17:12:51 +0200 Subject: [PATCH 54/66] new image --- .edf/hirad-ci.toml | 14 ++++++++++++++ ci/cscs.yml | 18 +++++++++--------- ci/docker/Dockerfile | 25 +++++++++++++------------ 3 files changed, 36 insertions(+), 21 deletions(-) create mode 100644 .edf/hirad-ci.toml diff --git a/.edf/hirad-ci.toml b/.edf/hirad-ci.toml new file mode 100644 index 0000000..2ff1c6b --- /dev/null +++ b/.edf/hirad-ci.toml @@ -0,0 +1,14 @@ +image = "/capstor/scratch/cscs/${USER}/images/hirad-pytorch-25.01-py3.sqsh" + +mounts = ["/capstor","/iopsstor"] + +writable = true + +[annotations] +com.hooks.aws_ofi_nccl.enabled = "true" +com.hooks.aws_ofi_nccl.variant = "cuda12" + +[env] +FI_CXI_DISABLE_HOST_REGISTER = "1" +FI_MR_CACHE_MONITOR = "userfaultfd" +NCCL_DEBUG = "INFO" diff --git a/ci/cscs.yml b/ci/cscs.yml index fc92645..911274b 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -14,12 +14,12 @@ build_job: variables: DOCKERFILE: ci/docker/Dockerfile -#test_job: -# stage: test -# extends: .container-runner-clariden-gh200 -# image: $PERSIST_IMAGE_NAME -# script: -# - /opt/helloworld/bin/hello -# variables: -# SLURM_JOB_NUM_NODES: 2 -# SLURM_NTASKS: 2 +test_job: + stage: test + extends: .container-runner-clariden-gh200 + image: $PERSIST_IMAGE_NAME + script: + - python src/hirad/eval/run_scoring.py + variables: + SLURM_JOB_NUM_NODES: 1 + SLURM_NTASKS: 1 diff --git a/ci/docker/Dockerfile b/ci/docker/Dockerfile index 4772d76..72a4e3f 100644 --- a/ci/docker/Dockerfile +++ b/ci/docker/Dockerfile @@ -3,10 +3,10 @@ #FROM ubuntu:22.04 as builder FROM nvcr.io/nvidia/pytorch:25.01-py3 -COPY . /src + # setup -RUN apt-get update && apt-get install python3-pip python3-venv -y +#RUN apt-get update && apt-get install python3-pip python3-venv -y RUN pip install --upgrade \ pip #ninja @@ -15,27 +15,28 @@ RUN pip install --upgrade \ #setuptools # update flash-attn -RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \ - flash-attn==2.7.4.post1 -v +#RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \ +# flash-attn==2.7.4.post1 -v # install the rest of dependencies # TODO: Factor pydeps into a separate file(s) # TODO: Add versions for things RUN pip install \ anemoi-datasets \ - cartopy \ - matplotlib \ - numpy \ - pandas \ - scipy \ - torch + cartopy + #matplotlib \ + #numpy \ + #pandas \ + #scipy \ + #torch # replace pynvml with nvidia-ml-py -RUN pip uninstall -y pynvml && pip install nvidia-ml-py +#RUN pip uninstall -y pynvml && pip install nvidia-ml-py #CMD ["python3.11" "src/input_data/interpolate_basic_test.py"] - +COPY . /src +WORKDIR /src From 5d4fdff6ff5f57fbdf40497905a50ab549f7a143 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 17:46:34 +0200 Subject: [PATCH 55/66] use absolute path --- ci/cscs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index 911274b..bc69beb 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -19,7 +19,7 @@ test_job: extends: .container-runner-clariden-gh200 image: $PERSIST_IMAGE_NAME script: - - python src/hirad/eval/run_scoring.py + - python /src/hirad/eval/run_scoring.py variables: SLURM_JOB_NUM_NODES: 1 SLURM_NTASKS: 1 From 9a46267486b6655888fe4a9fca3ea19f21ec9408 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 17:50:20 +0200 Subject: [PATCH 56/66] cd to /src --- ci/cscs.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index bc69beb..9f4ec56 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -19,7 +19,8 @@ test_job: extends: .container-runner-clariden-gh200 image: $PERSIST_IMAGE_NAME script: - - python /src/hirad/eval/run_scoring.py + - cd /src + - python src/hirad/eval/run_scoring.py variables: SLURM_JOB_NUM_NODES: 1 SLURM_NTASKS: 1 From a834730abd322b7db46ed3ed986604f4d21211eb Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 17:58:58 +0200 Subject: [PATCH 57/66] add USE_NCCL variable to pipeline --- ci/cscs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/cscs.yml b/ci/cscs.yml index 9f4ec56..c318a2e 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -24,3 +24,4 @@ test_job: variables: SLURM_JOB_NUM_NODES: 1 SLURM_NTASKS: 1 + USE_NCCL: cuda12 From cdb35f47b9858dcd9d8f9758367b805e7d1c8936 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 18:12:57 +0200 Subject: [PATCH 58/66] Attempting adding more variables --- ci/cscs.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ci/cscs.yml b/ci/cscs.yml index c318a2e..d446ab7 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -25,3 +25,7 @@ test_job: SLURM_JOB_NUM_NODES: 1 SLURM_NTASKS: 1 USE_NCCL: cuda12 + FI_CXI_DISABLE_HOST_REGISTER: 1 + FI_MR_CACHE_MONITOR: userfaultfd + NCCL_DEBUG: INFO + From 0601b70b8eacc8758c18215bd582720061d5cac3 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 18:15:23 +0200 Subject: [PATCH 59/66] add env logging --- ci/cscs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/cscs.yml b/ci/cscs.yml index d446ab7..bc120c7 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -19,6 +19,7 @@ test_job: extends: .container-runner-clariden-gh200 image: $PERSIST_IMAGE_NAME script: + - env - cd /src - python src/hirad/eval/run_scoring.py variables: From aa9b9aa8ee1183bdf19fbc6f69893fd9c6822522 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 19:12:54 +0200 Subject: [PATCH 60/66] Try distributed torch --- ci/cscs.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index bc120c7..d477f94 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -21,7 +21,8 @@ test_job: script: - env - cd /src - - python src/hirad/eval/run_scoring.py + #- python src/hirad/eval/run_scoring.py + - MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 RANK=${SLURM_PROCID} LOCAL_RANK=${SLURM_LOCALID} WORLD_SIZE=${SLURM_NPROCS} python -c "import os, torch; import torch.distributed as dist; local_rank = int(os.environ['LOCAL_RANK']); torch.cuda.set_device(local_rank); dist.init_process_group('nccl', init_method='env://'); rank = dist.get_rank(); print(f'Hello from rank {rank}'); t = torch.tensor([rank]).to('cuda'); dist.all_reduce(t); print(f'The sum of ranks is {t}.'); dist.destroy_process_group()" variables: SLURM_JOB_NUM_NODES: 1 SLURM_NTASKS: 1 From 18ab0dc4cc475ce51ed961cd487f3531826924f2 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Tue, 27 May 2025 19:13:58 +0200 Subject: [PATCH 61/66] distributed test --- ci/cscs.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index d477f94..0fb1640 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -24,8 +24,8 @@ test_job: #- python src/hirad/eval/run_scoring.py - MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 RANK=${SLURM_PROCID} LOCAL_RANK=${SLURM_LOCALID} WORLD_SIZE=${SLURM_NPROCS} python -c "import os, torch; import torch.distributed as dist; local_rank = int(os.environ['LOCAL_RANK']); torch.cuda.set_device(local_rank); dist.init_process_group('nccl', init_method='env://'); rank = dist.get_rank(); print(f'Hello from rank {rank}'); t = torch.tensor([rank]).to('cuda'); dist.all_reduce(t); print(f'The sum of ranks is {t}.'); dist.destroy_process_group()" variables: - SLURM_JOB_NUM_NODES: 1 - SLURM_NTASKS: 1 + SLURM_JOB_NUM_NODES: 2 + SLURM_NTASKS: 4 USE_NCCL: cuda12 FI_CXI_DISABLE_HOST_REGISTER: 1 FI_MR_CACHE_MONITOR: userfaultfd From d3411677452618b76e656d0abdb01d6f3947a197 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 28 May 2025 13:57:08 +0200 Subject: [PATCH 62/66] run training from ci/cd. --- ci/cscs.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index 0fb1640..a7df846 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -21,8 +21,9 @@ test_job: script: - env - cd /src + - python src/hirad/training/train.py --config-name=training_era_cosmo_regression.yaml #- python src/hirad/eval/run_scoring.py - - MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 RANK=${SLURM_PROCID} LOCAL_RANK=${SLURM_LOCALID} WORLD_SIZE=${SLURM_NPROCS} python -c "import os, torch; import torch.distributed as dist; local_rank = int(os.environ['LOCAL_RANK']); torch.cuda.set_device(local_rank); dist.init_process_group('nccl', init_method='env://'); rank = dist.get_rank(); print(f'Hello from rank {rank}'); t = torch.tensor([rank]).to('cuda'); dist.all_reduce(t); print(f'The sum of ranks is {t}.'); dist.destroy_process_group()" + #- MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) MASTER_PORT=29500 RANK=${SLURM_PROCID} LOCAL_RANK=${SLURM_LOCALID} WORLD_SIZE=${SLURM_NPROCS} python -c "import os, torch; import torch.distributed as dist; local_rank = int(os.environ['LOCAL_RANK']); torch.cuda.set_device(local_rank); dist.init_process_group('nccl', init_method='env://'); rank = dist.get_rank(); print(f'Hello from rank {rank}'); t = torch.tensor([rank]).to('cuda'); dist.all_reduce(t); print(f'The sum of ranks is {t}.'); dist.destroy_process_group()" variables: SLURM_JOB_NUM_NODES: 2 SLURM_NTASKS: 4 @@ -31,3 +32,4 @@ test_job: FI_MR_CACHE_MONITOR: userfaultfd NCCL_DEBUG: INFO + From 2f2e5f560122ff0e4f5be2a5d591594c52944040 Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Wed, 28 May 2025 17:57:11 +0200 Subject: [PATCH 63/66] Split Dockerfile so we have differnet environments, to keep the CI one light. (Note: May need to change webhook command to take the new docker file) --- .edf/README.md | 23 ++++++++++++++++++++++- .edf/ubuntu2.toml | 5 +++++ .gitignore | 4 +++- ci/docker/{Dockerfile => Dockerfile.ci} | 15 +++++++++++---- ci/docker/Dockerfile.dev | 22 ++++++++++++++++++++++ 5 files changed, 63 insertions(+), 6 deletions(-) create mode 100644 .edf/ubuntu2.toml rename ci/docker/{Dockerfile => Dockerfile.ci} (65%) create mode 100644 ci/docker/Dockerfile.dev diff --git a/.edf/README.md b/.edf/README.md index 4da3cd1..2316e95 100644 --- a/.edf/README.md +++ b/.edf/README.md @@ -7,4 +7,25 @@ This adds the repository path to the EDF search path. run: ``` srun -A a-a122 --environment=ubuntu2 cat /etc/os-release -``` \ No newline at end of file +``` + +# local development +srun --environment $PWD/.edf/hirad-ci.toml -A a-a122 -p debug --pty bash + + + +# list current images +podman images + +# build according to the dockerfile into an image with tag tmpv1, from current directory. +podman build -f ci/docker/Dockerfile -t tmpv1 . + +# +podman run -it localhost/tmpv1 + +mkdir /capstor/scratch/cscs/mmcgloho/images + +# export the image into a sqsh file so it is availabe outside the interactive shell +enroot import -x mount -o /capstor/scratch/cscs/mmcgloho/images/hirad-pytorch-25.01-py3.sqsh podman://localhost/tmpv1 + +ls /capstor/scratch/cscs/mmcgloho/images \ No newline at end of file diff --git a/.edf/ubuntu2.toml b/.edf/ubuntu2.toml new file mode 100644 index 0000000..22dba65 --- /dev/null +++ b/.edf/ubuntu2.toml @@ -0,0 +1,5 @@ + + +image = "library/ubuntu:24.04" +mounts = ["/capstor/scratch/cscs/${USER}:/capstor/scratch/cscs/${USER}"] +workdir = "/capstor/scratch/cscs/${USER}" diff --git a/.gitignore b/.gitignore index dee6b07..c7d62e7 100644 --- a/.gitignore +++ b/.gitignore @@ -183,4 +183,6 @@ plots/* temp.* # local script -interpolate.sh \ No newline at end of file +interpolate.sh +core_clariden-ln002_241188 + diff --git a/ci/docker/Dockerfile b/ci/docker/Dockerfile.ci similarity index 65% rename from ci/docker/Dockerfile rename to ci/docker/Dockerfile.ci index 72a4e3f..3378585 100644 --- a/ci/docker/Dockerfile +++ b/ci/docker/Dockerfile.ci @@ -1,9 +1,8 @@ # Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup #FROM ubuntu:22.04 as builder -FROM nvcr.io/nvidia/pytorch:25.01-py3 - - +#FROM nvcr.io/nvidia/pytorch:25.01-py3 +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.03 # setup #RUN apt-get update && apt-get install python3-pip python3-venv -y @@ -36,7 +35,15 @@ RUN pip install \ #CMD ["python3.11" "src/input_data/interpolate_basic_test.py"] - COPY . /src WORKDIR /src +ENV NCCL_TESTS_VERSION=2.15.0 + +RUN wget -O nccl-tests-${NCCL_TESTS_VERSION}.tar.gz https://github.com/NVIDIA/nccl-tests/archive/refs/tags/v${NCCL_TESTS_VERSION}.tar.gz \ + && tar xf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz \ + && cd nccl-tests-${NCCL_TESTS_VERSION} \ + && MPI=1 MPI_HOME=/opt/hpcx/ompi make -j$(nproc) \ + && cd .. \ + && rm -rf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz + diff --git a/ci/docker/Dockerfile.dev b/ci/docker/Dockerfile.dev new file mode 100644 index 0000000..a37e84d --- /dev/null +++ b/ci/docker/Dockerfile.dev @@ -0,0 +1,22 @@ +FROM nvcr.io/nvidia/pytorch:25.01-py3 + +# setup +#RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade \ + pip + +# install the rest of dependencies +# TODO: Factor pydeps into a separate file(s) +# TODO: Add versions for things +RUN pip install \ + anemoi-datasets \ + cartopy + + +COPY . /src +WORKDIR /src + +# Useful utilities for performance monitoring/analysis/debugging +RUN apt-get update \ + && apt-get install -yqq --no-install-recommends strace valgrind htop nvtop atop ioping fio \ + && rm -rf /var/lib/apt/lists/* \ No newline at end of file From 417291cae0f43d08d670ca00854caf0f8ec1c9fd Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 5 Jun 2025 12:06:59 +0200 Subject: [PATCH 64/66] update dockerfile --- ci/cscs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/cscs.yml b/ci/cscs.yml index a7df846..85579a0 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -12,7 +12,7 @@ build_job: stage: build extends: .container-builder-cscs-gh200 variables: - DOCKERFILE: ci/docker/Dockerfile + DOCKERFILE: ci/docker/Dockerfile.ci test_job: stage: test From 700d91da53a64d3c4ead6321b3ac92a595e2d68e Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 5 Jun 2025 12:16:50 +0200 Subject: [PATCH 65/66] Adding back Dockerfile to see if that causes webhook to work --- ci/docker/Dockerfile | 49 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 ci/docker/Dockerfile diff --git a/ci/docker/Dockerfile b/ci/docker/Dockerfile new file mode 100644 index 0000000..3378585 --- /dev/null +++ b/ci/docker/Dockerfile @@ -0,0 +1,49 @@ +# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup + +#FROM ubuntu:22.04 as builder +#FROM nvcr.io/nvidia/pytorch:25.01-py3 +FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.03 + +# setup +#RUN apt-get update && apt-get install python3-pip python3-venv -y +RUN pip install --upgrade \ + pip + #ninja + #wheel + #packaging + #setuptools + +# update flash-attn +#RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \ +# flash-attn==2.7.4.post1 -v + +# install the rest of dependencies +# TODO: Factor pydeps into a separate file(s) +# TODO: Add versions for things +RUN pip install \ + anemoi-datasets \ + cartopy + #matplotlib \ + #numpy \ + #pandas \ + #scipy \ + #torch + + +# replace pynvml with nvidia-ml-py +#RUN pip uninstall -y pynvml && pip install nvidia-ml-py + +#CMD ["python3.11" "src/input_data/interpolate_basic_test.py"] + +COPY . /src +WORKDIR /src + +ENV NCCL_TESTS_VERSION=2.15.0 + +RUN wget -O nccl-tests-${NCCL_TESTS_VERSION}.tar.gz https://github.com/NVIDIA/nccl-tests/archive/refs/tags/v${NCCL_TESTS_VERSION}.tar.gz \ + && tar xf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz \ + && cd nccl-tests-${NCCL_TESTS_VERSION} \ + && MPI=1 MPI_HOME=/opt/hpcx/ompi make -j$(nproc) \ + && cd .. \ + && rm -rf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz + From 72f40a283a5eb466c327ffc2ebb75e9f39638b2b Mon Sep 17 00:00:00 2001 From: Mary McGlohon Date: Thu, 5 Jun 2025 12:18:17 +0200 Subject: [PATCH 66/66] Re-delete Dockerfile, don't need it. --- ci/docker/Dockerfile | 49 -------------------------------------------- 1 file changed, 49 deletions(-) delete mode 100644 ci/docker/Dockerfile diff --git a/ci/docker/Dockerfile b/ci/docker/Dockerfile deleted file mode 100644 index 3378585..0000000 --- a/ci/docker/Dockerfile +++ /dev/null @@ -1,49 +0,0 @@ -# Following some suggestions in https://meteoswiss.atlassian.net/wiki/spaces/APN/pages/719684202/Clariden+Alps+environment+setup - -#FROM ubuntu:22.04 as builder -#FROM nvcr.io/nvidia/pytorch:25.01-py3 -FROM nvcr.io/nvidia/physicsnemo/physicsnemo:25.03 - -# setup -#RUN apt-get update && apt-get install python3-pip python3-venv -y -RUN pip install --upgrade \ - pip - #ninja - #wheel - #packaging - #setuptools - -# update flash-attn -#RUN MAX_JOBS=16 pip install --upgrade --no-build-isolation \ -# flash-attn==2.7.4.post1 -v - -# install the rest of dependencies -# TODO: Factor pydeps into a separate file(s) -# TODO: Add versions for things -RUN pip install \ - anemoi-datasets \ - cartopy - #matplotlib \ - #numpy \ - #pandas \ - #scipy \ - #torch - - -# replace pynvml with nvidia-ml-py -#RUN pip uninstall -y pynvml && pip install nvidia-ml-py - -#CMD ["python3.11" "src/input_data/interpolate_basic_test.py"] - -COPY . /src -WORKDIR /src - -ENV NCCL_TESTS_VERSION=2.15.0 - -RUN wget -O nccl-tests-${NCCL_TESTS_VERSION}.tar.gz https://github.com/NVIDIA/nccl-tests/archive/refs/tags/v${NCCL_TESTS_VERSION}.tar.gz \ - && tar xf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz \ - && cd nccl-tests-${NCCL_TESTS_VERSION} \ - && MPI=1 MPI_HOME=/opt/hpcx/ompi make -j$(nproc) \ - && cd .. \ - && rm -rf nccl-tests-${NCCL_TESTS_VERSION}.tar.gz -