Skip to content

Commit a230e01

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent cef1731 commit a230e01

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

jax/_src/numpy/reductions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import jax
2727
from jax import lax
28+
import jax._src.numpy as jnp
2829
from jax._src import api
2930
from jax._src import core
3031
from jax._src import deprecations
@@ -2452,7 +2453,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24522453

24532454
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24542455
method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array:
2455-
from jax import numpy as jnp
24562456
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
24572457
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'")
24582458
keepdim = []
@@ -2515,7 +2515,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25152515
raise ValueError("Length of weights not compatible with specified axis.")
25162516
weights = lax.expand_dims(weights, axis)
25172517
weights = _broadcast_to(weights, a.shape)
2518-
2518+
25192519

25202520
if squash_nans:
25212521
nan_mask = ~lax_internal._isnan(a)
@@ -2558,10 +2558,8 @@ def _weighted_quantile(qi):
25582558
else:
25592559
raise ValueError(f"{method=!r} not recognized")
25602560
return out
2561-
2561+
25622562
result = jax.vmap(_weighted_quantile)(q)
2563-
if q.shape == (1,):
2564-
result = result[0]
25652563
return result
25662564

25672565
if squash_nans:

0 commit comments

Comments
 (0)