@@ -247,7 +247,7 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
     return bitmask
 
 
-@memoize
+# @memoize
 def find_group_cohorts(
     labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
 ) -> dict:
@@ -286,7 +286,6 @@ def find_group_cohorts(
     nlabels = expected_groups[-1] + 1
 
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
-    ilabels = np.arange(nlabels)
     bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
 
     CHUNK_AXIS, LABEL_AXIS = 0, 1
@@ -303,54 +302,64 @@ def find_group_cohorts(
         for lab in range(bitmask.shape[-1])
     }
 
-    # These invert the label_chunks mapping so we know which labels occur together.
+    # Invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
-        arr = label_chunks.get(x)
-        return tuple(arr)  # type: ignore[arg-type] # pandas issue?
+        arr = label_chunks[x]
+        return tuple(arr)
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    # If our dataset has chunksize one along the axis,
-    # then no merging is possible.
+    # No merging is possible when
+    # 1. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
+    # 2. Every chunk only has a single group, but that group might extend across multiple chunks
     one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
-    # every group is contained to one block, we should be using blockwise here.
+    # 3. Every group is contained to one block, so we should be using blockwise here.
     every_group_one_block = (chunks_per_label == 1).all()
-    if every_group_one_block or one_group_per_chunk or single_chunks or not merge:
-        return chunks_cohorts
-
-    # First sort by number of chunks occupied by cohort
-    sorted_chunks_cohorts = dict(
-        sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
-    )
+    # 4. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
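+    # (bincount over the concatenated cohort keys counts how many cohorts each chunk
+    # appears in; all ones means no chunk is shared between cohorts)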
 
-    # precompute needed metrics for the quadratic loop below.
-    items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
+    if (
+        every_group_one_block
+        or one_group_per_chunk
+        or single_chunks
+        or no_overlapping_cohorts
+        or not merge
+    ):
+        return chunks_cohorts
 
+    # Containment = |Q & S| / |Q|
+    # - |X| is the cardinality of set X
+    # - Q is the query set being tested
+    # - S is the existing set
+    MIN_CONTAINMENT = 0.75  # arbitrary
+    asfloat = bitmask.astype(float)
+    containment = ((asfloat.T @ asfloat) / chunks_per_label[present_labels]).tocsr()
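+    # (asfloat.T @ asfloat)[i, j] counts the chunks containing both labels i and j;
+    # the broadcast division divides entry (i, j) by label j's chunk count, so
+    # containment[i, j] = |chunks(i) & chunks(j)| / |chunks(j)|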
+    mask = containment.data < MIN_CONTAINMENT
+    containment.data[mask] = 0
+    containment.eliminate_zeros()
+
+    # Iterate over labels, beginning with those with the most chunks
+    order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
     merged_cohorts = {}
-    merged_keys: set[tuple] = set()
-
-    # Now we iterate starting with the longest number of chunks,
-    # and then merge in cohorts that are present in a subset of those chunks
-    # I think this is suboptimal and must fail at some point.
-    # But it might work for most cases. There must be a better way...
-    for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
-        if k1 in merged_keys:
+    merged_keys = set()
+    # TODO: we can optimize this to loop over chunks_cohorts instead
+    # by zeroing out rows that are already in a cohort
+    for rowidx in order:
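+        # CSR row slice: indices of all labels j with containment[rowidx, j] >= MIN_CONTAINMENT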
+        cohort_ = containment.indices[
+            slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
+        ]
+        cohort = [elem for elem in cohort_ if elem not in merged_keys]
+        if not cohort:
             continue
-        new_key = set_k1
-        new_value = v1
-        # iterate in reverse since we expect small cohorts
-        # to be most likely merged in to larger ones
-        for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
-            if k2 not in merged_keys:
-                if (len(set_k2 & new_key) / len_k2) > 0.75:
-                    new_key |= set_k2
-                    new_value += v2
-                    merged_keys.update((k2,))
-        sorted_ = sorted(new_value)
-        merged_cohorts[tuple(sorted(new_key))] = sorted_
-        if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
-            break
+        merged_keys.update(cohort)
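+        # the new cohort key is the union of chunks occupied by all labels in the cohort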
+        allchunks = (label_chunks[member] for member in cohort)
+        chunk = tuple(set(itertools.chain(*allchunks)))
+        merged_cohorts[chunk] = cohort
+
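+    # sanity check: every label should land in exactly one merged cohort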
+    actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
+    expected_ngroups = bitmask.shape[LABEL_AXIS]
+    assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)
 
     # sort by first label in cohort
     # This will help when sort=True (default)