Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions pysteps/verification/probscores.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
ROC_curve_init
ROC_curve_accum
ROC_curve_compute
PR_curve
PR_curve_init
PR_curve_accum
PR_curve_compute
"""

import numpy as np
Expand Down Expand Up @@ -421,3 +425,142 @@ def ROC_curve_compute(ROC, compute_area=False):
return POFD_vals, POD_vals, area
else:
return POFD_vals, POD_vals


def PR_curve(P_f, X_o, X_min, n_prob_thrs=10, compute_area=False):
"""
Compute the Precision–Recall (PR) curve and optionally its area.

Parameters
----------
P_f : array_like
Forecasted probabilities for exceeding the threshold X_min.
Non-finite values are ignored.
X_o : array_like
Observed values. Non-finite values are ignored.
X_min : float
Precipitation intensity threshold for yes/no prediction.
n_prob_thrs : int, optional
Number of probability thresholds to evaluate.
The interval [0, 1] is divided into n_prob_thrs evenly spaced values.
compute_area : bool, optional
If True, compute the area under the PR curve using trapezoidal integration.

Returns
-------
out : tuple
(precision_vals, recall_vals) for each probability threshold.
If compute_area is True, return (precision_vals, recall_vals, area),
where area is the trapezoidal estimate of the PR curve area.
"""
P_f = P_f.copy()
X_o = X_o.copy()
pr = PR_curve_init(X_min, n_prob_thrs)
PR_curve_accum(pr, P_f, X_o)
return PR_curve_compute(pr, X_o, X_min, compute_area)


def PR_curve_init(X_min, n_prob_thrs=10):
"""
Initialize a Precision–Recall curve object.

Parameters
----------
X_min : float
Precipitation intensity threshold for yes/no prediction.
n_prob_thrs : int, optional
Number of probability thresholds to evaluate.

Returns
-------
PR : dict
Dictionary containing counters for hits, misses, false alarms,
correct negatives, and the probability thresholds.
Keys:
- "X_min" : threshold value
- "hits", "misses", "false_alarms", "corr_neg" : arrays of counts
- "prob_thrs" : array of evenly spaced thresholds in [0, 1]
"""
PR = {}
PR["X_min"] = X_min
PR["hits"] = np.zeros(n_prob_thrs, dtype=int)
PR["misses"] = np.zeros(n_prob_thrs, dtype=int)
PR["false_alarms"] = np.zeros(n_prob_thrs, dtype=int)
PR["corr_neg"] = np.zeros(n_prob_thrs, dtype=int)
PR["prob_thrs"] = np.linspace(0.0, 1.0, int(n_prob_thrs))
return PR


def PR_curve_accum(PR, P_f, X_o):
"""
Accumulate forecast–observation pairs into the PR object.

Parameters
----------
PR : dict
A PR curve object created with PR_curve_init.
P_f : array_like
Forecasted probabilities for exceeding X_min.
X_o : array_like
Observed values.

Notes
-----
Updates the counters (hits, misses, false alarms, correct negatives)
for each probability threshold in PR["prob_thrs"].
"""
mask = np.logical_and(np.isfinite(P_f), np.isfinite(X_o))
P_f = P_f[mask]
X_o = X_o[mask]
for i, p in enumerate(PR["prob_thrs"]):
forecast_yes = P_f >= p
obs_yes = X_o >= PR["X_min"]
PR["hits"][i] += np.sum(np.logical_and(forecast_yes, obs_yes))
PR["misses"][i] += np.sum(np.logical_and(~forecast_yes, obs_yes))
PR["false_alarms"][i] += np.sum(np.logical_and(forecast_yes, ~obs_yes))
PR["corr_neg"][i] += np.sum(np.logical_and(~forecast_yes, ~obs_yes))


def PR_curve_compute(PR, X_o, X_min, compute_area=False):
"""
Compute precision and recall values from the PR object.

Parameters
----------
PR : dict
A PR curve object created with PR_curve_init.
X_o : array_like
Observed values (used only if prevalence or area is computed).
X_min : float
Precipitation intensity threshold for yes/no prediction.
compute_area : bool, optional
If True, compute the area under the PR curve.

Returns
-------
out : tuple
(precision_vals, recall_vals) for each probability threshold.
If compute_area is True, return (precision_vals, recall_vals, area),
where area is the trapezoidal estimate of the PR curve area.
"""
precision_vals = []
recall_vals = []

for i in range(len(PR["prob_thrs"])):
hits = PR["hits"][i]
misses = PR["misses"][i]
false_alarms = PR["false_alarms"][i]

recall = hits / (hits + misses) if (hits + misses) > 0 else 0.0
precision = hits / (hits + false_alarms) if (hits + false_alarms) > 0 else 1.0

recall_vals.append(recall)
precision_vals.append(precision)

if compute_area:
# Sort by recall before integration
recall_sorted, precision_sorted = zip(*sorted(zip(recall_vals, precision_vals)))
area = np.trapz(precision_sorted, recall_sorted)
return precision_vals, recall_vals, area
else:
return precision_vals, recall_vals