Skip to content

Commit 4f522e6

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 7b967cb commit 4f522e6

File tree

2 files changed

+45
-60
lines changed

2 files changed

+45
-60
lines changed

jax/_src/numpy/reductions.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,7 +2379,6 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23792379
>>> jnp.quantile(x, q, method='nearest')
23802380
Array([2., 4., 7.], dtype=float32)
23812381
"""
2382-
check_arraylike("quantile", a, q)
23832382
if weights is None:
23842383
a, q = ensure_arraylike("quantile", a, q)
23852384
else:
@@ -2390,7 +2389,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23902389
if not isinstance(interpolation, DeprecatedArg):
23912390
raise TypeError("quantile() argument interpolation was removed in JAX"
23922391
" v0.8.0. Use method instead.")
2393-
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False, weights)
2392+
return _quantile(a, q, axis, method, keepdims, False, weights)
23942393

23952394
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
23962395
@export
@@ -2439,7 +2438,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24392438
>>> jnp.nanquantile(x, q)
24402439
Array([1.5, 3. , 4.5], dtype=float32)
24412440
"""
2442-
check_arraylike("nanquantile", a, q)
24432441
if weights is None:
24442442
a, q = ensure_arraylike("nanquantile", a, q)
24452443
else:
@@ -2451,7 +2449,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24512449
if not isinstance(interpolation, DeprecatedArg):
24522450
raise TypeError("nanquantile() argument interpolation was removed in JAX"
24532451
" v0.8.0. Use method instead.")
2454-
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights)
2452+
return _quantile(a, q, axis, method, keepdims, True, weights)
24552453

24562454
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24572455
method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array:
@@ -2498,7 +2496,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24982496
if weights is None:
24992497
a, = promote_dtypes_inexact(a)
25002498
else:
2501-
a, q = promote_dtypes_inexact(a, q)
2499+
a, weights = promote_dtypes_inexact(a, weights)
25022500
a_shape = a.shape
25032501
w_shape = np.shape(weights)
25042502
if w_shape != a_shape:
@@ -2513,11 +2511,6 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25132511
weights = lax.expand_dims(weights, axis)
25142512
weights = _broadcast_to(weights, a.shape)
25152513

2516-
weights_have_nan = jnp.any(jnp.isnan(weights))
2517-
if weights_have_nan:
2518-
out_shape = q.shape if hasattr(q, "shape") and getattr(q, "ndim", 0) > 0 else ()
2519-
return lax.full(out_shape, np.nan, dtype=a.dtype)
2520-
25212514
if squash_nans:
25222515
nan_mask = ~lax_internal._isnan(a)
25232516
weights = _where(nan_mask, weights, 0)
@@ -2530,7 +2523,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25302523
cum_weights = lax.cumsum(weights_sorted, axis=axis)
25312524
cum_weights_norm = lax.div(cum_weights, total_weight)
25322525

2533-
def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
2526+
def _weighted_quantile(qi):
25342527
index_dtype = dtypes.default_int_dtype()
25352528
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype)
25362529
idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
@@ -2558,9 +2551,6 @@ def _weighted_quantile(qi, weights_have_nan=weights_have_nan):
25582551
out = val
25592552
else:
25602553
raise ValueError(f"{method=!r} not recognized")
2561-
if weights_have_nan:
2562-
out = lax.full_like(out, np.nan)
2563-
out = lax.squeeze(out, axis=axis)
25642554
return out
25652555

25662556
if q.ndim == 0:
@@ -2700,7 +2690,6 @@ def percentile(a: ArrayLike, q: ArrayLike,
27002690
>>> jnp.percentile(x, q, method='nearest')
27012691
Array([1., 3., 4.], dtype=float32)
27022692
"""
2703-
check_arraylike("percentile", a, q)
27042693
if weights is None:
27052694
a, q = ensure_arraylike("percentile", a, q)
27062695
else:
@@ -2764,7 +2753,6 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
27642753
>>> jnp.nanpercentile(x, q)
27652754
Array([1.5, 3. , 4.5], dtype=float32)
27662755
"""
2767-
check_arraylike("nanpercentile", a, q)
27682756
if weights is None:
27692757
a, q = ensure_arraylike("nanpercentile", a, q)
27702758
else:

tests/lax_numpy_reducers_test.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from functools import partial
1818
import itertools
1919
import unittest
20-
import pytest
2120

2221
from absl.testing import absltest
2322
from absl.testing import parameterized
@@ -764,51 +763,58 @@ def testPercentilePrecision(self):
764763
x = jnp.float64([1, 2, 3, 4, 7, 10])
765764
self.assertEqual(jnp.percentile(x, 50), 3.5)
766765

767-
def test_weighted_quantile_all_weights_one(self):
768-
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
769-
weights = jnp.ones_like(a)
770-
q = jnp.array([0.25, 0.5, 0.75])
771-
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
772-
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
773-
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
774-
775-
def test_weighted_quantile_multiple_q(self):
776-
a = jnp.arange(10, dtype=float)
777-
weights = jnp.ones_like(a)
778-
q = jnp.array([0.25, 0.5, 0.75])
779-
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
780-
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
781-
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
782-
783-
def test_weighted_quantile_keepdims(self):
784-
a = jnp.array([1, 2, 3, 4], dtype=float)
785-
weights = jnp.array([1, 1, 1, 1], dtype=float)
786-
q = 0.5
787-
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=True, squash_nans=False, weights=weights)
788-
expected = np.quantile(np.array(a), np.array(q), axis=0, keepdims=True, weights=np.array(weights), method="inverted_cdf")
789-
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
766+
@jtu.sample_product(
767+
[dict(a_shape=a_shape, axis=axis)
768+
for a_shape, axis in (
769+
((7,), None),
770+
((6, 7,), None),
771+
((47, 7), 0),
772+
((47, 7), ()),
773+
((4, 101), 1),
774+
((4, 47, 7), (1, 2)),
775+
((4, 47, 7), (0, 2)),
776+
((4, 47, 7), (1, 0, 2)),
777+
)
778+
],
779+
a_dtype=default_dtypes,
780+
q_dtype=[np.float32],
781+
q_shape=scalar_shapes + [(1,), (4,)],
782+
keepdims=[False, True],
783+
method=['linear', 'lower', 'higher', 'nearest', 'midpoint', 'inverted_cdf'],
784+
)
785+
def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims, method):
786+
rng = jtu.rand_default(self.rng())
787+
a = rng(a_shape, a_dtype)
788+
q = rng(q_shape, q_dtype)
789+
if axis is None:
790+
weights_shape = a_shape
791+
elif isinstance(axis, tuple):
792+
weights_shape = tuple(a_shape[i] for i in axis)
793+
else:
794+
weights_shape = (a_shape[axis],)
795+
weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3
790796

791-
def test_weighted_quantile_linear(self):
792-
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
793-
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
794-
q = jnp.array([0.5])
795-
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
796-
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
797-
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
797+
def np_fun(a, q, weights):
798+
return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims)
799+
def jnp_fun(a, q, weights):
800+
return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims)
801+
args_maker = lambda: [a, q, weights]
802+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6)
803+
self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6)
798804

799805
def test_weighted_quantile_negative_weights(self):
800806
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
801807
weights = jnp.array([1, -1, 1, 1, 1], dtype=float)
802808
q = jnp.array([0.5])
803-
with pytest.raises(ValueError):
804-
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
809+
with self.assertRaisesRegex(ValueError, "Weights must be non-negative"):
810+
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
805811

806812
def test_weighted_quantile_all_weights_zero(self):
807813
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
808814
weights = jnp.zeros_like(a)
809815
q = jnp.array([0.5])
810-
with pytest.raises(ValueError):
811-
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
816+
with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"):
817+
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
812818

813819
def test_weighted_quantile_weights_with_nan(self):
814820
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
@@ -825,15 +831,6 @@ def test_weighted_quantile_scalar_q(self):
825831
assert jnp.issubdtype(result.dtype, jnp.floating)
826832
assert result.shape == ()
827833

828-
def test_weighted_quantile_jit(self):
829-
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
830-
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
831-
q = jnp.array([0.25, 0.5, 0.75])
832-
quantile_jit = jax.jit(lambda a, q, weights: jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights))
833-
result = quantile_jit(a, q, weights)
834-
expected = np.quantile(np.array(a), np.array(q), axis=0, weights=np.array(weights), method="inverted_cdf")
835-
np.testing.assert_allclose(np.array(result), expected, rtol=1e-6)
836-
837834
@jtu.sample_product(
838835
[dict(a_shape=a_shape, axis=axis)
839836
for a_shape, axis in (

0 commit comments

Comments
 (0)