Skip to content

Commit ef21d21

Browse files
committed
Cleanups from multiple groupers PR
1. Clean up chunk_reduce: always reindex, and use pandas for fastpaths instead of writing our own. 2. Remove unused reindex_multiple. 3. group_chunks is always a tuple.
1 parent 5547c46 commit ef21d21

File tree

1 file changed

+12
-35
lines changed

1 file changed

+12
-35
lines changed

flox/core.py

Lines changed: 12 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ def _prepare_for_flox(group_idx, array):
4848
"""
4949
Sort the input array once to save time.
5050
"""
51+
assert array.shape[-1] == group_idx.shape[0]
5152
issorted = (group_idx[:-1] <= group_idx[1:]).all()
5253
if issorted:
5354
ordered_array = array
@@ -323,14 +324,6 @@ def rechunk_for_blockwise(array, axis, labels):
323324
return array.rechunk({axis: newchunks})
324325

325326

326-
def reindex_multiple(array, groups, expected_groups, fill_value, promote=False):
327-
for ax, (group, expect) in enumerate(zip(groups, expected_groups)):
328-
array = reindex_(
329-
array, group, expect, fill_value=fill_value, axis=ax - len(groups), promote=promote
330-
)
331-
return array
332-
333-
334327
def reindex_(
335328
array: np.ndarray, from_, to, fill_value=None, axis: int = -1, promote: bool = False
336329
) -> np.ndarray:
@@ -562,19 +555,12 @@ def chunk_reduce(
562555
array = _collapse_axis(array, len(axis))
563556
axis = -1
564557

565-
if by.ndim == 1:
566-
# TODO: This assertion doesn't work with dask reducing across all dimensions
567-
# when by.ndim == array.ndim
568-
# the intermediates are 1D but axis=range(array.ndim)
569-
# assert axis in (0, -1, array.ndim - 1, None)
570-
axis = -1
571-
572558
# if indices=[2,2,2], npg assumes groups are (0, 1, 2);
573559
# and will return a result that is bigger than necessary
574560
# avoid by factorizing again so indices=[2,2,2] is changed to
575561
# indices=[0,0,0]. This is necessary when combining block results
576562
# factorize can handle strings etc unlike digitize
577-
group_idx, groups, _, ngroups, size, props = factorize_(
563+
group_idx, groups, found_groups_shape, ngroups, size, props = factorize_(
578564
(by,), axis, expected_groups=(expected_groups,)
579565
)
580566
groups = groups[0]
@@ -584,21 +570,17 @@ def chunk_reduce(
584570
array = array.reshape(newshape)
585571

586572
assert group_idx.ndim == 1
587-
empty = np.all(props.nanmask) or np.prod(by.shape) == 0
573+
empty = np.all(props.nanmask)
588574

589575
results: IntermediateDict = {"groups": [], "intermediates": []}
590576
if reindex and expected_groups is not None:
591577
# TODO: what happens with binning here?
592-
results["groups"] = expected_groups.values
578+
results["groups"] = expected_groups.to_numpy()
593579
else:
594580
if empty:
595581
results["groups"] = np.array([np.nan])
596582
else:
597-
if (groups[:-1] <= groups[1:]).all():
598-
sortidx = slice(None)
599-
else:
600-
sortidx = groups.argsort()
601-
results["groups"] = groups[sortidx]
583+
results["groups"] = np.sort(groups)
602584

603585
# npg's argmax ensures that index of first "max" is returned assuming there
604586
# are many elements equal to the "max". Sorting messes this up totally.
@@ -650,13 +632,8 @@ def chunk_reduce(
650632
if np.any(props.nanmask):
651633
# remove NaN group label which should be last
652634
result = result[..., :-1]
653-
if props.offset_group:
654-
result = result.reshape(*final_array_shape[:-1], ngroups)
655-
if reindex:
656-
result = reindex_(result, groups, expected_groups, fill_value=fv)
657-
else:
658-
result = result[..., sortidx]
659-
result = result.reshape(final_array_shape)
635+
result = result.reshape(final_array_shape[:-1] + found_groups_shape)
636+
result = reindex_(result, groups, results["groups"], fill_value=fv, promote=True)
660637
results["intermediates"].append(result)
661638

662639
results["groups"] = np.broadcast_to(results["groups"], final_groups_shape)
@@ -913,8 +890,8 @@ def split_blocks(applied, split_out, expected_groups, split_name):
913890

914891
chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
915892
ngroups = len(expected_groups)
916-
group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))[0]
917-
idx = tuple(np.cumsum((0,) + group_chunks))
893+
group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
894+
idx = tuple(np.cumsum((0,) + group_chunks[0]))
918895

919896
# split each block into `split_out` chunks
920897
dsk = {}
@@ -1123,7 +1100,7 @@ def dask_groupby_agg(
11231100
intermediate = applied
11241101
if expected_groups is None:
11251102
expected_groups = _get_expected_groups(by_input, raise_if_dask=False)
1126-
group_chunks = (len(expected_groups),) if expected_groups is not None else (np.nan,)
1103+
group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
11271104

11281105
if method == "map-reduce":
11291106
# these are negative axis indices useful for concatenating the intermediates
@@ -1155,7 +1132,7 @@ def dask_groupby_agg(
11551132
keepdims=True,
11561133
concatenate=False,
11571134
)
1158-
output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + (group_chunks,)
1135+
output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
11591136
elif method == "blockwise":
11601137
reduced = intermediate
11611138
# Here one input chunk → one output chunk
@@ -1193,7 +1170,7 @@ def dask_groupby_agg(
11931170
dask.array.Array(
11941171
HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
11951172
groups_name,
1196-
chunks=(group_chunks,),
1173+
chunks=group_chunks,
11971174
dtype=by.dtype,
11981175
),
11991176
)

0 commit comments

Comments
 (0)