@@ -2337,7 +2337,7 @@ def cumulative_prod(
23372337@partial (api .jit , static_argnames = ('axis' , 'overwrite_input' , 'interpolation' , 'keepdims' , 'method' ))
23382338def quantile (a : ArrayLike , q : ArrayLike , axis : int | tuple [int , ...] | None = None ,
23392339 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2340- keepdims : bool = False , weights : ArrayLike | None = None , * ,
2340+ keepdims : bool = False , * , weights : ArrayLike | None = None ,
23412341 interpolation : DeprecatedArg | str = DeprecatedArg ()) -> Array :
23422342 """Compute the quantile of the data along the specified axis.
23432343
@@ -2395,7 +2395,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
23952395@partial (api .jit , static_argnames = ('axis' , 'overwrite_input' , 'interpolation' , 'keepdims' , 'method' ))
23962396def nanquantile (a : ArrayLike , q : ArrayLike , axis : int | tuple [int , ...] | None = None ,
23972397 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2398- keepdims : bool = False , * , interpolation : DeprecatedArg | str = DeprecatedArg ()) -> Array :
2398+ keepdims : bool = False , * , weights : ArrayLike | None = None ,
2399+ interpolation : DeprecatedArg | str = DeprecatedArg ()) -> Array :
23992400 """Compute the quantile of the data along the specified axis, ignoring NaNs.
24002401
24012402 JAX implementation of :func:`numpy.nanquantile`.
@@ -2447,12 +2448,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24472448 ("The interpolation= argument to 'nanquantile' is deprecated. "
24482449 "Use 'method=' instead." ), stacklevel = 2 )
24492450 method = interpolation
2450- return _quantile (lax_internal .asarray (a ), lax_internal .asarray (q ), axis , method , keepdims , True )
2451+ return _quantile (lax_internal .asarray (a ), lax_internal .asarray (q ), axis , method , keepdims , True , weights )
24512452
24522453def _quantile (a : Array , q : Array , axis : int | tuple [int , ...] | None ,
24532454 method : str , keepdims : bool , squash_nans : bool , weights : ArrayLike | None = None ) -> Array :
2454- if method not in ["linear" , "lower" , "higher" , "midpoint" , "nearest" ]:
2455- raise ValueError ("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest '" )
2455+ if method not in ["linear" , "lower" , "higher" , "midpoint" , "nearest" , "inverted_cdf" ]:
2456+ raise ValueError ("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf '" )
24562457 a , = promote_dtypes_inexact (a )
24572458 keepdim = []
24582459 if dtypes .issubdtype (a .dtype , np .complexfloating ):
@@ -2482,6 +2483,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24822483 axis = _canonicalize_axis (- 1 , a .ndim )
24832484 else :
24842485 axis = _canonicalize_axis (axis , a .ndim )
2486+
2487+ # Ensure q is an array and inexact
2488+ q = lax_internal .asarray (q )
2489+ q , = promote_dtypes_inexact (q )
24852490
24862491 q_shape = q .shape
24872492 q_ndim = q .ndim
@@ -2492,63 +2497,103 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24922497 # Handle weights
24932498 if weights is not None :
24942499 a , weights = promote_dtypes_inexact (a , weights )
2495- if axis is None :
2496- a = a .ravel ()
2497- weights = weights .ravel ()
2498- axis = 0
2500+ a_shape = a .shape
2501+ w_shape = np .shape (weights )
2502+ if w_shape != a_shape :
2503+ if len (w_shape ) != 1 :
2504+ raise ValueError ("1D weights expected when shapes of a and weights differ." )
2505+ if axis is None :
2506+ raise TypeError ("Axis must be specified when shapes of a and weights differ." )
2507+ if w_shape [0 ] != a_shape [axis ]:
2508+ raise ValueError ("Length of weights not compatible with specified axis." )
2509+ resh = [1 ] * a .ndim
2510+ resh [axis ] = w_shape [0 ]
2511+ weights = lax .reshape (lax_internal .asarray (weights ), tuple (resh ))
2512+ weights = _broadcast_to (weights , a .shape )
2513+
2514+ if isinstance (weights , core .Tracer ):
2515+ weights_arr = None
24992516 else :
2500- weights = _broadcast_to (weights , a .shape )
2517+ try :
2518+ weights_arr = np .asarray (weights )
2519+ except Exception :
2520+ weights_arr = None
2521+
2522+ if weights_arr is not None :
2523+ if np .any (weights_arr < 0 ):
2524+ raise ValueError ("Weights must be non-negative." )
2525+ if np .all (weights_arr == 0 ):
2526+ raise ValueError ("Sum of weights must not be zero." )
2527+ if np .any (np .isnan (weights_arr )):
2528+ out_shape = q .shape if hasattr (q , "shape" ) and getattr (q , "ndim" , 0 ) > 0 else ()
2529+ return lax .full (out_shape , np .nan , dtype = a .dtype )
2530+ weights_have_nan = np .any (np .isnan (weights_arr ))
2531+ else :
2532+ weights_have_nan = False
2533+
25012534 if squash_nans :
25022535 nan_mask = ~ lax_internal ._isnan (a )
2503- if axis is None :
2504- a = a [nan_mask ]
2505- weights = weights [nan_mask ]
2506- else :
2507- weights = _where (nan_mask , weights , 0 )
2508- a_sorted , weights_sorted = lax .sort_key_val (a , weights , dimension = axis )
2536+ weights = _where (nan_mask , weights , 0 )
2537+ else :
2538+ with jax .debug_nans (False ):
2539+ a = _where (any (lax_internal ._isnan (a ), axis = axis , keepdims = True ), np .nan , a )
25092540
2541+ total_weight = sum (weights , axis = axis , keepdims = True )
2542+ a_sorted , weights_sorted = lax .sort_key_val (a , weights , dimension = axis )
25102543 cum_weights = lax .cumsum (weights_sorted , axis = axis )
2511- total_weight = lax .sum (weights_sorted , axis = axis , keepdims = True )
2512- if lax_internal ._all (total_weight == 0 ):
2513- raise ValueError ("Sum of weights must not be zero." )
2514- cum_weights_norm = cum_weights / total_weight
2515- quantile_pos = q
2516- mask = cum_weights_norm >= quantile_pos [..., None ]
2517- idx = lax .argmin (mask .astype (int ), axis = axis )
2518- idx_prev = lax .max (idx - 1 , _lax_const (idx , 0 ))
2519- idx_next = idx
2520- gather_shape = list (a_sorted .shape )
2521- gather_shape [axis ] = 1
2544+ cum_weights_norm = lax .div (cum_weights , total_weight )
2545+
2546+ slice_sizes = list (a_sorted .shape )
2547+ slice_sizes [axis ] = 1
25222548 dnums = lax .GatherDimensionNumbers (
2523- offset_dims = tuple (range (len (a_sorted .shape ))),
2524- collapsed_slice_dims = (axis ,),
2549+ offset_dims = tuple (range (
2550+ 0 ,
2551+ len (a_sorted .shape ) if keepdims else len (a_sorted .shape ) - 1 )),
2552+ collapsed_slice_dims = () if keepdims else (axis ,),
25252553 start_index_map = (axis ,))
2526- prev_value = lax .gather (a_sorted , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2527- next_value = lax .gather (a_sorted , idx_next [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2528- prev_cumw = lax .gather (cum_weights_norm , idx_prev [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2529- next_cumw = lax .gather (cum_weights_norm , idx_next [..., None ], dimension_numbers = dnums , slice_sizes = gather_shape )
2530-
2531- if method == "linear" :
2532- denom = next_cumw - prev_cumw
2533- denom = lax .select (denom == 0 , _lax_const (denom , 1 ), denom )
2534- weight = (quantile_pos - prev_cumw ) / denom
2535- result = prev_value * (1 - weight ) + next_value * weight
2536- elif method == "lower" :
2537- result = prev_value
2538- elif method == "higher" :
2539- result = next_value
2540- elif method == "nearest" :
2541- use_prev = (quantile_pos - prev_cumw ) < (next_cumw - quantile_pos )
2542- result = lax .select (use_prev , prev_value , next_value )
2543- elif method == "midpoint" :
2544- result = (prev_value + next_value ) / 2
2545- else :
2546- raise ValueError (f"{ method = !r} not recognized" )
2547-
2548- if not keepdims :
2549- result = lax .squeeze (result , axis )
2550- return lax .convert_element_type (result , a .dtype )
25512554
2555+ def _weighted_quantile (qi , weights_have_nan = weights_have_nan ):
2556+ index_dtype = dtypes .canonicalize_dtype (int )
2557+ idx = sum (lax .lt (cum_weights_norm , qi ), axis = axis , dtype = index_dtype )
2558+ idx = lax .clamp (0 , idx , a_sorted .shape [axis ] - 1 )
2559+ slicer = [slice (None )] * a_sorted .ndim
2560+ slicer [axis ] = idx
2561+ val = a_sorted [tuple (slicer )]
2562+
2563+ idx_prev = lax .clamp (idx - 1 , 0 , a_sorted .shape [axis ] - 1 )
2564+ slicer_prev = slicer .copy ()
2565+ slicer_prev [axis ] = idx_prev
2566+ val_prev = a_sorted [tuple (slicer_prev )]
2567+ cw_prev = cum_weights_norm [tuple (slicer_prev )]
2568+ cw_next = cum_weights_norm [tuple (slicer )]
2569+
2570+ if method == "linear" :
2571+ denom = cw_next - cw_prev
2572+ denom = _where (denom == 0 , 1 , denom )
2573+ weight = (qi - cw_prev ) / denom
2574+ out = val_prev * (1 - weight ) + val * weight
2575+ elif method == "lower" :
2576+ out = val_prev
2577+ elif method == "higher" :
2578+ out = val
2579+ elif method == "nearest" :
2580+ out = _where (lax .abs (qi - cw_prev ) < lax .abs (qi - cw_next ), val_prev , val )
2581+ elif method == "midpoint" :
2582+ out = (val_prev + val ) / 2
2583+ elif method == "inverted_cdf" :
2584+ out = val
2585+ else :
2586+ raise ValueError (f"{ method = !r} not recognized" )
2587+ if weights_have_nan :
2588+ out = lax .full_like (out , np .nan )
2589+ out = lax .squeeze (out , axis = axis )
2590+ return out
2591+
2592+ if q .ndim == 0 :
2593+ result = _weighted_quantile (q )
2594+ else :
2595+ result = jax .vmap (_weighted_quantile )(q )
2596+ return result
25522597
25532598 if squash_nans :
25542599 a = _where (lax_internal ._isnan (a ), np .nan , a ) # Ensure nans are positive so they sort to the end.
@@ -2566,10 +2611,10 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25662611
25672612 low = lax .max (_lax_const (low , 0 ), lax .min (low , counts - 1 ))
25682613 high = lax .max (_lax_const (high , 0 ), lax .min (high , counts - 1 ))
2569- low = lax .convert_element_type (low , int )
2570- high = lax .convert_element_type (high , int )
2614+ low = lax .convert_element_type (low , dtypes . canonicalize_dtype ( int ) )
2615+ high = lax .convert_element_type (high , dtypes . canonicalize_dtype ( int ) )
25712616 out_shape = q_shape + shape_after_reduction
2572- index = [lax .broadcasted_iota (int , out_shape , dim + q_ndim )
2617+ index = [lax .broadcasted_iota (dtypes . canonicalize_dtype ( int ) , out_shape , dim + q_ndim )
25732618 for dim in range (len (shape_after_reduction ))]
25742619 if keepdims :
25752620 index [axis ] = low
@@ -2591,8 +2636,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25912636
25922637 low = lax .clamp (_lax_const (low , 0 ), low , n - 1 )
25932638 high = lax .clamp (_lax_const (high , 0 ), high , n - 1 )
2594- low = lax .convert_element_type (low , int )
2595- high = lax .convert_element_type (high , int )
2639+ low = lax .convert_element_type (low , dtypes . int_ )
2640+ high = lax .convert_element_type (high , dtypes . int_ )
25962641
25972642 slice_sizes = list (a_shape )
25982643 slice_sizes [axis ] = 1
@@ -2624,6 +2669,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26242669 result = lax .select (pred , low_value , high_value )
26252670 elif method == "midpoint" :
26262671 result = lax .mul (lax .add (low_value , high_value ), _lax_const (low_value , 0.5 ))
2672+ elif method == "inverted_cdf" :
2673+ result = high_value
26272674 else :
26282675 raise ValueError (f"{ method = !r} not recognized" )
26292676 if keepdims and keepdim :
@@ -2639,7 +2686,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26392686def percentile (a : ArrayLike , q : ArrayLike ,
26402687 axis : int | tuple [int , ...] | None = None ,
26412688 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2642- keepdims : bool = False , weights : ArrayLike | None = None , * , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
2689+ keepdims : bool = False , * , weights : ArrayLike | None = None , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
26432690 """Compute the percentile of the data along the specified axis.
26442691
26452692 JAX implementation of :func:`numpy.percentile`.
@@ -2697,7 +2744,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
26972744def nanpercentile (a : ArrayLike , q : ArrayLike ,
26982745 axis : int | tuple [int , ...] | None = None ,
26992746 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2700- keepdims : bool = False , weights : ArrayLike | None = None , * , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
2747+ keepdims : bool = False , * , weights : ArrayLike | None = None , interpolation : str | DeprecatedArg = DeprecatedArg ()) -> Array :
27012748 """Compute the percentile of the data along the specified axis, ignoring NaN values.
27022749
27032750 JAX implementation of :func:`numpy.nanpercentile`.
0 commit comments