From 6a60ed40bca18acf6e388516b279a0e0321b10dd Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sat, 18 Oct 2025 20:48:13 +0530 Subject: [PATCH 01/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 75 +++++++++++++++++++++++++++++--- tests/lax_numpy_reducers_test.py | 9 ++++ 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 985b296bc06f..02efc9bb8c69 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2337,7 +2337,8 @@ def cumulative_prod( @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + keepdims: bool = False, weights: ArrayLike | None = None, *, + interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis. JAX implementation of :func:`numpy.quantile`. @@ -2387,7 +2388,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No ("The interpolation= argument to 'quantile' is deprecated. " "Use 'method=' instead."), stacklevel=2) method = interpolation - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) + return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False, weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export @@ -2449,7 +2450,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, - method: str, keepdims: bool, squash_nans: bool) -> Array: + method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array: if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") a, = promote_dtypes_inexact(a) @@ -2488,6 +2489,66 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, raise ValueError(f"q must be have rank <= 1, got shape {q.shape}") a_shape = a.shape + # Handle weights + if weights is not None: + a, weights = promote_dtypes_inexact(a, weights) + if axis is None: + a = a.ravel() + weights = weights.ravel() + axis = 0 + else: + weights = _broadcast_to(weights, a.shape) + if squash_nans: + nan_mask = ~lax_internal._isnan(a) + if axis is None: + a = a[nan_mask] + weights = weights[nan_mask] + else: + weights = _where(nan_mask, weights, 0) + a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) + + cum_weights = lax.cumsum(weights_sorted, axis=axis) + total_weight = lax.sum(weights_sorted, axis=axis, keepdims=True) + if lax_internal._all(total_weight == 0): + raise ValueError("Sum of weights must not be zero.") + cum_weights_norm = cum_weights / total_weight + quantile_pos = q + mask = cum_weights_norm >= quantile_pos[..., None] + idx = lax.argmin(mask.astype(int), axis=axis) + idx_prev = lax.max(idx - 1, _lax_const(idx, 0)) + idx_next = idx + gather_shape = list(a_sorted.shape) + gather_shape[axis] = 1 + dnums = lax.GatherDimensionNumbers( + offset_dims=tuple(range(len(a_sorted.shape))), + collapsed_slice_dims=(axis,), + start_index_map=(axis,)) + prev_value = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) + next_value = lax.gather(a_sorted, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) + prev_cumw = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) + next_cumw = lax.gather(cum_weights_norm, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) + + if method == "linear": + denom = next_cumw - prev_cumw + denom = lax.select(denom == 0, _lax_const(denom, 1), denom) + weight = (quantile_pos - prev_cumw) / denom + result = prev_value * (1 - weight) + next_value * weight + elif method == "lower": + result = prev_value + elif method == "higher": + result = next_value + elif method == "nearest": + use_prev = (quantile_pos - prev_cumw) < (next_cumw - quantile_pos) + result = lax.select(use_prev, prev_value, next_value) + elif method == "midpoint": + result = (prev_value + next_value) / 2 + else: + raise ValueError(f"{method=!r} not recognized") + + if not keepdims: + result = lax.squeeze(result, axis) + return lax.convert_element_type(result, a.dtype) + if squash_nans: a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. @@ -2578,7 +2639,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis. JAX implementation of :func:`numpy.percentile`. @@ -2627,7 +2688,7 @@ def percentile(a: ArrayLike, q: ArrayLike, "Use 'method=' instead."), stacklevel=2) method = interpolation return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, - method=method, keepdims=keepdims) + method=method, keepdims=keepdims, weights=weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @@ -2636,7 +2697,7 @@ def percentile(a: ArrayLike, q: ArrayLike, def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis, ignoring NaN values. JAX implementation of :func:`numpy.nanpercentile`. @@ -2688,7 +2749,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, "Use 'method=' instead."), stacklevel=2) method = interpolation return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, - method=method, keepdims=keepdims) + method=method, keepdims=keepdims, weights=weights) @export diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index aa5e08e96a3e..1e06f96cd87b 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -28,6 +28,7 @@ from jax._src import config from jax._src import dtypes +from jax._src.numpy.reductions import quantile from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning @@ -763,6 +764,14 @@ def testPercentilePrecision(self): x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) + def test_weighted_quantile_linear(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, 2, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + expected = np.quantile(a, q, weights=weights) + result = quantile(a, q, weights=weights, method="linear") + np.testing.assert_allclose(result, expected, rtol=1e-6) + @jtu.sample_product( [dict(a_shape=a_shape, axis=axis) for a_shape, axis in ( From f5d2177abf7d08939637c3a4921cd3fe9dd4c308 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 24 Oct 2025 13:58:05 +0530 Subject: [PATCH 02/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 169 ++++++++++++++++++++----------- tests/lax_numpy_reducers_test.py | 70 ++++++++++++- 2 files changed, 175 insertions(+), 64 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 02efc9bb8c69..93cfa9e4d0ed 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2337,7 +2337,7 @@ def cumulative_prod( @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, weights: ArrayLike | None = None, *, + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis. @@ -2395,7 +2395,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, + interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis, ignoring NaNs. JAX implementation of :func:`numpy.nanquantile`. @@ -2447,12 +2448,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ("The interpolation= argument to 'nanquantile' is deprecated. " "Use 'method=' instead."), stacklevel=2) method = interpolation - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True) + return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array: - if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: - raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") + if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: + raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") a, = promote_dtypes_inexact(a) keepdim = [] if dtypes.issubdtype(a.dtype, np.complexfloating): @@ -2482,6 +2483,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, axis = _canonicalize_axis(-1, a.ndim) else: axis = _canonicalize_axis(axis, a.ndim) + + # Ensure q is an array and inexact + q = lax_internal.asarray(q) + q, = promote_dtypes_inexact(q) q_shape = q.shape q_ndim = q.ndim @@ -2492,63 +2497,103 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, # Handle weights if weights is not None: a, weights = promote_dtypes_inexact(a, weights) - if axis is None: - a = a.ravel() - weights = weights.ravel() - axis = 0 + a_shape = a.shape + w_shape = np.shape(weights) + if w_shape != a_shape: + if len(w_shape) != 1: + raise ValueError("1D weights expected when shapes of a and weights differ.") + if axis is None: + raise TypeError("Axis must be specified when shapes of a and weights differ.") + if w_shape[0] != a_shape[axis]: + raise ValueError("Length of weights not compatible with specified axis.") + resh = [1] * a.ndim + resh[axis] = w_shape[0] + weights = lax.reshape(lax_internal.asarray(weights), tuple(resh)) + weights = _broadcast_to(weights, a.shape) + + if isinstance(weights, core.Tracer): + weights_arr = None else: - weights = _broadcast_to(weights, a.shape) + try: + weights_arr = np.asarray(weights) + except Exception: + weights_arr = None + + if weights_arr is not None: + if np.any(weights_arr < 0): + raise ValueError("Weights must be non-negative.") + if np.all(weights_arr == 0): + raise ValueError("Sum of weights must not be zero.") + if np.any(np.isnan(weights_arr)): + out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else () + return lax.full(out_shape, np.nan, dtype=a.dtype) + weights_have_nan = np.any(np.isnan(weights_arr)) + else: + weights_have_nan = False + if squash_nans: nan_mask = ~lax_internal._isnan(a) - if axis is None: - a = a[nan_mask] - weights = weights[nan_mask] - else: - weights = _where(nan_mask, weights, 0) - a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) + weights = _where(nan_mask, weights, 0) + else: + with jax.debug_nans(False): + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + total_weight = sum(weights, axis=axis, keepdims=True) + a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) cum_weights = lax.cumsum(weights_sorted, axis=axis) - total_weight = lax.sum(weights_sorted, axis=axis, keepdims=True) - if lax_internal._all(total_weight == 0): - raise ValueError("Sum of weights must not be zero.") - cum_weights_norm = cum_weights / total_weight - quantile_pos = q - mask = cum_weights_norm >= quantile_pos[..., None] - idx = lax.argmin(mask.astype(int), axis=axis) - idx_prev = lax.max(idx - 1, _lax_const(idx, 0)) - idx_next = idx - gather_shape = list(a_sorted.shape) - gather_shape[axis] = 1 + cum_weights_norm = lax.div(cum_weights, total_weight) + + slice_sizes = list(a_sorted.shape) + slice_sizes[axis] = 1 dnums = lax.GatherDimensionNumbers( - offset_dims=tuple(range(len(a_sorted.shape))), - collapsed_slice_dims=(axis,), + offset_dims=tuple(range( + 0, + len(a_sorted.shape) if keepdims else len(a_sorted.shape) - 1)), + collapsed_slice_dims=() if keepdims else (axis,), start_index_map=(axis,)) - prev_value = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) - next_value = lax.gather(a_sorted, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) - prev_cumw = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) - next_cumw = lax.gather(cum_weights_norm, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) - - if method == "linear": - denom = next_cumw - prev_cumw - denom = lax.select(denom == 0, _lax_const(denom, 1), denom) - weight = (quantile_pos - prev_cumw) / denom - result = prev_value * (1 - weight) + next_value * weight - elif method == "lower": - result = prev_value - elif method == "higher": - result = next_value - elif method == "nearest": - use_prev = (quantile_pos - prev_cumw) < (next_cumw - quantile_pos) - result = lax.select(use_prev, prev_value, next_value) - elif method == "midpoint": - result = (prev_value + next_value) / 2 - else: - raise ValueError(f"{method=!r} not recognized") - - if not keepdims: - result = lax.squeeze(result, axis) - return lax.convert_element_type(result, a.dtype) + def _weighted_quantile(qi, weights_have_nan=weights_have_nan): + index_dtype = dtypes.canonicalize_dtype(int) + idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype) + idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1) + slicer = [slice(None)] * a_sorted.ndim + slicer[axis] = idx + val = a_sorted[tuple(slicer)] + + idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1) + slicer_prev = slicer.copy() + slicer_prev[axis] = idx_prev + val_prev = a_sorted[tuple(slicer_prev)] + cw_prev = cum_weights_norm[tuple(slicer_prev)] + cw_next = cum_weights_norm[tuple(slicer)] + + if method == "linear": + denom = cw_next - cw_prev + denom = _where(denom == 0, 1, denom) + weight = (qi - cw_prev) / denom + out = val_prev * (1 - weight) + val * weight + elif method == "lower": + out = val_prev + elif method == "higher": + out = val + elif method == "nearest": + out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val) + elif method == "midpoint": + out = (val_prev + val) / 2 + elif method == "inverted_cdf": + out = val + else: + raise ValueError(f"{method=!r} not recognized") + if weights_have_nan: + out = lax.full_like(out, np.nan) + out = lax.squeeze(out, axis=axis) + return out + + if q.ndim == 0: + result = _weighted_quantile(q) + else: + result = jax.vmap(_weighted_quantile)(q) + return result if squash_nans: a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. @@ -2566,10 +2611,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) - low = lax.convert_element_type(low, int) - high = lax.convert_element_type(high, int) + low = lax.convert_element_type(low, dtypes.canonicalize_dtype(int)) + high = lax.convert_element_type(high, dtypes.canonicalize_dtype(int)) out_shape = q_shape + shape_after_reduction - index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim) + index = [lax.broadcasted_iota(dtypes.canonicalize_dtype(int), out_shape, dim + q_ndim) for dim in range(len(shape_after_reduction))] if keepdims: index[axis] = low @@ -2591,8 +2636,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, low = lax.clamp(_lax_const(low, 0), low, n - 1) high = lax.clamp(_lax_const(high, 0), high, n - 1) - low = lax.convert_element_type(low, int) - high = lax.convert_element_type(high, int) + low = lax.convert_element_type(low, dtypes.int_) + high = lax.convert_element_type(high, dtypes.int_) slice_sizes = list(a_shape) slice_sizes[axis] = 1 @@ -2624,6 +2669,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = lax.select(pred, low_value, high_value) elif method == "midpoint": result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) + elif method == "inverted_cdf": + result = high_value else: raise ValueError(f"{method=!r} not recognized") if keepdims and keepdim: @@ -2639,7 +2686,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis. JAX implementation of :func:`numpy.percentile`. @@ -2697,7 +2744,7 @@ def percentile(a: ArrayLike, q: ArrayLike, def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis, ignoring NaN values. JAX implementation of :func:`numpy.nanpercentile`. diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 1e06f96cd87b..88f7ca7e584d 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -17,11 +17,13 @@ from functools import partial import itertools import unittest +import pytest from absl.testing import absltest from absl.testing import parameterized import numpy as np +from jax._src.numpy.reductions import _quantile import jax from jax import numpy as jnp @@ -764,13 +766,75 @@ def testPercentilePrecision(self): x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) + def test_weighted_quantile_all_weights_one(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.ones_like(a) + q = jnp.array([0.25, 0.5, 0.75]) + result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") + np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) + + def test_weighted_quantile_multiple_q(self): + a = jnp.arange(10, dtype=float) + weights = jnp.ones_like(a) + q = jnp.array([0.25, 0.5, 0.75]) + result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") + np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) + + def test_weighted_quantile_keepdims(self): + a = jnp.array([1, 2, 3, 4], dtype=float) + weights = jnp.array([1, 1, 1, 1], dtype=float) + q = 0.5 + result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights) + expected = np.quantile(np.array(a), np.array(q), axis=0, keepdims=True, weights=np.array(weights), method="inverted_cdf") + np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) + def test_weighted_quantile_linear(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, 2, 1, 1, 1], dtype=float) q = jnp.array([0.5]) - expected = np.quantile(a, q, weights=weights) - result = quantile(a, q, weights=weights, method="linear") - np.testing.assert_allclose(result, expected, rtol=1e-6) + result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") + np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) + + def test_weighted_quantile_negative_weights(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, -1, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + with pytest.raises(ValueError): + _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + + def test_weighted_quantile_all_weights_zero(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.zeros_like(a) + q = jnp.array([0.5]) + with pytest.raises(ValueError): + _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + + def test_weighted_quantile_weights_with_nan(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + result = _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + assert np.isnan(np.array(result)).all() + + def test_weighted_quantile_scalar_q(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, 2, 1, 1, 1], dtype=float) + q = 0.5 + result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + assert jnp.issubdtype(result.dtype, jnp.floating) + assert result.shape == () + + def test_weighted_quantile_jit(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, 2, 1, 1, 1], dtype=float) + q = jnp.array([0.25, 0.5, 0.75]) + quantile_jit = jax.jit(lambda a, q, weights: _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)) + result = quantile_jit(a, q, weights) + expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") + np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) @jtu.sample_product( [dict(a_shape=a_shape, axis=axis) From f7ab68305f08386969ebcab2d50e59a95e71b190 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 24 Oct 2025 23:10:59 +0530 Subject: [PATCH 03/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 93cfa9e4d0ed..dcba85a0881a 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2383,11 +2383,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") if not isinstance(interpolation, DeprecatedArg): - deprecations.warn( - "jax-numpy-quantile-interpolation", - ("The interpolation= argument to 'quantile' is deprecated. " - "Use 'method=' instead."), stacklevel=2) - method = interpolation + raise TypeError("nanquantile() argument interpolation was removed in JAX" + " v0.8.0. Use method instead.") return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False, weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @@ -2443,11 +2440,8 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = "out != None") raise ValueError(msg) if not isinstance(interpolation, DeprecatedArg): - deprecations.warn( - "jax-numpy-quantile-interpolation", - ("The interpolation= argument to 'nanquantile' is deprecated. " - "Use 'method=' instead."), stacklevel=2) - method = interpolation + raise TypeError("nanquantile() argument interpolation was removed in JAX" + " v0.8.0. Use method instead.") return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, @@ -2485,7 +2479,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, axis = _canonicalize_axis(axis, a.ndim) # Ensure q is an array and inexact - q = lax_internal.asarray(q) + q = lax.asarray(q) q, = promote_dtypes_inexact(q) q_shape = q.shape From 5f881bf104749261ef1e4a2efdb19363217264f5 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 24 Oct 2025 23:23:20 +0530 Subject: [PATCH 04/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index dcba85a0881a..239d2d5affa4 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2383,9 +2383,9 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") if not isinstance(interpolation, DeprecatedArg): - raise TypeError("nanquantile() argument interpolation was removed in JAX" + raise TypeError("quantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False, weights) + return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False, weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export @@ -2442,7 +2442,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = if not isinstance(interpolation, DeprecatedArg): raise TypeError("nanquantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True, weights) + return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array: From 7b967cbbc01bb2fd58c91d5d11812d51a08512ab Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sat, 25 Oct 2025 02:54:38 +0530 Subject: [PATCH 05/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 100 ++++++++++++++----------------- tests/lax_numpy_reducers_test.py | 20 +++---- 2 files changed, 53 insertions(+), 67 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 239d2d5affa4..71c4bd7201a2 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -25,12 +25,13 @@ import jax from jax import lax +from jax import numpy as jnp from jax._src import api from jax._src import core from jax._src import deprecations from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, _complex_elem_type, + _broadcast_to, check_arraylike, ensure_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg @@ -2379,6 +2380,10 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No Array([2., 4., 7.], dtype=float32) """ check_arraylike("quantile", a, q) + if weights is None: + a, q = ensure_arraylike("quantile", a, q) + else: + a, q, weights = ensure_arraylike("quantile", a, q, weights) if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") @@ -2435,6 +2440,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = Array([1.5, 3. , 4.5], dtype=float32) """ check_arraylike("nanquantile", a, q) + if weights is None: + a, q = ensure_arraylike("nanquantile", a, q) + else: + a, q, weights = ensure_arraylike("nanquantile", a, q, weights) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") @@ -2445,10 +2454,9 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, - method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array: + method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") - a, = promote_dtypes_inexact(a) keepdim = [] if dtypes.issubdtype(a.dtype, np.complexfloating): raise ValueError("quantile does not support complex input, as the operation is poorly defined.") @@ -2477,9 +2485,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, axis = _canonicalize_axis(-1, a.ndim) else: axis = _canonicalize_axis(axis, a.ndim) - - # Ensure q is an array and inexact - q = lax.asarray(q) + q, = promote_dtypes_inexact(q) q_shape = q.shape @@ -2489,8 +2495,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape # Handle weights - if weights is not None: - a, weights = promote_dtypes_inexact(a, weights) + if weights is None: + a, = promote_dtypes_inexact(a) + else: + a, q = promote_dtypes_inexact(a, q) a_shape = a.shape w_shape = np.shape(weights) if w_shape != a_shape: @@ -2502,28 +2510,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, raise ValueError("Length of weights not compatible with specified axis.") resh = [1] * a.ndim resh[axis] = w_shape[0] - weights = lax.reshape(lax_internal.asarray(weights), tuple(resh)) + weights = lax.expand_dims(weights, axis) weights = _broadcast_to(weights, a.shape) - if isinstance(weights, core.Tracer): - weights_arr = None - else: - try: - weights_arr = np.asarray(weights) - except Exception: - weights_arr = None - - if weights_arr is not None: - if np.any(weights_arr < 0): - raise ValueError("Weights must be non-negative.") - if np.all(weights_arr == 0): - raise ValueError("Sum of weights must not be zero.") - if np.any(np.isnan(weights_arr)): - out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else () - return lax.full(out_shape, np.nan, dtype=a.dtype) - weights_have_nan = np.any(np.isnan(weights_arr)) - else: - weights_have_nan = False + weights_have_nan = jnp.any(jnp.isnan(weights)) + if weights_have_nan: + out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else () + return lax.full(out_shape, np.nan, dtype=a.dtype) if squash_nans: nan_mask = ~lax_internal._isnan(a) @@ -2537,29 +2530,16 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, cum_weights = lax.cumsum(weights_sorted, axis=axis) cum_weights_norm = lax.div(cum_weights, total_weight) - slice_sizes = list(a_sorted.shape) - slice_sizes[axis] = 1 - dnums = lax.GatherDimensionNumbers( - offset_dims=tuple(range( - 0, - len(a_sorted.shape) if keepdims else len(a_sorted.shape) - 1)), - collapsed_slice_dims=() if keepdims else (axis,), - start_index_map=(axis,)) - def _weighted_quantile(qi, weights_have_nan=weights_have_nan): - index_dtype = dtypes.canonicalize_dtype(int) + index_dtype = dtypes.default_int_dtype() idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype) idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1) - slicer = [slice(None)] * a_sorted.ndim - slicer[axis] = idx - val = a_sorted[tuple(slicer)] + val = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx, axis), axis) idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1) - slicer_prev = slicer.copy() - slicer_prev[axis] = idx_prev - val_prev = a_sorted[tuple(slicer_prev)] - cw_prev = cum_weights_norm[tuple(slicer_prev)] - cw_next = cum_weights_norm[tuple(slicer)] + val_prev = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx_prev, axis), axis) + cw_prev = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx_prev, axis), axis) + cw_next = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx, axis), axis) if method == "linear": denom = cw_next - cw_prev @@ -2603,12 +2583,12 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan): high_weight = lax.sub(q, low) low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) - high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) - low = lax.convert_element_type(low, dtypes.canonicalize_dtype(int)) - high = lax.convert_element_type(high, dtypes.canonicalize_dtype(int)) + low = lax.max(lax._const(low, 0), lax.min(low, counts - 1)) + high = lax.max(lax._const(high, 0), lax.min(high, counts - 1)) + low = lax.convert_element_type(low, int) + high = lax.convert_element_type(high, int) out_shape = q_shape + shape_after_reduction - index = [lax.broadcasted_iota(dtypes.canonicalize_dtype(int), out_shape, dim + q_ndim) + index = [lax.broadcasted_iota(int, out_shape, dim + q_ndim) for dim in range(len(shape_after_reduction))] if keepdims: index[axis] = low @@ -2628,10 +2608,10 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan): high_weight = lax.sub(q, low) low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - low = lax.clamp(_lax_const(low, 0), low, n - 1) - high = lax.clamp(_lax_const(high, 0), high, n - 1) - low = lax.convert_element_type(low, dtypes.int_) - high = lax.convert_element_type(high, dtypes.int_) + low = lax.clamp(lax._const(low, 0), low, n - 1) + high = lax.clamp(lax._const(high, 0), high, n - 1) + low = lax.convert_element_type(low, int) + high = lax.convert_element_type(high, int) slice_sizes = list(a_shape) slice_sizes[axis] = 1 @@ -2662,7 +2642,7 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan): pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) result = lax.select(pred, low_value, high_value) elif method == "midpoint": - result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) + result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) elif method == "inverted_cdf": result = high_value else: @@ -2721,6 +2701,10 @@ def percentile(a: ArrayLike, q: ArrayLike, Array([1., 3., 4.], dtype=float32) """ check_arraylike("percentile", a, q) + if weights is None: + a, q = ensure_arraylike("percentile", a, q) + else: + a, q, weights = ensure_arraylike("percentile", a, q, weights) q, = promote_dtypes_inexact(q) if not isinstance(interpolation, DeprecatedArg): deprecations.warn( @@ -2781,6 +2765,10 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, Array([1.5, 3. , 4.5], dtype=float32) """ check_arraylike("nanpercentile", a, q) + if weights is None: + a, q = ensure_arraylike("nanpercentile", a, q) + else: + a, q, weights = ensure_arraylike("nanpercentile", a, q, weights) q, = promote_dtypes_inexact(q) q = q / 100 if not isinstance(interpolation, DeprecatedArg): diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 88f7ca7e584d..17d6330439a0 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -23,14 +23,12 @@ from absl.testing import parameterized import numpy as np -from jax._src.numpy.reductions import _quantile import jax from jax import numpy as jnp from jax._src import config from jax._src import dtypes -from jax._src.numpy.reductions import quantile from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning @@ -770,7 +768,7 @@ def test_weighted_quantile_all_weights_one(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.ones_like(a) q = jnp.array([0.25, 0.5, 0.75]) - result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) @@ -778,7 +776,7 @@ def test_weighted_quantile_multiple_q(self): a = jnp.arange(10, dtype=float) weights = jnp.ones_like(a) q = jnp.array([0.25, 0.5, 0.75]) - result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) @@ -786,7 +784,7 @@ def test_weighted_quantile_keepdims(self): a = jnp.array([1, 2, 3, 4], dtype=float) weights = jnp.array([1, 1, 1, 1], dtype=float) q = 0.5 - result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights) expected = np.quantile(np.array(a), np.array(q), axis=0, keepdims=True, weights=np.array(weights), method="inverted_cdf") np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) @@ -794,7 +792,7 @@ def test_weighted_quantile_linear(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, 2, 1, 1, 1], dtype=float) q = jnp.array([0.5]) - result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) @@ -803,27 +801,27 @@ def test_weighted_quantile_negative_weights(self): weights = jnp.array([1, -1, 1, 1, 1], dtype=float) q = jnp.array([0.5]) with pytest.raises(ValueError): - _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) def test_weighted_quantile_all_weights_zero(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.zeros_like(a) q = jnp.array([0.5]) with pytest.raises(ValueError): - _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) def test_weighted_quantile_weights_with_nan(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float) q = jnp.array([0.5]) - result = _quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) assert np.isnan(np.array(result)).all() def test_weighted_quantile_scalar_q(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, 2, 1, 1, 1], dtype=float) q = 0.5 - result = _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) assert jnp.issubdtype(result.dtype, jnp.floating) assert result.shape == () @@ -831,7 +829,7 @@ def test_weighted_quantile_jit(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, 2, 1, 1, 1], dtype=float) q = jnp.array([0.25, 0.5, 0.75]) - quantile_jit = jax.jit(lambda a, q, weights: _quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)) + quantile_jit = jax.jit(lambda a, q, weights: jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)) result = quantile_jit(a, q, weights) expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) From 4f522e6bdcb929d015cc07309825a22e6ccb2f29 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sat, 25 Oct 2025 11:00:20 +0530 Subject: [PATCH 06/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 20 ++------ tests/lax_numpy_reducers_test.py | 85 +++++++++++++++----------------- 2 files changed, 45 insertions(+), 60 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 71c4bd7201a2..fc0d805f2d54 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2379,7 +2379,6 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No >>> jnp.quantile(x, q, method='nearest') Array([2., 4., 7.], dtype=float32) """ - check_arraylike("quantile", a, q) if weights is None: a, q = ensure_arraylike("quantile", a, q) else: @@ -2390,7 +2389,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No if not isinstance(interpolation, DeprecatedArg): raise TypeError("quantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False, weights) + return _quantile(a, q, axis, method, keepdims, False, weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export @@ -2439,7 +2438,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = >>> jnp.nanquantile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanquantile", a, q) if weights is None: a, q = ensure_arraylike("nanquantile", a, q) else: @@ -2451,7 +2449,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = if not isinstance(interpolation, DeprecatedArg): raise TypeError("nanquantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights) + return _quantile(a, q, axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: @@ -2498,7 +2496,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, if weights is None: a, = promote_dtypes_inexact(a) else: - a, q = promote_dtypes_inexact(a, q) + a, weights = promote_dtypes_inexact(a, weights) a_shape = a.shape w_shape = np.shape(weights) if w_shape != a_shape: @@ -2513,11 +2511,6 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, weights = lax.expand_dims(weights, axis) weights = _broadcast_to(weights, a.shape) - weights_have_nan = jnp.any(jnp.isnan(weights)) - if weights_have_nan: - out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else () - return lax.full(out_shape, np.nan, dtype=a.dtype) - if squash_nans: nan_mask = ~lax_internal._isnan(a) weights = _where(nan_mask, weights, 0) @@ -2530,7 +2523,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, cum_weights = lax.cumsum(weights_sorted, axis=axis) cum_weights_norm = lax.div(cum_weights, total_weight) - def _weighted_quantile(qi, weights_have_nan=weights_have_nan): + def _weighted_quantile(qi): index_dtype = dtypes.default_int_dtype() idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype) idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1) @@ -2558,9 +2551,6 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan): out = val else: raise ValueError(f"{method=!r} not recognized") - if weights_have_nan: - out = lax.full_like(out, np.nan) - out = lax.squeeze(out, axis=axis) return out if q.ndim == 0: @@ -2700,7 +2690,6 @@ def percentile(a: ArrayLike, q: ArrayLike, >>> jnp.percentile(x, q, method='nearest') Array([1., 3., 4.], dtype=float32) """ - check_arraylike("percentile", a, q) if weights is None: a, q = ensure_arraylike("percentile", a, q) else: @@ -2764,7 +2753,6 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, >>> jnp.nanpercentile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanpercentile", a, q) if weights is None: a, q = ensure_arraylike("nanpercentile", a, q) else: diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 17d6330439a0..82e6713ad968 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -17,7 +17,6 @@ from functools import partial import itertools import unittest -import pytest from absl.testing import absltest from absl.testing import parameterized @@ -764,51 +763,58 @@ def testPercentilePrecision(self): x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) - def test_weighted_quantile_all_weights_one(self): - a = jnp.array([1, 2, 3, 4, 5], dtype=float) - weights = jnp.ones_like(a) - q = jnp.array([0.25, 0.5, 0.75]) - result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) - expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") - np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) - - def test_weighted_quantile_multiple_q(self): - a = jnp.arange(10, dtype=float) - weights = jnp.ones_like(a) - q = jnp.array([0.25, 0.5, 0.75]) - result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) - expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") - np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) - - def test_weighted_quantile_keepdims(self): - a = jnp.array([1, 2, 3, 4], dtype=float) - weights = jnp.array([1, 1, 1, 1], dtype=float) - q = 0.5 - result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights) - expected = np.quantile(np.array(a), np.array(q), axis=0, keepdims=True, weights=np.array(weights), method="inverted_cdf") - np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) + @jtu.sample_product( + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((6, 7,), None), + ((47, 7), 0), + ((47, 7), ()), + ((4, 101), 1), + ((4, 47, 7), (1, 2)), + ((4, 47, 7), (0, 2)), + ((4, 47, 7), (1, 0, 2)), + ) + ], + a_dtype=default_dtypes, + q_dtype=[np.float32], + q_shape=scalar_shapes + [(1,), (4,)], + keepdims=[False, True], + method=['linear', 'lower', 'higher', 'nearest', 'midpoint', 'inverted_cdf'], +) + def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims, method): + rng = jtu.rand_default(self.rng()) + a = rng(a_shape, a_dtype) + q = rng(q_shape, q_dtype) + if axis is None: + weights_shape = a_shape + elif isinstance(axis, tuple): + weights_shape = tuple(a_shape[i] for i in axis) + else: + weights_shape = (a_shape[axis],) + weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3 - def test_weighted_quantile_linear(self): - a = jnp.array([1, 2, 3, 4, 5], dtype=float) - weights = jnp.array([1, 2, 1, 1, 1], dtype=float) - q = jnp.array([0.5]) - result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) - expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") - np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) + def np_fun(a, q, weights): + return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims) + def jnp_fun(a, q, weights): + return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims) + args_maker = lambda: [a, q, weights] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6) + self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6) def test_weighted_quantile_negative_weights(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, -1, 1, 1, 1], dtype=float) q = jnp.array([0.5]) - with pytest.raises(ValueError): - jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + with self.assertRaisesRegex(ValueError, "Weights must be non-negative"): + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) def test_weighted_quantile_all_weights_zero(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.zeros_like(a) q = jnp.array([0.5]) - with pytest.raises(ValueError): - jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"): + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) def test_weighted_quantile_weights_with_nan(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) @@ -825,15 +831,6 @@ def test_weighted_quantile_scalar_q(self): assert jnp.issubdtype(result.dtype, jnp.floating) assert result.shape == () - def test_weighted_quantile_jit(self): - a = jnp.array([1, 2, 3, 4, 5], dtype=float) - weights = jnp.array([1, 2, 1, 1, 1], dtype=float) - q = jnp.array([0.25, 0.5, 0.75]) - quantile_jit = jax.jit(lambda a, q, weights: jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)) - result = quantile_jit(a, q, weights) - expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf") - np.testing.assert_allclose(np.array(result), expected, rtol=1e-6) - @jtu.sample_product( [dict(a_shape=a_shape, axis=axis) for a_shape, axis in ( From cef1731b17fc2f4d9a6e7931e1a472469b298d8a Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sun, 26 Oct 2025 01:53:52 +0530 Subject: [PATCH 07/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 45 ++++++++++++++++++-------------- tests/lax_numpy_reducers_test.py | 24 ++++++++++------- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index fc0d805f2d54..cf18c39786dc 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -25,7 +25,6 @@ import jax from jax import lax -from jax import numpy as jnp from jax._src import api from jax._src import core from jax._src import deprecations @@ -2453,6 +2452,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: + from jax import numpy as jnp if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") keepdim = [] @@ -2485,7 +2485,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, axis = _canonicalize_axis(axis, a.ndim) q, = promote_dtypes_inexact(q) - + q = jnp.atleast_1d(q) q_shape = q.shape q_ndim = q.ndim if q_ndim > 1: @@ -2500,16 +2500,22 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape w_shape = np.shape(weights) if w_shape != a_shape: - if len(w_shape) != 1: - raise ValueError("1D weights expected when shapes of a and weights differ.") if axis is None: raise TypeError("Axis must be specified when shapes of a and weights differ.") - if w_shape[0] != a_shape[axis]: - raise ValueError("Length of weights not compatible with specified axis.") - resh = [1] * a.ndim - resh[axis] = w_shape[0] - weights = lax.expand_dims(weights, axis) - weights = _broadcast_to(weights, a.shape) + if isinstance(axis, tuple): + if w_shape != tuple(a_shape[i] for i in axis): + raise ValueError("Shape of weights must match the shape of the axes being reduced.") + weights = lax.broadcast_in_dim( + weights, + shape=a_shape, + broadcast_dimensions=axis + ) + else: + if len(w_shape) != 1 or w_shape[0] != a_shape[axis]: + raise ValueError("Length of weights not compatible with specified axis.") + weights = lax.expand_dims(weights, axis) + weights = _broadcast_to(weights, a.shape) + if squash_nans: nan_mask = ~lax_internal._isnan(a) @@ -2525,14 +2531,14 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, def _weighted_quantile(qi): index_dtype = dtypes.default_int_dtype() - idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype) + idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1) - val = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx, axis), axis) + val = jnp.take_along_axis(a_sorted, idx, axis) idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1) - val_prev = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx_prev, axis), axis) - cw_prev = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx_prev, axis), axis) - cw_next = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx, axis), axis) + val_prev = jnp.take_along_axis(a_sorted, idx_prev, axis) + cw_prev = jnp.take_along_axis(cum_weights_norm, idx_prev, axis) + cw_next = jnp.take_along_axis(cum_weights_norm, idx, axis) if method == "linear": denom = cw_next - cw_prev @@ -2552,11 +2558,10 @@ def _weighted_quantile(qi): else: raise ValueError(f"{method=!r} not recognized") return out - - if q.ndim == 0: - result = _weighted_quantile(q) - else: - result = jax.vmap(_weighted_quantile)(q) + + result = jax.vmap(_weighted_quantile)(q) + if q.shape == (1,): + result = result[0] return result if squash_nans: diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 82e6713ad968..c2ec6ff90098 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -787,18 +787,22 @@ def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdim a = rng(a_shape, a_dtype) q = rng(q_shape, q_dtype) if axis is None: - weights_shape = a_shape + weights_shape = a_shape elif isinstance(axis, tuple): - weights_shape = tuple(a_shape[i] for i in axis) + weights_shape = tuple(a_shape[i] for i in axis) else: - weights_shape = (a_shape[axis],) + weights_shape = (a_shape[axis],) weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3 def np_fun(a, q, weights): - return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims) + return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims) def jnp_fun(a, q, weights): - return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims) - args_maker = lambda: [a, q, weights] + return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims) + args_maker = lambda: [ + rng(a_shape, a_dtype), + rng(q_shape, q_dtype), + np.abs(rng(weights_shape, a_dtype)) + 1e-3 + ] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6) self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6) @@ -807,27 +811,27 @@ def test_weighted_quantile_negative_weights(self): weights = jnp.array([1, -1, 1, 1, 1], dtype=float) q = jnp.array([0.5]) with self.assertRaisesRegex(ValueError, "Weights must be non-negative"): - jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) def test_weighted_quantile_all_weights_zero(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.zeros_like(a) q = jnp.array([0.5]) with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"): - jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) def test_weighted_quantile_weights_with_nan(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float) q = jnp.array([0.5]) - result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) assert np.isnan(np.array(result)).all() def test_weighted_quantile_scalar_q(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, 2, 1, 1, 1], dtype=float) q = 0.5 - result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) assert jnp.issubdtype(result.dtype, jnp.floating) assert result.shape == () From a230e01600bacd60463856b565763466f3b7e686 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sun, 26 Oct 2025 09:31:54 +0530 Subject: [PATCH 08/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index cf18c39786dc..241ac7567830 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -25,6 +25,7 @@ import jax from jax import lax +import jax._src.numpy as jnp from jax._src import api from jax._src import core from jax._src import deprecations @@ -2452,7 +2453,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: - from jax import numpy as jnp if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") keepdim = [] @@ -2515,7 +2515,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, raise ValueError("Length of weights not compatible with specified axis.") weights = lax.expand_dims(weights, axis) weights = _broadcast_to(weights, a.shape) - + if squash_nans: nan_mask = ~lax_internal._isnan(a) @@ -2558,10 +2558,8 @@ def _weighted_quantile(qi): else: raise ValueError(f"{method=!r} not recognized") return out - + result = jax.vmap(_weighted_quantile)(q) - if q.shape == (1,): - result = result[0] return result if squash_nans: From 7d50a32666f916a1162d57ec7787bfd85177ab17 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Mon, 27 Oct 2025 17:03:45 +0530 Subject: [PATCH 09/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 70 +++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 241ac7567830..dd67841f375c 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2485,7 +2485,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, axis = _canonicalize_axis(axis, a.ndim) q, = promote_dtypes_inexact(q) - q = jnp.atleast_1d(q) + q = lax_internal.asarray(q) + if getattr(q, "ndim", 0) == 0: + q = lax.expand_dims(q, (0,)) q_shape = q.shape q_ndim = q.ndim if q_ndim > 1: @@ -2497,8 +2499,14 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a, = promote_dtypes_inexact(a) else: a, weights = promote_dtypes_inexact(a, weights) + weights = lax.convert_element_type(weights, a.dtype) a_shape = a.shape w_shape = np.shape(weights) + if np.ndim(weights) == 0: + weights = lax.broadcast_in_dim(weights, a_shape, ()) + w_shape = a_shape + else: + w_shape = np.shape(weights) if w_shape != a_shape: if axis is None: raise TypeError("Axis must be specified when shapes of a and weights differ.") @@ -2510,12 +2518,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, shape=a_shape, broadcast_dimensions=axis ) - else: + w_shape = a_shape + else: if len(w_shape) != 1 or w_shape[0] != a_shape[axis]: raise ValueError("Length of weights not compatible with specified axis.") - weights = lax.expand_dims(weights, axis) + weights = lax.expand_dims(weights, (axis,)) weights = _broadcast_to(weights, a.shape) - + w_shape = a_shape if squash_nans: nan_mask = ~lax_internal._isnan(a) @@ -2526,20 +2535,29 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, total_weight = sum(weights, axis=axis, keepdims=True) a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) - cum_weights = lax.cumsum(weights_sorted, axis=axis) + cum_weights = cumsum(weights_sorted, axis=axis) cum_weights_norm = lax.div(cum_weights, total_weight) def _weighted_quantile(qi): - index_dtype = dtypes.default_int_dtype() + qi = lax.convert_element_type(qi, cum_weights_norm.dtype) + index_dtype = dtypes.canonicalize_dtype(dtypes.int_) idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) - idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1) - val = jnp.take_along_axis(a_sorted, idx, axis) - - idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1) - val_prev = jnp.take_along_axis(a_sorted, idx_prev, axis) - cw_prev = jnp.take_along_axis(cum_weights_norm, idx_prev, axis) - cw_next = jnp.take_along_axis(cum_weights_norm, idx, axis) - + idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1)) + idx_prev = lax.clamp(idx - 1, _lax_const(idx, 0), _lax_const(idx, a_sorted.shape[axis] - 1)) + + slice_sizes = list(a_shape) + slice_sizes[axis] = 1 + offset_start = q_ndim + total_offset_dims = len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1 + dnums = lax.GatherDimensionNumbers( + offset_dims=tuple(range(offset_start, total_offset_dims)), + collapsed_slice_dims=(axis,), + start_index_map=(axis,) + ) + val = lax.gather(a_sorted, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) + val_prev = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) + cw_prev = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) + cw_next = lax.gather(cum_weights_norm, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) if method == "linear": denom = cw_next - cw_prev denom = _where(denom == 0, 1, denom) @@ -2560,7 +2578,15 @@ def _weighted_quantile(qi): return out result = jax.vmap(_weighted_quantile)(q) - return result + if keepdims and keepdim: + if q_ndim > 0: + keepdim = [q_shape[0], *keepdim] + result = result.reshape(tuple(keepdim)) + else: + if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1): + if result.ndim > 0 and result.shape[0] == 1: + result = lax.squeeze(result, (0,)) + return lax.convert_element_type(result, a.dtype) if squash_nans: a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. @@ -2576,8 +2602,8 @@ def _weighted_quantile(qi): high_weight = lax.sub(q, low) low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - low = lax.max(lax._const(low, 0), lax.min(low, counts - 1)) - high = lax.max(lax._const(high, 0), lax.min(high, counts - 1)) + low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) + high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) low = lax.convert_element_type(low, int) high = lax.convert_element_type(high, int) out_shape = q_shape + shape_after_reduction @@ -2601,8 +2627,8 @@ def _weighted_quantile(qi): high_weight = lax.sub(q, low) low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - low = lax.clamp(lax._const(low, 0), low, n - 1) - high = lax.clamp(lax._const(high, 0), high, n - 1) + low = lax.clamp(_lax_const(low, 0), low, n - 1) + high = lax.clamp(_lax_const(high, 0), high, n - 1) low = lax.convert_element_type(low, int) high = lax.convert_element_type(high, int) @@ -2635,7 +2661,7 @@ def _weighted_quantile(qi): pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) result = lax.select(pred, low_value, high_value) elif method == "midpoint": - result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) + result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) elif method == "inverted_cdf": result = high_value else: @@ -2644,6 +2670,10 @@ def _weighted_quantile(qi): if q_ndim > 0: keepdim = [np.shape(q)[0], *keepdim] result = result.reshape(keepdim) + else: + if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1): + if result.ndim > 0 and result.shape[0] == 1: + result = lax.squeeze(result, (0,)) return lax.convert_element_type(result, a.dtype) From ca1d95b515b52dd37df141f0d5996487dd2664f9 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 30 Oct 2025 19:44:14 +0530 Subject: [PATCH 10/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 44 ++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index dd67841f375c..13e4d3f19d15 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -23,9 +23,7 @@ import numpy as np -import jax -from jax import lax -import jax._src.numpy as jnp +from jax._src.lax import lax from jax._src import api from jax._src import core from jax._src import deprecations @@ -2486,7 +2484,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, q, = promote_dtypes_inexact(q) q = lax_internal.asarray(q) - if getattr(q, "ndim", 0) == 0: + q_was_scalar = getattr(q, "ndim", 0) == 0 + if q_was_scalar: q = lax.expand_dims(q, (0,)) q_shape = q.shape q_ndim = q.ndim @@ -2534,7 +2533,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) total_weight = sum(weights, axis=axis, keepdims=True) - a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) + a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis) cum_weights = cumsum(weights_sorted, axis=axis) cum_weights_norm = lax.div(cum_weights, total_weight) @@ -2576,17 +2575,16 @@ def _weighted_quantile(qi): else: raise ValueError(f"{method=!r} not recognized") return out + result = api.vmap(_weighted_quantile)(q) + keepdim_out = list(keepdim) + if not q_was_scalar: + keepdim_out = [q_shape[0], *keepdim_out] + result = result.reshape(tuple(keepdim_out)) + elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: + result = result.squeeze(axis=0) + return result + return result - result = jax.vmap(_weighted_quantile)(q) - if keepdims and keepdim: - if q_ndim > 0: - keepdim = [q_shape[0], *keepdim] - result = result.reshape(tuple(keepdim)) - else: - if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1): - if result.ndim > 0 and result.shape[0] == 1: - result = lax.squeeze(result, (0,)) - return lax.convert_element_type(result, a.dtype) if squash_nans: a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. @@ -2666,15 +2664,13 @@ def _weighted_quantile(qi): result = high_value else: raise ValueError(f"{method=!r} not recognized") - if keepdims and keepdim: - if q_ndim > 0: - keepdim = [np.shape(q)[0], *keepdim] - result = result.reshape(keepdim) - else: - if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1): - if result.ndim > 0 and result.shape[0] == 1: - result = lax.squeeze(result, (0,)) - return lax.convert_element_type(result, a.dtype) + keepdim_out = list(keepdim) + if not q_was_scalar: + keepdim_out = [q_shape[0], *keepdim_out] + result = result.reshape(tuple(keepdim_out)) + elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: + result = result.squeeze(axis=0) + return result # TODO(jakevdp): interpolation argument deprecated 2024-05-16 From 4ebbf21dcfb77d89062515fb69094acbf1d776fb Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 30 Oct 2025 20:17:22 +0530 Subject: [PATCH 11/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 6 ++---- tests/lax_numpy_reducers_test.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 13e4d3f19d15..b0b96e3d06f4 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2497,8 +2497,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, if weights is None: a, = promote_dtypes_inexact(a) else: - a, weights = promote_dtypes_inexact(a, weights) - weights = lax.convert_element_type(weights, a.dtype) + a, q, weights = promote_dtypes_inexact(a, q, weights) + #weights = lax.convert_element_type(weights, a.dtype) a_shape = a.shape w_shape = np.shape(weights) if np.ndim(weights) == 0: @@ -2583,8 +2583,6 @@ def _weighted_quantile(qi): elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: result = result.squeeze(axis=0) return result - return result - if squash_nans: a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index c2ec6ff90098..f66b750a9b64 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -789,7 +789,7 @@ def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdim if axis is None: weights_shape = a_shape elif isinstance(axis, tuple): - weights_shape = tuple(a_shape[i] for i in axis) + weights_shape = a_shape else: weights_shape = (a_shape[axis],) weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3 From 43cdb5c4987a358a64ada34a82aeb9abecee5f9a Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 31 Oct 2025 11:05:09 +0530 Subject: [PATCH 12/12] Add weighted quantile and percentile support with tests --- jax/_src/numpy/reductions.py | 100 +++++++++++-------------------- tests/lax_numpy_reducers_test.py | 14 +++-- 2 files changed, 44 insertions(+), 70 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index b0b96e3d06f4..7476f2a46e3c 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -23,7 +23,7 @@ import numpy as np -from jax._src.lax import lax +from jax._src import config from jax._src import api from jax._src import core from jax._src import deprecations @@ -2483,8 +2483,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, axis = _canonicalize_axis(axis, a.ndim) q, = promote_dtypes_inexact(q) - q = lax_internal.asarray(q) - q_was_scalar = getattr(q, "ndim", 0) == 0 + q_was_scalar = q.ndim == 0 if q_was_scalar: q = lax.expand_dims(q, (0,)) q_shape = q.shape @@ -2497,40 +2496,37 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, if weights is None: a, = promote_dtypes_inexact(a) else: + if method != "inverted_cdf": + raise ValueError("Weighted quantiles are only supported for method='inverted_cdf'") + if axis is None: + raise TypeError("Axis must be specified when shapes of a and weights differ.") + axis_tuple = canonicalize_axis_tuple(axis, a.ndim) + a, q, weights = promote_dtypes_inexact(a, q, weights) - #weights = lax.convert_element_type(weights, a.dtype) a_shape = a.shape w_shape = np.shape(weights) if np.ndim(weights) == 0: weights = lax.broadcast_in_dim(weights, a_shape, ()) w_shape = a_shape - else: - w_shape = np.shape(weights) if w_shape != a_shape: - if axis is None: - raise TypeError("Axis must be specified when shapes of a and weights differ.") - if isinstance(axis, tuple): - if w_shape != tuple(a_shape[i] for i in axis): - raise ValueError("Shape of weights must match the shape of the axes being reduced.") - weights = lax.broadcast_in_dim( - weights, - shape=a_shape, - broadcast_dimensions=axis - ) - w_shape = a_shape - else: - if len(w_shape) != 1 or w_shape[0] != a_shape[axis]: - raise ValueError("Length of weights not compatible with specified axis.") - weights = lax.expand_dims(weights, (axis,)) - weights = _broadcast_to(weights, a.shape) - w_shape = a_shape + expected_shape = tuple(a_shape[i] for i in axis_tuple) + if w_shape != expected_shape: + raise ValueError(f"Shape of weights must match the shape of the axes being reduced. " + f"Expected {expected_shape}, got {w_shape}") + weights = lax.broadcast_in_dim( + weights, + shape=a_shape, + broadcast_dimensions=axis_tuple + ) if squash_nans: nan_mask = ~lax_internal._isnan(a) weights = _where(nan_mask, weights, 0) else: - with jax.debug_nans(False): - a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + with config.debug_nans(False): + has_nan_data = any(lax_internal._isnan(a), axis=axis, keepdims=True) + has_nan_weights = any(lax_internal._isnan(weights), axis=axis, keepdims=True) + a = _where(has_nan_data | has_nan_weights, np.nan, a) total_weight = sum(weights, axis=axis, keepdims=True) a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis) @@ -2539,49 +2535,23 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, def _weighted_quantile(qi): qi = lax.convert_element_type(qi, cum_weights_norm.dtype) - index_dtype = dtypes.canonicalize_dtype(dtypes.int_) - idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) + index_dtype = dtypes.default_int_dtype() + idx = _reduce_sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1)) - idx_prev = lax.clamp(idx - 1, _lax_const(idx, 0), _lax_const(idx, a_sorted.shape[axis] - 1)) - - slice_sizes = list(a_shape) - slice_sizes[axis] = 1 - offset_start = q_ndim - total_offset_dims = len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1 - dnums = lax.GatherDimensionNumbers( - offset_dims=tuple(range(offset_start, total_offset_dims)), - collapsed_slice_dims=(axis,), - start_index_map=(axis,) - ) - val = lax.gather(a_sorted, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) - val_prev = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) - cw_prev = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) - cw_next = lax.gather(cum_weights_norm, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes) - if method == "linear": - denom = cw_next - cw_prev - denom = _where(denom == 0, 1, denom) - weight = (qi - cw_prev) / denom - out = val_prev * (1 - weight) + val * weight - elif method == "lower": - out = val_prev - elif method == "higher": - out = val - elif method == "nearest": - out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val) - elif method == "midpoint": - out = (val_prev + val) / 2 - elif method == "inverted_cdf": - out = val - else: - raise ValueError(f"{method=!r} not recognized") - return out + + idx_expanded = lax.expand_dims(idx, (axis,)) if not keepdims else idx + return jnp.take_along_axis(a_sorted, idx_expanded, axis=axis).squeeze(axis=axis) result = api.vmap(_weighted_quantile)(q) - keepdim_out = list(keepdim) + shape_after = list(a_shape) + if keepdims: + shape_after[axis] = 1 + else: + del shape_after[axis] if not q_was_scalar: - keepdim_out = [q_shape[0], *keepdim_out] - result = result.reshape(tuple(keepdim_out)) - elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: - result = result.squeeze(axis=0) + result = result.reshape((q_shape[0], *shape_after)) + else: + if result.ndim > 0 and result.shape[0] == 1: + result = result.reshape(tuple(shape_after)) return result if squash_nans: diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index f66b750a9b64..dc46f9171e52 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -803,28 +803,32 @@ def jnp_fun(a, q, weights): rng(q_shape, q_dtype), np.abs(rng(weights_shape, a_dtype)) + 1e-3 ] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6) - self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6) + if method == "inverted_cdf": + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6) + self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6) + else: + with self.assertRaisesRegex(ValueError, "Weighted quantiles are only supported for method='inverted_cdf'"): + jnp_fun(*args_maker()) def test_weighted_quantile_negative_weights(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, -1, 1, 1, 1], dtype=float) q = jnp.array([0.5]) with self.assertRaisesRegex(ValueError, "Weights must be non-negative"): - jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) + jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) def test_weighted_quantile_all_weights_zero(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.zeros_like(a) q = jnp.array([0.5]) with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"): - jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) + jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) def test_weighted_quantile_weights_with_nan(self): a = jnp.array([1, 2, 3, 4, 5], dtype=float) weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float) q = jnp.array([0.5]) - result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) assert np.isnan(np.array(result)).all() def test_weighted_quantile_scalar_q(self):