1717from functools import partial
1818import itertools
1919import unittest
20- import pytest
2120
2221from absl .testing import absltest
2322from absl .testing import parameterized
@@ -764,51 +763,58 @@ def testPercentilePrecision(self):
764763 x = jnp .float64 ([1 , 2 , 3 , 4 , 7 , 10 ])
765764 self .assertEqual (jnp .percentile (x , 50 ), 3.5 )
766765
767- def test_weighted_quantile_all_weights_one (self ):
768- a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
769- weights = jnp .ones_like (a )
770- q = jnp .array ([0.25 , 0.5 , 0.75 ])
771- result = jnp .quantile (a , q , axis = 0 , method = "inverted_cdf" , keepdims = False , squash_nans = False , weights = weights )
772- expected = np .quantile (np .array (a ), np .array (q ), axis = 0 , weights = np .array (weights ), method = "inverted_cdf" )
773- np .testing .assert_allclose (np .array (result ), expected , rtol = 1e-6 )
774-
775- def test_weighted_quantile_multiple_q (self ):
776- a = jnp .arange (10 , dtype = float )
777- weights = jnp .ones_like (a )
778- q = jnp .array ([0.25 , 0.5 , 0.75 ])
779- result = jnp .quantile (a , q , axis = 0 , method = "inverted_cdf" , keepdims = False , squash_nans = False , weights = weights )
780- expected = np .quantile (np .array (a ), np .array (q ), axis = 0 , weights = np .array (weights ), method = "inverted_cdf" )
781- np .testing .assert_allclose (np .array (result ), expected , rtol = 1e-6 )
782-
783- def test_weighted_quantile_keepdims (self ):
784- a = jnp .array ([1 , 2 , 3 , 4 ], dtype = float )
785- weights = jnp .array ([1 , 1 , 1 , 1 ], dtype = float )
786- q = 0.5
787- result = jnp .quantile (a , q , axis = 0 , method = "inverted_cdf" , keepdims = True , squash_nans = False , weights = weights )
788- expected = np .quantile (np .array (a ), np .array (q ), axis = 0 , keepdims = True , weights = np .array (weights ), method = "inverted_cdf" )
789- np .testing .assert_allclose (np .array (result ), expected , rtol = 1e-6 )
766+ @jtu .sample_product (
767+ [dict (a_shape = a_shape , axis = axis )
768+ for a_shape , axis in (
769+ ((7 ,), None ),
770+ ((6 , 7 ,), None ),
771+ ((47 , 7 ), 0 ),
772+ ((47 , 7 ), ()),
773+ ((4 , 101 ), 1 ),
774+ ((4 , 47 , 7 ), (1 , 2 )),
775+ ((4 , 47 , 7 ), (0 , 2 )),
776+ ((4 , 47 , 7 ), (1 , 0 , 2 )),
777+ )
778+ ],
779+ a_dtype = default_dtypes ,
780+ q_dtype = [np .float32 ],
781+ q_shape = scalar_shapes + [(1 ,), (4 ,)],
782+ keepdims = [False , True ],
783+ method = ['linear' , 'lower' , 'higher' , 'nearest' , 'midpoint' , 'inverted_cdf' ],
784+ )
785+ def testWeightedQuantile (self , a_shape , a_dtype , q_shape , q_dtype , axis , keepdims , method ):
786+ rng = jtu .rand_default (self .rng ())
787+ a = rng (a_shape , a_dtype )
788+ q = rng (q_shape , q_dtype )
789+ if axis is None :
790+ weights_shape = a_shape
791+ elif isinstance (axis , tuple ):
792+ weights_shape = tuple (a_shape [i ] for i in axis )
793+ else :
794+ weights_shape = (a_shape [axis ],)
795+ weights = np .abs (rng (weights_shape , a_dtype )) + 1e-3
790796
791- def test_weighted_quantile_linear ( self ):
792- a = jnp . array ([ 1 , 2 , 3 , 4 , 5 ], dtype = float )
793- weights = jnp . array ([ 1 , 2 , 1 , 1 , 1 ], dtype = float )
794- q = jnp .array ([ 0.5 ] )
795- result = jnp . quantile ( a , q , axis = 0 , method = "inverted_cdf" , keepdims = False , squash_nans = False , weights = weights )
796- expected = np . quantile ( np . array ( a ), np . array ( q ), axis = 0 , weights = np . array ( weights ), method = "inverted_cdf" )
797- np . testing . assert_allclose ( np . array ( result ), expected , rtol = 1e-6 )
797+ def np_fun ( a , q , weights ):
798+ return np . quantile ( np . array (a ), np . array ( q ), axis = axis , weights = np . array ( weights ), method = method , keepdims = keepdims )
799+ def jnp_fun ( a , q , weights ):
800+ return jnp .quantile ( a , q , axis = axis , weights = weights , method = method , keepdims = keepdims )
801+ args_maker = lambda : [ a , q , weights ]
802+ self . _CheckAgainstNumpy ( np_fun , jnp_fun , args_maker , tol = 1e-6 )
803+ self . _CompileAndCheck ( jnp_fun , args_maker , rtol = 1e-6 )
798804
799805 def test_weighted_quantile_negative_weights (self ):
800806 a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
801807 weights = jnp .array ([1 , - 1 , 1 , 1 , 1 ], dtype = float )
802808 q = jnp .array ([0.5 ])
803- with pytest . raises (ValueError ):
804- jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
809+ with self . assertRaisesRegex (ValueError , "Weights must be non-negative" ):
810+ jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
805811
806812 def test_weighted_quantile_all_weights_zero (self ):
807813 a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
808814 weights = jnp .zeros_like (a )
809815 q = jnp .array ([0.5 ])
810- with pytest . raises (ValueError ):
811- jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
816+ with self . assertRaisesRegex (ValueError , "Sum of weights must not be zero" ):
817+ jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
812818
813819 def test_weighted_quantile_weights_with_nan (self ):
814820 a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
@@ -825,15 +831,6 @@ def test_weighted_quantile_scalar_q(self):
825831 assert jnp .issubdtype (result .dtype , jnp .floating )
826832 assert result .shape == ()
827833
828- def test_weighted_quantile_jit (self ):
829- a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
830- weights = jnp .array ([1 , 2 , 1 , 1 , 1 ], dtype = float )
831- q = jnp .array ([0.25 , 0.5 , 0.75 ])
832- quantile_jit = jax .jit (lambda a , q , weights : jnp .quantile (a , q , axis = 0 , method = "inverted_cdf" , keepdims = False , squash_nans = False , weights = weights ))
833- result = quantile_jit (a , q , weights )
834- expected = np .quantile (np .array (a ), np .array (q ), axis = 0 , weights = np .array (weights ), method = "inverted_cdf" )
835- np .testing .assert_allclose (np .array (result ), expected , rtol = 1e-6 )
836-
837834 @jtu .sample_product (
838835 [dict (a_shape = a_shape , axis = axis )
839836 for a_shape , axis in (
0 commit comments