1515 Any ,
1616 Callable ,
1717 Literal ,
18+ TypedDict ,
1819 Union ,
1920 overload ,
2021)
8788DUMMY_AXIS = - 2
8889
8990
class FactorizeKwargs(TypedDict, total=False):
    """Keyword arguments forwarded to ``factorize_`` by ``_factorize_multiple``.

    ``total=False`` makes every key optional, so the dict can be built
    incrementally — e.g. ``by`` is only added on the non-dask code path,
    just before ``factorize_(**kwargs)`` is called.
    """

    # The tuple of label arrays being grouped over.
    by: T_Bys
    # Axes to factorize along; always () here — any offsetting happens later.
    axes: T_Axes
    fastpath: bool
    # Expected group labels per `by` array, when known up front.
    expected_groups: T_ExpectIndexOptTuple | None
    # Controls what's actually allocated in chunk_reduce; set to True at the
    # factorize-early call site because an accurate conversion to codes is
    # what matters there.
    reindex: bool
    sort: bool
101+
90102def _postprocess_numbagg (result , * , func , fill_value , size , seen_groups ):
91103 """Account for numbagg not providing a fill_value kwarg."""
92104 from .aggregate_numbagg import DEFAULT_FILL_VALUE
@@ -1434,7 +1446,7 @@ def dask_groupby_agg(
14341446 _ , (array , by ) = dask .array .unify_chunks (array , inds , by , inds [- by .ndim :])
14351447
14361448 # tokenize here since by has already been hashed if its numpy
1437- token = dask .base .tokenize (array , by , agg , expected_groups , axis )
1449+ token = dask .base .tokenize (array , by , agg , expected_groups , axis , method )
14381450
14391451 # preprocess the array:
14401452 # - for argreductions, this zips the index together with the array block
@@ -1454,7 +1466,8 @@ def dask_groupby_agg(
14541466 # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
14551467 # This allows us to discover groups at compute time, support argreductions, lower intermediate
14561468 # memory usage (but method="cohorts" would also work to reduce memory in some cases)
1457- do_simple_combine = not _is_arg_reduction (agg )
1469+ labels_are_unknown = is_duck_dask_array (by_input ) and expected_groups is None
1470+ do_simple_combine = not _is_arg_reduction (agg ) and not labels_are_unknown
14581471
14591472 if method == "blockwise" :
14601473 # use the "non dask" code path, but applied blockwise
@@ -1510,7 +1523,7 @@ def dask_groupby_agg(
15101523
15111524 tree_reduce = partial (
15121525 dask .array .reductions ._tree_reduce ,
1513- name = f"{ name } -reduce- { method } " ,
1526+ name = f"{ name } -reduce" ,
15141527 dtype = array .dtype ,
15151528 axis = axis ,
15161529 keepdims = True ,
@@ -1529,7 +1542,7 @@ def dask_groupby_agg(
15291542 combine = partial (combine , agg = agg ),
15301543 aggregate = partial (aggregate , expected_groups = expected_groups , reindex = reindex ),
15311544 )
1532- if is_duck_dask_array ( by_input ) and expected_groups is None :
1545+ if labels_are_unknown :
15331546 groups = _extract_unknown_groups (reduced , dtype = by .dtype )
15341547 group_chunks = ((np .nan ,),)
15351548 else :
@@ -1747,17 +1760,26 @@ def _convert_expected_groups_to_index(
17471760
17481761
def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
    """Call ``factorize_`` and return only the integer group codes.

    ``factorize_`` returns several values (codes, found groups, shape, ...);
    this wrapper discards everything but the first so that callers needing a
    single ndarray result — presumably the lazy/dask blockwise path, judging
    by the ``chunks``/``meta`` call site; confirm against caller — can use it
    directly.

    Parameters
    ----------
    *by :
        One or more label arrays, passed through to ``factorize_`` as a tuple.
    **kwargs :
        Remaining ``factorize_`` keyword arguments (see ``FactorizeKwargs``).

    Returns
    -------
    np.ndarray
        Integer codes identifying the group of each element.
    """
    group_idx, *_ = factorize_(by, **kwargs)
    return group_idx
17521765
17531766
17541767def _factorize_multiple (
17551768 by : T_Bys ,
17561769 expected_groups : T_ExpectIndexOptTuple ,
17571770 any_by_dask : bool ,
1758- reindex : bool ,
17591771 sort : bool = True ,
17601772) -> tuple [tuple [np .ndarray ], tuple [np .ndarray , ...], tuple [int , ...]]:
1773+ kwargs : FactorizeKwargs = dict (
1774+ axes = (), # always (), we offset later if necessary.
1775+ expected_groups = expected_groups ,
1776+ fastpath = True ,
1777+ # This is the only way it makes sense I think.
1778+ # reindex controls what's actually allocated in chunk_reduce
1779+ # At this point, we care about an accurate conversion to codes.
1780+ reindex = True ,
1781+ sort = sort ,
1782+ )
17611783 if any_by_dask :
17621784 import dask .array
17631785
@@ -1771,11 +1793,7 @@ def _factorize_multiple(
17711793 * by_ ,
17721794 chunks = tuple (chunks .values ()),
17731795 meta = np .array ((), dtype = np .int64 ),
1774- axes = (), # always (), we offset later if necessary.
1775- expected_groups = expected_groups ,
1776- fastpath = True ,
1777- reindex = reindex ,
1778- sort = sort ,
1796+ ** kwargs ,
17791797 )
17801798
17811799 fg , gs = [], []
@@ -1796,14 +1814,8 @@ def _factorize_multiple(
17961814 found_groups = tuple (fg )
17971815 grp_shape = tuple (gs )
17981816 else :
1799- group_idx , found_groups , grp_shape , ngroups , size , props = factorize_ (
1800- by ,
1801- axes = (), # always (), we offset later if necessary.
1802- expected_groups = expected_groups ,
1803- fastpath = True ,
1804- reindex = reindex ,
1805- sort = sort ,
1806- )
1817+ kwargs ["by" ] = by
1818+ group_idx , found_groups , grp_shape , * _ = factorize_ (** kwargs )
18071819
18081820 return (group_idx ,), found_groups , grp_shape
18091821
@@ -2058,7 +2070,7 @@ def groupby_reduce(
20582070 # (pd.IntervalIndex or not)
20592071 expected_groups = _convert_expected_groups_to_index (expected_groups , isbins , sort )
20602072
2061- # Don't factorize " early only when
2073+ # Don't factorize early only when
20622074 # grouping by dask arrays, and not having expected_groups
20632075 factorize_early = not (
20642076 # can't do it if we are grouping by dask array but don't have expected_groups
@@ -2069,10 +2081,6 @@ def groupby_reduce(
20692081 bys ,
20702082 expected_groups ,
20712083 any_by_dask = any_by_dask ,
2072- # This is the only way it makes sense I think.
2073- # reindex controls what's actually allocated in chunk_reduce
2074- # At this point, we care about an accurate conversion to codes.
2075- reindex = True ,
20762084 sort = sort ,
20772085 )
20782086 expected_groups = (pd .RangeIndex (math .prod (grp_shape )),)
@@ -2103,21 +2111,17 @@ def groupby_reduce(
21032111 "along a single axis or when reducing across all dimensions of `by`."
21042112 )
21052113
2106- # TODO: make sure expected_groups is unique
21072114 if nax == 1 and by_ .ndim > 1 and expected_groups is None :
2108- if not any_by_dask :
2109- expected_groups = _get_expected_groups (by_ , sort )
2110- else :
2111- # When we reduce along all axes, we are guaranteed to see all
2112- # groups in the final combine stage, so everything works.
2113- # This is not necessarily true when reducing along a subset of axes
2114- # (of by)
2115- # TODO: Does this depend on chunking of by?
2116- # For e.g., we could relax this if there is only one chunk along all
2117- # by dim != axis?
2118- raise NotImplementedError (
2119- "Please provide ``expected_groups`` when not reducing along all axes."
2120- )
2115+ # When we reduce along all axes, we are guaranteed to see all
2116+ # groups in the final combine stage, so everything works.
2117+ # This is not necessarily true when reducing along a subset of axes
2118+ # (of by)
2119+ # TODO: Does this depend on chunking of by?
2120+ # For e.g., we could relax this if there is only one chunk along all
2121+ # by dim != axis?
2122+ raise NotImplementedError (
2123+ "Please provide ``expected_groups`` when not reducing along all axes."
2124+ )
21212125
21222126 assert nax <= by_ .ndim
21232127 if nax < by_ .ndim :
0 commit comments