
Commit 0d78db9

Support cohorts for nD by arrays (#55)
1 parent 2cef60e commit 0d78db9

3 files changed: +161 -78 lines changed

flox/core.py

Lines changed: 105 additions & 71 deletions
@@ -8,7 +8,6 @@
     Any,
     Callable,
     Dict,
-    Iterable,
     Mapping,
     Optional,
     Sequence,
@@ -131,9 +130,9 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     Parameters
     ----------
     labels: np.ndarray
-        1D Array of group labels
-    chunks: tuple
-        chunks along grouping dimension for array that is being reduced
+        mD Array of group labels
+    array: tuple
+        nD array that is being reduced
     merge: bool, optional
         Attempt to merge cohorts when one cohort's chunks are a subset
         of another cohort's chunks.
@@ -147,19 +146,35 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     """
     import copy

+    import dask
     import toolz as tlz

     if method == "split-reduce":
         return np.unique(labels).reshape(-1, 1).tolist()

-    which_chunk = np.repeat(np.arange(len(chunks)), chunks)
+    # To do this, we must have values in memory so casting to numpy should be safe
+    labels = np.asarray(labels)
+
+    # Build an array with the shape of labels, but where every element is the "chunk number"
+    # 1. First subset the array appropriately
+    axis = range(-labels.ndim, 0)
+    # Easier to create a dask array and use the .blocks property
+    array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+
+    # Iterate over each block and create a new block of same shape with "chunk number"
+    shape = tuple(array.blocks.shape[ax] for ax in axis)
+    blocks = np.empty(np.prod(shape), dtype=object)
+    for idx, block in enumerate(array.blocks.ravel()):
+        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
+    which_chunk = np.block(blocks.reshape(shape).tolist()).ravel()
+
     # these are chunks where a label is present
-    label_chunks = {lab: tuple(np.unique(which_chunk[labels == lab])) for lab in np.unique(labels)}
+    label_chunks = {
+        lab: tuple(np.unique(which_chunk[labels.ravel() == lab])) for lab in np.unique(labels)
+    }
     # These invert the label_chunks mapping so we know which labels occur together.
     chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())

-    # TODO: sort by length of values (i.e. cohort);
-    # then loop in reverse and merge when keys are subsets of initial keys?
     if merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
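
A minimal standalone sketch of the "chunk number" construction in the hunk above, assuming a dask version whose Array.blocks accessor exposes .shape and .ravel() (as the new code requires); this is illustration only, not part of the commit:

import dask.array
import numpy as np

# a 2 x 2 grid of blocks covering a (4, 4) labels array
chunks = ((2, 2), (2, 2))
array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)

shape = array.blocks.shape                      # (2, 2) blocks
blocks = np.empty(np.prod(shape), dtype=object)
for idx, block in enumerate(array.blocks.ravel()):
    blocks[idx] = np.full(block.shape, idx)     # fill each block with its flat block number
which_chunk = np.block(blocks.reshape(shape).tolist())

print(which_chunk)
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]

Every element of `labels` can then be matched to the chunk holding it, which is what `label_chunks` uses to find which labels co-occur in the same chunks.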
@@ -299,7 +314,6 @@ def reindex_(array: np.ndarray, from_, to, fill_value=None, axis: int = -1) -> n
         reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
         return reindexed

-    from_ = np.atleast_1d(from_)
     if from_.dtype.kind == "O" and isinstance(from_[0], tuple):
         raise NotImplementedError(
             "Currently does not support reindexing with object arrays of tuples. "
@@ -706,14 +720,14 @@ def _npg_aggregate(
     expected_groups: Union[Sequence, np.ndarray, None],
     axis: Sequence,
     keepdims,
-    group_ndim: int,
+    neg_axis: Sequence,
     fill_value: Any = None,
     min_count: Optional[int] = None,
     engine: str = "numpy",
     finalize_kwargs: Optional[Mapping] = None,
 ) -> FinalResultsDict:
     """Final aggregation step of tree reduction"""
-    results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim, engine)
+    results = _npg_combine(x_chunk, agg, axis, keepdims, neg_axis, engine)
     return _finalize_results(
         results, agg, axis, expected_groups, fill_value, min_count, finalize_kwargs
     )
@@ -742,7 +756,7 @@ def _npg_combine(
     agg: Aggregation,
     axis: Sequence,
     keepdims: bool,
-    group_ndim: int,
+    neg_axis: Sequence,
     engine: str,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
@@ -771,12 +785,7 @@ def reindex_intermediates(x):

     x_chunk = deepmap(reindex_intermediates, x_chunk)

-    group_conc_axis: Iterable[int]
-    if group_ndim == 1:
-        group_conc_axis = (0,)
-    else:
-        group_conc_axis = sorted(group_ndim - ax - 1 for ax in axis)
-    groups = _conc2(x_chunk, "groups", axis=group_conc_axis)
+    groups = _conc2(x_chunk, "groups", axis=neg_axis)

     if agg.reduction_type == "argreduce":
         # If "nanlen" was added for masking later, we need to account for that
@@ -830,7 +839,7 @@ def reindex_intermediates(x):
                 np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=agg.dtype)
             )
             results["groups"] = np.empty(
-                shape=(1,) * (len(group_conc_axis) - 1) + (0,), dtype=groups.dtype
+                shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype
             )
         else:
             _results = chunk_reduce(
@@ -891,6 +900,7 @@ def groupby_agg(
     method: str = "map-reduce",
     min_count: Optional[int] = None,
     isbin: bool = False,
+    reindex: bool = False,
     engine: str = "numpy",
     finalize_kwargs: Optional[Mapping] = None,
 ) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
@@ -902,6 +912,9 @@
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)

+    # these are negative axis indices useful for concatenating the intermediates
+    neg_axis = range(-len(axis), 0)
+
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
     token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
@@ -926,11 +939,11 @@
             axis=axis,
             # with the current implementation we want reindexing at the blockwise step
             # only reindex to groups present at combine stage
-            expected_groups=expected_groups if split_out > 1 or isbin else None,
+            expected_groups=expected_groups if reindex or split_out > 1 or isbin else None,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
             isbin=isbin,
-            reindex=split_out > 1,
+            reindex=reindex or (split_out > 1),
             engine=engine,
         ),
         inds,
@@ -964,7 +977,7 @@
         expected_agg = expected_groups

     agg_kwargs = dict(
-        group_ndim=by.ndim,
+        neg_axis=neg_axis,
         fill_value=fill_value,
         min_count=min_count,
         engine=engine,
@@ -984,7 +997,7 @@
             expected_groups=expected_agg,
             **agg_kwargs,
         ),
-        combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim, engine=engine),
+        combine=partial(_npg_combine, agg=agg, neg_axis=neg_axis, engine=engine),
         name=f"{name}-reduce",
         dtype=array.dtype,
         axis=axis,
@@ -996,12 +1009,7 @@
         # Blockwise apply the aggregation step so that one input chunk → one output chunk
        # TODO: We could combine this with the chunk reduction and do everything in one task.
        # This would also optimize the single block along reduced-axis case.
-        if (
-            expected_groups is None
-            or split_out > 1
-            or len(axis) > 1
-            or not isinstance(by_maybe_numpy, np.ndarray)
-        ):
+        if expected_groups is None or split_out > 1 or not isinstance(by_maybe_numpy, np.ndarray):
             raise NotImplementedError

         reduced = dask.array.blockwise(
@@ -1020,17 +1028,25 @@
             dtype=array.dtype,
             meta=array._meta,
             align_arrays=False,
-            name=f"{name}-blockwise-agg-{token}",
+            name=f"{name}-blockwise-{token}",
         )
-        chunks = array.chunks[axis[0]]

         # find number of groups in each chunk, this is needed for output chunks
         # along the reduced axis
-        bnds = np.insert(np.cumsum(chunks), 0, 0)
-        groups_per_chunk = tuple(
-            len(np.unique(by_maybe_numpy[i0:i1])) for i0, i1 in zip(bnds[:-1], bnds[1:])
-        )
-        output_chunks = reduced.chunks[: -(len(axis))] + (groups_per_chunk,)
+        from dask.array.core import slices_from_chunks
+
+        slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
+        if expected_groups is None:
+            groups_in_block = tuple(np.unique(by_maybe_numpy[slc]) for slc in slices)
+        else:
+            # For cohorts, we could be indexing a block with groups that
+            # are not in the cohort (usually for nD `by`)
+            # Only keep the expected groups.
+            groups_in_block = tuple(
+                np.intersect1d(by_maybe_numpy[slc], expected_groups) for slc in slices
+            )
+        ngroups_per_block = tuple(len(groups) for groups in groups_in_block)
+        output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
     else:
         raise ValueError(f"Unknown method={method}.")

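A small editorial sketch of the per-block bookkeeping introduced here, assuming dask.array.core.slices_from_chunks yields one tuple of slices per block, in block order (as the new code relies on); the `by` values and chunking are made up:

import numpy as np
from dask.array.core import slices_from_chunks

by = np.array([[0, 0, 1, 1],
               [0, 0, 1, 1],
               [2, 2, 3, 3],
               [2, 2, 3, 3]])
chunks = ((2, 2), (2, 2))          # hypothetical chunking of the reduced axes

slices = slices_from_chunks(chunks)
groups_in_block = tuple(np.unique(by[slc]) for slc in slices)
ngroups_per_block = tuple(len(g) for g in groups_in_block)
print(ngroups_per_block)           # (1, 1, 1, 1): one output group per block

These per-block group counts become the output chunks along the collapsed, reduced axis.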
@@ -1059,15 +1075,22 @@ def _getitem(d, key1, key2):
             ),
         )
     else:
-        groups = (expected_groups,)
+        if method == "map-reduce":
+            groups = (expected_groups,)
+        else:
+            groups = (np.concatenate(groups_in_block),)

     layer: Dict[Tuple, Tuple] = {}  # type: ignore
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
         if method == "blockwise":
-            inchunk = ochunk
+            if len(axis) == 1:
+                inchunk = ochunk
+            else:
+                nblocks = tuple(len(array.chunks[ax]) for ax in axis)
+                inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * (len(axis)) + (ochunk[-1],) * int(split_out > 1)
+            inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
         layer[(agg_name, *ochunk)] = (
             operator.getitem,
             (reduced.name, *inchunk),
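
The np.unravel_index branch above converts the flat block number carried in the last output-chunk index back into an nD block index of `reduced`. A tiny illustration with made-up block counts:

import numpy as np

nblocks = (2, 3)   # hypothetical number of blocks along two reduced axes
for flat in range(int(np.prod(nblocks))):
    print(flat, "->", tuple(int(i) for i in np.unravel_index(flat, nblocks)))
# 0 -> (0, 0)   1 -> (0, 1)   2 -> (0, 2)
# 3 -> (1, 0)   4 -> (1, 1)   5 -> (1, 2)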
@@ -1089,6 +1112,7 @@ def groupby_reduce(
     func: Union[str, Aggregation],
     *,
     expected_groups: Union[Sequence, np.ndarray] = None,
+    sort: bool = True,
     isbin: bool = False,
     axis=None,
     fill_value=None,
@@ -1114,6 +1138,10 @@
         Expected unique labels.
     isbin : bool, optional
         Are ``expected_groups`` bin edges?
+    sort : (optional), bool
+        Whether groups should be returned in sorted order. Only applies for dask
+        reductions when ``method`` is not ``"map-reduce"``. For ``"map-reduce"``, the groups
+        are always sorted.
     axis : (optional) None or int or Sequence[int]
         If None, reduce across all dimensions of by
         Else, reduce across corresponding axes of array
@@ -1138,17 +1166,19 @@
         * ``"blockwise"``:
           Only reduce using blockwise and avoid aggregating blocks
           together. Useful for resampling-style reductions where group
-          members are always together. The array is rechunked so that
-          chunk boundaries line up with group boundaries
+          members are always together. If `by` is 1D, `array` is automatically
+          rechunked so that chunk boundaries line up with group boundaries
           i.e. each block contains all members of any group present
-          in that block.
+          in that block. For nD `by`, you must make sure that all members of a group
+          are present in a single block.
         * ``"cohorts"``:
           Finds group labels that tend to occur together ("cohorts"),
           indexes out cohorts and reduces that subset using "map-reduce",
           repeat for all cohorts. This works well for many time groupings
           where the group labels repeat at regular intervals like 'hour',
           'month', dayofyear' etc. Optimize chunking ``array`` for this
-          method by first rechunking using ``rechunk_for_cohorts``.
+          method by first rechunking using ``rechunk_for_cohorts``
+          (for 1D ``by`` only).
         * ``"split-reduce"``:
           Break out each group into its own array and then ``"map-reduce"``.
           This is implemented by having each group be its own cohort,
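
A hypothetical usage sketch for the documented behaviour: a 2D `by` reduced with method="cohorts" and the new `sort` flag. The shapes and labels are invented and the call simply follows the parameters shown in this diff; it is not taken from the flox test suite:

import dask.array
import numpy as np
from flox.core import groupby_reduce

# 2 x 4 group labels; each group falls entirely inside one block of `array`
by = np.array([[0, 0, 1, 1],
               [2, 2, 3, 3]])
array = dask.array.ones((3, 2, 4), chunks=(3, 1, 2))

result, groups = groupby_reduce(
    array, by, "sum", axis=(-2, -1), method="cohorts", sort=True
)
print(groups)          # [0 1 2 3] -- sorted because sort=True
print(result.shape)    # (3, 4): one value per group along the last axis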
@@ -1208,6 +1238,11 @@ def groupby_reduce(
     else:
         axis = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore

+    if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
+        raise NotImplementedError(
+            "Must reduce along all dimensions of `by` when method != 'map-reduce'."
+        )
+
     if expected_groups is None and isinstance(by, np.ndarray):
         flatby = by.ravel()
         expected_groups = np.unique(flatby[~isnull(flatby)])
@@ -1366,59 +1401,58 @@
         )

     if method in ["split-reduce", "cohorts"]:
-        if by.ndim > 1:
-            raise ValueError(
-                "`by` must be 1D when method='split-reduce' and method='cohorts'. "
-                f"Received {by.ndim}D array. Please use method='map-reduce' instead."
-            )
-        assert axis == (array.ndim - 1,)
-
-        cohorts = find_group_cohorts(by, array.chunks[axis[0]], merge=True, method=method)
-        idx = np.arange(len(by))
+        cohorts = find_group_cohorts(
+            by, [array.chunks[ax] for ax in axis], merge=True, method=method
+        )

         results = []
         groups_ = []
         for cohort in cohorts:
             cohort = sorted(cohort)
-            # indexes for a subset of groups
-            subset_idx = idx[np.isin(by, cohort)]
-            array_subset = array[..., subset_idx]
-            numblocks = len(array_subset.chunks[-1])
+            # equivalent of xarray.DataArray.where(mask, drop=True)
+            mask = np.isin(by, cohort)
+            indexer = [np.unique(v) for v in np.nonzero(mask)]
+            array_subset = array
+            for ax, idxr in zip(range(-by.ndim, 0), indexer):
+                array_subset = np.take(array_subset, idxr, axis=ax)
+            numblocks = np.prod([len(array_subset.chunks[ax]) for ax in axis])

             # get final result for these groups
             r, *g = partial_agg(
                 array_subset,
-                by[subset_idx],
+                by[np.ix_(*indexer)],
                 expected_groups=cohort,
+                # reindex to expected_groups at the blockwise step.
+                # this approach avoids replacing non-cohort members with
+                # np.nan or some other sentinel value, and preserves dtypes
+                reindex=True,
                 # if only a single block along axis, we can just work blockwise
                 # inspired by https://github.com/dask/dask/issues/8361
-                method="blockwise" if numblocks == 1 else "map-reduce",
+                method="blockwise" if numblocks == 1 and len(axis) == by.ndim else "map-reduce",
             )
             results.append(r)
             groups_.append(cohort)

         # concatenate results together,
         # sort to make sure we match expected output
-        allgroups = np.hstack(groups_)
-        sorted_idx = np.argsort(allgroups)
-        result = np.concatenate(results, axis=-1)[..., sorted_idx]
-        groups = (allgroups[sorted_idx],)
-
+        groups = (np.hstack(groups_),)
+        result = np.concatenate(results, axis=-1)
     else:
         if method == "blockwise":
-            if by.ndim > 1:
-                raise ValueError(
-                    "For method='blockwise', ``by`` must be 1D. "
-                    f"Received {by.ndim} dimensions instead."
-                )
-            array = rechunk_for_blockwise(array, axis=-1, labels=by)
-
-        # TODO: test with mixed array kinds (numpy + dask; dask + numpy)
+            if by.ndim == 1:
+                array = rechunk_for_blockwise(array, axis=-1, labels=by)
+
+        # TODO: test with mixed array kinds (numpy array + dask by)
         result, *groups = partial_agg(
             array,
             by,
             expected_groups=expected_groups,
             method=method,
         )
+        if sort and method != "map-reduce":
+            assert len(groups) == 1
+            sorted_idx = np.argsort(groups[0])
+            result = result[..., sorted_idx]
+            groups = (groups[0][sorted_idx],)

     return (result, *groups)
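
Finally, a standalone sketch (illustration, not commit code) of the indexing trick the cohorts loop now uses instead of 1D fancy indexing: np.nonzero plus np.unique build per-axis indexers, and np.ix_ subsets `by`, mimicking xarray.DataArray.where(mask, drop=True):

import numpy as np

by = np.array([[0, 0, 1, 1],
               [2, 2, 3, 3]])
cohort = [0, 2]

mask = np.isin(by, cohort)
indexer = [np.unique(v) for v in np.nonzero(mask)]
print(indexer)               # [array([0, 1]), array([0, 1])]
print(by[np.ix_(*indexer)])  # [[0 0]
                             #  [2 2]]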
