diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py
new file mode 100644
index 000000000..a0a3f2632
--- /dev/null
+++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_casestudy_main.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env -S uv run --script
+# /// script
+# dependencies = [
+#   "scikit-image>=0.26.0",
+#   "scikit-learn>=1.8.0",
+#   "xarray",
+#   "tqdm",
+#   "cartopy",
+#   "omegaconf",
+#   "netcdf4"
+# ]
+# ///
+
+"""
+This script tracks tropical cyclones in forecast and target
+for a single sample, then looks for the tracks corresponding
+to a user-selected storm of interest and produces diagnostic
+plots for that storm, including the track error, simulated pressure
+and wind speed.
+
+The tracking functionality is also intended for future use in
+systematic evaluation of all tropical cyclones in the prediction.
+
+Before running, export 10u, 10v and msl to netcdf, regridded
+to 1°x1° as follows:
+uv run export --run-id RUN_ID --stream ERA5 \
+--output-dir OUTPUT_DIR --format netcdf --regrid-degree 1 \
+--regrid-type regular_ll \
+--channel 10u 10v msl
+and again with --type target for the target. In TC_config.yml, set inpath to
+OUTPUT_DIR where the regridded data is. Make sure that
+the timesteps specified in TC_config.yml are within the simulation.
+
+Then run this script via
+uv run TC_casestudy_main.py
+
+All parameters including the storm of interest are set in the config.
+"""
+
+from functools import cached_property
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import xarray as xr
+from cyclone_finder import (
+    Cyclone,
+    CycloneFinder,
+    cyclones_in_ds,
+    track2pandas,
+    track_cyclones,
+    wrap_lon,
+)
+from cyclone_plots import track_eval_plot, track_snapshots
+from omegaconf import DictConfig, OmegaConf
+
+
+class TcCaseStudy:
+    """
+    Read the cyclone tracker settings, data paths and the target cyclone
+    from config, then find the matched tracks corresponding to that cyclone
+    in the prediction and target.
+    """
+
+    def __init__(self, cfg: DictConfig):
+        self.cfg = cfg
+        # only position and time of the selected storm are taken from the config,
+        # wind and pressure are set to 0
+        self.selected_storm = Cyclone(
+            wind=0,
+            pressure=0,
+            lon=cfg.selected_storm.lon,
+            lat=cfg.selected_storm.lat,
+            time=np.datetime64(cfg.selected_storm.time),
+        )
+        self.finder = CycloneFinder(
+            sigma=cfg.tracking_params.laplace_size,
+            th_laplace=cfg.tracking_params.laplace_threshold,
+            th_pressure=cfg.tracking_params.pressure_threshold,
+            th_wind=cfg.tracking_params.wind_threshold,
+            min_distance=cfg.tracking_params.peak_separation,
+        )
+        self.outpath = Path(cfg.outpath)
+
+    @cached_property
+    def datasets(self):
+        """Target and prediction datasets, cut to latmin...latmax, keyed by name."""
+        indir = Path(self.cfg.inpath)
+        infiles = {
+            k: indir / f"{k}_{self.cfg.init_time}_{self.cfg.runid}_ERA5.nc"
+            for k in ("target", "prediction")
+        }
+        datasets = {
+            k: wrap_lon(xr.open_dataset(f)).sel(latitude=slice(self.cfg.latmin, self.cfg.latmax))
+            for k, f in infiles.items()
+        }
+        return datasets
+
+    @cached_property
+    def cyclones(self):
+        """For each dataset, the cyclones found at every valid_time of the target."""
+        times = self.datasets["target"].valid_time.values
+        cyclones = {
+            k: [cyclones_in_ds(ds, self.finder, time=t) for t in times]
+            for k, ds in self.datasets.items()
+        }
+        return cyclones
+
+    @cached_property
+    def tracks(self):
+        """For each dataset, the cyclones of all timesteps connected to tracks."""
+        tracks = {
+            k: track_cyclones(d, self.cfg.tracking_params.merge_distance)
+            for k, d in self.cyclones.items()
+        }
+        return tracks
+
+    @cached_property
+    def matched_tracks(self):
+        """
+        For each dataset, the track of the selected storm as a DataFrame.
+        Raises a ValueError if the selected storm is not found in a dataset.
+        """
+        times = self.datasets["target"].valid_time.values
+        # index of the timestep closest to the time of the selected storm
+        storm_index = np.argmin(np.abs(times - self.selected_storm.time))
+        matched_storms = {
+            k: self.selected_storm.match(x[storm_index]) for k, x in self.cyclones.items()
+        }
+        missing = [k for k, storm in matched_storms.items() if storm is None]
+        if missing:
+            raise ValueError(f"selected storm not found in: {', '.join(missing)}")
+        matched_tracks = {
+            k: track2pandas(d.subset(matched_storms[k])) for k, d in self.tracks.items()
+        }
+        return matched_tracks
+
+    def plot(self):
+        """Save the evaluation plot and the snapshot maps as png files in outpath."""
+        self.outpath.mkdir(parents=True, exist_ok=True)
+        stem = f"{self.cfg.runid}_cyclone_{self.cfg.init_time}"
+        # evaluation plot
+        evalfile = self.outpath / f"{stem}.png"
+        fig, _axs = track_eval_plot(self.matched_tracks)
+        init_time = self.datasets["target"].forecast_reference_time.values
+        fig.suptitle(f"forecast initialized {init_time}")
+        fig.savefig(evalfile)
+
+        # example maps
+        snapshotfile = self.outpath / f"{stem}_snapshots.png"
+        track_snapshots(self.matched_tracks, self.datasets)
+        # saves the current pyplot figure, expected to be the snapshot plot
+        plt.savefig(snapshotfile)
+
+
+def main():
+    """Run the case study configured in TC_config.yml (read from the working directory)."""
+    cfg = OmegaConf.load("TC_config.yml")
+    casestudy = TcCaseStudy(cfg)
+    casestudy.plot()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml
new file mode 100644
index 000000000..17a02da61
--- /dev/null
+++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/TC_config.yml
@@ -0,0 +1,17 @@
+runid : "i0xr7z48"
+init_time : "2023-10-07T00"
+inpath : "/p/project1/weatherai/buschow1/wegen_export/cyclones/"
+outpath : "./plots/"
+latmin: -30 # TCs are only detected for latmin<=lat<=latmax
+latmax: 30
+selected_storm : # the storm you want to analyze
+  lon : 154.7
+  lat : 9.6
+  time : "2023-10-07T00:00"
+tracking_params:
+  laplace_size : 2 # in units of gridboxes
+  laplace_threshold : 0 # should be >= 0 to find low pressure systems
+  pressure_threshold : 103000 # in Pa
+  wind_threshold : 0 # in m/s
+  peak_separation : 5 # in units of gridboxes
+  merge_distance: 300 # in km
\ No newline at end of file
diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py
new file mode 100644
index 000000000..8e7641d4a
--- /dev/null
+++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_finder.py
@@ -0,0 +1,198 @@
+from dataclasses import dataclass
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+from scipy.cluster.hierarchy import DisjointSet
+from scipy.ndimage import gaussian_laplace, maximum_filter
+from skimage.feature import peak_local_max
+from sklearn.metrics.pairwise import haversine_distances
+from tqdm import tqdm
+
+# earth radius in km, converts angular distances (radians) to km
+R_EARTH_KM = 6371.0
+
+
+@dataclass(order=True, frozen=True)
+class Cyclone:
+    """
+    A cyclone center at one timestep. Frozen, hence hashable, so that
+    cyclones can be stored in sets and in the DisjointSet of tracks.
+    """
+
+    wind: float
+    pressure: float
+    lon: float
+    lat: float
+    ID: str | None = None
+    time: np.datetime64 | None = None
+
+    def dist_to(self, other: "Cyclone") -> float:
+        """
+        Great-circle distance to another cyclone in km
+        """
+        p1 = [np.deg2rad(deg) for deg in (self.lat, self.lon)]
+        p2 = [np.deg2rad(deg) for deg in (other.lat, other.lon)]
+        angle = haversine_distances(X=np.array(p1).reshape(1, -1), Y=np.array(p2).reshape(1, -1))
+        # angle is a 1x1 matrix, return its only entry as a plain float
+        return float(R_EARTH_KM * angle[0, 0])
+
+    def match(self, cyclones: list["Cyclone"], maxdist_km: float = 3000) -> "Cyclone | None":
+        """
+        Select the closest from a set of other cyclones.
+        Returns None if there are no other cyclones or the
+        closest one is not within maxdist_km.
+        """
+        if not cyclones:
+            return None
+        dists = [self.dist_to(other) for other in cyclones]
+        closest = int(np.argmin(dists))
+        if dists[closest] < maxdist_km:
+            return cyclones[closest]
+        return None
+
+
+class CycloneFinder:
+    def __init__(
+        self,
+        sigma: float = 2,
+        th_laplace: float = 30,
+        th_pressure: float = 101000,
+        th_wind: float = 10,
+        min_distance: float = 5,
+    ):
+        """
+        Try finding cyclones with simple blob detection
+        plus some heuristic filter criteria
+        Attributes
+        ----------
+        sigma: Gauss standard deviation. The zeros of the laplace filter
+            are at sqrt(2)*sigma distance from the center
+        th_laplace: minimum value of the filtered field
+        th_pressure: maximum pressure value
+        th_wind: minimum wind speed
+        min_distance: minimum distance between peaks in number of gridpoints
+        """
+        self.sigma = sigma
+        self.th_laplace = th_laplace
+        self.th_pressure = th_pressure
+        self.th_wind = th_wind
+        self.min_distance = min_distance
+
+    def filter(self, image):
+        """Apply the Laplace-of-Gaussian filter with standard deviation sigma."""
+        return gaussian_laplace(image, sigma=self.sigma)
+
+    def mask(self, pressure, windmax):
+        """True where pressure is below th_pressure and windmax is above th_wind."""
+        pressuremask = (pressure < self.th_pressure).values
+        windmask = windmax > self.th_wind
+        return pressuremask & windmask
+
+    def find_cyclones(self, pressure, wind, windmaxsize=5, timestamp=None) -> list["Cyclone"]:
+        """
+        Find cyclones in one pressure field with dimensions (latitude, longitude).
+        The wind of each cyclone is the maximum of the wind field within a
+        window of windmaxsize gridpoints, timestamp is stored as its time.
+        """
+        # apply the LoG filter to pressure
+        filtered = self.filter(pressure)
+        # find candidate maxima
+        candidates = peak_local_max(
+            filtered, threshold_abs=self.th_laplace, min_distance=self.min_distance
+        )
+        # apply mask
+        windmax = maximum_filter(wind.values, size=windmaxsize)
+        mask = self.mask(pressure, windmax)[candidates[:, 0], candidates[:, 1]]
+        cyclones = candidates[mask, :]
+        # each row of cyclones holds the (latitude, longitude) indices of one center
+        res = [
+            Cyclone(
+                lon=pressure.longitude.values[i_lon],
+                lat=pressure.latitude.values[i_lat],
+                wind=windmax[i_lat, i_lon],
+                pressure=pressure.values[i_lat, i_lon],
+                time=timestamp,
+            )
+            for i_lat, i_lon in cyclones
+        ]
+        return res
+
+
+def track_cyclones(timesteps: list[list["Cyclone"]], merge_distance_km: float = 300) -> DisjointSet:
+    """
+    Takes a list of lists of cyclones, each top level entry representing one timestep,
+    returns a DisjointSet where each entry represents a track.
+    """
+    tracks = DisjointSet()
+    prev_step = []
+
+    for step in tqdm(timesteps):
+        # Add all storms from this timestep
+        for storm in step:
+            tracks.add(storm)
+
+        # Build all candidate matches (prev → curr)
+        candidates = []
+        for s_prev in prev_step:
+            for s_curr in step:
+                d = s_prev.dist_to(s_curr)
+                if d <= merge_distance_km:
+                    candidates.append((d, s_prev, s_curr))
+
+        # Sort by distance (closest first)
+        candidates.sort(key=lambda x: x[0])
+
+        # Keep track of which storms have already been matched
+        used_prev = set()
+        used_curr = set()
+
+        # Greedy matching: closest pairs first
+        for _dist, s_prev, s_curr in candidates:
+            if s_prev not in used_prev and s_curr not in used_curr:
+                tracks.merge(s_prev, s_curr)
+                used_prev.add(s_prev)
+                used_curr.add(s_curr)
+
+        prev_step = step
+
+    return tracks
+
+
+def track2pandas(track: list["Cyclone"]) -> pd.DataFrame:
+    """Convert the cyclones of a track to a DataFrame indexed and sorted by time."""
+    return pd.DataFrame([storm.__dict__ for storm in track]).set_index("time").sort_index()
+
+
+def cyclones_in_ds(ds: xr.Dataset, finder: "CycloneFinder", time: np.datetime64) -> list["Cyclone"]:
+    """
+    Find cyclones in a dataset containing at least msl, u10, v10,
+    at a given timestep, using a given CycloneFinder.
+    """
+    ds_t = ds.sel(valid_time=time)
+    msl = ds_t.msl
+    v = np.sqrt(ds_t.u10**2 + ds_t.v10**2)
+    return finder.find_cyclones(pressure=msl, wind=v, timestamp=time)
+
+
+def track_error(track1: pd.DataFrame, track2: pd.DataFrame) -> pd.DataFrame:
+    """
+    Given two tracks as pd.DataFrames, compute their distance in km.
+    At timesteps where one track is missing, the result is NaN.
+    """
+    coords = [np.deg2rad(x.loc[:, ["lat", "lon"]]) for x in track1.align(track2, join="inner")]
+    angle = haversine_distances(X=coords[0].values, Y=coords[1].values)
+    distance = pd.DataFrame({"distance": R_EARTH_KM * np.diag(angle)}, index=coords[0].index)
+    all_idx = track1.index.union(track2.index)
+    distance = distance.reindex(all_idx)
+
+    return distance
+
+
+def wrap_lon(ds: xr.Dataset) -> xr.Dataset:
+    """
+    Convert longitude from 0...360 to -180...180 and sort by it.
+    A new dataset is returned, the input is not modified.
+    """
+    wrapped = ds.assign_coords(longitude=(ds["longitude"] + 180) % 360 - 180)
+    return wrapped.sortby("longitude")
diff --git a/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py
new file mode 100644
index 000000000..caaaa8c5f
--- /dev/null
+++ b/packages/evaluate/src/weathergen/evaluate/example_extras/tropical_cyclones/cyclone_plots.py
@@ -0,0 +1,79 @@
+import cartopy.crs as ccrs
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import xarray as xr
+from cyclone_finder import track_error
+
+
+def track_eval_plot(matched_tracks):
+    """
+    A four panel plot showing
+    * the target and predicted track on a map
+    * the track error in km
+    * the pressure at the cyclone core for target and prediction
+    * the maximum wind speed near the cyclone center
+    """
+    fig, axs = plt.subplots(2, 2, sharex=True, figsize=(10, 6))
+    # replace the first panel by a map
+    fig.delaxes(axs[0, 0])
+    axs[0, 0] = fig.add_subplot(2, 2, 1, projection=ccrs.PlateCarree())
+    axs[0, 0].coastlines()
+    axs[0, 0].set_title("storm tracks")
+    axs[0, 1].set_title("track error in km")
+    axs[1, 0].set_title("core pressure in Pa")
+    axs[1, 1].set_title("max wind speed in m/s")
+    track_error(*matched_tracks.values()).plot(ax=axs[0, 1])
+    for lab, track in matched_tracks.items():
+        track.plot(x="lon", y="lat", ax=axs[0, 0], label=lab)
+        track.plot(y="pressure", ax=axs[1, 0], label=lab)
+        track.plot(y="wind", ax=axs[1, 1], label=lab)
+    return fig, axs
+
+
+def bounding_box(matched_tracks, pad=2):
+    """
+    Compute a lon/lat box containing the matched cyclone tracks.
+    """
+    all_lons = pd.concat([track["lon"] for track in matched_tracks.values()])
+    all_lats = pd.concat([track["lat"] for track in matched_tracks.values()])
+    lon_min = all_lons.min() - pad
+    lon_max = all_lons.max() + pad
+    lat_min = all_lats.min() - pad
+    lat_max = all_lats.max() + pad
+    bbox = (lon_min, lon_max, lat_min, lat_max)
+    return bbox
+
+
+def track_snapshots(matched_tracks, datasets, skip=5):
+    """
+    A plot with two rows showing the spatial distribution of windspeeds
+    in prediction and target, with crosses marking the cyclone centers found
+    by the tracker. The time difference between snapshots is controlled by
+    skip.
+    """
+    bbox = bounding_box(matched_tracks)
+    all_steps = matched_tracks["target"].index.union(matched_tracks["prediction"].index)
+    # select the snapshots by time rather than by position, so that the maps show
+    # the times used for titles and markers even if the tracks start after the
+    # first timestep of the datasets
+    selsteps = all_steps[::skip]
+    plotdat = xr.concat(datasets.values(), dim=datasets.keys()).sel(valid_time=selsteps)
+    plotdat = plotdat.sel(longitude=slice(bbox[0], bbox[1]), latitude=slice(bbox[2], bbox[3]))
+    speed = np.sqrt(plotdat.u10**2 + plotdat.v10**2)
+    p = speed.plot(
+        row="concat_dim", col="valid_time", subplot_kws=dict(projection=ccrs.PlateCarree())
+    )
+    for ax in p.axs.flatten():
+        ax.coastlines()
+        ax.set_extent(bbox)
+    for i, s in enumerate(selsteps):
+        leadtime = plotdat.forecast_period[i].values / np.timedelta64(1, "h")
+        p.axs[0, i].set_title(s)
+        p.axs[1, i].set_title(f"{leadtime}h forecast")
+        for j, tr in enumerate(matched_tracks.values()):
+            if s in tr.index:
+                tr.loc[[s]].plot.scatter(
+                    x="lon", y="lat", ax=p.axs[j, i], color="tab:red", marker="x", s=100
+                )
+    return p