44import itertools
55import operator
66from collections import namedtuple
7- from functools import partial
7+ from functools import partial , reduce
88from typing import TYPE_CHECKING , Any , Callable , Dict , Mapping , Sequence , Union
99
1010import numpy as np
@@ -59,16 +59,14 @@ def _prepare_for_flox(group_idx, array):
5959 return group_idx , ordered_array
6060
6161
def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
    """Infer the expected group labels from ``by``.

    For numpy-backed ``by``, returns the unique non-null labels as a
    pd.Index (sorted when ``sort`` is True). Dask-backed ``by`` cannot be
    inspected lazily, so we either raise or return None.
    """
    if is_duck_dask_array(by):
        # Labels are lazy; we cannot enumerate them without computing.
        if raise_if_dask:
            raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
        return None
    flat = by.ravel()
    unique_labels = pd.unique(flat[~isnull(flat)])
    (index,) = _convert_expected_groups_to_index((unique_labels,), isbin=(False,), sort=sort)
    return index
7270
7371
7472def _get_chunk_reduction (reduction_type : str ) -> Callable :
@@ -378,6 +376,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
378376 Copied from xhistogram &
379377 https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
380378 """
379+ assert labels .ndim > 1
381380 offset : np .ndarray = (
382381 labels + np .arange (np .prod (labels .shape [:- 1 ])).reshape ((* labels .shape [:- 1 ], - 1 )) * ngroups
383382 )
@@ -388,7 +387,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
388387
389388
390389def factorize_ (
391- by : tuple , axis , expected_groups : tuple [pd .Index , ...] = None , reindex = False , sort = True
390+ by : tuple ,
391+ axis ,
392+ expected_groups : tuple [pd .Index , ...] = None ,
393+ reindex = False ,
394+ sort = True ,
395+ fastpath = False ,
392396):
393397 """
394398 Returns an array of integer codes for groups (and associated data)
@@ -413,7 +417,7 @@ def factorize_(
413417 raise ValueError ("Please pass bin edges in expected_groups." )
414418 # TODO: fix for binning
415419 found_groups .append (expect )
416- # pd.cut with bins = IntervalIndex[datetime64] doesn't work...
420+ # pd.cut with bins = IntervalIndex[datetime64] doesn't work...
417421 if groupvar .dtype .kind == "M" :
418422 expect = np .concatenate ([expect .left .to_numpy (), [expect .right [- 1 ].to_numpy ()]])
419423 idx = pd .cut (groupvar .ravel (), bins = expect ).codes .copy ()
@@ -440,10 +444,15 @@ def factorize_(
440444 grp_shape = tuple (len (grp ) for grp in found_groups )
441445 ngroups = np .prod (grp_shape )
442446 if len (by ) > 1 :
443- group_idx = np .ravel_multi_index (factorized , grp_shape ).reshape (by [0 ].shape )
447+ group_idx = np .ravel_multi_index (factorized , grp_shape , mode = "wrap" ).reshape (by [0 ].shape )
448+ nan_by_mask = reduce (np .logical_or , [isnull (b ) for b in by ])
449+ group_idx [nan_by_mask ] = - 1
444450 else :
445451 group_idx = factorized [0 ]
446452
453+ if fastpath :
454+ return group_idx , found_groups , grp_shape
455+
447456 if np .isscalar (axis ) and groupvar .ndim > 1 :
448457 # Not reducing along all dimensions of by
449458 # this is OK because for 3D by and axis=(1,2),
@@ -1244,33 +1253,78 @@ def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
12441253
12451254
12461255def _assert_by_is_aligned (shape , by ):
1247- if shape [- by .ndim :] != by .shape :
1248- raise ValueError (
1249- "`array` and `by` arrays must be aligned "
1250- "i.e. array.shape[-by.ndim :] == by.shape. "
1251- "for every array in `by`."
1252- f"Received array of shape { shape } but "
1253- f"`by` has shape { by .shape } ."
1256+ for idx , b in enumerate (by ):
1257+ if shape [- b .ndim :] != b .shape :
1258+ raise ValueError (
1259+ "`array` and `by` arrays must be aligned "
1260+ "i.e. array.shape[-by.ndim :] == by.shape. "
1261+ "for every array in `by`."
1262+ f"Received array of shape { shape } but "
1263+ f"array { idx } in `by` has shape { b .shape } ."
1264+ )
1265+
1266+
1267+ def _convert_expected_groups_to_index (
1268+ expected_groups : tuple , isbin : bool , sort : bool
1269+ ) -> pd .Index | None :
1270+ out = []
1271+ for ex , isbin_ in zip (expected_groups , isbin ):
1272+ if isinstance (ex , pd .IntervalIndex ) or (isinstance (ex , pd .Index ) and not isbin ):
1273+ if sort :
1274+ ex = ex .sort_values ()
1275+ out .append (ex )
1276+ elif ex is not None :
1277+ if isbin_ :
1278+ out .append (pd .IntervalIndex .from_arrays (ex [:- 1 ], ex [1 :]))
1279+ else :
1280+ if sort :
1281+ ex = np .sort (ex )
1282+ out .append (pd .Index (ex ))
1283+ else :
1284+ assert ex is None
1285+ out .append (None )
1286+ return tuple (out )
1287+
1288+
def _lazy_factorize_wrapper(*by, **kwargs):
    """Adapter for use with dask.array.map_blocks: factorize one block of
    the grouping arrays and return only the integer group codes, discarding
    the other outputs of ``factorize_``."""
    group_idx = factorize_(by, **kwargs)[0]
    return group_idx
1292+
1293+
def _factorize_multiple(by, expected_groups, by_is_dask):
    """Collapse multiple grouping arrays into a single set of integer codes.

    Returns ``((group_idx,), final_groups, grp_shape)`` where ``group_idx``
    encodes the cartesian product of all grouping variables, ``final_groups``
    holds the label values per variable, and ``grp_shape`` is the number of
    groups per variable.
    """
    factorize_kwargs = dict(
        expected_groups=expected_groups,
        axis=None,  # always None, we offset later if necessary.
        fastpath=True,
    )

    if not by_is_dask:
        group_idx, found_groups, grp_shape = factorize_(by, **factorize_kwargs)
    else:
        import dask.array

        # Factorize lazily, one block at a time; every block maps its labels
        # against the same expected_groups so codes are globally consistent.
        group_idx = dask.array.map_blocks(
            _lazy_factorize_wrapper,
            *np.broadcast_arrays(*by),
            meta=np.array((), dtype=np.int64),
            **factorize_kwargs,
        )
        # Unique labels are only knowable eagerly for numpy-backed inputs.
        found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by)
        grp_shape = tuple(len(e) for e in expected_groups)

    final_groups = tuple(
        found if expect is None else expect.to_numpy()
        for found, expect in zip(found_groups, expected_groups)
    )

    if any(grp is None for grp in final_groups):
        raise ValueError("Please provide expected_groups when grouping by a dask array.")
    return (group_idx,), final_groups, grp_shape
12671322
12681323
12691324def groupby_reduce (
12701325 array : np .ndarray | DaskArray ,
1271- by : np .ndarray | DaskArray ,
1326+ * by : np .ndarray | DaskArray ,
12721327 func : str | Aggregation ,
1273- * ,
12741328 expected_groups : Sequence | np .ndarray | None = None ,
12751329 sort : bool = True ,
12761330 isbin : bool = False ,
@@ -1383,18 +1437,38 @@ def groupby_reduce(
13831437 )
13841438 reindex = _validate_reindex (reindex , func , method , expected_groups )
13851439
1386- if not is_duck_array (by ):
1387- by = np .asarray (by )
1440+ by : tuple = tuple (np .asarray (b ) if not is_duck_array (b ) else b for b in by )
1441+ nby = len (by )
1442+ by_is_dask = any (is_duck_dask_array (b ) for b in by )
13881443 if not is_duck_array (array ):
13891444 array = np .asarray (array )
1445+ if isinstance (isbin , bool ):
1446+ isbin = (isbin ,) * len (by )
1447+ if expected_groups is None :
1448+ expected_groups = (None ,) * len (by )
13901449
13911450 _assert_by_is_aligned (array .shape , by )
13921451
1452+ if len (by ) == 1 and not isinstance (expected_groups , tuple ):
1453+ expected_groups = (np .asarray (expected_groups ),)
1454+ elif len (expected_groups ) != len (by ):
1455+ raise ValueError ("len(expected_groups) != len(by)" )
1456+
13931457 # We convert to pd.Index since that lets us know if we are binning or not
13941458 # (pd.IntervalIndex or not)
1395- expected_groups = _convert_expected_groups_to_index (expected_groups , isbin )
1396- if expected_groups is not None and sort :
1397- expected_groups = expected_groups .sort_values ()
1459+ expected_groups = _convert_expected_groups_to_index (expected_groups , isbin , sort )
1460+
1461+ # when grouping by multiple variables, we factorize early.
1462+ # TODO: could restrict this to dask-only
1463+ if nby > 1 :
1464+ by , final_groups , grp_shape = _factorize_multiple (
1465+ by , expected_groups , by_is_dask = by_is_dask
1466+ )
1467+ expected_groups = (pd .RangeIndex (np .prod (grp_shape )),)
1468+
1469+ assert len (by ) == 1
1470+ by = by [0 ]
1471+ expected_groups = expected_groups [0 ]
13981472
13991473 if axis is None :
14001474 axis = tuple (array .ndim + np .arange (- by .ndim , 0 ))
@@ -1408,7 +1482,7 @@ def groupby_reduce(
14081482
14091483 # TODO: make sure expected_groups is unique
14101484 if len (axis ) == 1 and by .ndim > 1 and expected_groups is None :
1411- if not is_duck_dask_array ( by ) :
1485+ if not by_is_dask :
14121486 expected_groups = _get_expected_groups (by , sort )
14131487 else :
14141488 # When we reduce along all axes, we are guaranteed to see all
@@ -1422,6 +1496,7 @@ def groupby_reduce(
14221496 "Please provide ``expected_groups`` when not reducing along all axes."
14231497 )
14241498
1499+ assert len (axis ) <= by .ndim
14251500 if len (axis ) < by .ndim :
14261501 by = _move_reduce_dims_to_end (by , - array .ndim + np .array (axis ) + by .ndim )
14271502 array = _move_reduce_dims_to_end (array , axis )
@@ -1514,7 +1589,7 @@ def groupby_reduce(
15141589 result , * groups = partial_agg (
15151590 array ,
15161591 by ,
1517- expected_groups = expected_groups ,
1592+ expected_groups = None if method == "blockwise" else expected_groups ,
15181593 reindex = reindex ,
15191594 method = method ,
15201595 sort = sort ,
@@ -1526,4 +1601,10 @@ def groupby_reduce(
15261601 result = result [..., sorted_idx ]
15271602 groups = (groups [0 ][sorted_idx ],)
15281603
1604+ if nby > 1 :
1605+ # nan group labels are factorized to -1, and preserved
1606+ # now we get rid of them
1607+ nanmask = groups [0 ] == - 1
1608+ groups = final_groups
1609+ result = result [..., ~ nanmask ].reshape (result .shape [:- 1 ] + grp_shape )
15291610 return (result , * groups )
0 commit comments