@@ -803,29 +803,11 @@ def chunk_reduce(
803803 dict
804804 """
805805
806- if not (isinstance (func , str ) or callable (func )):
807- funcs = func
808- else :
809- funcs = (func ,)
806+ funcs = _atleast_1d (func )
810807 nfuncs = len (funcs )
811-
812- if isinstance (dtype , Sequence ):
813- dtypes = dtype
814- else :
815- dtypes = (dtype ,) * nfuncs
816- assert len (dtypes ) >= nfuncs
817-
818- if isinstance (fill_value , Sequence ):
819- fill_values = fill_value
820- else :
821- fill_values = (fill_value ,) * nfuncs
822- assert len (fill_values ) >= nfuncs
823-
824- if isinstance (kwargs , Sequence ):
825- kwargss = kwargs
826- else :
827- kwargss = ({},) * nfuncs
828- assert len (kwargss ) >= nfuncs
808+ dtypes = _atleast_1d (dtype , nfuncs )
809+ fill_values = _atleast_1d (fill_value , nfuncs )
810+ kwargss = _atleast_1d ({}, nfuncs ) if kwargs is None else kwargs
829811
830812 if isinstance (axis , Sequence ):
831813 axes : T_Axes = axis
@@ -862,7 +844,8 @@ def chunk_reduce(
862844
863845 # do this *before* possible broadcasting below.
864846 # factorize_ has already taken care of offsetting
865- seen_groups = _unique (group_idx )
847+ if engine == "numbagg" :
848+ seen_groups = _unique (group_idx )
866849
867850 order = "C"
868851 if nax > 1 :
@@ -1551,12 +1534,9 @@ def dask_groupby_agg(
15511534 groups = _extract_unknown_groups (reduced , dtype = by .dtype )
15521535 group_chunks = ((np .nan ,),)
15531536 else :
1554- if expected_groups is None :
1555- expected_groups_ = _get_expected_groups (by_input , sort = sort )
1556- else :
1557- expected_groups_ = expected_groups
1558- groups = (expected_groups_ .to_numpy (),)
1559- group_chunks = ((len (expected_groups_ ),),)
1537+ assert expected_groups is not None
1538+ groups = (expected_groups .to_numpy (),)
1539+ group_chunks = ((len (expected_groups ),),)
15601540
15611541 elif method == "cohorts" :
15621542 chunks_cohorts = find_group_cohorts (
@@ -2063,10 +2043,7 @@ def groupby_reduce(
20632043 is_bool_array = np .issubdtype (array .dtype , bool )
20642044 array = array .astype (int ) if is_bool_array else array
20652045
2066- if isinstance (isbin , Sequence ):
2067- isbins = isbin
2068- else :
2069- isbins = (isbin ,) * nby
2046+ isbins = _atleast_1d (isbin , nby )
20702047
20712048 _assert_by_is_aligned (array .shape , bys )
20722049
0 commit comments