Skip to content

Commit 7d50a32

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent a230e01 commit 7d50a32

File tree

1 file changed

+50
-20
lines changed

1 file changed

+50
-20
lines changed

jax/_src/numpy/reductions.py

Lines changed: 50 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -2485,7 +2485,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24852485
axis = _canonicalize_axis(axis, a.ndim)
24862486

24872487
q, = promote_dtypes_inexact(q)
2488-
q = jnp.atleast_1d(q)
2488+
q = lax_internal.asarray(q)
2489+
if getattr(q, "ndim", 0) == 0:
2490+
q = lax.expand_dims(q, (0,))
24892491
q_shape = q.shape
24902492
q_ndim = q.ndim
24912493
if q_ndim > 1:
@@ -2497,8 +2499,14 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24972499
a, = promote_dtypes_inexact(a)
24982500
else:
24992501
a, weights = promote_dtypes_inexact(a, weights)
2502+
weights = lax.convert_element_type(weights, a.dtype)
25002503
a_shape = a.shape
25012504
w_shape = np.shape(weights)
2505+
if np.ndim(weights) == 0:
2506+
weights = lax.broadcast_in_dim(weights, a_shape, ())
2507+
w_shape = a_shape
2508+
else:
2509+
w_shape = np.shape(weights)
25022510
if w_shape != a_shape:
25032511
if axis is None:
25042512
raise TypeError("Axis must be specified when shapes of a and weights differ.")
@@ -2510,12 +2518,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25102518
shape=a_shape,
25112519
broadcast_dimensions=axis
25122520
)
2513-
else:
2521+
w_shape = a_shape
2522+
else:
25142523
if len(w_shape) != 1 or w_shape[0] != a_shape[axis]:
25152524
raise ValueError("Length of weights not compatible with specified axis.")
2516-
weights = lax.expand_dims(weights, axis)
2525+
weights = lax.expand_dims(weights, (axis,))
25172526
weights = _broadcast_to(weights, a.shape)
2518-
2527+
w_shape = a_shape
25192528

25202529
if squash_nans:
25212530
nan_mask = ~lax_internal._isnan(a)
@@ -2526,20 +2535,29 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25262535

25272536
total_weight = sum(weights, axis=axis, keepdims=True)
25282537
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)
2529-
cum_weights = lax.cumsum(weights_sorted, axis=axis)
2538+
cum_weights = cumsum(weights_sorted, axis=axis)
25302539
cum_weights_norm = lax.div(cum_weights, total_weight)
25312540

25322541
def _weighted_quantile(qi):
2533-
index_dtype = dtypes.default_int_dtype()
2542+
qi = lax.convert_element_type(qi, cum_weights_norm.dtype)
2543+
index_dtype = dtypes.canonicalize_dtype(dtypes.int_)
25342544
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims)
2535-
idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
2536-
val = jnp.take_along_axis(a_sorted, idx, axis)
2537-
2538-
idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1)
2539-
val_prev = jnp.take_along_axis(a_sorted, idx_prev, axis)
2540-
cw_prev = jnp.take_along_axis(cum_weights_norm, idx_prev, axis)
2541-
cw_next = jnp.take_along_axis(cum_weights_norm, idx, axis)
2542-
2545+
idx = lax.clamp(_lax_const(idx, 0), idx, _lax_const(idx, a_sorted.shape[axis] - 1))
2546+
idx_prev = lax.clamp(idx - 1, _lax_const(idx, 0), _lax_const(idx, a_sorted.shape[axis] - 1))
2547+
2548+
slice_sizes = list(a_shape)
2549+
slice_sizes[axis] = 1
2550+
offset_start = q_ndim
2551+
total_offset_dims = len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1
2552+
dnums = lax.GatherDimensionNumbers(
2553+
offset_dims=tuple(range(offset_start, total_offset_dims)),
2554+
collapsed_slice_dims=(axis,),
2555+
start_index_map=(axis,)
2556+
)
2557+
val = lax.gather(a_sorted, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2558+
val_prev = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2559+
cw_prev = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
2560+
cw_next = lax.gather(cum_weights_norm, idx[..., None], dimension_numbers=dnums, slice_sizes=slice_sizes)
25432561
if method == "linear":
25442562
denom = cw_next - cw_prev
25452563
denom = _where(denom == 0, 1, denom)
@@ -2560,7 +2578,15 @@ def _weighted_quantile(qi):
25602578
return out
25612579

25622580
result = jax.vmap(_weighted_quantile)(q)
2563-
return result
2581+
if keepdims and keepdim:
2582+
if q_ndim > 0:
2583+
keepdim = [q_shape[0], *keepdim]
2584+
result = result.reshape(tuple(keepdim))
2585+
else:
2586+
if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1):
2587+
if result.ndim > 0 and result.shape[0] == 1:
2588+
result = lax.squeeze(result, (0,))
2589+
return lax.convert_element_type(result, a.dtype)
25642590

25652591
if squash_nans:
25662592
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
@@ -2576,8 +2602,8 @@ def _weighted_quantile(qi):
25762602
high_weight = lax.sub(q, low)
25772603
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
25782604

2579-
low = lax.max(lax._const(low, 0), lax.min(low, counts - 1))
2580-
high = lax.max(lax._const(high, 0), lax.min(high, counts - 1))
2605+
low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1))
2606+
high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1))
25812607
low = lax.convert_element_type(low, int)
25822608
high = lax.convert_element_type(high, int)
25832609
out_shape = q_shape + shape_after_reduction
@@ -2601,8 +2627,8 @@ def _weighted_quantile(qi):
26012627
high_weight = lax.sub(q, low)
26022628
low_weight = lax.sub(_lax_const(high_weight, 1), high_weight)
26032629

2604-
low = lax.clamp(lax._const(low, 0), low, n - 1)
2605-
high = lax.clamp(lax._const(high, 0), high, n - 1)
2630+
low = lax.clamp(_lax_const(low, 0), low, n - 1)
2631+
high = lax.clamp(_lax_const(high, 0), high, n - 1)
26062632
low = lax.convert_element_type(low, int)
26072633
high = lax.convert_element_type(high, int)
26082634

@@ -2635,7 +2661,7 @@ def _weighted_quantile(qi):
26352661
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
26362662
result = lax.select(pred, low_value, high_value)
26372663
elif method == "midpoint":
2638-
result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
2664+
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
26392665
elif method == "inverted_cdf":
26402666
result = high_value
26412667
else:
@@ -2644,6 +2670,10 @@ def _weighted_quantile(qi):
26442670
if q_ndim > 0:
26452671
keepdim = [np.shape(q)[0], *keepdim]
26462672
result = result.reshape(keepdim)
2673+
else:
2674+
if q_ndim == 0 or (q_ndim == 1 and q_shape[0] == 1):
2675+
if result.ndim > 0 and result.shape[0] == 1:
2676+
result = lax.squeeze(result, (0,))
26472677
return lax.convert_element_type(result, a.dtype)
26482678

26492679

0 commit comments

Comments
 (0)