@@ -44,6 +44,7 @@
     quantile_new_dims_func,
 )
 from .cache import memoize
+from .lib import ArrayLayer
 from .xrutils import (
     _contains_cftime_datetimes,
     _to_pytimedelta,
@@ -72,10 +73,7 @@
         from typing import Unpack
     except (ModuleNotFoundError, ImportError):
         Unpack: Any  # type: ignore[no-redef]
-
-    import cubed.Array as CubedArray
-    import dask.array.Array as DaskArray
-    from dask.typing import Graph
+    from .types import CubedArray, DaskArray, Graph

     T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray  # Any ?
     T_By: TypeAlias = T_DuckArray
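The per-package imports removed above are replaced by one import from a .types module that is not shown in this diff. Presumably it centralizes the optional-dependency aliases; a hypothetical sketch (module contents inferred from the import line, not confirmed by the diff):

# Hypothetical sketch of flox/types.py; the diff references this module
# but does not show it, so the fallbacks below are assumptions.
from typing import Any

try:
    import cubed

    CubedArray = cubed.Array
except ImportError:
    CubedArray = Any  # cubed is an optional dependency

try:
    import dask.array
    from dask.typing import Graph

    DaskArray = dask.array.Array
except ImportError:
    DaskArray = Any  # dask is an optional dependency
    Graph = Any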
@@ -1191,7 +1189,7 @@ def _aggregate(
     agg: Aggregation,
     expected_groups: pd.Index | None,
     axis: T_Axes,
-    keepdims,
+    keepdims: bool,
     fill_value: Any,
     reindex: bool,
 ) -> FinalResultsDict:
@@ -1511,7 +1509,7 @@ def subset_to_blocks(
     blkshape: tuple[int, ...] | None = None,
     reindexer=identity,
     chunks_as_array: tuple[np.ndarray, ...] | None = None,
-) -> DaskArray:
+) -> ArrayLayer:
     """
     Advanced indexing of .blocks such that we always get a regular array back.

@@ -1525,10 +1523,8 @@ def subset_to_blocks(
     -------
     dask.array
     """
-    import dask.array
     from dask.array.slicing import normalize_index
     from dask.base import tokenize
-    from dask.highlevelgraph import HighLevelGraph

     if blkshape is None:
         blkshape = array.blocks.shape
@@ -1538,9 +1534,6 @@ def subset_to_blocks(

     index = _normalize_indexes(array, flatblocks, blkshape)

-    if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return dask.array.map_blocks(reindexer, array, meta=array._meta)
-
     # The rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
@@ -1553,10 +1546,7 @@ def subset_to_blocks(

     keys = itertools.product(*(range(len(c)) for c in chunks))
     layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
-
-    graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
-
-    return dask.array.Array(graph, name, chunks, meta=array)
+    return ArrayLayer(layer=layer, chunks=chunks, name=name)


 def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
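With this change, subset_to_blocks returns a bare graph layer plus metadata instead of materializing a dask array per call, which is why the slice(None) fast path and the per-call HighLevelGraph construction above can be dropped. The ArrayLayer definition in .lib is not part of this diff; judging from the ArrayLayer(layer=..., chunks=..., name=...) constructor call and the subset.layer access later in the diff, it is presumably a small container along these lines (a sketch, not the actual definition):

from dataclasses import dataclass

from dask.typing import Graph


@dataclass
class ArrayLayer:
    # Hypothetical sketch of flox's ArrayLayer container (flox/lib.py);
    # field names are taken from the constructor call in the diff.
    layer: Graph  # task-key -> task mapping for one graph layer
    chunks: tuple[tuple[int, ...], ...]  # chunks of the array this layer produces
    name: str  # name shared by every task key in the layer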
@@ -1613,6 +1603,9 @@ def dask_groupby_agg(
 ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
     import dask.array
     from dask.array.core import slices_from_chunks
+    from dask.highlevelgraph import HighLevelGraph
+
+    from .dask_array_ops import _tree_reduce

     # I think _tree_reduce expects this
     assert isinstance(axis, Sequence)
@@ -1742,35 +1735,44 @@ def dask_groupby_agg(
         assert chunks_cohorts
         block_shape = array.blocks.shape[-len(axis) :]

-        reduced_ = []
+        out_name = f"{name}-reduce-{method}-{token}"
         groups_ = []
         chunks_as_array = tuple(np.array(c) for c in array.chunks)
-        for blks, cohort in chunks_cohorts.items():
+        dsk: Graph = {}
+        for icohort, (blks, cohort) in enumerate(chunks_cohorts.items()):
             cohort_index = pd.Index(cohort)
             reindexer = (
                 partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
                 if do_simple_combine
                 else identity
             )
-            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
+            subset = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
+            dsk |= subset.layer  # type: ignore[operator]
             # now that we have reindexed, we can set reindex=True explicitly
-            reduced_.append(
-                tree_reduce(
-                    reindexed,
-                    combine=partial(combine, agg=agg, reindex=do_simple_combine),
-                    aggregate=partial(
-                        aggregate,
-                        expected_groups=cohort_index,
-                        reindex=do_simple_combine,
-                    ),
-                )
+            _tree_reduce(
+                subset,
+                out_dsk=dsk,
+                name=out_name,
+                block_index=icohort,
+                axis=axis,
+                combine=partial(combine, agg=agg, reindex=do_simple_combine, keepdims=True),
+                aggregate=partial(
+                    aggregate, expected_groups=cohort_index, reindex=do_simple_combine, keepdims=True
+                ),
             )
             # This is done because pandas promotes to 64-bit types when an Index is created
             # So we use the index to generate the return value for consistency with "map-reduce"
             # This is important on windows
             groups_.append(cohort_index.values)

-        reduced = dask.array.concatenate(reduced_, axis=-1)
+        graph = HighLevelGraph.from_collections(out_name, dsk, dependencies=[intermediate])
+
+        out_chunks = list(array.chunks)
+        out_chunks[axis[-1]] = tuple(len(c) for c in chunks_cohorts.values())
+        for ax in axis[:-1]:
+            out_chunks[ax] = (1,)
+        reduced = dask.array.Array(graph, out_name, out_chunks, meta=array._meta)
+
         groups = (np.concatenate(groups_),)
         group_chunks = (tuple(len(cohort) for cohort in groups_),)

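Taken together, the cohorts branch now accumulates every per-cohort reduction task into one mapping, wraps that mapping in a single HighLevelGraph layer, and constructs the output array directly, with one chunk per cohort along the last reduced axis, instead of materializing and concatenating one dask array per cohort. Below is a minimal standalone sketch of that hand-built-layer pattern; the array, task, and layer names are invented for the demo and are unrelated to flox's real task names:

import dask.array
import numpy as np
from dask.highlevelgraph import HighLevelGraph

x = dask.array.ones((4, 4), chunks=(2, 2))
name = "double-demo"
# One task per output block: apply np.multiply to the matching input block.
layer = {
    (name, i, j): (np.multiply, (x.name, i, j), 2)
    for i in range(x.numblocks[0])
    for j in range(x.numblocks[1])
}
# A single new layer whose only dependency is the input collection.
graph = HighLevelGraph.from_collections(name, layer, dependencies=[x])
y = dask.array.Array(graph, name, x.chunks, meta=x._meta)
assert (y == 2).all().compute()

Building one layer for all cohorts keeps the graph flat: the scheduler sees a single named layer rather than one concatenate plus one tree-reduction collection per cohort.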