@@ -48,6 +48,7 @@ def _prepare_for_flox(group_idx, array):
4848 """
4949 Sort the input array once to save time.
5050 """
51+ assert array .shape [- 1 ] == group_idx .shape [0 ]
5152 issorted = (group_idx [:- 1 ] <= group_idx [1 :]).all ()
5253 if issorted :
5354 ordered_array = array
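A minimal sketch of the invariant the new assert enforces and of the sort-once idea (illustrative only, not flox's exact code; the unsorted branch is cut off above): the last axis of `array` must line up with `group_idx`, and an unsorted `group_idx` is sorted once so downstream reductions see contiguous groups.

    import numpy as np

    group_idx = np.array([2, 0, 1, 0])
    array = np.arange(8).reshape(2, 4)            # last axis aligned with group_idx

    assert array.shape[-1] == group_idx.shape[0]  # the invariant added above

    if not (group_idx[:-1] <= group_idx[1:]).all():
        perm = group_idx.argsort(kind="stable")   # sort once, reuse for every reduction
        group_idx = group_idx[perm]
        array = array[..., perm]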
@@ -323,14 +324,6 @@ def rechunk_for_blockwise(array, axis, labels):
     return array.rechunk({axis: newchunks})
 
 
-def reindex_multiple(array, groups, expected_groups, fill_value, promote=False):
-    for ax, (group, expect) in enumerate(zip(groups, expected_groups)):
-        array = reindex_(
-            array, group, expect, fill_value=fill_value, axis=ax - len(groups), promote=promote
-        )
-    return array
-
-
 def reindex_(
     array: np.ndarray, from_, to, fill_value=None, axis: int = -1, promote: bool = False
 ) -> np.ndarray:
@@ -562,19 +555,12 @@ def chunk_reduce(
         array = _collapse_axis(array, len(axis))
         axis = -1
 
-    if by.ndim == 1:
-        # TODO: This assertion doesn't work with dask reducing across all dimensions
-        # when by.ndim == array.ndim
-        # the intermediates are 1D but axis=range(array.ndim)
-        # assert axis in (0, -1, array.ndim - 1, None)
-        axis = -1
-
     # if indices=[2,2,2], npg assumes groups are (0, 1, 2);
     # and will return a result that is bigger than necessary
     # avoid by factorizing again so indices=[2,2,2] is changed to
     # indices=[0,0,0]. This is necessary when combining block results
     # factorize can handle strings etc unlike digitize
-    group_idx, groups, _, ngroups, size, props = factorize_(
+    group_idx, groups, found_groups_shape, ngroups, size, props = factorize_(
         (by,), axis, expected_groups=(expected_groups,)
     )
     groups = groups[0]
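The comment above carries the key reasoning for re-factorizing: numpy_groupies sizes its output from the label values themselves, so labels like [2, 2, 2] would allocate slots for groups 0 and 1 that never occur. A rough illustration with pandas (flox's factorize_ wraps a similar idea; this is not its exact implementation):

    import numpy as np
    import pandas as pd

    codes, uniques = pd.factorize(np.array([2, 2, 2]))
    print(codes)    # [0 0 0]  -> compact group_idx, no wasted slots
    print(uniques)  # [2]      -> the actual group labels, kept separately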
@@ -584,21 +570,17 @@ def chunk_reduce(
     array = array.reshape(newshape)
 
     assert group_idx.ndim == 1
-    empty = np.all(props.nanmask) or np.prod(by.shape) == 0
+    empty = np.all(props.nanmask)
 
     results: IntermediateDict = {"groups": [], "intermediates": []}
     if reindex and expected_groups is not None:
         # TODO: what happens with binning here?
-        results["groups"] = expected_groups.values
+        results["groups"] = expected_groups.to_numpy()
     else:
         if empty:
             results["groups"] = np.array([np.nan])
         else:
-            if (groups[:-1] <= groups[1:]).all():
-                sortidx = slice(None)
-            else:
-                sortidx = groups.argsort()
-            results["groups"] = groups[sortidx]
+            results["groups"] = np.sort(groups)
 
     # npg's argmax ensures that index of first "max" is returned assuming there
     # are many elements equal to the "max". Sorting messes this up totally.
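One likely motivation for the `.values` → `.to_numpy()` switch (an assumption, the commit does not say): with binned groupers `expected_groups` can be a pandas IntervalIndex, whose `.values` is an IntervalArray rather than a plain ndarray, while `.to_numpy()` always returns a numpy array:

    import pandas as pd

    bins = pd.IntervalIndex.from_breaks([0, 1, 2, 3])
    print(type(bins.values))      # pandas IntervalArray (an ExtensionArray)
    print(type(bins.to_numpy()))  # numpy.ndarray of Interval objects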
@@ -650,13 +632,8 @@ def chunk_reduce(
         if np.any(props.nanmask):
             # remove NaN group label which should be last
             result = result[..., :-1]
-        if props.offset_group:
-            result = result.reshape(*final_array_shape[:-1], ngroups)
-        if reindex:
-            result = reindex_(result, groups, expected_groups, fill_value=fv)
-        else:
-            result = result[..., sortidx]
-        result = result.reshape(final_array_shape)
+        result = result.reshape(final_array_shape[:-1] + found_groups_shape)
+        result = reindex_(result, groups, results["groups"], fill_value=fv, promote=True)
         results["intermediates"].append(result)
 
     results["groups"] = np.broadcast_to(results["groups"], final_groups_shape)
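The new path always reshapes to the found groups and then reindexes the last axis onto the sorted labels stored in results["groups"]. A hypothetical sketch of that last-axis reindex (illustrative only; reindex_last_axis is not flox's reindex_, and the promote semantics are not shown in this diff):

    import numpy as np
    import pandas as pd

    def reindex_last_axis(result, from_groups, to_groups, fill_value):
        # Locate each target group among the found groups (-1 if absent),
        # gather those columns, and fill the misses.
        idx = pd.Index(from_groups).get_indexer(to_groups)
        return np.where(idx == -1, fill_value, result[..., idx])

    block = np.array([[10.0, 20.0]])   # groups found in this block: [3, 1]
    print(reindex_last_axis(block, [3, 1], [1, 2, 3], np.nan))
    # [[20. nan 10.]]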
@@ -913,8 +890,8 @@ def split_blocks(applied, split_out, expected_groups, split_name):
 
     chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
     ngroups = len(expected_groups)
-    group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))[0]
-    idx = tuple(np.cumsum((0,) + group_chunks))
+    group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
+    idx = tuple(np.cumsum((0,) + group_chunks[0]))
 
     # split each block into `split_out` chunks
     dsk = {}
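Context for moving the [0] (a rough check of the dask behaviour, assuming dask.array's normalize_chunks; int() is used here where the diff passes the float directly): normalize_chunks returns one chunk tuple per axis, so for a 1-D shape the useful tuple is the first element, which the cumsum now indexes while group_chunks itself keeps the nested form.

    import numpy as np
    from dask.array.core import normalize_chunks

    ngroups, split_out = 10, 3
    group_chunks = normalize_chunks(int(np.ceil(ngroups / split_out)), (ngroups,))
    print(group_chunks)                         # ((4, 4, 2),) -- nested: one tuple per axis
    print(np.cumsum((0,) + group_chunks[0]))    # block boundaries 0, 4, 8, 10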
@@ -1123,7 +1100,7 @@ def dask_groupby_agg(
         intermediate = applied
     if expected_groups is None:
         expected_groups = _get_expected_groups(by_input, raise_if_dask=False)
-    group_chunks = (len(expected_groups),) if expected_groups is not None else (np.nan,)
+    group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
 
     if method == "map-reduce":
         # these are negative axis indices useful for concatenating the intermediates
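The extra nesting of group_chunks is what this hunk and the two below rely on: dask describes chunks as one tuple per axis, so the 1-D groups axis becomes ((n,),), which can be concatenated onto reduced.chunks[...] and passed straight to dask.array.Array as chunks=group_chunks. A small check of that convention (the size 5 here is made up):

    import dask.array as da

    group_chunks = ((5,),)                 # one axis, one block holding all 5 groups
    groups = da.zeros((5,), chunks=group_chunks)
    print(groups.chunks)                   # ((5,),)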
@@ -1155,7 +1132,7 @@ def dask_groupby_agg(
             keepdims=True,
             concatenate=False,
         )
-        output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + (group_chunks,)
+        output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
     elif method == "blockwise":
         reduced = intermediate
         # Here one input chunk → one output chunk
@@ -1193,7 +1170,7 @@ def dask_groupby_agg(
         dask.array.Array(
             HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
             groups_name,
-            chunks=(group_chunks,),
+            chunks=group_chunks,
             dtype=by.dtype,
         ),
     )