From bfe05710cf1a0eb938218381c59cbe03bb0c9374 Mon Sep 17 00:00:00 2001 From: Kobe Vandelanotte <61313519+kobebryant432@users.noreply.github.com> Date: Fri, 13 Feb 2026 09:51:37 +0100 Subject: [PATCH] Revert "Model selection swiss (#245)" This reverts commit 4ec34b9bb6e247f0d8d4071747599f46ae25b484. --- src/valenspy/diagnostic/__init__.py | 2 +- src/valenspy/diagnostic/_ensemble2ref.py | 26 +------- src/valenspy/diagnostic/diagnostic.py | 72 ++++++++--------------- src/valenspy/diagnostic/functions.py | 44 -------------- src/valenspy/diagnostic/plot_utils.py | 5 +- src/valenspy/diagnostic/visualizations.py | 58 ------------------ 6 files changed, 28 insertions(+), 179 deletions(-) diff --git a/src/valenspy/diagnostic/__init__.py b/src/valenspy/diagnostic/__init__.py index c67c0a93..0d9fac73 100644 --- a/src/valenspy/diagnostic/__init__.py +++ b/src/valenspy/diagnostic/__init__.py @@ -7,4 +7,4 @@ from ._model2self import * from ._model2ref import * from ._ensemble2self import * -from ._ensemble2ref import * +from ._ensemble2ref import * \ No newline at end of file diff --git a/src/valenspy/diagnostic/_ensemble2ref.py b/src/valenspy/diagnostic/_ensemble2ref.py index e1c74431..5cf51100 100644 --- a/src/valenspy/diagnostic/_ensemble2ref.py +++ b/src/valenspy/diagnostic/_ensemble2ref.py @@ -2,33 +2,11 @@ from valenspy.diagnostic.functions import * from valenspy.diagnostic.visualizations import * -__all__ = [ - "MetricsRankings", - "EnsembleSubSelection", - ] +__all__ = ["MetricsRankings"] MetricsRankings = Ensemble2Ref( calc_metrics_dt, plot_metric_ranking, "Metrics Rankings", "The rankings of ensemble members with respect to several metrics when compared to the reference." -) - -EnsembleSubSelection = Ensemble2Ref( - case_sub_selection, - {"default": - default_plot_kwargs({ - "x": "var", - "y": "abs_change", - "selected": ["highest", "middle", "lowest"], - "sel_colors": {"highest": "red", "middle": "blue", "lowest": "green"} - })(ensemble_selection_boxplot), - "heatmap": - default_plot_kwargs({ - "index": "label", - "columns": "var", - "values": "rel_change" - })(ensemble_change_signal_heatmap)}, - "Ensemble Sub Selection", - "The sub selection of ensemble members." - ) +) \ No newline at end of file diff --git a/src/valenspy/diagnostic/diagnostic.py b/src/valenspy/diagnostic/diagnostic.py index 473a493e..926915dd 100644 --- a/src/valenspy/diagnostic/diagnostic.py +++ b/src/valenspy/diagnostic/diagnostic.py @@ -2,7 +2,7 @@ import xarray as xr import matplotlib.pyplot as plt from valenspy.processing.mask import add_prudence_regions -from valenspy.diagnostic.plot_utils import default_plot_kwargs, _augment_kwargs +from valenspy.diagnostic.plot_utils import _augment_kwargs from valenspy._utilities import generate_parameters_doc import numpy as np import inspect @@ -18,16 +18,16 @@ class Diagnostic(): """An abstract class representing a diagnostic.""" def __init__( - self, diagnostic_function, plotting_functions, name=None, description=None + self, diagnostic_function, plotting_function, name=None, description=None ): """Initialize the Diagnostic. Parameters ---------- - diagnostic_function : function + diagnostic_function The function that applies a diagnostic to the data. - plotting_functions : function or dict of functions - The functions or dictionary of functions that visualize the diagnostic. + plotting_function + The function that visualizes the results of the diagnostic. name : str The name of the diagnostic. description : str @@ -36,11 +36,7 @@ def __init__( self.name = name self._description = description self.diagnostic_function = diagnostic_function - - if callable(plotting_functions): - plotting_functions = {"default": plotting_functions} - - self.plotting_functions = plotting_functions + self.plotting_function = plotting_function self.__signature__ = inspect.signature(self.diagnostic_function) self.__doc__ = self.description @@ -65,7 +61,7 @@ def apply(self, data): """ pass - def plot(self, result, kind="default", title=None, **kwargs): + def plot(self, result, title=None, **kwargs): """Plot the diagnostic. Single ax plots. Parameters @@ -82,29 +78,12 @@ def plot(self, result, kind="default", title=None, **kwargs): ax : matplotlib.axis.Axis The axis (singular) of the plot. """ - ax = self.plotting_functions[kind](result, **kwargs) + ax = self.plotting_function(result, **kwargs) if not title: title = self.name ax.set_title(title) return ax - #Support easy access to the plotting functions - # class PlotAccessor: - # """An accessor to the plotting functions of the diagnostic.""" - - # def __init__(self, diagnostic): - # self.diagnostic = diagnostic - - # def __getattr__(self, kind): - # def plot_kind(*args, **kwargs): - # return self._diagnostic.plot(*args, kind=kind, **kwargs) - # return plot_kind - - # @property - # def plot(self): - # return self.PlotAccessor(self) - - @property def description(self): """Generate the docstring for the diagnostic.""" @@ -121,7 +100,7 @@ class DataSetDiagnostic(Diagnostic): """A class representing a diagnostic that operates on the level of single datasets.""" def __init__( - self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single" + self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single" ): """ Initialize the DataSetDiagnostic. @@ -133,7 +112,7 @@ def __init__( If "single", plot_dt will plot all the leaves of the DataTree on the same axis. If "facetted", plot_dt will plot all the leaves of the DataTree on different axes. """ - super().__init__(diagnostic_function, plotting_functions, name, description) + super().__init__(diagnostic_function, plotting_function, name, description) self.plot_type = plot_type def __call__(self, data, *args, **kwargs): @@ -185,7 +164,6 @@ def apply(self, ds: xr.Dataset, *args, **kwargs): """ return self.diagnostic_function(ds, *args, **kwargs) - #Currently no support for different plotting kinds def plot_dt(self, dt, *args, **kwargs): if self.plot_type == "single": return self.plot_dt_single(dt, *args, **kwargs) @@ -279,20 +257,20 @@ class Model2Self(DataSetDiagnostic): """A class representing a diagnostic that compares a model to itself.""" def __init__( - self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="single" + self, diagnostic_function, plotting_function, name=None, description=None, plot_type="single" ): """Initialize the Model2Self diagnostic.""" - super().__init__(diagnostic_function, plotting_functions, name, description, plot_type) + super().__init__(diagnostic_function, plotting_function, name, description, plot_type) class Model2Ref(DataSetDiagnostic): """A class representing a diagnostic that compares a model to a reference.""" def __init__( - self, diagnostic_function, plotting_functions, name=None, description=None, plot_type="facetted" + self, diagnostic_function, plotting_function, name=None, description=None, plot_type="facetted" ): """Initialize the Model2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_functions, name, description, plot_type) + super().__init__(diagnostic_function, plotting_function, name, description, plot_type) def apply(self, ds: xr.Dataset, ref: xr.Dataset, **kwargs): """Apply the diagnostic to the data. Only the common variables between the data and the reference are used. @@ -318,11 +296,11 @@ class Ensemble2Self(Diagnostic): """A class representing a diagnostic that compares an ensemble to itself.""" def __init__( - self, diagnostic_function, plotting_functions, name=None, description=None, iterative_plotting=False + self, diagnostic_function, plotting_function, name=None, description=None, iterative_plotting=False ): """Initialize the Ensemble2Self diagnostic.""" self.iterative_plotting = iterative_plotting - super().__init__(diagnostic_function, plotting_functions, name, description) + super().__init__(diagnostic_function, plotting_function, name, description) def apply(self, dt: DataTree, mask=None, **kwargs): @@ -343,7 +321,7 @@ def apply(self, dt: DataTree, mask=None, **kwargs): return self.diagnostic_function(dt, **kwargs) - def plot(self, result, kind="default", variables=None, title=None, facetted=None, **kwargs): + def plot(self, result, variables=None, title=None, facetted=None, **kwargs): """Plot the diagnostic. If facetted multiple plots on different axes are created. If not facetted, the plots are created on the same axis. @@ -371,10 +349,10 @@ class Ensemble2Ref(Diagnostic): """A class representing a diagnostic that compares an ensemble to a reference.""" def __init__( - self, diagnostic_function, plotting_functions, name=None, description=None + self, diagnostic_function, plotting_function, name=None, description=None ): """Initialize the Ensemble2Ref diagnostic.""" - super().__init__(diagnostic_function, plotting_functions, name, description) + super().__init__(diagnostic_function, plotting_function, name, description) def apply(self, dt: DataTree, ref, **kwargs): """Apply the diagnostic to the data. @@ -391,13 +369,10 @@ def apply(self, dt: DataTree, ref, **kwargs): DataTree or dict The data after applying the diagnostic as a DataTree or a dictionary of results with the tree nodes as keys. """ - #Make sure that the dt and ref are isomorphic - if isinstance(ref, DataTree): - dt = filter_like(dt, ref) - ref = filter_like(ref, dt) + # TODO: Add some checks to make sure the reference is a DataTree or a Dataset and contain common variables with the data. return self.diagnostic_function(dt, ref, **kwargs) - def plot(self, result, kind="default", facetted=True, **kwargs): + def plot(self, result, facetted=True, **kwargs): """Plot the diagnostic. If axes are provided, the diagnostic is plotted facetted. If ax is provided, the diagnostic is plotted non-facetted. @@ -417,9 +392,9 @@ def plot(self, result, kind="default", facetted=True, **kwargs): raise ValueError("Either ax or axes can be provided, not both.") elif "ax" not in kwargs and "axes" not in kwargs: ax = plt.gca() - return self.plotting_functions[kind](result, ax=ax, **kwargs) + return self.plotting_function(result, ax=ax, **kwargs) else: - return self.plotting_functions[kind](result, **kwargs) + return self.plotting_function(result, **kwargs) def _common_vars(ds1, ds2): """Return the common variables in two datasets.""" @@ -436,4 +411,3 @@ def _initialize_multiaxis_plot(n, subplot_kws={}): nrows=n//2+1, ncols=2, figsize=(10, 5 * n), subplot_kw=subplot_kws ) return fig, axes - diff --git a/src/valenspy/diagnostic/functions.py b/src/valenspy/diagnostic/functions.py index 62c1e43e..37855f90 100644 --- a/src/valenspy/diagnostic/functions.py +++ b/src/valenspy/diagnostic/functions.py @@ -314,50 +314,6 @@ def calc_metrics_dt(dt_mod: DataTree, da_obs: xr.Dataset, metrics=None, pss_binw df = _add_ranks_metrics(df) return df -########################################## -# Ensemble2Ensemble diagnostic functions # -########################################## - -def case_sub_selection(dt_future: DataTree, dt_ref: DataTree, vars): - """ - Select 3 ensemble members based on avg normalized climate change in the variable of interest. - The two extreme and mediaan members are selected. - """ - #TODO - add direction of indicators - #TODO - Check consistnecy with paper (slightly different approach)! - #TEST how to leave an extra dimension left over - e.g. regions - #Test with different periods of variables (tas_JJA, etc) - #Improve the heatmap plot (square default, absolute values of the change?, larger squares) - dt_change = dt_future.mean() - dt_ref.mean() - dt_rel_change = dt_change / dt_ref.mean() - - data = [[x.path, var, x[var].values] for x in dt_rel_change.leaves for var in x.data_vars] - data_abs = [[x.path, var, x[var].values] for x in dt_change.leaves for var in x.data_vars] - - #Create one dataframe containing the relative change and the absolute change - df = pd.DataFrame(data, columns=["label", "var", "rel_change"]) - df_abs = pd.DataFrame(data_abs, columns=["label", "var", "abs_change"]) - - df = pd.merge(df, df_abs, on=["label", "var"]) - - df["rel_change"] = df["rel_change"].astype(float) - df["abs_change"] = df["abs_change"].astype(float) - - #Normalize the absolute change per variable - df["norm_rel_change"] = df.groupby("var")["rel_change"].transform(lambda x: (x - x.mean()) / x.std()) - df["norm_abs_change"] = df.groupby("var")["abs_change"].transform(lambda x: (x - x.mean()) / x.std()) - - #Get the rank per variable - df["rank"] = df["norm_abs_change"].groupby(df["var"]).rank(ascending=True, method='min') - df["member_rank"] = df["rank"].where(df["var"].isin(vars)).groupby(df["label"]).transform("mean").rank(ascending=True, method='min') # - - #Rank the ensemble members based on the mean of the normalized absolute change - df["highest"] = df["member_rank"] == 1 - df["lowest"] = df["member_rank"] == df["member_rank"].max() - df["middle"] = df["member_rank"] == np.floor(df["member_rank"].median()) - - return df - ################################## ####### Helper functions ######### diff --git a/src/valenspy/diagnostic/plot_utils.py b/src/valenspy/diagnostic/plot_utils.py index f24fe1f9..931200e3 100644 --- a/src/valenspy/diagnostic/plot_utils.py +++ b/src/valenspy/diagnostic/plot_utils.py @@ -29,14 +29,13 @@ def _augment_kwargs(def_kwargs, **kwargs): cbar_kwargs = _merge_kwargs(def_kwargs.pop('cbar_kwargs'), kwargs.pop('cbar_kwargs', {})) def_kwargs['cbar_kwargs'] = cbar_kwargs - #Is this correct? Are the cbar_kwargs not overwritten if defined by the user? return _merge_kwargs(def_kwargs, kwargs) ###################################### ############## Wrappers ############## ###################################### -def default_plot_kwargs(def_kwargs): +def default_plot_kwargs(kwargs): """ Decorator to set the default keyword arguments for the plotting function. User will override and/or be augmented with the default keyword arguments. subplot_kws and cbar_kwargs can also be set as default keyword arguments for the plotting function. @@ -64,7 +63,7 @@ def decorator(plotting_function): @wraps(plotting_function) def wrapper(*args, **kwargs): - return plotting_function(*args, **_augment_kwargs(def_kwargs=def_kwargs, **kwargs)) + return plotting_function(*args, **_augment_kwargs(def_kwargs=kwargs, **kwargs)) return wrapper diff --git a/src/valenspy/diagnostic/visualizations.py b/src/valenspy/diagnostic/visualizations.py index 448722e2..cd702907 100644 --- a/src/valenspy/diagnostic/visualizations.py +++ b/src/valenspy/diagnostic/visualizations.py @@ -541,64 +541,6 @@ def plot_metric_ranking(df_metric, ax=None, plot_colorbar=True, hex_color1 = Non return ax -########################################## -# Ensemble2Ensemble diagnostic functions # -########################################## - -#Add default plot kwargs at the diagnostic level -def ensemble_selection_boxplot(df, selected=None, sel_colors=None, ax=None, **kwargs): - """ - Create a boxplot of the ensemble members with the selected ensemble members highlighted. - - Parameters - ---------- - df : pd.DataFrame - A DataFrame containing the ensemble members and their values. - selected : str or list of str, optional - The column name(s) with boolean values indicating the selected ensemble members. - If None all members are plotted. - ax : matplotlib.axes.Axes, optional - The axes on which to plot the boxplot. If None, a new figure and axes are created. - **kwargs : dict - Additional keyword arguments passed to `seaborn.boxplot`. - """ - sns.boxplot(data=df, ax=ax, **kwargs) - - if selected is None: - sns.swarmplot(data=df, ax=ax, **kwargs) - else: - for sel in selected: - if sel in df.columns: - if isinstance(sel_colors, dict) and sel in sel_colors: - kwargs.update({"color": sel_colors[sel]}) - sns.swarmplot(data=df[df[sel]], ax=ax, **kwargs) - - ax = _get_gca(**kwargs) - return ax - -def ensemble_change_signal_heatmap(df, index=None, columns=None, values=None, ax=None, **kwargs): - """ - Create a heatmap of the ensemble change signal for different variables per ensemble member. - - Parameters - ---------- - df : pd.DataFrame - A DataFrame containing the ensemble change signals for different variables per ensemble member. - index : str, optional - The column name to use as the index for the heatmap. - columns : str, optional - The column name to use as the columns for the heatmap. - values : str, optional - The column name to use as the climate signal values for the heatmap. - ax : matplotlib.axes.Axes, optional - The axes on which to plot the heatmap. If None, a new figure and axes are created. - **kwargs : dict - Additional keyword arguments passed to `seaborn.heatmap`. - """ - sns.heatmap(data=df.pivot(index=index, columns=columns, values=values), ax=ax, **kwargs) - ax = _get_gca(**kwargs) - return ax - ################################## # Helper functions # ##################################