     Any,
     Callable,
     Dict,
-    Iterable,
     Mapping,
     Optional,
     Sequence,
@@ -131,9 +130,9 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     Parameters
     ----------
     labels: np.ndarray
-        1D Array of group labels
-    chunks : tuple
-        chunks along grouping dimension for array that is being reduced
+        mD Array of group labels
+    array : tuple
+        nD array that is being reduced
     merge: bool, optional
         Attempt to merge cohorts when one cohort's chunks are a subset
         of another cohort's chunks.
@@ -147,19 +146,35 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     """
     import copy

+    import dask
     import toolz as tlz

     if method == "split-reduce":
         return np.unique(labels).reshape(-1, 1).tolist()

-    which_chunk = np.repeat(np.arange(len(chunks)), chunks)
+    # To do this, we must have values in memory so casting to numpy should be safe
+    labels = np.asarray(labels)
+
+    # Build an array with the shape of labels, but where every element is the "chunk number"
+    # 1. First subset the array appropriately
+    axis = range(-labels.ndim, 0)
+    # Easier to create a dask array and use the .blocks property
+    array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+
+    # Iterate over each block and create a new block of same shape with "chunk number"
+    shape = tuple(array.blocks.shape[ax] for ax in axis)
+    blocks = np.empty(np.prod(shape), dtype=object)
+    for idx, block in enumerate(array.blocks.ravel()):
+        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
+    which_chunk = np.block(blocks.reshape(shape).tolist()).ravel()
+
     # these are chunks where a label is present
-    label_chunks = {lab: tuple(np.unique(which_chunk[labels == lab])) for lab in np.unique(labels)}
+    label_chunks = {
+        lab: tuple(np.unique(which_chunk[labels.ravel() == lab])) for lab in np.unique(labels)
+    }
     # These invert the label_chunks mapping so we know which labels occur together.
     chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())

-    # TODO: sort by length of values (i.e. cohort);
-    # then loop in reverse and merge when keys are subsets of initial keys?
     if merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
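
To make the ``which_chunk`` construction above concrete, here is a minimal sketch of the same idea for the simple 1D case it generalizes (numpy only; the inputs are illustrative, not flox API):

import numpy as np

labels = np.array([0, 0, 1, 1, 2, 2, 0, 0])  # group label of each element
chunks = (3, 3, 2)                           # chunk sizes along the same axis

# chunk number of every element: [0, 0, 0, 1, 1, 1, 2, 2]
which_chunk = np.repeat(np.arange(len(chunks)), chunks)

# chunks in which each label occurs; labels occurring in the same set of
# chunks end up in the same cohort
label_chunks = {lab: tuple(np.unique(which_chunk[labels == lab])) for lab in np.unique(labels)}
print(label_chunks)  # {0: (0, 2), 1: (0, 1), 2: (1,)}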
@@ -299,7 +314,6 @@ def reindex_(array: np.ndarray, from_, to, fill_value=None, axis: int = -1) -> n
         reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
         return reindexed

-    from_ = np.atleast_1d(from_)
     if from_.dtype.kind == "O" and isinstance(from_[0], tuple):
         raise NotImplementedError(
             "Currently does not support reindexing with object arrays of tuples. "
@@ -706,14 +720,14 @@ def _npg_aggregate(
     expected_groups: Union[Sequence, np.ndarray, None],
     axis: Sequence,
     keepdims,
-    group_ndim: int,
+    neg_axis: Sequence,
     fill_value: Any = None,
     min_count: Optional[int] = None,
     engine: str = "numpy",
     finalize_kwargs: Optional[Mapping] = None,
 ) -> FinalResultsDict:
     """Final aggregation step of tree reduction"""
-    results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim, engine)
+    results = _npg_combine(x_chunk, agg, axis, keepdims, neg_axis, engine)
     return _finalize_results(
         results, agg, axis, expected_groups, fill_value, min_count, finalize_kwargs
     )
@@ -742,7 +756,7 @@ def _npg_combine(
     agg: Aggregation,
     axis: Sequence,
     keepdims: bool,
-    group_ndim: int,
+    neg_axis: Sequence,
     engine: str,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
@@ -771,12 +785,7 @@ def reindex_intermediates(x):

     x_chunk = deepmap(reindex_intermediates, x_chunk)

-    group_conc_axis: Iterable[int]
-    if group_ndim == 1:
-        group_conc_axis = (0,)
-    else:
-        group_conc_axis = sorted(group_ndim - ax - 1 for ax in axis)
-    groups = _conc2(x_chunk, "groups", axis=group_conc_axis)
+    groups = _conc2(x_chunk, "groups", axis=neg_axis)

     if agg.reduction_type == "argreduce":
         # If "nanlen" was added for masking later, we need to account for that
@@ -830,7 +839,7 @@ def reindex_intermediates(x):
                 np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=agg.dtype)
             )
         results["groups"] = np.empty(
-            shape=(1,) * (len(group_conc_axis) - 1) + (0,), dtype=groups.dtype
+            shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype
         )
     else:
         _results = chunk_reduce(
@@ -891,6 +900,7 @@ def groupby_agg(
     method: str = "map-reduce",
     min_count: Optional[int] = None,
     isbin: bool = False,
+    reindex: bool = False,
     engine: str = "numpy",
     finalize_kwargs: Optional[Mapping] = None,
 ) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
@@ -902,6 +912,9 @@ def groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)

+    # these are negative axis indices useful for concatenating the intermediates
+    neg_axis = range(-len(axis), 0)
+
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
     token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
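
A rough standalone illustration (not flox API) of why negative axis indices are convenient here: after the blockwise step the grouped axes are the trailing axes of every intermediate, so negative indices address them without knowing how many leading dimensions survive.

import numpy as np

axis = (1, 2)                           # two axes being reduced
neg_axis = tuple(range(-len(axis), 0))  # (-2, -1)

left = np.arange(6).reshape(1, 2, 3)    # intermediate from one set of blocks
right = np.arange(4).reshape(1, 2, 2)   # intermediate from a neighbouring set
combined = np.concatenate([left, right], axis=neg_axis[-1])
print(combined.shape)  # (1, 2, 5)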
@@ -926,11 +939,11 @@ def groupby_agg(
             axis=axis,
             # with the current implementation we want reindexing at the blockwise step
             # only reindex to groups present at combine stage
-            expected_groups=expected_groups if split_out > 1 or isbin else None,
+            expected_groups=expected_groups if reindex or split_out > 1 or isbin else None,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
             isbin=isbin,
-            reindex=split_out > 1,
+            reindex=reindex or (split_out > 1),
             engine=engine,
         ),
         inds,
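
To picture what reindexing the blockwise intermediates buys us: each block's result is expanded onto the full set of expected groups, so blocks can later be combined without tracking which labels each block happened to contain. A toy sketch of that idea (illustrative names and values, not the reindex_ implementation):

import numpy as np

expected_groups = np.array([0, 1, 2, 3])
block_groups = np.array([1, 3])     # groups actually present in one block
block_sums = np.array([10.0, 2.0])  # one reduced value per present group

reindexed = np.full(expected_groups.shape, 0.0)  # 0 is the fill value for "sum"
reindexed[np.searchsorted(expected_groups, block_groups)] = block_sums
print(reindexed)  # [ 0. 10.  0.  2.]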
@@ -964,7 +977,7 @@ def groupby_agg(
         expected_agg = expected_groups

     agg_kwargs = dict(
-        group_ndim=by.ndim,
+        neg_axis=neg_axis,
         fill_value=fill_value,
         min_count=min_count,
         engine=engine,
@@ -984,7 +997,7 @@ def groupby_agg(
                 expected_groups=expected_agg,
                 **agg_kwargs,
             ),
-            combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim, engine=engine),
+            combine=partial(_npg_combine, agg=agg, neg_axis=neg_axis, engine=engine),
             name=f"{name}-reduce",
             dtype=array.dtype,
             axis=axis,
@@ -996,12 +1009,7 @@ def groupby_agg(
         # Blockwise apply the aggregation step so that one input chunk → one output chunk
         # TODO: We could combine this with the chunk reduction and do everything in one task.
         # This would also optimize the single block along reduced-axis case.
-        if (
-            expected_groups is None
-            or split_out > 1
-            or len(axis) > 1
-            or not isinstance(by_maybe_numpy, np.ndarray)
-        ):
+        if expected_groups is None or split_out > 1 or not isinstance(by_maybe_numpy, np.ndarray):
             raise NotImplementedError

         reduced = dask.array.blockwise(
@@ -1020,17 +1028,25 @@ def groupby_agg(
             dtype=array.dtype,
             meta=array._meta,
             align_arrays=False,
-            name=f"{name}-blockwise-agg-{token}",
+            name=f"{name}-blockwise-{token}",
         )
-        chunks = array.chunks[axis[0]]

         # find number of groups in each chunk, this is needed for output chunks
         # along the reduced axis
-        bnds = np.insert(np.cumsum(chunks), 0, 0)
-        groups_per_chunk = tuple(
-            len(np.unique(by_maybe_numpy[i0:i1])) for i0, i1 in zip(bnds[:-1], bnds[1:])
-        )
-        output_chunks = reduced.chunks[: -(len(axis))] + (groups_per_chunk,)
+        from dask.array.core import slices_from_chunks
+
+        slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
+        if expected_groups is None:
+            groups_in_block = tuple(np.unique(by_maybe_numpy[slc]) for slc in slices)
+        else:
+            # For cohorts, we could be indexing a block with groups that
+            # are not in the cohort (usually for nD `by`)
+            # Only keep the expected groups.
+            groups_in_block = tuple(
+                np.intersect1d(by_maybe_numpy[slc], expected_groups) for slc in slices
+            )
+        ngroups_per_block = tuple(len(groups) for groups in groups_in_block)
+        output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
     else:
         raise ValueError(f"Unknown method={method}.")

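
A small standalone sketch of how ``slices_from_chunks`` is used above to count the groups present in each block (toy inputs, not the real call sites):

import numpy as np
from dask.array.core import slices_from_chunks

by = np.array([[0, 0, 1, 1],
               [2, 2, 3, 3]])
chunks = ((1, 1), (2, 2))            # a 2x2 grid of blocks over `by`

slices = slices_from_chunks(chunks)  # one tuple of slices per block, in C order
groups_in_block = tuple(np.unique(by[slc]) for slc in slices)
print([g.tolist() for g in groups_in_block])  # [[0], [1], [2], [3]]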
@@ -1059,15 +1075,22 @@ def _getitem(d, key1, key2):
             ),
         )
     else:
-        groups = (expected_groups,)
+        if method == "map-reduce":
+            groups = (expected_groups,)
+        else:
+            groups = (np.concatenate(groups_in_block),)

     layer: Dict[Tuple, Tuple] = {}  # type: ignore
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
         if method == "blockwise":
-            inchunk = ochunk
+            if len(axis) == 1:
+                inchunk = ochunk
+            else:
+                nblocks = tuple(len(array.chunks[ax]) for ax in axis)
+                inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * (len(axis)) + (ochunk[-1],) * int(split_out > 1)
+            inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
         layer[(agg_name, *ochunk)] = (
             operator.getitem,
             (reduced.name, *inchunk),
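
Toy illustration of the ``np.unravel_index`` trick above: the flattened output-chunk index along the reduced axis is mapped back to an nD input block index (values here are made up):

import numpy as np

nblocks = (2, 3)  # number of blocks along the two reduced axes of the input
for flat in range(np.prod(nblocks)):
    print(flat, np.unravel_index(flat, nblocks))
# prints 0 (0, 0), 1 (0, 1), 2 (0, 2), 3 (1, 0), 4 (1, 1), 5 (1, 2)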
@@ -1089,6 +1112,7 @@ def groupby_reduce(
     func: Union[str, Aggregation],
     *,
     expected_groups: Union[Sequence, np.ndarray] = None,
+    sort: bool = True,
     isbin: bool = False,
     axis=None,
     fill_value=None,
@@ -1114,6 +1138,10 @@ def groupby_reduce(
         Expected unique labels.
     isbin : bool, optional
         Are ``expected_groups`` bin edges?
+    sort : (optional), bool
+        Whether groups should be returned in sorted order. Only applies for dask
+        reductions when ``method`` is not ``"map-reduce"``. For ``"map-reduce"``, the groups
+        are always sorted.
     axis : (optional) None or int or Sequence[int]
         If None, reduce across all dimensions of by
         Else, reduce across corresponding axes of array
@@ -1138,17 +1166,19 @@ def groupby_reduce(
           * ``"blockwise"``:
             Only reduce using blockwise and avoid aggregating blocks
             together. Useful for resampling-style reductions where group
-            members are always together. The array is rechunked so that
-            chunk boundaries line up with group boundaries
+            members are always together. If `by` is 1D, `array` is automatically
+            rechunked so that chunk boundaries line up with group boundaries
             i.e. each block contains all members of any group present
-            in that block.
+            in that block. For nD `by`, you must make sure that all members of a group
+            are present in a single block.
           * ``"cohorts"``:
             Finds group labels that tend to occur together ("cohorts"),
             indexes out cohorts and reduces that subset using "map-reduce",
             repeat for all cohorts. This works well for many time groupings
             where the group labels repeat at regular intervals like 'hour',
             'month', dayofyear' etc. Optimize chunking ``array`` for this
-            method by first rechunking using ``rechunk_for_cohorts``.
+            method by first rechunking using ``rechunk_for_cohorts``
+            (for 1D ``by`` only).
           * ``"split-reduce"``:
             Break out each group into its own array and then ``"map-reduce"``.
             This is implemented by having each group be its own cohort,
@@ -1208,6 +1238,11 @@ def groupby_reduce(
     else:
         axis = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore

+    if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
+        raise NotImplementedError(
+            "Must reduce along all dimensions of `by` when method != 'map-reduce'."
+        )
+
     if expected_groups is None and isinstance(by, np.ndarray):
         flatby = by.ravel()
         expected_groups = np.unique(flatby[~isnull(flatby)])
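
For orientation, a hedged usage sketch of the public entry point with these changes (assuming this module is importable as ``flox.core`` and the usual ``array, by, func`` positional order; values are illustrative):

import dask.array as da
import numpy as np
from flox.core import groupby_reduce

array = da.random.random((10, 12), chunks=(5, 4))
by = np.arange(4).repeat(3)  # 1D month-like labels along the last axis

# For method="cohorts" (and "blockwise"/"split-reduce") the reduction must span
# every dimension of `by`; here `by` is 1D so axis=-1 satisfies that check.
result, groups = groupby_reduce(
    array,
    by,
    "sum",
    axis=-1,
    method="cohorts",
    sort=True,  # non map-reduce methods sort the returned groups only when asked
)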
@@ -1366,59 +1401,58 @@ def groupby_reduce(
             )

         if method in ["split-reduce", "cohorts"]:
-            if by.ndim > 1:
-                raise ValueError(
-                    "`by` must be 1D when method='split-reduce' and method='cohorts'. "
-                    f"Received {by.ndim}D array. Please use method='map-reduce' instead."
-                )
-            assert axis == (array.ndim - 1,)
-
-            cohorts = find_group_cohorts(by, array.chunks[axis[0]], merge=True, method=method)
-            idx = np.arange(len(by))
+            cohorts = find_group_cohorts(
+                by, [array.chunks[ax] for ax in axis], merge=True, method=method
+            )

             results = []
             groups_ = []
             for cohort in cohorts:
                 cohort = sorted(cohort)
-                # indexes for a subset of groups
-                subset_idx = idx[np.isin(by, cohort)]
-                array_subset = array[..., subset_idx]
-                numblocks = len(array_subset.chunks[-1])
+                # equivalent of xarray.DataArray.where(mask, drop=True)
+                mask = np.isin(by, cohort)
+                indexer = [np.unique(v) for v in np.nonzero(mask)]
+                array_subset = array
+                for ax, idxr in zip(range(-by.ndim, 0), indexer):
+                    array_subset = np.take(array_subset, idxr, axis=ax)
+                numblocks = np.prod([len(array_subset.chunks[ax]) for ax in axis])

                 # get final result for these groups
                 r, *g = partial_agg(
                     array_subset,
-                    by[subset_idx],
+                    by[np.ix_(*indexer)],
                     expected_groups=cohort,
+                    # reindex to expected_groups at the blockwise step.
+                    # this approach avoids replacing non-cohort members with
+                    # np.nan or some other sentinel value, and preserves dtypes
+                    reindex=True,
                     # if only a single block along axis, we can just work blockwise
                     # inspired by https://github.com/dask/dask/issues/8361
-                    method="blockwise" if numblocks == 1 else "map-reduce",
+                    method="blockwise" if numblocks == 1 and len(axis) == by.ndim else "map-reduce",
                 )
                 results.append(r)
                 groups_.append(cohort)

             # concatenate results together,
             # sort to make sure we match expected output
-            allgroups = np.hstack(groups_)
-            sorted_idx = np.argsort(allgroups)
-            result = np.concatenate(results, axis=-1)[..., sorted_idx]
-            groups = (allgroups[sorted_idx],)
-
+            groups = (np.hstack(groups_),)
+            result = np.concatenate(results, axis=-1)
         else:
             if method == "blockwise":
-                if by.ndim > 1:
-                    raise ValueError(
-                        "For method='blockwise', ``by`` must be 1D. "
-                        f"Received {by.ndim} dimensions instead."
-                    )
-                array = rechunk_for_blockwise(array, axis=-1, labels=by)
-
-            # TODO: test with mixed array kinds (numpy + dask; dask + numpy)
+                if by.ndim == 1:
+                    array = rechunk_for_blockwise(array, axis=-1, labels=by)
+
+            # TODO: test with mixed array kinds (numpy array + dask by)
             result, *groups = partial_agg(
                 array,
                 by,
                 expected_groups=expected_groups,
                 method=method,
             )
+        if sort and method != "map-reduce":
+            assert len(groups) == 1
+            sorted_idx = np.argsort(groups[0])
+            result = result[..., sorted_idx]
+            groups = (groups[0][sorted_idx],)

     return (result, *groups)