Skip to content

Commit 6e1e93a

Browse files
committed
Avoid reindexing all groups when reducing along a single axis.
A single concatenate will suffice.
1 parent 39b0ca4 commit 6e1e93a

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

flox/core.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -678,19 +678,6 @@ def _npg_combine(
678678
if not isinstance(x_chunk, list):
679679
x_chunk = [x_chunk]
680680

681-
unique_groups = np.unique(
682-
tuple(flatten(deepmap(lambda x: list(np.atleast_1d(x["groups"].squeeze())), x_chunk)))
683-
)
684-
685-
def reindex_intermediates(x):
686-
new_shape = x["groups"].shape[:-1] + (len(unique_groups),)
687-
newx = {"groups": np.broadcast_to(unique_groups, new_shape)}
688-
newx["intermediates"] = tuple(
689-
reindex_(v, from_=x["groups"].squeeze(), to=unique_groups, fill_value=f)
690-
for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
691-
)
692-
return newx
693-
694681
def _conc2(key1, key2=None, axis=None) -> np.ndarray:
695682
"""copied from dask.array.reductions.mean_combine"""
696683
if key2 is not None:
@@ -699,7 +686,24 @@ def _conc2(key1, key2=None, axis=None) -> np.ndarray:
699686
mapped = deepmap(lambda x: x[key1], x_chunk)
700687
return _concatenate2(mapped, axes=axis)
701688

702-
x_chunk = deepmap(reindex_intermediates, x_chunk)
689+
if len(axis) != 1:
690+
# when there's only a single axis of reduction, we can just concatenate later,
691+
# reindexing is unnecessary
692+
# I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
693+
unique_groups = np.unique(
694+
tuple(flatten(deepmap(lambda x: list(np.atleast_1d(x["groups"].squeeze())), x_chunk)))
695+
)
696+
697+
def reindex_intermediates(x):
698+
new_shape = x["groups"].shape[:-1] + (len(unique_groups),)
699+
newx = {"groups": np.broadcast_to(unique_groups, new_shape)}
700+
newx["intermediates"] = tuple(
701+
reindex_(v, from_=x["groups"].squeeze(), to=unique_groups, fill_value=f)
702+
for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
703+
)
704+
return newx
705+
706+
x_chunk = deepmap(reindex_intermediates, x_chunk)
703707

704708
group_conc_axis: Iterable[int]
705709
if group_ndim == 1:

0 commit comments

Comments
 (0)