Skip to content

Commit 34602d5

Browse files
committed
Scan updates
1 parent 9ea30c5 commit 34602d5

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

flox/dask.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,16 +572,14 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
572572
from dask.array.reductions import cumreduction as scan
573573

574574
from .aggregations import scan_binary_op
575+
from .scan import _finalize_scan, _zip, chunk_scan, grouped_reduce
575576

576577
if len(axes) > 1:
577578
raise NotImplementedError("Scans are only supported along a single axis.")
578579
(axis,) = axes
579580

580581
array, by = _unify_chunks(array, by)
581582

582-
# Import scan-specific functions from scan module
583-
from .scan import _finalize_scan, _zip, chunk_scan, grouped_reduce
584-
585583
# 1. zip together group indices & array
586584
zipped = map_blocks(
587585
_zip,

flox/scan.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
from typing import TYPE_CHECKING
1010

1111
import numpy as np
12+
import pandas as pd
1213

1314
from .aggregations import (
14-
AGGREGATIONS,
15+
SCANS,
1516
AlignedArrays,
1617
Scan,
1718
ScanState,
@@ -37,19 +38,17 @@
3738
from .types import DaskArray
3839

3940

40-
def _validate_expected_groups_for_scan(nby, expected_groups):
41+
def _validate_expected_groups(nby, expected_groups):
4142
"""Validate expected_groups for scan operations."""
4243
if expected_groups is None:
4344
return (None,) * nby
4445
return expected_groups
4546

4647

47-
def _convert_expected_groups_to_index_for_scan(expected_groups, isbin, sort):
48+
def _convert_expected_groups_to_index(expected_groups):
4849
"""Convert expected_groups to index for scan operations."""
49-
import pandas as pd
50-
5150
result = []
52-
for expect, isbin_ in zip(expected_groups, isbin):
51+
for expect in expected_groups:
5352
if expect is None:
5453
result.append(None)
5554
elif isinstance(expect, pd.Index):
@@ -159,21 +158,18 @@ def groupby_scan(
159158
if not is_duck_array(array):
160159
array = np.asarray(array)
161160

162-
if isinstance(func, str):
163-
agg = AGGREGATIONS[func]
161+
agg = SCANS[func] if isinstance(func, str) else func
164162
assert isinstance(agg, Scan)
165163
agg = copy.deepcopy(agg)
166164

167-
if (agg == AGGREGATIONS["ffill"] or agg == AGGREGATIONS["bfill"]) and array.dtype.kind != "f":
165+
if (agg == SCANS["ffill"] or agg == SCANS["bfill"]) and array.dtype.kind != "f":
168166
# nothing to do, no NaNs!
169167
return array
170168

171169
if expected_groups is not None:
172-
raise NotImplementedError("Setting `expected_groups` and binning is not supported yet.")
173-
expected_groups = _validate_expected_groups_for_scan(nby, expected_groups)
174-
expected_groups = _convert_expected_groups_to_index_for_scan(
175-
expected_groups, isbin=(False,) * nby, sort=False
176-
)
170+
raise NotImplementedError("Setting `expected_groups` with scans is not supported yet.")
171+
expected_groups = _validate_expected_groups(nby, expected_groups)
172+
expected_groups = _convert_expected_groups_to_index(expected_groups)
177173

178174
# Don't factorize early only when
179175
# grouping by dask arrays, and not having expected_groups

0 commit comments

Comments
 (0)