@@ -363,37 +363,55 @@ def invert(x) -> tuple[np.ndarray, ...]:
363363 logger .info ("find_group_cohorts: cohorts is preferred, chunking is perfect." )
364364 return "cohorts" , chunks_cohorts
365365
366- # Containment = |Q & S| / |Q|
366+ # We'll use containment to measure degree of overlap between labels.
367+ # Containment C = |Q & S| / |Q|
367368 # - |X| is the cardinality of set X
368369 # - Q is the query set being tested
369370 # - S is the existing set
370- # We'll use containment to measure degree of overlap between labels. The bitmask
371- # matrix allows us to calculate this pretty efficiently.
372- asfloat = bitmask .astype (float )
373- # Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
374- # makes it non-symmetric.
375- containment = csr_array ((asfloat .T @ asfloat ) / chunks_per_label )
376-
377- # The containment matrix is a measure of how much the labels overlap
378- # with each other. We treat the sparsity = (nnz/size) as a summary measure of the net overlap.
371+ # The bitmask matrix S allows us to calculate this pretty efficiently using a dot product.
372+ # S.T @ S / chunks_per_label
373+ #
374+ # We treat sparsity(C) = nnz / size (the fraction of nonzero entries; larger means denser) as a summary measure of the net overlap.
379375 # 1. For high enough sparsity, there is a lot of overlap and we should use "map-reduce".
380376 # 2. When labels are uniformly distributed amongst all chunks
381377 # (and number of labels < chunk size), sparsity is 1.
382378 # 3. Time grouping cohorts (e.g. dayofyear) appear as lines in this matrix.
383379 # 4. When there are no overlaps at all between labels, containment is a block diagonal matrix
384380 # (approximately).
385- MAX_SPARSITY_FOR_COHORTS = 0.6 # arbitrary
386- sparsity = containment .nnz / math .prod (containment .shape )
381+ #
382+ # However, computing S.T @ S can still be the slowest step, especially if S
383+ # is not particularly sparse. Empirically, sparsity(S.T @ S) > min(1, 2 x sparsity(S)),
384+ # so we use sparsity(S) as a shortcut.
385+ MAX_SPARSITY_FOR_COHORTS = 0.4 # arbitrary
386+ sparsity = bitmask .nnz / math .prod (bitmask .shape )
387387 preferred_method : Literal ["map-reduce" ] | Literal ["cohorts" ]
388+ logger .debug (
389+ "sparsity of bitmask is {}, threshold is {}" .format ( # noqa
390+ sparsity , MAX_SPARSITY_FOR_COHORTS
391+ )
392+ )
388393 if sparsity > MAX_SPARSITY_FOR_COHORTS :
389- logger .info ("sparsity is {}" .format (sparsity )) # noqa
390394 if not merge :
391- logger .info ("find_group_cohorts: merge=False, choosing 'map-reduce'" )
395+ logger .info (
396+ "find_group_cohorts: bitmask sparsity={}, merge=False, choosing 'map-reduce'" .format ( # noqa
397+ sparsity
398+ )
399+ )
392400 return "map-reduce" , {}
393401 preferred_method = "map-reduce"
394402 else :
395403 preferred_method = "cohorts"
396404
405+ # Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
406+ # makes it non-symmetric.
407+ asfloat = bitmask .astype (float )
408+ containment = csr_array (asfloat .T @ asfloat / chunks_per_label )
409+
410+ logger .debug (
411+ "sparsity of containment matrix is {}" .format ( # noqa
412+ containment .nnz / math .prod (containment .shape )
413+ )
414+ )
397415 # Use a threshold to force some merging. We do not use the filtered
398416 # containment matrix for estimating "sparsity" because it is a bit
399417 # hard to reason about.
0 commit comments