@@ -254,7 +254,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
254254 fill_value = np .nan
255255 tolerance = {"rtol" : 1e-14 , "atol" : 1e-16 }
256256 elif "quantile" in func :
257- finalize_kwargs = [{"q" : DEFAULT_QUANTILE }]
257+ finalize_kwargs = [{"q" : DEFAULT_QUANTILE }, { "q" : [ DEFAULT_QUANTILE / 2 , DEFAULT_QUANTILE ]} ]
258258 fill_value = None
259259 tolerance = None
260260 else :
@@ -265,6 +265,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
265265 array_func = _get_array_func (func )
266266
267267 for kwargs in finalize_kwargs :
268+ if "quantile" in func and isinstance (kwargs ["q" ], list ) and engine != "flox" :
269+ continue
268270 flox_kwargs = dict (func = func , engine = engine , finalize_kwargs = kwargs , fill_value = fill_value )
269271 with np .errstate (invalid = "ignore" , divide = "ignore" ):
270272 with warnings .catch_warnings ():
@@ -289,10 +291,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
289291
290292 if func in BLOCKWISE_FUNCS :
291293 assert chunks == - 1
292- flox_kwargs ["method" ] = "blockwise"
293294
294295 actual , * groups = groupby_reduce (array , * by , ** flox_kwargs )
295- assert actual .ndim == expected .ndim == (array .ndim + nby - 1 )
296+ if "quantile" in func and isinstance (kwargs ["q" ], list ):
297+ assert actual .ndim == expected .ndim == (array .ndim + nby )
298+ else :
299+ assert actual .ndim == expected .ndim == (array .ndim + nby - 1 )
300+
296301 expected_groups = tuple (np .array ([idx + 1.0 ]) for idx in range (nby ))
297302 for actual_group , expect in zip (groups , expected_groups ):
298303 assert_equal (actual_group , expect )
@@ -598,6 +603,15 @@ def test_nanfirst_nanlast_disallowed_dask(axis, func):
598603
599604
600605@requires_dask
606+ @pytest .mark .xfail
607+ @pytest .mark .parametrize ("func" , ["first" , "last" ])
608+ def test_first_last_allowed_dask (func ):
609+ # blockwise should be fine... but doesn't work now.
610+ groupby_reduce (dask .array .empty ((2 , 3 , 2 )), np .ones ((2 , 3 , 2 )), func = func , axis = - 1 )
611+
612+
613+ @requires_dask
614+ @pytest .mark .xfail
601615@pytest .mark .parametrize ("func" , ["first" , "last" ])
602616def test_first_last_disallowed_dask (func ):
603617 # blockwise is fine
@@ -1678,19 +1692,25 @@ def test_xarray_fill_value_behaviour():
16781692 assert_equal (expected , actual )
16791693
16801694
1681- @pytest .mark .parametrize ("q" , (0.5 , (0.5 ,), (0.5 , 0.85 )))
1695+ @pytest .mark .parametrize ("q" , (0.5 , (0.5 ,), (0.5 , 0.67 , 0. 85 )))
16821696@pytest .mark .parametrize ("func" , ["nanquantile" , "quantile" ])
16831697@pytest .mark .parametrize ("chunk" , [pytest .param (True , marks = requires_dask ), False ])
1684- def test_multiple_quantiles (q , chunk , func ):
1698+ @pytest .mark .parametrize ("by_ndim" , [1 , 2 ])
1699+ def test_multiple_quantiles (q , chunk , func , by_ndim ):
16851700 array = np .array ([[1 , - 1 , np .nan , 3 , 4 , 10 , 5 ], [1 , np .nan , np .nan , 3 , 4 , np .nan , np .nan ]])
16861701 labels = np .array ([0 , 0 , 0 , 1 , 0 , 1 , 1 ])
1687- axis = - 1
1702+ if by_ndim == 2 :
1703+ labels = np .broadcast_to (labels , (5 , * labels .shape ))
1704+ array = np .broadcast_to (np .expand_dims (array , - 2 ), (2 , 5 , array .shape [- 1 ]))
1705+ axis = tuple (range (- by_ndim , 0 ))
16881706
16891707 if chunk :
1690- array = dask .array .from_array (array , chunks = (1 , - 1 ) )
1708+ array = dask .array .from_array (array , chunks = (1 ,) + ( - 1 ,) * by_ndim )
16911709
16921710 actual , _ = groupby_reduce (array , labels , func = func , finalize_kwargs = dict (q = q ), axis = axis )
16931711 sorted_array = array [..., [0 , 1 , 2 , 4 , 3 , 5 , 6 ]]
16941712 f = partial (getattr (np , func ), q = q , axis = axis , keepdims = True )
1695- expected = np .concatenate ((f (sorted_array [..., :4 ]), f (sorted_array [..., 4 :])), axis = axis )
1696- assert_equal (expected , actual )
1713+ expected = np .concatenate ((f (sorted_array [..., :4 ]), f (sorted_array [..., 4 :])), axis = - 1 )
1714+ if by_ndim == 2 :
1715+ expected = expected .squeeze (axis = - 2 )
1716+ assert_equal (expected , actual , tolerance = 1e-14 )
0 commit comments