@@ -248,12 +248,22 @@ def slices_from_chunks(chunks):
 
 
 def _compute_label_chunk_bitmask(labels, chunks, nlabels):
+    def make_bitmask(rows, cols):
+        data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
+        return csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))
+
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
 
-    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+    # Shortcut for 1D with size-1 chunks
+    if shape == (nchunks,):
+        rows_array = np.arange(nchunks)
+        cols_array = labels
+        mask = labels >= 0
+        return make_bitmask(rows_array[mask], cols_array[mask])
 
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
    cols = []
     # Add one to handle the -1 sentinel value
     label_is_present = np.zeros((nlabels + 1,), dtype=bool)
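The new make_bitmask helper and the 1D shortcut are easier to see on a concrete input. A minimal standalone sketch (illustrative values, not code from this commit): with one element per chunk, chunk i can only contain label labels[i], so the (chunk, label) bitmask follows directly from the label array once the -1 sentinel (missing values) is masked out.

    import numpy as np
    from scipy.sparse import csc_array

    # Six size-1 chunks; -1 marks elements that belong to no group
    labels = np.array([0, 0, 1, -1, 2, 1])
    nchunks, nlabels = labels.size, 3

    mask = labels >= 0  # drop the -1 sentinel entries
    rows, cols = np.arange(nchunks)[mask], labels[mask]
    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
    bitmask = csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))

    print(bitmask.todense())
    # [[ True False False]
    #  [ True False False]
    #  [False  True False]
    #  [False False False]
    #  [False False  True]
    #  [False  True False]]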
@@ -272,10 +282,8 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
         label_is_present[:] = False
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
-    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
-    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
 
-    return bitmask
+    return make_bitmask(rows_array, cols_array)
 
 
 # @memoize
@@ -312,13 +320,18 @@ def find_group_cohorts(
     labels = np.asarray(labels)
 
     shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
     if expected_groups is None:
         nlabels = labels.max() + 1
     else:
         nlabels = expected_groups[-1] + 1
 
+    # 1. Single chunk, blockwise always
+    if nchunks == 1:
+        return "blockwise", {(0,): list(range(nlabels))}
+
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
     bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
 
@@ -346,21 +359,21 @@ def invert(x) -> tuple[np.ndarray, ...]:
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    # 1. Every group is contained to one block, use blockwise here.
+    # 2. Every group is contained to one block, use blockwise here.
     if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
         logger.info("find_group_cohorts: blockwise is preferred.")
         return "blockwise", chunks_cohorts
 
-    # 2. Perfectly chunked so there is only a single cohort
+    # 3. Perfectly chunked so there is only a single cohort
     if len(chunks_cohorts) == 1:
         logger.info("Only found a single cohort. 'map-reduce' is preferred.")
         return "map-reduce", chunks_cohorts if merge else {}
 
-    # 3. Our dataset has chunksize one along the axis,
+    # 4. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    # 4. Every chunk only has a single group, but that group might extend across multiple chunks
+    # 5. Every chunk only has a single group, but that group might extend across multiple chunks
     one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
-    # 5. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    # 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
     no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
     if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
         logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
@@ -393,6 +406,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
             sparsity, MAX_SPARSITY_FOR_COHORTS
         )
     )
+    # 7. Groups seem fairly randomly distributed, use "map-reduce".
     if sparsity > MAX_SPARSITY_FOR_COHORTS:
         if not merge:
             logger.info(
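For intuition about how the bitmask feeds the numbered checks above, here is a hedged sketch of checks 2 and 5 on a toy bitmask. It assumes the conventions visible in the diff (bitmask shape (nchunks, nlabels), so CHUNK_AXIS = 0 and LABEL_AXIS = 1); it is an illustration, not the library's code.

    import numpy as np
    from scipy.sparse import csc_array

    CHUNK_AXIS, LABEL_AXIS = 0, 1  # bitmask is (nchunks, nlabels)

    # Toy bitmask: 3 chunks x 3 labels; labels 1 and 2 each straddle two chunks
    rows = np.array([0, 0, 1, 1, 2])
    cols = np.array([0, 1, 1, 2, 2])
    bitmask = csc_array((np.ones(rows.shape, dtype=bool), (rows, cols)), shape=(3, 3))

    # Check 2: every label confined to a single chunk -> "blockwise"
    chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
    print((chunks_per_label == 1).all())  # False: labels 1 and 2 span two chunks

    # Check 5: every chunk holds exactly one label -> "cohorts" with perfect chunking
    print((bitmask.sum(axis=LABEL_AXIS) == 1).all())  # False: chunks 0 and 1 hold two labels

When neither early exit fires, the code falls through to the sparsity test in the last hunk and picks "map-reduce" for densely mixed groups.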