1919 _atleast_1d ,
2020 generic_aggregate ,
2121)
22+ from .cohorts import find_group_cohorts
2223from .factorize import _factorize_multiple
24+ from .lib import _should_auto_rechunk_blockwise
25+ from .rechunk import rechunk_for_blockwise
2326from .xrutils import is_duck_array , is_duck_dask_array , module_available
2427
2528if module_available ("numpy" , minversion = "2.0.0" ):
2831 from numpy .core .numeric import normalize_axis_tuple # type: ignore[no-redef]
2932
3033if TYPE_CHECKING :
34+ from typing import Literal
35+
3136 from .core import (
3237 T_By ,
3338 T_EngineOpt ,
3742 )
3843 from .types import DaskArray
3944
45+ T_ScanMethod = Literal ["blockwise" , "blelloch" ]
46+
47+
48+ def _choose_scan_method (
49+ method : T_MethodOpt , preferred_method : T_ScanMethod , nax : int , by_ndim : int
50+ ) -> T_ScanMethod :
51+ """Choose the scan method based on user input and preferred method.
52+
53+ Parameters
54+ ----------
55+ method : T_MethodOpt
56+ User-specified method, or None for automatic selection.
57+ preferred_method : T_ScanMethod
58+ The preferred method based on data layout analysis.
59+ nax : int
60+ Number of axes being reduced.
61+ by_ndim : int
62+ Number of dimensions in the `by` array.
63+
64+ Returns
65+ -------
66+ T_ScanMethod
67+ The chosen scan method: "blockwise" or "blelloch".
68+ """
69+ if method is None :
70+ # Scans must reduce along all dimensions of by for blockwise
71+ if nax != by_ndim :
72+ return "blelloch"
73+ return preferred_method
74+ elif method == "blockwise" :
75+ return "blockwise"
76+ else :
77+ # For any other method (including map-reduce, cohorts), use blelloch
78+ return "blelloch"
79+
4080
4181def _validate_expected_groups (nby , expected_groups ):
4282 """Validate expected_groups for scan operations."""
@@ -91,8 +131,8 @@ def groupby_scan(
91131 Value to assign when a label in ``expected_groups`` is not present.
92132 dtype : data-type , optional
93133 DType for the output. Can be anything that is accepted by ``np.dtype``.
94- method : {"blockwise", "cohorts "}, optional
95- Strategy for reduction of dask arrays only:
134+ method : {"blockwise", "blelloch "}, optional
135+ Strategy for scan of dask arrays only:
96136 * ``"blockwise"``:
97137 Only scan using blockwise and avoid aggregating blocks
98138 together. Useful for resampling-style groupby problems where group
@@ -101,14 +141,10 @@ def groupby_scan(
101141 i.e. each block contains all members of any group present
102142 in that block. For nD `by`, you must make sure that all members of a group
103143 are present in a single block.
104- * ``"cohorts"``:
105- Finds group labels that tend to occur together ("cohorts"),
106- indexes out cohorts and reduces that subset using "map-reduce",
107- repeat for all cohorts. This works well for many time groupings
108- where the group labels repeat at regular intervals like 'hour',
109- 'month', dayofyear' etc. Optimize chunking ``array`` for this
110- method by first rechunking using ``rechunk_for_cohorts``
111- (for 1D ``by`` only).
144+ * ``"blelloch"``:
145+ Use Blelloch's parallel prefix scan algorithm, which allows
146+ scanning across chunk boundaries. This is the default when groups
147+ span multiple chunks.
112148 engine : {"flox", "numpy", "numba", "numbagg"}, optional
113149 Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
114150 * ``"numpy"``:
@@ -149,8 +185,6 @@ def groupby_scan(
149185
150186 if engine is not None :
151187 raise NotImplementedError ("Setting `engine` is not supported for scans yet." )
152- if method is not None :
153- raise NotImplementedError ("Setting `method` is not supported for scans yet." )
154188 if engine is None :
155189 engine = "flox"
156190 assert engine == "flox"
@@ -191,6 +225,38 @@ def groupby_scan(
191225 by_ : np .ndarray
192226 (by_ ,) = bys
193227 has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by_ )
228+ nax = len (axis_ )
229+
230+ # Method selection for dask arrays
231+ scan_method : T_ScanMethod = "blelloch"
232+ if has_dask :
233+ (single_axis ,) = axis_ # type: ignore[misc]
234+
235+ # Try rechunking for sorted 1D by when method is not specified
236+ if _should_auto_rechunk_blockwise (method , array , any_by_dask , by_ ):
237+ rechunk_method , array = rechunk_for_blockwise (array , single_axis , by_ , force = False )
238+ if rechunk_method == "blockwise" :
239+ method = "blockwise"
240+
241+ # Determine preferred method based on data layout
242+ if not any_by_dask and method is None :
243+ cohorts_method , _ = find_group_cohorts (
244+ by_ ,
245+ [array .chunks [ax ] for ax in range (- by_ .ndim , 0 )], # type: ignore[union-attr]
246+ expected_groups = None ,
247+ merge = False ,
248+ )
249+ # Map groupby_reduce methods to scan methods
250+ preferred_method : T_ScanMethod = "blockwise" if cohorts_method == "blockwise" else "blelloch"
251+ else :
252+ preferred_method = "blelloch"
253+
254+ # Choose the final method
255+ scan_method = _choose_scan_method (method , preferred_method , nax , by_ .ndim )
256+
257+ # Rechunk if blockwise was explicitly requested but data isn't aligned
258+ if preferred_method != "blockwise" and scan_method == "blockwise" and by_ .ndim == 1 :
259+ _ , array = rechunk_for_blockwise (array , axis = - 1 , labels = by_ )
194260
195261 if array .dtype .kind in "Mm" :
196262 cast_to = array .dtype
@@ -237,7 +303,7 @@ def groupby_scan(
237303 else :
238304 from .dask import dask_groupby_scan
239305
240- result = dask_groupby_scan (inp .array , inp .group_idx , axes = axis_ , agg = agg )
306+ result = dask_groupby_scan (inp .array , inp .group_idx , axes = axis_ , agg = agg , method = scan_method )
241307
242308 # Made a design choice here to have `postprocess` handle both array and group_idx
243309 out = AlignedArrays (array = result , group_idx = by_ )
0 commit comments