@@ -48,6 +48,7 @@
 )
 from .cache import memoize
 from .lib import ArrayLayer, dask_array_type, sparse_array_type
+from .options import OPTIONS
 from .xrutils import (
     _contains_cftime_datetimes,
     _to_pytimedelta,
@@ -111,6 +112,7 @@
 # _simple_combine.
 DUMMY_AXIS = -2
 
+
 logger = logging.getLogger("flox")
 
 
@@ -215,8 +217,11 @@ def identity(x: T) -> T:
     return x
 
 
-def _issorted(arr: np.ndarray) -> bool:
-    return bool((arr[:-1] <= arr[1:]).all())
+def _issorted(arr: np.ndarray, ascending=True) -> bool:
+    if ascending:
+        return bool((arr[:-1] <= arr[1:]).all())
+    else:
+        return bool((arr[:-1] >= arr[1:]).all())
 
 
 def _is_arg_reduction(func: T_Agg) -> bool:
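For intuition, here is a minimal standalone sketch of the monotonicity test that `_issorted` now performs in both directions (the `issorted` name and the example arrays are illustrative, not part of the diff):

```python
import numpy as np

def issorted(arr: np.ndarray, ascending: bool = True) -> bool:
    # Compare each element with its neighbour; length-0 and length-1
    # arrays are trivially sorted because the comparison is empty.
    if ascending:
        return bool((arr[:-1] <= arr[1:]).all())
    return bool((arr[:-1] >= arr[1:]).all())

assert issorted(np.array([1, 2, 2, 3]))                # non-strict ascending
assert issorted(np.array([3, 2, 1]), ascending=False)  # descending
assert not issorted(np.array([1, 3, 2]))
```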
@@ -299,7 +304,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
 def _get_optimal_chunks_for_groups(chunks, labels):
     chunkidx = np.cumsum(chunks) - 1
     # what are the groups at chunk boundaries
-    labels_at_chunk_bounds = _unique(labels[chunkidx])
+    labels_at_chunk_bounds = pd.unique(labels[chunkidx])
     # what's the last index of all groups
     last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
     # what's the last index of groups at the chunk boundaries.
@@ -317,6 +322,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
         Δl = abs(c - l)
         if c == 0 or newchunkidx[-1] > l:
             continue
+        f = f.item()  # noqa
+        l = l.item()  # noqa
         if Δf < Δl and f > newchunkidx[-1]:
             newchunkidx.append(f)
         else:
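To see what `_get_optimal_chunks_for_groups` is optimizing, consider a rough sketch of the boundary-snapping idea for sorted labels (the arrays and helper code below are illustrative, not flox API):

```python
import numpy as np

# Sorted group labels, with a chunk boundary falling inside group 1.
labels = np.array([0, 0, 0, 1, 1, 1, 1, 2, 2])
chunks = (4, 3, 2)

chunkidx = np.cumsum(chunks) - 1  # last index of each chunk: [3 6 8]
for idx in chunkidx:
    g = labels[idx]
    first = int(np.flatnonzero(labels == g)[0])  # where group g starts
    last = int(np.flatnonzero(labels == g)[-1])  # where group g ends
    # The Δf/Δl comparison above moves each boundary to whichever group
    # edge is closer: just before the group starts, or where it ends.
    print(f"boundary at {idx}: snap toward {first - 1} or {last}")
# With boundaries snapped to group edges, e.g. chunks (3, 4, 2), each group
# lives entirely inside one chunk and the reduction can run blockwise.
```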
@@ -708,7 +715,9 @@ def rechunk_for_cohorts(
     return array.rechunk({axis: newchunks})
 
 
-def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
+def rechunk_for_blockwise(
+    array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
+) -> tuple[T_MethodOpt, DaskArray]:
     """
     Rechunks array so that group boundaries line up with chunk boundaries, allowing
     embarrassingly parallel group reductions.
@@ -731,14 +740,47 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
     DaskArray
         Rechunked array
     """
-    # TODO: this should be unnecessary?
-    labels = factorize_((labels,), axes=())[0]
+
     chunks = array.chunks[axis]
-    newchunks = _get_optimal_chunks_for_groups(chunks, labels)
+    if len(chunks) == 1:
+        return "blockwise", array
+
+    # import dask
+    # from dask.utils import parse_bytes
+    # factor = parse_bytes(dask.config.get("array.chunk-size")) / (
+    #     math.prod(array.chunksize) * array.dtype.itemsize
+    # )
+    # if factor > BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR:
+    #     new_constant_chunks = math.ceil(factor) * max(chunks)
+    #     q, r = divmod(array.shape[axis], new_constant_chunks)
+    #     new_input_chunks = (new_constant_chunks,) * q + (r,)
+    # else:
+    new_input_chunks = chunks
+
+    # FIXME: this should be unnecessary?
+    labels = factorize_((labels,), axes=())[0]
+    newchunks = _get_optimal_chunks_for_groups(new_input_chunks, labels)
     if newchunks == chunks:
-        return array
+        return "blockwise", array
+
+    Δn = abs(len(newchunks) - len(new_input_chunks))
+    if pass_num_chunks_threshold := (
+        Δn / len(new_input_chunks) < OPTIONS["rechunk_blockwise_num_chunks_threshold"]
+    ):
+        logger.debug("blockwise rechunk passes num chunks threshold")
+    if pass_chunk_size_threshold := (
+        # we just pick the max because number of chunks may have changed.
+        (abs(max(newchunks) - max(new_input_chunks)) / max(new_input_chunks))
+        < OPTIONS["rechunk_blockwise_chunk_size_threshold"]
+    ):
+        logger.debug("blockwise rechunk passes chunk size change threshold")
+
+    if force or (pass_num_chunks_threshold and pass_chunk_size_threshold):
+        logger.debug("Rechunking to enable blockwise.")
+        return "blockwise", array.rechunk({axis: newchunks})
     else:
-        return array.rechunk({axis: newchunks})
+        logger.debug("Didn't meet thresholds to do automatic rechunking for blockwise reductions.")
+        return None, array
 
 
 def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
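The threshold gating above can be exercised in isolation. A minimal sketch, assuming stand-in chunk tuples and threshold values (the real defaults live in flox's `OPTIONS`, imported at the top of this diff):

```python
# Stand-in thresholds; the actual defaults come from flox's OPTIONS dict.
options = {
    "rechunk_blockwise_num_chunks_threshold": 0.25,
    "rechunk_blockwise_chunk_size_threshold": 1.5,
}

chunks = (10, 10, 10, 10)   # current chunking along the reduced axis
newchunks = (12, 9, 11, 8)  # group-aligned chunking proposed by flox

# Gate 1: the relative change in the *number* of chunks must be small.
dn = abs(len(newchunks) - len(chunks))
pass_num = dn / len(chunks) < options["rechunk_blockwise_num_chunks_threshold"]

# Gate 2: the relative change in the *largest* chunk must be small
# (max is compared because the number of chunks may have changed).
pass_size = (
    abs(max(newchunks) - max(chunks)) / max(chunks)
    < options["rechunk_blockwise_chunk_size_threshold"]
)

force = False  # the automatic path in groupby_reduce passes force=False
if force or (pass_num and pass_size):
    print("rechunk; return ('blockwise', rechunked_array)")
else:
    print("skip; return (None, array) and let method selection proceed")
```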
@@ -2704,6 +2746,11 @@ def groupby_reduce(
     has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
     has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
 
+    if method is None and is_duck_dask_array(array) and not any_by_dask and by_.ndim == 1 and _issorted(by_):
+        # Let's try rechunking for sorted 1D by.
+        (single_axis,) = axis_
+        method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)
+
     is_first_last = _is_first_last_reduction(func)
     if is_first_last:
         if has_dask and nax != 1:
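From the user's side, this branch means `groupby_reduce` can opt into blockwise automatically when `method` is left unset. A hypothetical usage sketch (array shape and chunking invented for illustration):

```python
import dask.array as da
import numpy as np
from flox.core import groupby_reduce

by = np.repeat(np.arange(4), 25)    # sorted, 1D, non-dask labels
array = da.ones((100,), chunks=30)  # chunk edges don't match group edges

# With method=None (the default), the new branch calls
# rechunk_for_blockwise(..., force=False): if the proposed rechunk stays
# within both OPTIONS thresholds, `array` is rechunked and method becomes
# "blockwise"; otherwise method stays None and selection proceeds as before.
result, groups = groupby_reduce(array, by, func="sum")
```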
@@ -2891,7 +2938,7 @@ def groupby_reduce(
 
     # if preferred method is already blockwise, no need to rechunk
     if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1:
-        array = rechunk_for_blockwise(array, axis=-1, labels=by_)
+        _, array = rechunk_for_blockwise(array, axis=-1, labels=by_)
 
     result, groups = partial_agg(
         array=array,