Skip to content

Commit f7ab683

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent f5d2177 commit f7ab683

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

jax/_src/numpy/reductions.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,11 +2383,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23832383
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
23842384
"or out != None")
23852385
if not isinstance(interpolation, DeprecatedArg):
2386-
deprecations.warn(
2387-
"jax-numpy-quantile-interpolation",
2388-
("The interpolation= argument to 'quantile' is deprecated. "
2389-
"Use 'method=' instead."), stacklevel=2)
2390-
method = interpolation
2386+
raise TypeError("nanquantile() argument interpolation was removed in JAX"
2387+
" v0.8.0. Use method instead.")
23912388
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False, weights)
23922389

23932390
# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@@ -2443,11 +2440,8 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24432440
"out != None")
24442441
raise ValueError(msg)
24452442
if not isinstance(interpolation, DeprecatedArg):
2446-
deprecations.warn(
2447-
"jax-numpy-quantile-interpolation",
2448-
("The interpolation= argument to 'nanquantile' is deprecated. "
2449-
"Use 'method=' instead."), stacklevel=2)
2450-
method = interpolation
2443+
raise TypeError("nanquantile() argument interpolation was removed in JAX"
2444+
" v0.8.0. Use method instead.")
24512445
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True, weights)
24522446

24532447
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
@@ -2485,7 +2479,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24852479
axis = _canonicalize_axis(axis, a.ndim)
24862480

24872481
# Ensure q is an array and inexact
2488-
q = lax_internal.asarray(q)
2482+
q = lax.asarray(q)
24892483
q, = promote_dtypes_inexact(q)
24902484

24912485
q_shape = q.shape

0 commit comments

Comments
 (0)