Skip to content

Commit 0c4a7f9

Browse files
authored
cohorts: Delete the merge kwarg (#313)
So fast now, we do it always!
1 parent 58bc9be commit 0c4a7f9

File tree

2 files changed

+12
-29
lines changed

2 files changed

+12
-29
lines changed

flox/core.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,7 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
248248

249249

250250
# @memoize
251-
def find_group_cohorts(
252-
labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
253-
) -> dict:
251+
def find_group_cohorts(labels, chunks, expected_groups: None | pd.RangeIndex = None) -> dict:
254252
"""
255253
Finds groups labels that occur together aka "cohorts"
256254
@@ -265,9 +263,8 @@ def find_group_cohorts(
265263
represents NaNs.
266264
chunks : tuple
267265
chunks of the array being reduced
268-
merge : bool, optional
269-
Attempt to merge cohorts when one cohort's chunks are a subset
270-
of another cohort's chunks.
266+
expected_groups: pd.RangeIndex (optional)
267+
Used to extract the largest label expected
271268
272269
Returns
273270
-------
@@ -322,13 +319,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
322319
# 4. Existing cohorts don't overlap, great for time grouping with perfect chunking
323320
no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
324321

325-
if (
326-
every_group_one_block
327-
or one_group_per_chunk
328-
or single_chunks
329-
or no_overlapping_cohorts
330-
or not merge
331-
):
322+
if every_group_one_block or one_group_per_chunk or single_chunks or no_overlapping_cohorts:
332323
return chunks_cohorts
333324

334325
# Containment = |Q & S| / |Q|
@@ -1569,10 +1560,7 @@ def dask_groupby_agg(
15691560

15701561
elif method == "cohorts":
15711562
chunks_cohorts = find_group_cohorts(
1572-
by_input,
1573-
[array.chunks[ax] for ax in axis],
1574-
merge=True,
1575-
expected_groups=expected_groups,
1563+
by_input, [array.chunks[ax] for ax in axis], expected_groups=expected_groups
15761564
)
15771565
reduced_ = []
15781566
groups_ = []

tests/test_core.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -844,21 +844,16 @@ def test_rechunk_for_blockwise(inchunks, expected):
844844

845845
@requires_dask
846846
@pytest.mark.parametrize(
847-
"expected, labels, chunks, merge",
847+
"expected, labels, chunks",
848848
[
849-
[[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), True],
850-
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), True],
851-
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), True],
852-
[
853-
[[0], [1, 2, 3, 4], [5]],
854-
np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
855-
(4, 8, 4, 9, 4),
856-
True,
857-
],
849+
[[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4)],
850+
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1)],
851+
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1)],
852+
[[[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4)],
858853
],
859854
)
860-
def test_find_group_cohorts(expected, labels, chunks: tuple[int], merge: bool) -> None:
861-
actual = list(find_group_cohorts(labels, (chunks,), merge).values())
855+
def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
856+
actual = list(find_group_cohorts(labels, (chunks,)).values())
862857
assert actual == expected, (actual, expected)
863858

864859

0 commit comments

Comments
 (0)