@@ -58,13 +58,14 @@ def _prepare_for_flox(group_idx, array):
5858 return group_idx , ordered_array
5959
6060
61- def _get_expected_groups (by , raise_if_dask = True ) -> np . ndarray | None :
61+ def _get_expected_groups (by , raise_if_dask = True ) -> pd . Index | None :
6262 if is_duck_dask_array (by ):
6363 if raise_if_dask :
6464 raise ValueError ("Please provide `expected_groups`." )
6565 return None
6666 flatby = by .ravel ()
67- return np .unique (flatby [~ isnull (flatby )])
67+ expected = np .unique (flatby [~ isnull (flatby )])
68+ return _convert_expected_groups_to_index (expected , isbin = False )
6869
6970
7071def _get_chunk_reduction (reduction_type : str ) -> Callable :
@@ -324,31 +325,28 @@ def rechunk_for_blockwise(array, axis, labels):
324325
325326def reindex_ (array : np .ndarray , from_ , to , fill_value = None , axis : int = - 1 ) -> np .ndarray :
326327
328+ assert isinstance (to , pd .Index )
327329 assert axis in (0 , - 1 )
328330
329- from_ = np .atleast_1d (from_ )
330- to = np .atleast_1d (to )
331-
332331 if to .ndim > 1 :
333332 raise ValueError (f"Cannot reindex to a multidimensional array: { to } " )
334333
335- # short-circuit for trivial case
336- if len (from_ ) == len (to ) and np .all (from_ == to ):
337- return array
338-
339334 if array .shape [axis ] == 0 :
340335 # all groups were NaN
341336 reindexed = np .full (array .shape [:- 1 ] + (len (to ),), fill_value , dtype = array .dtype )
342337 return reindexed
343338
339+ from_ = pd .Index (from_ )
340+ # short-circuit for trivial case
341+ if from_ .equals (to ):
342+ return array
343+
344344 if from_ .dtype .kind == "O" and isinstance (from_ [0 ], tuple ):
345345 raise NotImplementedError (
346346 "Currently does not support reindexing with object arrays of tuples. "
347347 "These occur when grouping by multi-indexed variables in xarray."
348348 )
349- idx = np .array (
350- [np .argwhere (np .array (from_ ) == label )[0 , 0 ] if label in from_ else - 1 for label in to ]
351- )
349+ idx = from_ .get_indexer (to )
352350 indexer = [slice (None , None )] * array .ndim
353351 indexer [axis ] = idx # type: ignore
354352 reindexed = array [tuple (indexer )]
@@ -384,29 +382,25 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
384382 return offset , size
385383
386384
387- def factorize_ (by : tuple , axis , expected_groups : tuple = None , isbin : tuple = None ):
385+ def factorize_ (by : tuple , axis , expected_groups : tuple [ pd . Index , ...] = None ):
388386 if not isinstance (by , tuple ):
389387 raise ValueError (f"Expected `by` to be a tuple. Received { type (by )} instead" )
390388
391- if isbin is None :
392- isbin = (False ,) * len (by )
393389 if expected_groups is None :
394390 expected_groups = (None ,) * len (by )
395391
396392 factorized = []
397393 found_groups = []
398- for groupvar , expect , tobin in zip (by , expected_groups , isbin ):
399- if tobin :
394+ for groupvar , expect in zip (by , expected_groups ):
395+ if isinstance ( expect , pd . IntervalIndex ) :
400396 # when binning we change expected groups to integers marking the interval
401397 # this makes the reindexing logic simpler.
402398 if expect is None :
403399 raise ValueError ("Please pass bin edges in expected_groups." )
404400 # idx = np.digitize(groupvar.ravel(), expect) - 1
405- idx = pd .cut (groupvar .ravel (), bins = expect , labels = False )
401+ idx = pd .cut (groupvar .ravel (), bins = expect , labels = False ). codes . copy ()
406402 # same sentinel value as factorize
407- idx [np .isnan (idx )] = - 1
408- idx = idx .astype (int , copy = False )
409- found_groups .append (np .arange (len (expect ) - 1 ))
403+ found_groups .append (expect )
410404 else :
411405 idx , groups = pd .factorize (groupvar .ravel ())
412406 found_groups .append (np .array (groups ))
@@ -447,12 +441,11 @@ def chunk_argreduce(
447441 array_plus_idx : tuple [np .ndarray , ...],
448442 by : np .ndarray ,
449443 func : Sequence [str ],
450- expected_groups : Sequence | np . ndarray | None ,
444+ expected_groups : pd . Index | None ,
451445 axis : int | Sequence [int ],
452446 fill_value : Mapping [str | Callable , Any ],
453447 dtype = None ,
454448 reindex : bool = False ,
455- isbin : bool = False ,
456449 engine : str = "numpy" ,
457450) -> IntermediateDict :
458451 """
@@ -470,7 +463,6 @@ def chunk_argreduce(
470463 expected_groups = None ,
471464 axis = axis ,
472465 fill_value = fill_value ,
473- isbin = isbin ,
474466 dtype = dtype ,
475467 engine = engine ,
476468 )
@@ -493,12 +485,11 @@ def chunk_reduce(
493485 array : np .ndarray ,
494486 by : np .ndarray ,
495487 func : str | Callable | Sequence [str ] | Sequence [Callable ],
496- expected_groups : Sequence | np . ndarray = None ,
488+ expected_groups : pd . Index | None ,
497489 axis : int | Sequence [int ] = None ,
498490 fill_value : Mapping [str | Callable , Any ] = None ,
499491 dtype = None ,
500492 reindex : bool = False ,
501- isbin : bool = False ,
502493 engine : str = "numpy" ,
503494 kwargs = None ,
504495) -> IntermediateDict :
@@ -574,12 +565,9 @@ def chunk_reduce(
574565 # indices=[0,0,0]. This is necessary when combining block results
575566 # factorize can handle strings etc unlike digitize
576567 group_idx , groups , _ , ngroups , size , props = factorize_ (
577- (by ,), axis , expected_groups = (expected_groups ,), isbin = ( isbin ,)
568+ (by ,), axis , expected_groups = (expected_groups ,)
578569 )
579570 groups = groups [0 ]
580- # TODO: why?
581- if isbin :
582- expected_groups = groups
583571
584572 # always reshape to 1D along group dimensions
585573 newshape = array .shape [: array .ndim - by .ndim ] + (np .prod (array .shape [- by .ndim :]),)
@@ -590,7 +578,8 @@ def chunk_reduce(
590578
591579 results : IntermediateDict = {"groups" : [], "intermediates" : []}
592580 if reindex and expected_groups is not None :
593- results ["groups" ] = np .array (expected_groups )
581+ # TODO: what happens with binning here?
582+ results ["groups" ] = expected_groups .values
594583 else :
595584 if empty :
596585 results ["groups" ] = np .array ([np .nan ])
@@ -654,16 +643,13 @@ def chunk_reduce(
654643 if props .offset_group :
655644 result = result .reshape (* final_array_shape [:- 1 ], ngroups )
656645 if reindex :
657- if not isbin :
658- result = reindex_ (result , groups , expected_groups , fill_value = fv )
646+ result = reindex_ (result , groups , expected_groups , fill_value = fv )
659647 else :
660648 result = result [..., sortidx ]
661649 result = result .reshape (final_array_shape )
662650 results ["intermediates" ].append (result )
663- if final_groups_shape :
664- # This happens when to_group is broadcasted, and we reduce along the broadcast
665- # dimension
666- results ["groups" ] = np .broadcast_to (results ["groups" ], final_groups_shape )
651+
652+ results ["groups" ] = np .broadcast_to (results ["groups" ], final_groups_shape )
667653 return results
668654
669655
@@ -691,7 +677,7 @@ def _finalize_results(
691677 results : IntermediateDict ,
692678 agg : Aggregation ,
693679 axis : Sequence [int ],
694- expected_groups : Sequence | np . ndarray | None ,
680+ expected_groups : pd . Index | None ,
695681 fill_value : Any ,
696682):
697683 """Finalize results by
@@ -731,7 +717,7 @@ def _finalize_results(
731717 finalized [agg .name ] = reindex_ (
732718 finalized [agg .name ], squeezed ["groups" ], expected_groups , fill_value = fill_value
733719 )
734- finalized ["groups" ] = expected_groups
720+ finalized ["groups" ] = expected_groups . to_numpy ()
735721 else :
736722 finalized ["groups" ] = squeezed ["groups" ]
737723
@@ -742,7 +728,7 @@ def _aggregate(
742728 x_chunk ,
743729 combine : Callable ,
744730 agg : Aggregation ,
745- expected_groups : Sequence | np . ndarray | None ,
731+ expected_groups : pd . Index | None ,
746732 axis : Sequence ,
747733 keepdims ,
748734 fill_value : Any ,
@@ -796,7 +782,9 @@ def reindex_intermediates(x, agg, unique_groups):
796782 new_shape = x ["groups" ].shape [:- 1 ] + (len (unique_groups ),)
797783 newx = {"groups" : np .broadcast_to (unique_groups , new_shape )}
798784 newx ["intermediates" ] = tuple (
799- reindex_ (v , from_ = x ["groups" ].squeeze (), to = unique_groups , fill_value = f )
785+ reindex_ (
786+ v , from_ = np .atleast_1d (x ["groups" ].squeeze ()), to = pd .Index (unique_groups ), fill_value = f
787+ )
800788 for v , f in zip (x ["intermediates" ], agg .fill_value ["intermediate" ])
801789 )
802790 return newx
@@ -940,7 +928,7 @@ def split_blocks(applied, split_out, expected_groups, split_name):
940928 return intermediate , group_chunks
941929
942930
943- def _reduce_blockwise (array , by , agg , * , axis , expected_groups , fill_value , isbin , engine ):
931+ def _reduce_blockwise (array , by , agg , * , axis , expected_groups , fill_value , engine ):
944932 """
945933 Blockwise groupby reduction that produces the final result. This code path is
946934 also used for non-dask array aggregations.
@@ -964,13 +952,12 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, isbi
964952 by ,
965953 func = func ,
966954 axis = axis ,
967- expected_groups = expected_groups if isbin else None ,
955+ expected_groups = expected_groups ,
968956 # This fill_value should only apply to groups that only contain NaN observations
969957 # BUT there is funkiness when axis is a subset of all possible values
970958 # (see below)
971959 fill_value = (agg .fill_value [agg .name ], 0 ),
972960 dtype = (agg .dtype [agg .name ], np .intp ),
973- isbin = isbin ,
974961 kwargs = finalize_kwargs ,
975962 engine = engine ,
976963 ) # type: ignore
@@ -995,9 +982,6 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, isbi
995982 value [mask ] = np .nan
996983 results ["intermediates" ][0 ] = value
997984
998- if isbin :
999- expected_groups = np .arange (len (expected_groups ) - 1 )
1000-
1001985 # When axis is a subset of possible values; then npg will
1002986 # apply it to groups that don't exist along a particular axis (for e.g.)
1003987 # since these count as a group that is absent. thoo!
@@ -1022,12 +1006,11 @@ def dask_groupby_agg(
10221006 array : DaskArray ,
10231007 by : DaskArray | np .ndarray ,
10241008 agg : Aggregation ,
1025- expected_groups : Sequence | np . ndarray | None ,
1009+ expected_groups : pd . Index | None ,
10261010 axis : Sequence = None ,
10271011 split_out : int = 1 ,
10281012 fill_value : Any = None ,
10291013 method : str = "map-reduce" ,
1030- isbin : bool = False ,
10311014 reindex : bool = False ,
10321015 engine : str = "numpy" ,
10331016) -> tuple [DaskArray , np .ndarray | DaskArray ]:
@@ -1107,8 +1090,7 @@ def dask_groupby_agg(
11071090 partial (
11081091 blockwise_method ,
11091092 axis = axis ,
1110- expected_groups = expected_groups if reindex or split_out > 1 or isbin else None ,
1111- isbin = isbin ,
1093+ expected_groups = expected_groups ,
11121094 engine = engine ,
11131095 ),
11141096 inds ,
@@ -1129,9 +1111,6 @@ def dask_groupby_agg(
11291111 )
11301112 else :
11311113 intermediate = applied
1132- # from this point on, we just work with bin indexes when binning
1133- if isbin :
1134- expected_groups = np .arange (len (expected_groups ) - 1 )
11351114 if expected_groups is None :
11361115 expected_groups = _get_expected_groups (by_input , raise_if_dask = False )
11371116 group_chunks = (len (expected_groups ),) if expected_groups is not None else (np .nan ,)
@@ -1212,7 +1191,7 @@ def dask_groupby_agg(
12121191 if method == "map-reduce" :
12131192 if expected_groups is None :
12141193 expected_groups = _get_expected_groups (by_input )
1215- groups = (expected_groups ,)
1194+ groups = (expected_groups . values ,)
12161195 else :
12171196 groups = (np .concatenate (groups_in_block ),)
12181197
@@ -1269,12 +1248,22 @@ def _assert_by_is_aligned(shape, by):
12691248 )
12701249
12711250
1251+ def _convert_expected_groups_to_index (expected_groups , isbin : bool ) -> pd .Index | None :
1252+ if isinstance (expected_groups , pd .Index ):
1253+ return expected_groups
1254+ if isbin :
1255+ return pd .IntervalIndex .from_arrays (expected_groups [:- 1 ], expected_groups [1 :])
1256+ elif expected_groups is not None :
1257+ return pd .Index (expected_groups )
1258+ return None
1259+
1260+
12721261def groupby_reduce (
12731262 array : np .ndarray | DaskArray ,
12741263 by : np .ndarray | DaskArray ,
12751264 func : str | Aggregation ,
12761265 * ,
1277- expected_groups : Sequence | np .ndarray = None ,
1266+ expected_groups : Sequence | np .ndarray | None = None ,
12781267 sort : bool = True ,
12791268 isbin : bool = False ,
12801269 axis = None ,
@@ -1393,6 +1382,10 @@ def groupby_reduce(
13931382
13941383 _assert_by_is_aligned (array .shape , by )
13951384
1385+ # We convert to pd.Index since that lets us know if we are binning or not
1386+ # (pd.IntervalIndex or not)
1387+ expected_groups = _convert_expected_groups_to_index (expected_groups , isbin )
1388+
13961389 if axis is None :
13971390 axis = tuple (array .ndim + np .arange (- by .ndim , 0 ))
13981391 else :
@@ -1435,7 +1428,7 @@ def groupby_reduce(
14351428 assert isinstance (finalize_kwargs , dict )
14361429 agg .finalize_kwargs = finalize_kwargs
14371430
1438- kwargs = dict (axis = axis , fill_value = fill_value , isbin = isbin , engine = engine )
1431+ kwargs = dict (axis = axis , fill_value = fill_value , engine = engine )
14391432
14401433 if not is_duck_dask_array (array ) and not is_duck_dask_array (by ):
14411434 results = _reduce_blockwise (array , by , agg , expected_groups = expected_groups , ** kwargs )
@@ -1484,7 +1477,7 @@ def groupby_reduce(
14841477 r , * g = partial_agg (
14851478 array_subset ,
14861479 by [np .ix_ (* indexer )],
1487- expected_groups = cohort ,
1480+ expected_groups = pd . Index ( cohort ) ,
14881481 # reindex to expected_groups at the blockwise step.
14891482 # this approach avoids replacing non-cohort members with
14901483 # np.nan or some other sentinel value, and preserves dtypes
@@ -1504,7 +1497,6 @@ def groupby_reduce(
15041497 if method == "blockwise" and by .ndim == 1 :
15051498 array = rechunk_for_blockwise (array , axis = - 1 , labels = by )
15061499
1507- # TODO: test with mixed array kinds (numpy array + dask by)
15081500 result , * groups = partial_agg (
15091501 array , by , expected_groups = expected_groups , reindex = reindex , method = method
15101502 )
0 commit comments