
Commit 15abf49

Use dot product for containment (#306)
* Use dot product
* avoid advanced indexing
* small edits
* Cache label_chunks as earlier
* WIP
* Readd chunks_cohorts groupby
* Fix
* comments
* Remove cache test since find_group_cohorts is a lot faster now
1 parent 80ae6a4 commit 15abf49

3 files changed: +68, -62 lines changed

flox/core.py

Lines changed: 48 additions & 39 deletions
@@ -247,7 +247,7 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
     return bitmask


-@memoize
+# @memoize
 def find_group_cohorts(
     labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
 ) -> dict:
@@ -286,7 +286,6 @@ def find_group_cohorts(
         nlabels = expected_groups[-1] + 1

     labels = np.broadcast_to(labels, shape[-labels.ndim :])
-    ilabels = np.arange(nlabels)
     bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)

     CHUNK_AXIS, LABEL_AXIS = 0, 1
@@ -303,54 +302,64 @@ def find_group_cohorts(
         for lab in range(bitmask.shape[-1])
     }

-    # These invert the label_chunks mapping so we know which labels occur together.
+    # Invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
-        arr = label_chunks.get(x)
-        return tuple(arr)  # type: ignore [arg-type] # pandas issue?
+        arr = label_chunks[x]
+        return tuple(arr)

     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

-    # If our dataset has chunksize one along the axis,
-    # then no merging is possible.
+    # No merging is possible when
+    # 1. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
+    # 2. Every chunk only has a single group, but that group might extend across multiple chunks
     one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
-    # every group is contained to one block, we should be using blockwise here.
+    # 3. Every group is contained to one block, we should be using blockwise here.
     every_group_one_block = (chunks_per_label == 1).all()
-    if every_group_one_block or one_group_per_chunk or single_chunks or not merge:
-        return chunks_cohorts
-
-    # First sort by number of chunks occupied by cohort
-    sorted_chunks_cohorts = dict(
-        sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
-    )
+    # 4. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()

-    # precompute needed metrics for the quadratic loop below.
-    items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
+    if (
+        every_group_one_block
+        or one_group_per_chunk
+        or single_chunks
+        or no_overlapping_cohorts
+        or not merge
+    ):
+        return chunks_cohorts

+    # Containment = |Q & S| / |Q|
+    #  - |X| is the cardinality of set X
+    #  - Q is the query set being tested
+    #  - S is the existing set
+    MIN_CONTAINMENT = 0.75  # arbitrary
+    asfloat = bitmask.astype(float)
+    containment = ((asfloat.T @ asfloat) / chunks_per_label[present_labels]).tocsr()
+    mask = containment.data < MIN_CONTAINMENT
+    containment.data[mask] = 0
+    containment.eliminate_zeros()
+
+    # Iterate over labels, beginning with those with most chunks
+    order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
     merged_cohorts = {}
-    merged_keys: set[tuple] = set()
-
-    # Now we iterate starting with the longest number of chunks,
-    # and then merge in cohorts that are present in a subset of those chunks
-    # I think this is suboptimal and must fail at some point.
-    # But it might work for most cases. There must be a better way...
-    for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
-        if k1 in merged_keys:
+    merged_keys = set()
+    # TODO: we can optimize this to loop over chunk_cohorts instead
+    # by zeroing out rows that are already in a cohort
+    for rowidx in order:
+        cohort_ = containment.indices[
+            slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
+        ]
+        cohort = [elem for elem in cohort_ if elem not in merged_keys]
+        if not cohort:
             continue
-        new_key = set_k1
-        new_value = v1
-        # iterate in reverse since we expect small cohorts
-        # to be most likely merged in to larger ones
-        for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
-            if k2 not in merged_keys:
-                if (len(set_k2 & new_key) / len_k2) > 0.75:
-                    new_key |= set_k2
-                    new_value += v2
-                    merged_keys.update((k2,))
-        sorted_ = sorted(new_value)
-        merged_cohorts[tuple(sorted(new_key))] = sorted_
-        if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
-            break
+        merged_keys.update(cohort)
+        allchunks = (label_chunks[member] for member in cohort)
+        chunk = tuple(set(itertools.chain(*allchunks)))
+        merged_cohorts[chunk] = cohort
+
+    actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
+    expected_ngroups = bitmask.shape[LABEL_AXIS]
+    assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)

     # sort by first label in cohort
     # This will help when sort=True (default)
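
The trick in this diff: containment |Q & S| / |Q| for every pair of labels falls out of a single matrix product of the chunk-by-label bitmask with itself, followed by a row-wise normalization and a threshold. Below is a minimal, self-contained sketch of that idea using a dense NumPy bitmask and invented toy labels and chunk sizes; the variable names are mine, and flox itself builds a sparse bitmask via _compute_label_chunk_bitmask and walks CSR indices/indptr, so treat this as an illustration of the technique rather than the library's code.

# Illustration only: dense NumPy version of the containment-by-dot-product idea.
import numpy as np

# Toy data: 7 elements split into two chunks of sizes (3, 4), labels 0..3.
labels = np.asarray([0, 1, 2, 0, 1, 2, 3])
chunk_sizes = (3, 4)
nchunks, nlabels = len(chunk_sizes), labels.max() + 1

# bitmask[chunk, label] is True if `label` occurs anywhere in `chunk`.
chunk_of = np.repeat(np.arange(nchunks), chunk_sizes)
bitmask = np.zeros((nchunks, nlabels), dtype=bool)
bitmask[chunk_of, labels] = True
chunks_per_label = bitmask.sum(axis=0)

# Containment(Q, S) = |chunks(Q) & chunks(S)| / |chunks(Q)| for all label pairs:
# entry (q, s) of B.T @ B counts the chunks holding both labels, and dividing
# row q by chunks_per_label[q] normalizes by |Q|.
asfloat = bitmask.astype(float)
containment = (asfloat.T @ asfloat) / chunks_per_label[:, np.newaxis]
MIN_CONTAINMENT = 0.75  # same arbitrary threshold as the commit
containment[containment < MIN_CONTAINMENT] = 0

# Merge labels into cohorts, starting with the label that touches the most chunks.
order = np.argsort(containment.sum(axis=1))[::-1]
merged: set[int] = set()
cohorts = {}
for lab in order:
    members = [int(m) for m in np.flatnonzero(containment[lab]) if int(m) not in merged]
    if not members:
        continue
    merged.update(members)
    chunk_key = tuple(np.flatnonzero(bitmask[:, members].any(axis=1)).tolist())
    cohorts[chunk_key] = members

print(cohorts)  # {(0, 1): [0, 1, 2, 3]} for this toy example

Label 3 occurs in only one chunk, and that chunk also holds labels 0, 1, and 2, so label 3's row has containment 1.0 with every label and all four labels collapse into a single cohort spanning both chunks, mirroring the first parametrized case in tests/test_core.py below.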

tests/test_core.py

Lines changed: 0 additions & 3 deletions
@@ -847,11 +847,8 @@ def test_rechunk_for_blockwise(inchunks, expected):
     "expected, labels, chunks, merge",
     [
         [[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), True],
-        [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), False],
-        [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), False],
         [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), True],
         [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), True],
-        [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), False],
         [
             [[0], [1, 2, 3, 4], [5]],
             np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
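
One of the remaining parametrized rows can be exercised directly against find_group_cohorts as defined in the flox/core.py diff above. The extra tuple around the chunk sizes (one entry per grouped axis) and reading only the dict values are assumptions about how the test consumes the result, so this is a sketch rather than the test's exact body.

import numpy as np

from flox.core import find_group_cohorts

labels = np.asarray([0, 1, 2, 0, 1, 2, 3])
# With chunk sizes (3, 3, 1), labels 0-2 always share the same chunks while
# label 3 sits alone in the last chunk, so two cohorts are expected.
cohorts = find_group_cohorts(labels, ((3, 3, 1),), merge=True)  # chunks-per-axis wrapping is an assumption
print(list(cohorts.values()))  # expected: [[0, 1, 2], [3]]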

tests/test_xarray.py

Lines changed: 20 additions & 20 deletions
@@ -367,26 +367,26 @@ def test_func_is_aggregation():
     xarray_reduce(ds.Tair, ds.time.dt.month, func=mean, skipna=False)


-@requires_dask
-def test_cache():
-    pytest.importorskip("cachey")
-
-    from flox.cache import cache
-
-    ds = xr.Dataset(
-        {
-            "foo": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
-            "bar": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
-        },
-        coords={"labels": ("y", np.repeat([1, 2], 10))},
-    )
-
-    cache.clear()
-    xarray_reduce(ds, "labels", func="mean", method="cohorts")
-    assert len(cache.data) == 1
-
-    xarray_reduce(ds, "labels", func="mean", method="blockwise")
-    assert len(cache.data) == 2
+# @requires_dask
+# def test_cache():
+#     pytest.importorskip("cachey")
+
+#     from flox.cache import cache
+
+#     ds = xr.Dataset(
+#         {
+#             "foo": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
+#             "bar": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
+#         },
+#         coords={"labels": ("y", np.repeat([1, 2], 10))},
+#     )
+
+#     cache.clear()
+#     xarray_reduce(ds, "labels", func="mean", method="cohorts")
+#     assert len(cache.data) == 1
+
+#     xarray_reduce(ds, "labels", func="mean", method="blockwise")
+#     assert len(cache.data) == 2


 @requires_dask
