@@ -509,6 +509,7 @@ def chunk_argreduce(
     dask.array.reductions.argtopk
     """
     array, idx = array_plus_idx
+    by = np.broadcast_to(by, array.shape)

     results = chunk_reduce(
         array,
@@ -522,17 +523,22 @@ def chunk_argreduce(
         sort=sort,
     )
     if not isnull(results["groups"]).all():
-        # will not work for empty groups...
-        # glorious
         idx = np.broadcast_to(idx, array.shape)
+
+        # array, by get flattened to 1D before passing to npg
+        # so the indexes need to be unraveled
         newidx = np.unravel_index(results["intermediates"][1], array.shape)
+
+        # Now index into the actual "global" indexes `idx`
         results["intermediates"][1] = idx[newidx]

     if reindex and expected_groups is not None:
         results["intermediates"][1] = reindex_(
             results["intermediates"][1], results["groups"].squeeze(), expected_groups, fill_value=0
         )

+    assert results["intermediates"][0].shape == results["intermediates"][1].shape
+
     return results


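Note: the unraveling step added above is easiest to see with plain NumPy. A minimal sketch, with toy stand-ins for `array` and `idx` (not flox's actual intermediates): an npg-style argreduction returns positions into the flattened array, which must be converted back to nD indices before they can look up the "global" positions stored in `idx`.

```python
import numpy as np

# Toy stand-ins: `array` holds values, `idx` holds the "global" position of
# each element along the reduced axis (as tracked across chunks).
array = np.array([[3.0, 1.0], [2.0, 7.0]])
idx = np.broadcast_to(np.arange(array.shape[-1]), array.shape)

# An npg-style argmax over the flattened array returns a flat index ...
flat_arg = np.array([array.argmax()])  # -> [3], index into array.ravel()

# ... so it must be unraveled back to nD before indexing `idx`.
nd_index = np.unravel_index(flat_arg, array.shape)  # -> (array([1]), array([1]))
global_index = idx[nd_index]  # -> [1], position of 7.0 along the last axis

assert global_index[0] == 1
```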
@@ -879,34 +885,45 @@ def _grouped_combine(
         array_idx = tuple(
             _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)
         )
-        results = chunk_argreduce(
-            array_idx,
-            groups,
-            func=agg.combine[slicer],  # count gets treated specially next
-            axis=axis,
-            expected_groups=None,
-            fill_value=agg.fill_value["intermediate"][slicer],
-            dtype=agg.dtype["intermediate"][slicer],
-            engine=engine,
-            sort=sort,
-        )
+
+        # for a single element along axis, we don't want to run the argreduction twice
+        # This happens when we are reducing along an axis with a single chunk.
+        avoid_reduction = array_idx[0].shape[axis[0]] == 1
+        if avoid_reduction:
+            results = {"groups": groups, "intermediates": list(array_idx)}
+        else:
+            results = chunk_argreduce(
+                array_idx,
+                groups,
+                func=agg.combine[slicer],  # count gets treated specially next
+                axis=axis,
+                expected_groups=None,
+                fill_value=agg.fill_value["intermediate"][slicer],
+                dtype=agg.dtype["intermediate"][slicer],
+                engine=engine,
+                sort=sort,
+            )

         if agg.chunk[-1] == "nanlen":
             counts = _conc2(x_chunk, key1="intermediates", key2=2, axis=axis)
-            # sum the counts
-            results["intermediates"].append(
-                chunk_reduce(
-                    counts,
-                    groups,
-                    func="sum",
-                    axis=axis,
-                    expected_groups=None,
-                    fill_value=(0,),
-                    dtype=(np.intp,),
-                    engine=engine,
-                    sort=sort,
-                )["intermediates"][0]
-            )
+
+            if avoid_reduction:
+                results["intermediates"].append(counts)
+            else:
+                # sum the counts
+                results["intermediates"].append(
+                    chunk_reduce(
+                        counts,
+                        groups,
+                        func="sum",
+                        axis=axis,
+                        expected_groups=None,
+                        fill_value=(0,),
+                        dtype=(np.intp,),
+                        engine=engine,
+                        sort=sort,
+                    )["intermediates"][0]
+                )

     elif agg.reduction_type == "reduce":
         # Here we reduce the intermediates individually
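Note: the `avoid_reduction` short-circuit rests on a simple invariant: combining intermediates along an axis of length one is the identity, so re-running the argreduction there would only reduce the same values a second time. A quick NumPy check of that invariant, on made-up data rather than flox's combine machinery:

```python
import numpy as np

# When an axis has a single chunk, the concatenated intermediates have
# length 1 along the combine axis; any reduction over it is a no-op.
intermediate = np.random.default_rng(0).random((4, 1))
combined = intermediate.max(axis=-1, keepdims=True)
assert np.array_equal(combined, intermediate)  # nothing to combine
```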
@@ -1006,24 +1023,7 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
     )  # type: ignore

     if _is_arg_reduction(agg):
-        if array.ndim > 1:
-            # default fill_value is -1; we can't unravel that;
-            # so replace -1 with 0; unravel; then replace 0 with -1
-            # UGH!
-            idx = results["intermediates"][0]
-            mask = idx == agg.fill_value["numpy"][0]
-            idx[mask] = 0
-            # Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
-            # will return wrong indices
-            idx = np.unravel_index(idx, array.shape)[-1]
-            idx[mask] = agg.fill_value["numpy"][0]
-            results["intermediates"][0] = idx
-    elif agg.name in ["nanvar", "nanstd"]:
-        # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
-        value, counts = results["intermediates"]
-        mask = counts <= 0
-        value[mask] = np.nan
-        results["intermediates"][0] = value
+        results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]

     result = _finalize_results(
         results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex
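Note: the replacement line relies on taking only the last component of `np.unravel_index`, which converts a flat index into the index along the trailing (reduced) axis. A small demonstration with made-up data:

```python
import numpy as np

array = np.array([[1.0, 9.0, 3.0],
                  [8.0, 2.0, 4.0]])

# Flat argmax per row: 9.0 sits at flat index 1, 8.0 at flat index 3.
flat = np.array([1, 3])

# The last component of the unraveled index is the position along axis -1.
along_last = np.unravel_index(flat, array.shape)[-1]
print(along_last)  # [1 0]
```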
@@ -1530,12 +1530,7 @@ def groupby_reduce(
     # The only way to do this consistently is mask out using min_count
     # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
     if min_count is None:
-        if (
-            len(axis) < by.ndim
-            or fill_value is not None
-            # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
-            or (not has_dask and isinstance(func, str) and func in ["nanvar", "nanstd"])
-        ):
+        if len(axis) < by.ndim or fill_value is not None:
             min_count = 1

     # TODO: set in xarray?
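Note: the comment above captures the inconsistency this masking works around: for an all-NaN group, a NaN-skipping sum returns 0 while the plain sum returns NaN. Setting `min_count=1` masks out groups with no valid values. A standalone illustration of the idea, where the `counts` and `result` arrays are hypothetical per-group intermediates rather than flox internals:

```python
import numpy as np

print(np.sum([np.nan]))     # nan
print(np.nansum([np.nan]))  # 0.0  <- the inconsistency

# Hypothetical per-group intermediates: group 0 saw only NaNs.
counts = np.array([0, 2])      # number of valid observations per group
result = np.array([0.0, 5.0])  # nansum per group

# Applying min_count=1 restores NaN for groups with no valid values.
result = np.where(counts >= 1, result, np.nan)
print(result)  # [nan  5.]
```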