@@ -2485,7 +2485,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24852485 axis = _canonicalize_axis (axis , a .ndim )
24862486
24872487 q , = promote_dtypes_inexact (q )
2488- q = jnp .atleast_1d (q )
2488+ q = lax_internal .asarray (q )
2489+ if getattr (q , "ndim" , 0 ) == 0 :
2490+ q = lax .expand_dims (q , (0 ,))
24892491 q_shape = q .shape
24902492 q_ndim = q .ndim
24912493 if q_ndim > 1 :
@@ -2497,8 +2499,14 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24972499 a , = promote_dtypes_inexact (a )
24982500 else :
24992501 a , weights = promote_dtypes_inexact (a , weights )
2502+ weights = lax .convert_element_type (weights , a .dtype )
25002503 a_shape = a .shape
25012504 w_shape = np .shape (weights )
2505+ if np .ndim (weights ) == 0 :
2506+ weights = lax .broadcast_in_dim (weights , a_shape , ())
2507+ w_shape = a_shape
2508+ else :
2509+ w_shape = np .shape (weights )
25022510 if w_shape != a_shape :
25032511 if axis is None :
25042512 raise TypeError ("Axis must be specified when shapes of a and weights differ." )
@@ -2510,12 +2518,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25102518 shape = a_shape ,
25112519 broadcast_dimensions = axis
25122520 )
2513- else :
2521+ w_shape = a_shape
2522+ else :
25142523 if len (w_shape ) != 1 or w_shape [0 ] != a_shape [axis ]:
25152524 raise ValueError ("Length of weights not compatible with specified axis." )
2516- weights = lax .expand_dims (weights , axis )
2525+ weights = lax .expand_dims (weights , ( axis ,) )
25172526 weights = _broadcast_to (weights , a .shape )
2518-
2527+ w_shape = a_shape
25192528
25202529 if squash_nans :
25212530 nan_mask = ~ lax_internal ._isnan (a )
@@ -2526,20 +2535,29 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25262535
25272536 total_weight = sum (weights , axis = axis , keepdims = True )
25282537 a_sorted , weights_sorted = lax .sort_key_val (a , weights , dimension = axis )
2529- cum_weights = lax . cumsum (weights_sorted , axis = axis )
2538+ cum_weights = cumsum (weights_sorted , axis = axis )
25302539 cum_weights_norm = lax .div (cum_weights , total_weight )
25312540
25322541 def _weighted_quantile (qi ):
2533- index_dtype = dtypes .default_int_dtype ()
2542+ qi = lax .convert_element_type (qi , cum_weights_norm .dtype )
2543+ index_dtype = dtypes .canonicalize_dtype (dtypes .int_ )
25342544 idx = sum (lax .lt (cum_weights_norm , qi ), axis = axis , dtype = index_dtype , keepdims = keepdims )
2535- idx = lax .clamp (0 , idx , a_sorted .shape [axis ] - 1 )
2536- val = jnp .take_along_axis (a_sorted , idx , axis )
2537-
2538- idx_prev = lax .clamp (idx - 1 , 0 , a_sorted .shape [axis ] - 1 )
2539- val_prev = jnp .take_along_axis (a_sorted , idx_prev , axis )
2540- cw_prev = jnp .take_along_axis (cum_weights_norm , idx_prev , axis )
2541- cw_next = jnp .take_along_axis (cum_weights_norm , idx , axis )
2542-
2545+ idx = lax .clamp (_lax_const (idx , 0 ), idx , _lax_const (idx , a_sorted .shape [axis ] - 1 ))
2546+ idx_prev = lax .clamp (idx - 1 , _lax_const (idx , 0 ), _lax_const (idx , a_sorted .shape [axis ] - 1 ))
2547+
2548+ slice_sizes = list (a_shape )
2549+ slice_sizes [axis ] = 1
2550+ offset_start = q_ndim
2551+ total_offset_dims = len (a_shape ) + q_ndim if keepdims else len (a_shape ) + q_ndim - 1
2552+ dnums = lax .GatherDimensionNumbers (
2553+ offset_dims = tuple (range (offset_start , total_offset_dims )),
2554+ collapsed_slice_dims = (axis ,),
2555+ start_index_map = (axis ,)
2556+ )
2557+ val = lax .gather (a_sorted , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2558+ val_prev = lax .gather (a_sorted , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2559+ cw_prev = lax .gather (cum_weights_norm , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
2560+ cw_next = lax .gather (cum_weights_norm , idx [..., None ], dimension_numbers = dnums , slice_sizes = slice_sizes )
25432561 if method == "linear" :
25442562 denom = cw_next - cw_prev
25452563 denom = _where (denom == 0 , 1 , denom )
@@ -2560,7 +2578,15 @@ def _weighted_quantile(qi):
25602578 return out
25612579
25622580 result = jax .vmap (_weighted_quantile )(q )
2563- return result
2581+ if keepdims and keepdim :
2582+ if q_ndim > 0 :
2583+ keepdim = [q_shape [0 ], * keepdim ]
2584+ result = result .reshape (tuple (keepdim ))
2585+ else :
2586+ if q_ndim == 0 or (q_ndim == 1 and q_shape [0 ] == 1 ):
2587+ if result .ndim > 0 and result .shape [0 ] == 1 :
2588+ result = lax .squeeze (result , (0 ,))
2589+ return lax .convert_element_type (result , a .dtype )
25642590
25652591 if squash_nans :
25662592 a = _where (lax_internal ._isnan (a ), np .nan , a ) # Ensure nans are positive so they sort to the end.
@@ -2576,8 +2602,8 @@ def _weighted_quantile(qi):
25762602 high_weight = lax .sub (q , low )
25772603 low_weight = lax .sub (_lax_const (high_weight , 1 ), high_weight )
25782604
2579- low = lax .max (lax . _const (low , 0 ), lax .min (low , counts - 1 ))
2580- high = lax .max (lax . _const (high , 0 ), lax .min (high , counts - 1 ))
2605+ low = lax .max (_lax_const (low , 0 ), lax .min (low , counts - 1 ))
2606+ high = lax .max (_lax_const (high , 0 ), lax .min (high , counts - 1 ))
25812607 low = lax .convert_element_type (low , int )
25822608 high = lax .convert_element_type (high , int )
25832609 out_shape = q_shape + shape_after_reduction
@@ -2601,8 +2627,8 @@ def _weighted_quantile(qi):
26012627 high_weight = lax .sub (q , low )
26022628 low_weight = lax .sub (_lax_const (high_weight , 1 ), high_weight )
26032629
2604- low = lax .clamp (lax . _const (low , 0 ), low , n - 1 )
2605- high = lax .clamp (lax . _const (high , 0 ), high , n - 1 )
2630+ low = lax .clamp (_lax_const (low , 0 ), low , n - 1 )
2631+ high = lax .clamp (_lax_const (high , 0 ), high , n - 1 )
26062632 low = lax .convert_element_type (low , int )
26072633 high = lax .convert_element_type (high , int )
26082634
@@ -2635,7 +2661,7 @@ def _weighted_quantile(qi):
26352661 pred = lax .le (high_weight , _lax_const (high_weight , 0.5 ))
26362662 result = lax .select (pred , low_value , high_value )
26372663 elif method == "midpoint" :
2638- result = lax .mul (lax .add (low_value , high_value ), lax . _const (low_value , 0.5 ))
2664+ result = lax .mul (lax .add (low_value , high_value ), _lax_const (low_value , 0.5 ))
26392665 elif method == "inverted_cdf" :
26402666 result = high_value
26412667 else :
@@ -2644,6 +2670,10 @@ def _weighted_quantile(qi):
26442670 if q_ndim > 0 :
26452671 keepdim = [np .shape (q )[0 ], * keepdim ]
26462672 result = result .reshape (keepdim )
2673+ else :
2674+ if q_ndim == 0 or (q_ndim == 1 and q_shape [0 ] == 1 ):
2675+ if result .ndim > 0 and result .shape [0 ] == 1 :
2676+ result = lax .squeeze (result , (0 ,))
26472677 return lax .convert_element_type (result , a .dtype )
26482678
26492679
0 commit comments