Skip to content

Commit f5d2177

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 6a60ed4 commit f5d2177

File tree

2 files changed

+175
-64
lines changed

2 files changed

+175
-64
lines changed

jax/_src/numpy/reductions.py

Lines changed: 108 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,7 +2337,7 @@ def cumulative_prod(
23372337
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
23382338
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
23392339
out: None = None, overwrite_input: bool = False, method: str = "linear",
2340-
keepdims: bool = False, weights: ArrayLike | None = None, *,
2340+
keepdims: bool = False, *, weights: ArrayLike | None = None,
23412341
interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
23422342
"""Compute the quantile of the data along the specified axis.
23432343
@@ -2395,7 +2395,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23952395
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
23962396
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
23972397
out: None = None, overwrite_input: bool = False, method: str = "linear",
2398-
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
2398+
keepdims: bool = False, *, weights: ArrayLike | None = None,
2399+
interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
23992400
"""Compute the quantile of the data along the specified axis, ignoring NaNs.
24002401
24012402
JAX implementation of :func:`numpy.nanquantile`.
@@ -2447,12 +2448,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24472448
("The interpolation= argument to 'nanquantile' is deprecated. "
24482449
"Use 'method=' instead."), stacklevel=2)
24492450
method = interpolation
2450-
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True)
2451+
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True, weights)
24512452

24522453
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24532454
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array:
2454-
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]:
2455-
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'")
2455+
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
2456+
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'")
24562457
a, = promote_dtypes_inexact(a)
24572458
keepdim = []
24582459
if dtypes.issubdtype(a.dtype, np.complexfloating):
@@ -2482,6 +2483,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24822483
axis = _canonicalize_axis(-1, a.ndim)
24832484
else:
24842485
axis = _canonicalize_axis(axis, a.ndim)
2486+
2487+
# Ensure q is an array and inexact
2488+
q = lax_internal.asarray(q)
2489+
q, = promote_dtypes_inexact(q)
24852490

24862491
q_shape = q.shape
24872492
q_ndim = q.ndim
@@ -2492,63 +2497,103 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24922497
# Handle weights
24932498
if weights is not None:
24942499
a, weights = promote_dtypes_inexact(a, weights)
2495-
if axis is None:
2496-
a = a.ravel()
2497-
weights = weights.ravel()
2498-
axis = 0
2500+
a_shape = a.shape
2501+
w_shape = np.shape(weights)
2502+
if w_shape != a_shape:
2503+
if len(w_shape) != 1:
2504+
raise ValueError("1D weights expected when shapes of a and weights differ.")
2505+
if axis is None:
2506+
raise TypeError("Axis must be specified when shapes of a and weights differ.")
2507+
if w_shape[0] != a_shape[axis]:
2508+
raise ValueError("Length of weights not compatible with specified axis.")
2509+
resh = [1] * a.ndim
2510+
resh[axis] = w_shape[0]
2511+
weights = lax.reshape(lax_internal.asarray(weights), tuple(resh))
2512+
weights = _broadcast_to(weights, a.shape)
2513+
2514+
if isinstance(weights, core.Tracer):
2515+
weights_arr = None
24992516
else:
2500-
weights = _broadcast_to(weights, a.shape)
2517+
try:
2518+
weights_arr = np.asarray(weights)
2519+
except Exception:
2520+
weights_arr = None
2521+
2522+
if weights_arr is not None:
2523+
if np.any(weights_arr < 0):
2524+
raise ValueError("Weights must be non-negative.")
2525+
if np.all(weights_arr == 0):
2526+
raise ValueError("Sum of weights must not be zero.")
2527+
if np.any(np.isnan(weights_arr)):
2528+
out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else ()
2529+
return lax.full(out_shape, np.nan, dtype=a.dtype)
2530+
weights_have_nan = np.any(np.isnan(weights_arr))
2531+
else:
2532+
weights_have_nan = False
2533+
25012534
if squash_nans:
25022535
nan_mask = ~lax_internal._isnan(a)
2503-
if axis is None:
2504-
a = a[nan_mask]
2505-
weights = weights[nan_mask]
2506-
else:
2507-
weights = _where(nan_mask, weights, 0)
2508-
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)
2536+
weights = _where(nan_mask, weights, 0)
2537+
else:
2538+
with jax.debug_nans(False):
2539+
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
25092540

2541+
total_weight = sum(weights, axis=axis, keepdims=True)
2542+
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)
25102543
cum_weights = lax.cumsum(weights_sorted, axis=axis)
2511-
total_weight = lax.sum(weights_sorted, axis=axis, keepdims=True)
2512-
if lax_internal._all(total_weight == 0):
2513-
raise ValueError("Sum of weights must not be zero.")
2514-
cum_weights_norm = cum_weights / total_weight
2515-
quantile_pos = q
2516-
mask = cum_weights_norm >= quantile_pos[..., None]
2517-
idx = lax.argmin(mask.astype(int), axis=axis)
2518-
idx_prev = lax.max(idx - 1, _lax_const(idx, 0))
2519-
idx_next = idx
2520-
gather_shape = list(a_sorted.shape)
2521-
gather_shape[axis] = 1
2544+
cum_weights_norm = lax.div(cum_weights, total_weight)
2545+
2546+
slice_sizes = list(a_sorted.shape)
2547+
slice_sizes[axis] = 1
25222548
dnums = lax.GatherDimensionNumbers(
2523-
offset_dims=tuple(range(len(a_sorted.shape))),
2524-
collapsed_slice_dims=(axis,),
2549+
offset_dims=tuple(range(
2550+
0,
2551+
len(a_sorted.shape) if keepdims else len(a_sorted.shape) - 1)),
2552+
collapsed_slice_dims=() if keepdims else (axis,),
25252553
start_index_map=(axis,))
2526-
prev_value = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2527-
next_value = lax.gather(a_sorted, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2528-
prev_cumw = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2529-
next_cumw = lax.gather(cum_weights_norm, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
2530-
2531-
if method == "linear":
2532-
denom = next_cumw - prev_cumw
2533-
denom = lax.select(denom == 0, _lax_const(denom, 1), denom)
2534-
weight = (quantile_pos - prev_cumw) / denom
2535-
result = prev_value * (1 - weight) + next_value * weight
2536-
elif method == "lower":
2537-
result = prev_value
2538-
elif method == "higher":
2539-
result = next_value
2540-
elif method == "nearest":
2541-
use_prev = (quantile_pos - prev_cumw) < (next_cumw - quantile_pos)
2542-
result = lax.select(use_prev, prev_value, next_value)
2543-
elif method == "midpoint":
2544-
result = (prev_value + next_value) / 2
2545-
else:
2546-
raise ValueError(f"{method=!r} not recognized")
2547-
2548-
if not keepdims:
2549-
result = lax.squeeze(result, axis)
2550-
return lax.convert_element_type(result, a.dtype)
25512554

2555+
def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
2556+
index_dtype = dtypes.canonicalize_dtype(int)
2557+
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype)
2558+
idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
2559+
slicer = [slice(None)] * a_sorted.ndim
2560+
slicer[axis] = idx
2561+
val = a_sorted[tuple(slicer)]
2562+
2563+
idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1)
2564+
slicer_prev = slicer.copy()
2565+
slicer_prev[axis] = idx_prev
2566+
val_prev = a_sorted[tuple(slicer_prev)]
2567+
cw_prev = cum_weights_norm[tuple(slicer_prev)]
2568+
cw_next = cum_weights_norm[tuple(slicer)]
2569+
2570+
if method == "linear":
2571+
denom = cw_next - cw_prev
2572+
denom = _where(denom == 0, 1, denom)
2573+
weight = (qi - cw_prev) / denom
2574+
out = val_prev * (1 - weight) + val * weight
2575+
elif method == "lower":
2576+
out = val_prev
2577+
elif method == "higher":
2578+
out = val
2579+
elif method == "nearest":
2580+
out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val)
2581+
elif method == "midpoint":
2582+
out = (val_prev + val) / 2
2583+
elif method == "inverted_cdf":
2584+
out = val
2585+
else:
2586+
raise ValueError(f"{method=!r} not recognized")
2587+
if weights_have_nan:
2588+
out = lax.full_like(out, np.nan)
2589+
out = lax.squeeze(out, axis=axis)
2590+
return out
2591+
2592+
if q.ndim == 0:
2593+
result = _weighted_quantile(q)
2594+
else:
2595+
result = jax.vmap(_weighted_quantile)(q)
2596+
return result
25522597

25532598
if squash_nans:
25542599
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
@@ -2566,10 +2611,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25662611

25672612
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
25682613
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
2569-
low = lax.convert_element_type(low, int)
2570-
high = lax.convert_element_type(high, int)
2614+
low = lax.convert_element_type(low, dtypes.canonicalize_dtype(int))
2615+
high = lax.convert_element_type(high, dtypes.canonicalize_dtype(int))
25712616
out_shape = q_shape + shape_after_reduction
2572-
index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim)
2617+
index = [lax.broadcasted_iota(dtypes.canonicalize_dtype(int), out_shape, dim + q_ndim)
25732618
for dim in range(len(shape_after_reduction))]
25742619
if keepdims:
25752620
index[axis] = low
@@ -2591,8 +2636,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25912636

25922637
low = lax.clamp(_lax_const(low, 0), low, n - 1)
25932638
high = lax.clamp(_lax_const(high, 0), high, n - 1)
2594-
low = lax.convert_element_type(low, int)
2595-
high = lax.convert_element_type(high, int)
2639+
low = lax.convert_element_type(low, dtypes.int_)
2640+
high = lax.convert_element_type(high, dtypes.int_)
25962641

25972642
slice_sizes = list(a_shape)
25982643
slice_sizes[axis] = 1
@@ -2624,6 +2669,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26242669
result = lax.select(pred, low_value, high_value)
26252670
elif method == "midpoint":
26262671
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
2672+
elif method == "inverted_cdf":
2673+
result = high_value
26272674
else:
26282675
raise ValueError(f"{method=!r} not recognized")
26292676
if keepdims and keepdim:
@@ -2639,7 +2686,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26392686
def percentile(a: ArrayLike, q: ArrayLike,
26402687
axis: int | tuple[int, ...] | None = None,
26412688
out: None = None, overwrite_input: bool = False, method: str = "linear",
2642-
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
2689+
keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
26432690
"""Compute the percentile of the data along the specified axis.
26442691
26452692
JAX implementation of :func:`numpy.percentile`.
@@ -2697,7 +2744,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
26972744
def nanpercentile(a: ArrayLike, q: ArrayLike,
26982745
axis: int | tuple[int, ...] | None = None,
26992746
out: None = None, overwrite_input: bool = False, method: str = "linear",
2700-
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
2747+
keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
27012748
"""Compute the percentile of the data along the specified axis, ignoring NaN values.
27022749
27032750
JAX implementation of :func:`numpy.nanpercentile`.

tests/lax_numpy_reducers_test.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
from functools import partial
1818
import itertools
1919
import unittest
20+
import pytest
2021

2122
from absl.testing import absltest
2223
from absl.testing import parameterized
2324

2425
import numpy as np
26+
from jax._src.numpy.reductions import _quantile
2527

2628
import jax
2729
from jax import numpy as jnp
@@ -764,13 +766,75 @@ def testPercentilePrecision(self):
764766
x = jnp.float64([1, 2, 3, 4, 7, 10])
765767
self.assertEqual(jnp.percentile(x, 50), 3.5)
766768

769+
def test_weighted_quantile_all_weights_one(self):
770+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
771+
weights = jnp.ones_like(a)
772+
q = jnp.array([0.25, 0.5, 0.75])
773+
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
774+
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
775+
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
776+
777+
def test_weighted_quantile_multiple_q(self):
778+
a = jnp.arange(10, dtype=float)
779+
weights = jnp.ones_like(a)
780+
q = jnp.array([0.25, 0.5, 0.75])
781+
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
782+
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
783+
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
784+
785+
def test_weighted_quantile_keepdims(self):
786+
a = jnp.array([1, 2, 3, 4], dtype=float)
787+
weights = jnp.array([1, 1, 1, 1], dtype=float)
788+
q = 0.5
789+
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights)
790+
expected = np.quantile(np.array(a), np.array(q), axis=0, keepdims=True, weights=np.array(weights), method="inverted_cdf")
791+
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
792+
767793
def test_weighted_quantile_linear(self):
768794
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
769795
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
770796
q = jnp.array([0.5])
771-
expected = np.quantile(a, q, weights=weights)
772-
result = quantile(a, q, weights=weights, method="linear")
773-
np.testing.assert_allclose(result, expected, rtol=1e-6)
797+
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
798+
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
799+
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
800+
801+
def test_weighted_quantile_negative_weights(self):
802+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
803+
weights = jnp.array([1, -1, 1, 1, 1], dtype=float)
804+
q = jnp.array([0.5])
805+
with pytest.raises(ValueError):
806+
_quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
807+
808+
def test_weighted_quantile_all_weights_zero(self):
809+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
810+
weights = jnp.zeros_like(a)
811+
q = jnp.array([0.5])
812+
with pytest.raises(ValueError):
813+
_quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
814+
815+
def test_weighted_quantile_weights_with_nan(self):
816+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
817+
weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float)
818+
q = jnp.array([0.5])
819+
result = _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
820+
assert np.isnan(np.array(result)).all()
821+
822+
def test_weighted_quantile_scalar_q(self):
823+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
824+
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
825+
q = 0.5
826+
result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
827+
assert jnp.issubdtype(result.dtype, jnp.floating)
828+
assert result.shape == ()
829+
830+
def test_weighted_quantile_jit(self):
831+
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
832+
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
833+
q = jnp.array([0.25, 0.5, 0.75])
834+
quantile_jit = jax.jit(lambda a, q, weights: _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights))
835+
result = quantile_jit(a, q, weights)
836+
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
837+
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
774838

775839
@jtu.sample_product(
776840
[dict(a_shape=a_shape, axis=axis)

0 commit comments

Comments
 (0)