Skip to content

Commit 1b30f08

Browse files
dcherian and claude
committed
Support method="blockwise" for scans
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 34602d5 commit 1b30f08

File tree

5 files changed

+203
-37
lines changed

5 files changed

+203
-37
lines changed

flox/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
_is_reindex_sparse_supported_reduction,
114114
_issorted,
115115
_postprocess_numbagg,
116+
_should_auto_rechunk_blockwise,
116117
)
117118

118119

@@ -962,7 +963,7 @@ def groupby_reduce(
962963
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
963964
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
964965

965-
if method is None and is_duck_dask_array(array) and not any_by_dask and by_.ndim == 1 and _issorted(by_):
966+
if _should_auto_rechunk_blockwise(method, array, any_by_dask, by_):
966967
# Let's try rechunking for sorted 1D by.
967968
(single_axis,) = axis_
968969
method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)

flox/dask.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414
import toolz as tlz
1515

1616
if TYPE_CHECKING:
17-
from .aggregations import Aggregation, Scan
17+
from typing import Literal
18+
19+
from .aggregations import Aggregation
1820
from .core import T_Axes, T_Engine, T_Method
1921
from .lib import ArrayLayer
2022
from .reindex import ReindexArrayType, ReindexStrategy
2123
from .types import DaskArray, Graph, IntermediateDict, T_By
2224

25+
T_ScanMethod = Literal["blelloch", "blockwise"]
26+
27+
from .aggregations import Scan, scan_binary_op
2328
from .core import (
2429
DUMMY_AXIS,
2530
_get_chunk_reduction,
@@ -34,6 +39,7 @@
3439
ReindexStrategy,
3540
reindex_,
3641
)
42+
from .scan import _finalize_scan, _zip, chunk_scan, grouped_reduce
3743
from .types import FinalResultsDict, IntermediateDict
3844
from .xrutils import is_duck_dask_array, notnull
3945

@@ -567,49 +573,91 @@ def dask_groupby_agg(
567573
return (result, groups)
568574

569575

570-
def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
576+
def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan, method: T_ScanMethod = "blelloch") -> DaskArray:
    """Run a grouped scan over dask-backed inputs.

    The computation has three stages: the labels and values are zipped
    block-by-block, the zipped blocks are scanned, and the scanned blocks are
    finalized into the output array.

    Parameters
    ----------
    array : DaskArray
        Values to scan.
    by : DaskArray
        Group labels. ``array`` and ``by`` are passed through ``_unify_chunks``
        before any graph is built.
    axes : T_Axes
        Axes to scan along; exactly one axis is supported.
    agg : Scan
        Scan specification. Its ``dtype`` and ``identity`` attributes are used here.
    method : {"blelloch", "blockwise"}, optional
        - "blockwise": ``chunk_scan`` is mapped over every block independently,
          so nothing is carried across chunk boundaries. Only valid when every
          group lives entirely inside one chunk.
        - any other value (default "blelloch"): dask's ``cumreduction`` is used
          with ``method="blelloch"``, which combines results across chunks.

    Returns
    -------
    DaskArray
        The scanned array; its chunks are asserted to equal those of the
        (chunk-unified) input ``array``.

    Raises
    ------
    NotImplementedError
        If more than one axis is requested.
    """
    from dask.array import map_blocks
    from dask.array.reductions import cumreduction
    from dask.base import tokenize

    if len(axes) > 1:
        raise NotImplementedError("Scans are only supported along a single axis.")
    (axis,) = axes

    array, by = _unify_chunks(array, by)

    # The method is part of the token so that the two strategies never share layer names.
    token = tokenize(array, by, agg, axes, method)
    out_dtype = agg.dtype

    # Stage 1: pair each block of labels with the matching block of values.
    paired = map_blocks(
        _zip,
        by,
        array,
        dtype=array.dtype,
        meta=array._meta,
        name=f"groupby-scan-preprocess-{token}",
    )

    # Stage 2: scan the paired blocks.
    if method == "blockwise":
        # Every block is scanned on its own; no information crosses chunk boundaries.
        per_block = map_blocks(
            partial(chunk_scan, agg=agg, axis=axis, dtype=out_dtype),
            paired,
            dtype=out_dtype,
            meta=array._meta,
            name=f"groupby-scan-{token}",
        )
    else:
        # Blelloch parallel prefix scan via dask's cumreduction.
        chunk_func = partial(chunk_scan, agg=agg)
        # dask tokenizing error workaround
        chunk_func.__name__ = chunk_func.func.__name__  # type: ignore[attr-defined]

        per_block = cumreduction(
            func=chunk_func,
            binop=partial(scan_binary_op, agg=agg),
            ident=agg.identity,
            x=paired,
            axis=axis,
            # TODO: support method="sequential" here.
            method="blelloch",
            preop=partial(grouped_reduce, agg=agg),
            dtype=out_dtype,
        )

    # Stage 3: pull the final array out of the scanned blocks.
    result = map_blocks(
        partial(_finalize_scan, dtype=out_dtype),
        per_block,
        dtype=out_dtype,
        name=f"groupby-scan-finalize-{token}",
    )

    assert result.chunks == array.chunks

    return result

flox/lib.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING, TypeAlias, TypeVar
55

66
from .types import DaskArray, Graph
7-
from .xrutils import module_available
7+
from .xrutils import is_duck_dask_array, module_available
88

99
if TYPE_CHECKING:
1010
from .aggregations import Aggregation
@@ -78,6 +78,11 @@ def _issorted(arr, ascending=True) -> bool:
7878
return bool((arr[:-1] >= arr[1:]).all())
7979

8080

81+
def _should_auto_rechunk_blockwise(method, array, any_by_dask: bool, by) -> bool:
82+
"""Check if we should attempt automatic rechunking for blockwise operations."""
83+
return method is None and is_duck_dask_array(array) and not any_by_dask and by.ndim == 1 and _issorted(by)
84+
85+
8186
def _is_nanlen(reduction) -> bool:
8287
return isinstance(reduction, str) and reduction == "nanlen"
8388

flox/scan.py

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
_atleast_1d,
2020
generic_aggregate,
2121
)
22+
from .cohorts import find_group_cohorts
2223
from .factorize import _factorize_multiple
24+
from .lib import _should_auto_rechunk_blockwise
25+
from .rechunk import rechunk_for_blockwise
2326
from .xrutils import is_duck_array, is_duck_dask_array, module_available
2427

2528
if module_available("numpy", minversion="2.0.0"):
@@ -28,6 +31,8 @@
2831
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
2932

3033
if TYPE_CHECKING:
34+
from typing import Literal
35+
3136
from .core import (
3237
T_By,
3338
T_EngineOpt,
@@ -37,6 +42,41 @@
3742
)
3843
from .types import DaskArray
3944

45+
T_ScanMethod = Literal["blockwise", "blelloch"]
46+
47+
48+
def _choose_scan_method(
49+
method: T_MethodOpt, preferred_method: T_ScanMethod, nax: int, by_ndim: int
50+
) -> T_ScanMethod:
51+
"""Choose the scan method based on user input and preferred method.
52+
53+
Parameters
54+
----------
55+
method : T_MethodOpt
56+
User-specified method, or None for automatic selection.
57+
preferred_method : T_ScanMethod
58+
The preferred method based on data layout analysis.
59+
nax : int
60+
Number of axes being reduced.
61+
by_ndim : int
62+
Number of dimensions in the `by` array.
63+
64+
Returns
65+
-------
66+
T_ScanMethod
67+
The chosen scan method: "blockwise" or "blelloch".
68+
"""
69+
if method is None:
70+
# Scans must reduce along all dimensions of by for blockwise
71+
if nax != by_ndim:
72+
return "blelloch"
73+
return preferred_method
74+
elif method == "blockwise":
75+
return "blockwise"
76+
else:
77+
# For any other method (including map-reduce, cohorts), use blelloch
78+
return "blelloch"
79+
4080

4181
def _validate_expected_groups(nby, expected_groups):
4282
"""Validate expected_groups for scan operations."""
@@ -91,8 +131,8 @@ def groupby_scan(
91131
Value to assign when a label in ``expected_groups`` is not present.
92132
dtype : data-type , optional
93133
DType for the output. Can be anything that is accepted by ``np.dtype``.
94-
method : {"blockwise", "cohorts"}, optional
95-
Strategy for reduction of dask arrays only:
134+
method : {"blockwise", "blelloch"}, optional
135+
Strategy for scan of dask arrays only:
96136
* ``"blockwise"``:
97137
Only scan using blockwise and avoid aggregating blocks
98138
together. Useful for resampling-style groupby problems where group
@@ -101,14 +141,10 @@ def groupby_scan(
101141
i.e. each block contains all members of any group present
102142
in that block. For nD `by`, you must make sure that all members of a group
103143
are present in a single block.
104-
* ``"cohorts"``:
105-
Finds group labels that tend to occur together ("cohorts"),
106-
indexes out cohorts and reduces that subset using "map-reduce",
107-
repeat for all cohorts. This works well for many time groupings
108-
where the group labels repeat at regular intervals like 'hour',
109-
'month', dayofyear' etc. Optimize chunking ``array`` for this
110-
method by first rechunking using ``rechunk_for_cohorts``
111-
(for 1D ``by`` only).
144+
* ``"blelloch"``:
145+
Use Blelloch's parallel prefix scan algorithm, which allows
146+
scanning across chunk boundaries. This is the default when groups
147+
span multiple chunks.
112148
engine : {"flox", "numpy", "numba", "numbagg"}, optional
113149
Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
114150
* ``"numpy"``:
@@ -149,8 +185,6 @@ def groupby_scan(
149185

150186
if engine is not None:
151187
raise NotImplementedError("Setting `engine` is not supported for scans yet.")
152-
if method is not None:
153-
raise NotImplementedError("Setting `method` is not supported for scans yet.")
154188
if engine is None:
155189
engine = "flox"
156190
assert engine == "flox"
@@ -191,6 +225,38 @@ def groupby_scan(
191225
by_: np.ndarray
192226
(by_,) = bys
193227
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
228+
nax = len(axis_)
229+
230+
# Method selection for dask arrays
231+
scan_method: T_ScanMethod = "blelloch"
232+
if has_dask:
233+
(single_axis,) = axis_ # type: ignore[misc]
234+
235+
# Try rechunking for sorted 1D by when method is not specified
236+
if _should_auto_rechunk_blockwise(method, array, any_by_dask, by_):
237+
rechunk_method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)
238+
if rechunk_method == "blockwise":
239+
method = "blockwise"
240+
241+
# Determine preferred method based on data layout
242+
if not any_by_dask and method is None:
243+
cohorts_method, _ = find_group_cohorts(
244+
by_,
245+
[array.chunks[ax] for ax in range(-by_.ndim, 0)], # type: ignore[union-attr]
246+
expected_groups=None,
247+
merge=False,
248+
)
249+
# Map groupby_reduce methods to scan methods
250+
preferred_method: T_ScanMethod = "blockwise" if cohorts_method == "blockwise" else "blelloch"
251+
else:
252+
preferred_method = "blelloch"
253+
254+
# Choose the final method
255+
scan_method = _choose_scan_method(method, preferred_method, nax, by_.ndim)
256+
257+
# Rechunk if blockwise was explicitly requested but data isn't aligned
258+
if preferred_method != "blockwise" and scan_method == "blockwise" and by_.ndim == 1:
259+
_, array = rechunk_for_blockwise(array, axis=-1, labels=by_)
194260

195261
if array.dtype.kind in "Mm":
196262
cast_to = array.dtype
@@ -237,7 +303,7 @@ def groupby_scan(
237303
else:
238304
from .dask import dask_groupby_scan
239305

240-
result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg)
306+
result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg, method=scan_method)
241307

242308
# Made a design choice here to have `postprocess` handle both array and group_idx
243309
out = AlignedArrays(array=result, group_idx=by_)

tests/test_core.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,6 +2072,52 @@ def test_blockwise_nans() -> None:
20722072
assert_equal(expected, actual)
20732073

20742074

2075+
@requires_dask
@pytest.mark.parametrize("func", ["nancumsum", "ffill", "bfill"])
@pytest.mark.parametrize("method", ["blockwise", "blelloch"])
def test_groupby_scan_method(func, method) -> None:
    """An explicit ``method`` must give the same result as the in-memory scan."""
    # The fill functions only do something when there are gaps, so give them NaNs.
    values = (
        [1.0, np.nan, 3.0, 4.0, np.nan, 6.0]
        if "fill" in func
        else [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    )
    # Two chunks of three elements, each holding exactly one group (blockwise-friendly).
    chunked = dask.array.from_array(values, chunks=3)
    labels = np.array([0, 0, 0, 1, 1, 1])

    reference = groupby_scan(chunked.compute(), labels, func=func, axis=-1)
    computed = groupby_scan(chunked, labels, func=func, axis=-1, method=method)

    assert_equal(reference, computed)
2093+
2094+
2095+
@requires_dask
def test_groupby_scan_blockwise_auto_rechunk() -> None:
    """Sorted groups that straddle chunk boundaries should trigger a rechunk."""
    from flox import scan
    from flox.rechunk import rechunk_for_blockwise as real_rechunk

    # Chunks of two with groups of three: the labels are sorted, but group 0
    # spans chunks 0 and 1.
    labels = np.array([0, 0, 0, 1, 1, 1])
    chunked = dask.array.from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], chunks=2)

    reference = groupby_scan(chunked.compute(), labels, func="nancumsum", axis=-1)

    # First pass: no method given, so rechunking should happen automatically.
    # Second pass: explicit method="blockwise" should rechunk as well.
    for extra_kwargs in ({}, {"method": "blockwise"}):
        with patch.object(scan, "rechunk_for_blockwise", wraps=real_rechunk) as rechunk_spy:
            computed = groupby_scan(chunked, labels, func="nancumsum", axis=-1, **extra_kwargs)
            assert_equal(reference, computed)
            # The spy wraps the real function, so a call means rechunking was attempted.
            assert rechunk_spy.call_count >= 1
2119+
2120+
20752121
@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"])
20762122
@pytest.mark.parametrize("engine", ["flox", "numpy"])
20772123
def test_agg_dtypes(func, engine) -> None:

0 commit comments

Comments
 (0)