@@ -1459,6 +1459,10 @@ def groupby_reduce(
14591459 by : tuple = tuple (np .asarray (b ) if not is_duck_array (b ) else b for b in by )
14601460 nby = len (by )
14611461 by_is_dask = any (is_duck_dask_array (b ) for b in by )
1462+
1463+ if method in ["split-reduce" , "cohorts" ] and by_is_dask :
1464+ raise ValueError (f"method={ method !r} can only be used when grouping by numpy arrays." )
1465+
14621466 if not is_duck_array (array ):
14631467 array = np .asarray (array )
14641468 if isinstance (isbin , bool ):
@@ -1477,9 +1481,11 @@ def groupby_reduce(
14771481 # (pd.IntervalIndex or not)
14781482 expected_groups = _convert_expected_groups_to_index (expected_groups , isbin , sort )
14791483
1480- # when grouping by multiple variables, we factorize early.
14811484 # TODO: could restrict this to dask-only
1482- if nby > 1 :
1485+ factorize_early = (nby > 1 ) or (
1486+ any (isbin ) and method in ["split-reduce" , "cohorts" ] and is_duck_dask_array (array )
1487+ )
1488+ if factorize_early :
14831489 by , final_groups , grp_shape = _factorize_multiple (
14841490 by , expected_groups , by_is_dask = by_is_dask
14851491 )
@@ -1497,6 +1503,7 @@ def groupby_reduce(
14971503 if method in ["blockwise" , "cohorts" , "split-reduce" ] and len (axis ) != by .ndim :
14981504 raise NotImplementedError (
14991505 "Must reduce along all dimensions of `by` when method != 'map-reduce'."
1506+ f" Received method={ method !r} "
15001507 )
15011508
15021509 # TODO: make sure expected_groups is unique
@@ -1617,10 +1624,12 @@ def groupby_reduce(
16171624 result = result [..., sorted_idx ]
16181625 groups = (groups [0 ][sorted_idx ],)
16191626
1620- if nby > 1 :
1627+ if factorize_early :
16211628 # nan group labels are factorized to -1, and preserved
1622- # now we get rid of them
1623- nanmask = groups [0 ] == - 1
1629+ # now we get rid of them by reindexing
1630+ # This also handles bins with no data
1631+ result = reindex_ (
1632+ result , from_ = groups [0 ], to = expected_groups , fill_value = fill_value
1633+ ).reshape (result .shape [:- 1 ] + grp_shape )
16241634 groups = final_groups
1625- result = result [..., ~ nanmask ].reshape (result .shape [:- 1 ] + grp_shape )
16261635 return (result , * groups )
0 commit comments