@@ -1798,87 +1798,123 @@ def cubed_groupby_agg(
17981798 assert isinstance (axis , Sequence )
17991799 assert all (ax >= 0 for ax in axis )
18001800
1801- inds = tuple (range (array .ndim ))
1801+ if method == "blockwise" :
1802+ assert by .ndim == 1
1803+ assert expected_groups is not None
18021804
1803- by_input = by
1805+ def _reduction_func (a , by , axis , start_group , num_groups ):
1806+ # adjust group labels to start from 0 for each chunk
1807+ by_for_chunk = by - start_group
1808+ expected_groups_for_chunk = pd .RangeIndex (num_groups )
18041809
1805- # Unifying chunks is necessary for argreductions.
1806- # We need to rechunk before zipping up with the index
1807- # let's always do it anyway
1808- if not is_chunked_array (by ):
1809- # chunk numpy arrays like the input array
1810- chunks = tuple (array .chunks [ax ] if by .shape [ax ] != 1 else (1 ,) for ax in range (- by .ndim , 0 ))
1810+ axis = (axis ,) # convert integral axis to tuple
18111811
1812- by = cubed .from_array (by , chunks = chunks , spec = array .spec )
1813- _ , (array , by ) = cubed .core .unify_chunks (array , inds , by , inds [- by .ndim :])
1812+ blockwise_method = partial (
1813+ _reduce_blockwise ,
1814+ agg = agg ,
1815+ axis = axis ,
1816+ expected_groups = expected_groups_for_chunk ,
1817+ fill_value = fill_value ,
1818+ engine = engine ,
1819+ sort = sort ,
1820+ reindex = reindex ,
1821+ )
1822+ out = blockwise_method (a , by_for_chunk )
1823+ return out [agg .name ]
18141824
1815- # Cubed's groupby_reduction handles the generation of "intermediates", and the
1816- # "map-reduce" combination step, so we don't have to do that here.
1817- # Only the equivalent of "_simple_combine" is supported, there is no
1818- # support for "_grouped_combine".
1819- labels_are_unknown = is_chunked_array ( by_input ) and expected_groups is None
1820- do_simple_combine = not _is_arg_reduction ( agg ) and not labels_are_unknown
1825+ num_groups = len ( expected_groups )
1826+ result = cubed . core . groupby . groupby_blockwise (
1827+ array , by , axis = axis , func = _reduction_func , num_groups = num_groups
1828+ )
1829+ groups = ( expected_groups . to_numpy (),)
1830+ return ( result , groups )
18211831
1822- assert do_simple_combine
1823- assert method == "map-reduce"
1824- assert expected_groups is not None
1825- assert reindex is True
1826- assert len (axis ) == 1 # one axis/grouping
1832+ else :
1833+ inds = tuple (range (array .ndim ))
18271834
1828- def _groupby_func (a , by , axis , intermediate_dtype , num_groups ):
1829- blockwise_method = partial (
1830- _get_chunk_reduction (agg .reduction_type ),
1831- func = agg .chunk ,
1832- fill_value = agg .fill_value ["intermediate" ],
1833- dtype = agg .dtype ["intermediate" ],
1834- reindex = reindex ,
1835- user_dtype = agg .dtype ["user" ],
1835+ by_input = by
1836+
1837+ # Unifying chunks is necessary for argreductions.
1838+ # We need to rechunk before zipping up with the index
1839+ # let's always do it anyway
1840+ if not is_chunked_array (by ):
1841+ # chunk numpy arrays like the input array
1842+ chunks = tuple (
1843+ array .chunks [ax ] if by .shape [ax ] != 1 else (1 ,) for ax in range (- by .ndim , 0 )
1844+ )
1845+
1846+ by = cubed .from_array (by , chunks = chunks , spec = array .spec )
1847+ _ , (array , by ) = cubed .core .unify_chunks (array , inds , by , inds [- by .ndim :])
1848+
1849+ # Cubed's groupby_reduction handles the generation of "intermediates", and the
1850+ # "map-reduce" combination step, so we don't have to do that here.
1851+ # Only the equivalent of "_simple_combine" is supported, there is no
1852+ # support for "_grouped_combine".
1853+ labels_are_unknown = is_chunked_array (by_input ) and expected_groups is None
1854+ do_simple_combine = not _is_arg_reduction (agg ) and not labels_are_unknown
1855+
1856+ assert do_simple_combine
1857+ assert method == "map-reduce"
1858+ assert expected_groups is not None
1859+ assert reindex is True
1860+ assert len (axis ) == 1 # one axis/grouping
1861+
1862+ def _groupby_func (a , by , axis , intermediate_dtype , num_groups ):
1863+ blockwise_method = partial (
1864+ _get_chunk_reduction (agg .reduction_type ),
1865+ func = agg .chunk ,
1866+ fill_value = agg .fill_value ["intermediate" ],
1867+ dtype = agg .dtype ["intermediate" ],
1868+ reindex = reindex ,
1869+ user_dtype = agg .dtype ["user" ],
1870+ axis = axis ,
1871+ expected_groups = expected_groups ,
1872+ engine = engine ,
1873+ sort = sort ,
1874+ )
1875+ out = blockwise_method (a , by )
1876+ # Convert dict to one that cubed understands, dropping groups since they are
1877+ # known, and the same for every block.
1878+ return {
1879+ f"f{ idx } " : intermediate for idx , intermediate in enumerate (out ["intermediates" ])
1880+ }
1881+
1882+ def _groupby_combine (a , axis , dummy_axis , dtype , keepdims ):
1883+ # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
1884+ # only combine over the dummy axis, to preserve grouping along 'axis'
1885+ dtype = dict (dtype )
1886+ out = {}
1887+ for idx , combine in enumerate (agg .simple_combine ):
1888+ field = f"f{ idx } "
1889+ out [field ] = combine (a [field ], axis = dummy_axis , keepdims = keepdims )
1890+ return out
1891+
1892+ def _groupby_aggregate (a ):
1893+ # Convert cubed dict to one that _finalize_results works with
1894+ results = {"groups" : expected_groups , "intermediates" : a .values ()}
1895+ out = _finalize_results (results , agg , axis , expected_groups , fill_value , reindex )
1896+ return out [agg .name ]
1897+
1898+ # convert list of dtypes to a structured dtype for cubed
1899+ intermediate_dtype = [(f"f{ i } " , dtype ) for i , dtype in enumerate (agg .dtype ["intermediate" ])]
1900+ dtype = agg .dtype ["final" ]
1901+ num_groups = len (expected_groups )
1902+
1903+ result = cubed .core .groupby .groupby_reduction (
1904+ array ,
1905+ by ,
1906+ func = _groupby_func ,
1907+ combine_func = _groupby_combine ,
1908+ aggregate_func = _groupby_aggregate ,
18361909 axis = axis ,
1837- expected_groups = expected_groups ,
1838- engine = engine ,
1839- sort = sort ,
1910+ intermediate_dtype = intermediate_dtype ,
1911+ dtype = dtype ,
1912+ num_groups = num_groups ,
18401913 )
1841- out = blockwise_method (a , by )
1842- # Convert dict to one that cubed understands, dropping groups since they are
1843- # known, and the same for every block.
1844- return {f"f{ idx } " : intermediate for idx , intermediate in enumerate (out ["intermediates" ])}
1845-
1846- def _groupby_combine (a , axis , dummy_axis , dtype , keepdims ):
1847- # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
1848- # only combine over the dummy axis, to preserve grouping along 'axis'
1849- dtype = dict (dtype )
1850- out = {}
1851- for idx , combine in enumerate (agg .simple_combine ):
1852- field = f"f{ idx } "
1853- out [field ] = combine (a [field ], axis = dummy_axis , keepdims = keepdims )
1854- return out
1855-
1856- def _groupby_aggregate (a ):
1857- # Convert cubed dict to one that _finalize_results works with
1858- results = {"groups" : expected_groups , "intermediates" : a .values ()}
1859- out = _finalize_results (results , agg , axis , expected_groups , fill_value , reindex )
1860- return out [agg .name ]
1861-
1862- # convert list of dtypes to a structured dtype for cubed
1863- intermediate_dtype = [(f"f{ i } " , dtype ) for i , dtype in enumerate (agg .dtype ["intermediate" ])]
1864- dtype = agg .dtype ["final" ]
1865- num_groups = len (expected_groups )
1866-
1867- result = cubed .core .groupby .groupby_reduction (
1868- array ,
1869- by ,
1870- func = _groupby_func ,
1871- combine_func = _groupby_combine ,
1872- aggregate_func = _groupby_aggregate ,
1873- axis = axis ,
1874- intermediate_dtype = intermediate_dtype ,
1875- dtype = dtype ,
1876- num_groups = num_groups ,
1877- )
18781914
1879- groups = (expected_groups .to_numpy (),)
1915+ groups = (expected_groups .to_numpy (),)
18801916
1881- return (result , groups )
1917+ return (result , groups )
18821918
18831919
18841920def _collapse_blocks_along_axes (reduced : DaskArray , axis : T_Axes , group_chunks ) -> DaskArray :
@@ -2467,9 +2503,9 @@ def groupby_reduce(
24672503 if method is None :
24682504 method = "map-reduce"
24692505
2470- if method != "map-reduce" :
2506+ if method not in ( "map-reduce" , "blockwise" ) :
24712507 raise NotImplementedError (
2472- "Reduction for Cubed arrays is only implemented for method 'map-reduce'."
2508+ "Reduction for Cubed arrays is only implemented for methods 'map-reduce' and 'blockwise'."
24732509 )
24742510
24752511 partial_agg = partial (cubed_groupby_agg , ** kwargs )
0 commit comments