diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index ec56e1c0506b..d7f3b23a2066 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -27,7 +27,7 @@ from jax._src import core from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, ensure_arraylike, + _broadcast_to, check_arraylike, ensure_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import control_flow from jax._src.lax import lax as lax @@ -2376,7 +2376,8 @@ def cumulative_prod( @api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, + interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis. JAX implementation of :func:`numpy.quantile`. @@ -2414,7 +2415,10 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No >>> jnp.quantile(x, q, method='nearest') Array([2., 4., 7.], dtype=float32) """ - a, q = ensure_arraylike("quantile", a, q) + if weights is None: + a, q = ensure_arraylike("quantile", a, q) + else: + a, q, weights = ensure_arraylike("quantile", a, q, weights) if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") @@ -2422,14 +2426,15 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No if not isinstance(interpolation, DeprecatedArg): raise TypeError("quantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False) + return _quantile(a, q, axis, method, keepdims, False, weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export @api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, + interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis, ignoring NaNs. JAX implementation of :func:`numpy.nanquantile`. @@ -2468,7 +2473,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = >>> jnp.nanquantile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - a, q = ensure_arraylike("nanquantile", a, q) + if weights is None: + a, q = ensure_arraylike("nanquantile", a, q) + else: + a, q, weights = ensure_arraylike("nanquantile", a, q, weights) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") @@ -2477,13 +2485,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = if not isinstance(interpolation, DeprecatedArg): raise TypeError("nanquantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True) + return _quantile(a, q, axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, - method: str, keepdims: bool, squash_nans: bool) -> Array: - if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: - raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") - a, = promote_dtypes_inexact(a) + method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: + if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: + raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") keepdim = [] if dtypes.issubdtype(a.dtype, np.complexfloating): raise ValueError("quantile does not support complex input, as the operation is poorly defined.") @@ -2513,12 +2520,77 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, else: axis = canonicalize_axis(axis, a.ndim) + q, = promote_dtypes_inexact(q) + q_was_scalar = q.ndim == 0 + if q_was_scalar: + q = lax.expand_dims(q, (0,)) q_shape = q.shape q_ndim = q.ndim if q_ndim > 1: raise ValueError(f"q must be have rank <= 1, got shape {q.shape}") a_shape = a.shape + # Handle weights + if weights is None: + a, = promote_dtypes_inexact(a) + else: + if method != "inverted_cdf": + raise ValueError("Weighted quantiles are only supported for method='inverted_cdf'") + if axis is None: + raise TypeError("Axis must be specified when shapes of a and weights differ.") + axis_tuple = canonicalize_axis_tuple(axis, a.ndim) + + a, q, weights = promote_dtypes_inexact(a, q, weights) + a_shape = a.shape + w_shape = np.shape(weights) + if np.ndim(weights) == 0: + weights = lax.broadcast_in_dim(weights, a_shape, ()) + w_shape = a_shape + if w_shape != a_shape: + expected_shape = tuple(a_shape[i] for i in axis_tuple) + if w_shape != expected_shape: + raise ValueError(f"Shape of weights must match the shape of the axes being reduced. " + f"Expected {expected_shape}, got {w_shape}") + weights = lax.broadcast_in_dim( + weights, + shape=a_shape, + broadcast_dimensions=axis_tuple + ) + + if squash_nans: + nan_mask = ~lax_internal._isnan(a) + weights = _where(nan_mask, weights, 0) + else: + with config.debug_nans(False): + has_nan_data = any(lax_internal._isnan(a), axis=axis, keepdims=True) + has_nan_weights = any(lax_internal._isnan(weights), axis=axis, keepdims=True) + a = _where(has_nan_data | has_nan_weights, np.nan, a) + + total_weight = sum(weights, axis=axis, keepdims=True) + a_sorted, weights_sorted = lax_internal.sort_key_val(a, weights, dimension=axis) + cum_weights = cumsum(weights_sorted, axis=axis) + cum_weights_norm = lax.div(cum_weights, total_weight) + + def _weighted_quantile(qi): + qi = lax.convert_element_type(qi, cum_weights_norm.dtype) + index_dtype = dtypes.default_int_dtype() + idx = _reduce_sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) + idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1)) + + idx_expanded = lax.expand_dims(idx, (axis,)) if not keepdims else idx + return jnp.take_along_axis(a_sorted, idx_expanded, axis=axis).squeeze(axis=axis) + result = api.vmap(_weighted_quantile)(q) + shape_after = list(a_shape) + if keepdims: + shape_after[axis] = 1 + else: + del shape_after[axis] + if not q_was_scalar: + result = result.reshape((q_shape[0], *shape_after)) + else: + if result.ndim > 0 and result.shape[0] == 1: + result = result.reshape(tuple(shape_after)) + return result if squash_nans: a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. @@ -2593,14 +2665,18 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, pred = lax.le(high_weight, lax._const(high_weight, 0.5)) result = lax.select(pred, low_value, high_value) elif method == "midpoint": - result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) + result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) + elif method == "inverted_cdf": + result = high_value else: raise ValueError(f"{method=!r} not recognized") - if keepdims and keepdim: - if q_ndim > 0: - keepdim = [np.shape(q)[0], *keepdim] - result = result.reshape(keepdim) - return lax.convert_element_type(result, a.dtype) + keepdim_out = list(keepdim) + if not q_was_scalar: + keepdim_out = [q_shape[0], *keepdim_out] + result = result.reshape(tuple(keepdim_out)) + elif q_was_scalar and result.ndim > 0 and result.shape[0] == 1: + result = result.squeeze(axis=0) + return result # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @@ -2609,7 +2685,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis. JAX implementation of :func:`numpy.percentile`. @@ -2647,14 +2723,17 @@ def percentile(a: ArrayLike, q: ArrayLike, >>> jnp.percentile(x, q, method='nearest') Array([1., 3., 4.], dtype=float32) """ - a, q = ensure_arraylike("percentile", a, q) + if weights is None: + a, q = ensure_arraylike("percentile", a, q) + else: + a, q, weights = ensure_arraylike("percentile", a, q, weights) q, = promote_dtypes_inexact(q) # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 if not isinstance(interpolation, DeprecatedArg): raise TypeError("percentile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, - method=method, keepdims=keepdims) + method=method, keepdims=keepdims, weights=weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @@ -2663,7 +2742,7 @@ def percentile(a: ArrayLike, q: ArrayLike, def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis, ignoring NaN values. JAX implementation of :func:`numpy.nanpercentile`. @@ -2703,7 +2782,10 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, >>> jnp.nanpercentile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - a, q = ensure_arraylike("nanpercentile", a, q) + if weights is None: + a, q = ensure_arraylike("nanpercentile", a, q) + else: + a, q, weights = ensure_arraylike("nanpercentile", a, q, weights) q, = promote_dtypes_inexact(q) q = q / 100 # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 @@ -2711,7 +2793,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, raise TypeError("nanpercentile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, - method=method, keepdims=keepdims) + method=method, keepdims=keepdims, weights=weights) @export diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 8a863f68d5e7..902a2086cd7a 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -876,6 +876,82 @@ def testPercentilePrecision(self): x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) + @jtu.sample_product( + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((6, 7,), None), + ((47, 7), 0), + ((47, 7), ()), + ((4, 101), 1), + ((4, 47, 7), (1, 2)), + ((4, 47, 7), (0, 2)), + ((4, 47, 7), (1, 0, 2)), + ) + ], + a_dtype=default_dtypes, + q_dtype=[np.float32], + q_shape=scalar_shapes + [(1,), (4,)], + keepdims=[False, True], + method=['linear', 'lower', 'higher', 'nearest', 'midpoint', 'inverted_cdf'], +) + def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims, method): + rng = jtu.rand_default(self.rng()) + a = rng(a_shape, a_dtype) + q = rng(q_shape, q_dtype) + if axis is None: + weights_shape = a_shape + elif isinstance(axis, tuple): + weights_shape = a_shape + else: + weights_shape = (a_shape[axis],) + weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3 + + def np_fun(a, q, weights): + return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims) + def jnp_fun(a, q, weights): + return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims) + args_maker = lambda: [ + rng(a_shape, a_dtype), + rng(q_shape, q_dtype), + np.abs(rng(weights_shape, a_dtype)) + 1e-3 + ] + if method == "inverted_cdf": + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6) + self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6) + else: + with self.assertRaisesRegex(ValueError, "Weighted quantiles are only supported for method='inverted_cdf'"): + jnp_fun(*args_maker()) + + def test_weighted_quantile_negative_weights(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, -1, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + with self.assertRaisesRegex(ValueError, "Weights must be non-negative"): + jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) + + def test_weighted_quantile_all_weights_zero(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.zeros_like(a) + q = jnp.array([0.5]) + with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"): + jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) + + def test_weighted_quantile_weights_with_nan(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) + assert np.isnan(np.array(result)).all() + + def test_weighted_quantile_scalar_q(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, 2, 1, 1, 1], dtype=float) + q = 0.5 + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) + assert jnp.issubdtype(result.dtype, jnp.floating) + assert result.shape == () + @jtu.sample_product( [dict(a_shape=a_shape, axis=axis) for a_shape, axis in (