@@ -735,14 +735,24 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
735735def _simple_combine (
736736 x_chunk , agg : Aggregation , axis : Sequence , keepdims : bool , is_aggregate : bool = False
737737) -> IntermediateDict :
738+ """
739+ 'Simple' combination of blockwise results.
740+
741+ 1. After the blockwise groupby-reduce, all blocks contain a value for all possible groups,
742+ and are of the same shape; i.e. reindex must have been True
743+ 2. _expand_dims was used to insert an extra axis DUMMY_AXIS
744+ 3. Here we concatenate along DUMMY_AXIS, and then call the combine function along
745+ DUMMY_AXIS
746+ 4. At the final agggregate step, we squeeze out DUMMY_AXIS
747+ """
738748 from dask .array .core import deepfirst
739749
740750 results = {"groups" : deepfirst (x_chunk )["groups" ]}
741751 results ["intermediates" ] = []
742752 for idx , combine in enumerate (agg .combine ):
743- array = _conc2 (x_chunk , key1 = "intermediates" , key2 = idx , axis = axis )
753+ array = _conc2 (x_chunk , key1 = "intermediates" , key2 = idx , axis = axis [: - 1 ] + ( DUMMY_AXIS ,) )
744754 assert array .ndim >= 2
745- result = getattr (np , combine )(array , axis = axis , keepdims = True )
755+ result = getattr (np , combine )(array , axis = axis [: - 1 ] + ( DUMMY_AXIS ,) , keepdims = True )
746756 if is_aggregate :
747757 # squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
748758 result = result .squeeze (axis = DUMMY_AXIS )
0 commit comments