diff --git a/src/fev/metrics.py b/src/fev/metrics.py index 0091e73..45a2439 100644 --- a/src/fev/metrics.py +++ b/src/fev/metrics.py @@ -55,6 +55,31 @@ def compute( """ raise NotImplementedError + def compute_scores( + self, + *, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, + seasonality: int, + quantile_levels: list[float], + per_quantile_scores: bool = False, + ) -> dict[str, float]: + """Named scores reported for this metric. Returns `{self.name: self.compute(...)}`.""" + return { + self.name: self.compute( + y_true=y_true, + y_pred=y_pred, + y_past=y_past, + y_past_lengths=y_past_lengths, + q_pred=q_pred, + seasonality=seasonality, + quantile_levels=quantile_levels, + ) + } + def get_metric(metric: MetricConfig) -> Metric: """Get a metric class by name or configuration.""" @@ -265,11 +290,29 @@ def compute( return float(np.mean(self._safemean(val, axis=(0, 1)))) -class MQL(Metric): - """Mean quantile loss.""" +class QuantileMetric(Metric): + """Base class for quantile loss metrics (MQL, WQL, SQL). + + Subclasses implement `_per_quantile_level`. The overall score is the mean over quantile levels, + so `SQL` always equals the mean of `SQL[0.1], SQL[0.5], ...` (single code path, cannot drift). + """ needs_quantiles: bool = True + def _per_quantile_level( + self, + *, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, + seasonality: int, + quantile_levels: list[float], + ) -> np.ndarray: + """Compute the metric at each quantile level. Returns [Q].""" + raise NotImplementedError + def compute( self, *, @@ -282,13 +325,68 @@ def compute( quantile_levels: list[float], ) -> float: if len(quantile_levels) == 0: - raise ValueError(f"{self.__class__.__name__} cannot be computed without quantile_levels") + raise ValueError(f"{self.name} cannot be computed without quantile_levels") + per_level = self._per_quantile_level( + y_true=y_true, + y_pred=y_pred, + y_past=y_past, + y_past_lengths=y_past_lengths, + q_pred=q_pred, + seasonality=seasonality, + quantile_levels=quantile_levels, + ) # [Q] + return float(np.mean(per_level)) + + def compute_scores( + self, + *, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, + seasonality: int, + quantile_levels: list[float], + per_quantile_scores: bool = False, + ) -> dict[str, float]: + if len(quantile_levels) == 0: + raise ValueError(f"{self.name} cannot be computed without quantile_levels") + per_level = self._per_quantile_level( + y_true=y_true, + y_pred=y_pred, + y_past=y_past, + y_past_lengths=y_past_lengths, + q_pred=q_pred, + seasonality=seasonality, + quantile_levels=quantile_levels, + ) # [Q] + assert len(per_level) == len(quantile_levels) + scores = {self.name: float(np.mean(per_level))} + if per_quantile_scores: + scores.update({f"{self.name}[{q}]": float(v) for q, v in zip(quantile_levels, per_level)}) + return scores + + +class MQL(QuantileMetric): + """Mean quantile loss.""" + + def _per_quantile_level( + self, + *, + y_true: np.ndarray, + y_pred: np.ndarray, + y_past: np.ndarray, + y_past_lengths: np.ndarray, + q_pred: np.ndarray, + seasonality: int, + quantile_levels: list[float], + ) -> np.ndarray: ql = _quantile_loss(y_true=y_true, q_pred=q_pred, quantile_levels=quantile_levels) # [N, H, D, Q] - per_dim = np.nanmean(ql, axis=(0, 1, 3)) # [D] - return float(np.mean(per_dim)) + per_dim = np.nanmean(ql, axis=(0, 1)) # [D, Q] + return np.mean(per_dim, axis=0) # [Q] -class SQL(Metric): +class SQL(QuantileMetric): """Scaled quantile loss. Warning: @@ -296,12 +394,10 @@ class SQL(Metric): all-NaN history, or zero seasonal error) are excluded from aggregation. """ - needs_quantiles: bool = True - def __init__(self, epsilon: float = 0.0) -> None: self.epsilon = epsilon - def compute( + def _per_quantile_level( self, *, y_true: np.ndarray, @@ -311,26 +407,24 @@ def compute( q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - ) -> float: + ) -> np.ndarray: ql = _quantile_loss(y_true=y_true, q_pred=q_pred, quantile_levels=quantile_levels) # [N, H, D, Q] - ql_avg_q = np.nanmean(ql, axis=3) # [N, H, D] seasonal_error = _abs_seasonal_error_per_item( y_past=y_past, y_past_lengths=y_past_lengths, seasonality=seasonality ) # [N, D] seasonal_error = np.clip(seasonal_error, self.epsilon, None) - scaled = ql_avg_q / seasonal_error[:, None, :] # [N, H, D] - return float(np.mean(self._safemean(scaled, axis=(0, 1)))) + scaled = ql / seasonal_error[:, None, :, None] # [N, H, D, Q] + per_dim = self._safemean(scaled, axis=(0, 1)) # [D, Q] + return np.mean(per_dim, axis=0) # [Q] -class WQL(Metric): +class WQL(QuantileMetric): """Weighted quantile loss.""" - needs_quantiles: bool = True - def __init__(self, epsilon: float = 0.0) -> None: self.epsilon = epsilon - def compute( + def _per_quantile_level( self, *, y_true: np.ndarray, @@ -340,12 +434,12 @@ def compute( q_pred: np.ndarray, seasonality: int, quantile_levels: list[float], - ) -> float: + ) -> np.ndarray: ql = _quantile_loss(y_true=y_true, q_pred=q_pred, quantile_levels=quantile_levels) # [N, H, D, Q] - ql_per_dim = np.nanmean(ql, axis=(0, 1, 3)) # [D] + ql_per_dim = np.nanmean(ql, axis=(0, 1)) # [D, Q] abs_true_per_dim = np.nanmean(np.abs(y_true), axis=(0, 1)) # [D] - per_dim = ql_per_dim / np.maximum(abs_true_per_dim, self.epsilon) - return float(np.mean(per_dim)) + per_dim = ql_per_dim / np.maximum(abs_true_per_dim, self.epsilon)[:, None] # [D, Q] + return np.mean(per_dim, axis=0) # [Q] def _quantile_loss( diff --git a/src/fev/task.py b/src/fev/task.py index 8397441..607c0a9 100644 --- a/src/fev/task.py +++ b/src/fev/task.py @@ -1,3 +1,4 @@ +import collections import copy import dataclasses import logging @@ -129,12 +130,16 @@ def compute_metrics( metrics: list[Metric], seasonality: int, quantile_levels: list[float], + per_quantile_scores: bool = False, ) -> dict[str, float]: """Compute accuracy metrics on the predictions made for this window. To compute metrics on your predictions, use [`Task.evaluation_summary`][fev.Task.evaluation_summary] instead. This is a convenience method that exists for debugging and additional evaluation. + + If `per_quantile_scores=True`, quantile metrics additionally report a breakdown per quantile level + (e.g. `SQL[0.1]`, `SQL[0.5]`, `SQL[0.9]`) alongside the overall score. """ past_data, _, test_data = self._get_past_future_test_data() @@ -189,14 +194,17 @@ def compute_metrics( with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) for metric in metrics: - test_scores[metric.name] = metric.compute( - y_true=y_true, - y_pred=y_pred, - y_past=y_past_flat, - y_past_lengths=y_past_lengths, - q_pred=q_pred, - seasonality=seasonality, - quantile_levels=quantile_levels, + test_scores.update( + metric.compute_scores( + y_true=y_true, + y_pred=y_pred, + y_past=y_past_flat, + y_past_lengths=y_past_lengths, + q_pred=q_pred, + seasonality=seasonality, + quantile_levels=quantile_levels, + per_quantile_scores=per_quantile_scores, + ) ) return test_scores @@ -842,6 +850,7 @@ def evaluation_summary( inference_time_s: float | None = None, trained_on_this_dataset: bool = False, extra_info: dict | None = None, + per_quantile_scores: bool = False, ) -> dict[str, Any]: """Get a summary of the model performance for the given forecasting task. @@ -864,6 +873,10 @@ def evaluation_summary( zero-shot mode. extra_info : dict | None Optional dictionary with additional information that will be appended to the evaluation summary. + per_quantile_scores : bool, default False + If True, quantile metrics (MQL, WQL, SQL) additionally report a breakdown per quantile level + (e.g. `SQL[0.1]`, `SQL[0.5]`, `SQL[0.9]`) alongside the overall score. Non-quantile metrics + are unaffected. Returns ------- @@ -884,7 +897,8 @@ def evaluation_summary( metrics = [get_metric(m) for m in [self.eval_metric] + self.extra_metrics] eval_metric = metrics[0] - metrics_per_window = {metric.name: [] for metric in metrics} + # Use defaultdict since per-quantile breakdown adds score keys (e.g. SQL[0.1]) not known up front + metrics_per_window: dict[str, list[float]] = collections.defaultdict(list) if isinstance(predictions_per_window, (datasets.Dataset, datasets.DatasetDict, dict)): raise ValueError( f"predictions_per_window must be iterable (e.g., a list) but got {type(predictions_per_window)}" @@ -900,6 +914,7 @@ def evaluation_summary( metrics=metrics, seasonality=self.seasonality, quantile_levels=self.quantile_levels, + per_quantile_scores=per_quantile_scores, ) for metric, value in metric_scores.items(): metrics_per_window[metric].append(value) diff --git a/test/test_metrics.py b/test/test_metrics.py index 0343be4..f50980f 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -38,6 +38,12 @@ def _to_pandas(ds: datasets.Dataset) -> pd.DataFrame: return task, train_df, test_df, predictor +def _fev_predictions(predictor, train_df): + """Build fev-formatted predictions (one dict per item) from an AutoGluon predictor.""" + ag_predictions = predictor.predict(train_df).rename(columns={"mean": "predictions"}) + return [pred.to_dict("list") for _, pred in ag_predictions.groupby("item_id", as_index=False)] + + @pytest.mark.parametrize("eval_metric", list(AVAILABLE_METRICS)) def test_when_metrics_computed_then_score_matches_autogluon(model_setup, eval_metric): task, train_df, test_df, predictor = model_setup @@ -51,12 +57,7 @@ def test_when_metrics_computed_then_score_matches_autogluon(model_setup, eval_me else: ag_score = predictor.evaluate(full_df, metrics=[task.eval_metric])[task.eval_metric] * -1 - ag_predictions = predictor.predict(train_df).rename(columns={"mean": "predictions"}) - fev_predictions = [] - for _, pred in ag_predictions.groupby("item_id", as_index=False): - fev_predictions.append(pred.to_dict("list")) - - fev_score = task.evaluation_summary([fev_predictions], model_name="")[eval_metric] + fev_score = task.evaluation_summary([_fev_predictions(predictor, train_df)], model_name="")[eval_metric] assert np.isclose(ag_score, fev_score) @@ -110,3 +111,38 @@ def test_seasonal_error_per_item_empty(): result = _seasonal_error_per_item(y_past=flat, y_past_lengths=lengths, seasonality=2, aggregate_fn=np.abs) assert result.size == 0 assert result.dtype == np.float64 + + +@pytest.mark.parametrize("metric_name", ["MQL", "WQL", "SQL"]) +def test_when_per_quantile_scores_then_overall_equals_mean_of_per_level(model_setup, metric_name): + task, train_df, _, predictor = model_setup + task.eval_metric = metric_name + + summary = task.evaluation_summary([_fev_predictions(predictor, train_df)], model_name="", per_quantile_scores=True) + + per_level = [summary[f"{metric_name}[{q}]"] for q in task.quantile_levels] + assert np.isclose(summary[metric_name], np.mean(per_level)) + + +@pytest.mark.parametrize("metric_name", ["MQL", "WQL", "SQL"]) +def test_when_per_quantile_scores_disabled_then_no_per_level_keys(model_setup, metric_name): + task, train_df, _, predictor = model_setup + task.eval_metric = metric_name + + summary = task.evaluation_summary([_fev_predictions(predictor, train_df)], model_name="") + + assert metric_name in summary + assert not any(key.startswith(f"{metric_name}[") for key in summary) + + +def test_when_per_quantile_scores_then_non_quantile_metrics_have_no_breakdown(model_setup): + task, train_df, _, predictor = model_setup + task.eval_metric = "MASE" + task.extra_metrics = ["MAE", "SQL"] + + summary = task.evaluation_summary([_fev_predictions(predictor, train_df)], model_name="", per_quantile_scores=True) + + # Quantile metric is broken down per level + assert all(f"SQL[{q}]" in summary for q in task.quantile_levels) + # Non-quantile metrics emit only their overall score + assert not any(key.startswith("MAE[") or key.startswith("MASE[") for key in summary)