|
25 | 25 |
|
26 | 26 | import jax |
27 | 27 | from jax import lax |
| 28 | +import jax._src.numpy as jnp |
28 | 29 | from jax._src import api |
29 | 30 | from jax._src import core |
30 | 31 | from jax._src import deprecations |
@@ -2452,7 +2453,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = |
2452 | 2453 |
|
2453 | 2454 | def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, |
2454 | 2455 | method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: |
2455 | | - from jax import numpy as jnp |
2456 | 2456 | if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: |
2457 | 2457 | raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") |
2458 | 2458 | keepdim = [] |
@@ -2515,7 +2515,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, |
2515 | 2515 | raise ValueError("Length of weights not compatible with specified axis.") |
2516 | 2516 | weights = lax.expand_dims(weights, axis) |
2517 | 2517 | weights = _broadcast_to(weights, a.shape) |
2518 | | - |
| 2518 | + |
2519 | 2519 |
|
2520 | 2520 | if squash_nans: |
2521 | 2521 | nan_mask = ~lax_internal._isnan(a) |
@@ -2558,10 +2558,8 @@ def _weighted_quantile(qi): |
2558 | 2558 | else: |
2559 | 2559 | raise ValueError(f"{method=!r} not recognized") |
2560 | 2560 | return out |
2561 | | - |
| 2561 | + |
2562 | 2562 | result = jax.vmap(_weighted_quantile)(q) |
2563 | | - if q.shape == (1,): |
2564 | | - result = result[0] |
2565 | 2563 | return result |
2566 | 2564 |
|
2567 | 2565 | if squash_nans: |
|
0 commit comments