Skip to content

Commit 5dbf992

Browse files
authored
Multiple groupers v3 (#76)
1 parent 35dd38d commit 5dbf992

File tree

4 files changed

+197
-124
lines changed

4 files changed

+197
-124
lines changed

flox/core.py

Lines changed: 116 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itertools
55
import operator
66
from collections import namedtuple
7-
from functools import partial
7+
from functools import partial, reduce
88
from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union
99

1010
import numpy as np
@@ -59,16 +59,14 @@ def _prepare_for_flox(group_idx, array):
5959
return group_idx, ordered_array
6060

6161

62-
def _get_expected_groups(by, sort, raise_if_dask=True) -> pd.Index | None:
62+
def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
6363
if is_duck_dask_array(by):
6464
if raise_if_dask:
65-
raise ValueError("Please provide `expected_groups`.")
65+
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
6666
return None
6767
flatby = by.ravel()
6868
expected = pd.unique(flatby[~isnull(flatby)])
69-
if sort:
70-
expected = np.sort(expected)
71-
return _convert_expected_groups_to_index(expected, isbin=False)
69+
return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0]
7270

7371

7472
def _get_chunk_reduction(reduction_type: str) -> Callable:
@@ -378,6 +376,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
378376
Copied from xhistogram &
379377
https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
380378
"""
379+
assert labels.ndim > 1
381380
offset: np.ndarray = (
382381
labels + np.arange(np.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups
383382
)
@@ -388,7 +387,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
388387

389388

390389
def factorize_(
391-
by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, sort=True
390+
by: tuple,
391+
axis,
392+
expected_groups: tuple[pd.Index, ...] = None,
393+
reindex=False,
394+
sort=True,
395+
fastpath=False,
392396
):
393397
"""
394398
Returns an array of integer codes for groups (and associated data)
@@ -413,7 +417,7 @@ def factorize_(
413417
raise ValueError("Please pass bin edges in expected_groups.")
414418
# TODO: fix for binning
415419
found_groups.append(expect)
416-
# pd.cut with bins = IntervalIndex[datetime64] doesn't work...
420+
# pd.cut with bins = IntervalIndex[datetime64] doesn't work...
417421
if groupvar.dtype.kind == "M":
418422
expect = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
419423
idx = pd.cut(groupvar.ravel(), bins=expect).codes.copy()
@@ -440,10 +444,15 @@ def factorize_(
440444
grp_shape = tuple(len(grp) for grp in found_groups)
441445
ngroups = np.prod(grp_shape)
442446
if len(by) > 1:
443-
group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
447+
group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap").reshape(by[0].shape)
448+
nan_by_mask = reduce(np.logical_or, [isnull(b) for b in by])
449+
group_idx[nan_by_mask] = -1
444450
else:
445451
group_idx = factorized[0]
446452

453+
if fastpath:
454+
return group_idx, found_groups, grp_shape
455+
447456
if np.isscalar(axis) and groupvar.ndim > 1:
448457
# Not reducing along all dimensions of by
449458
# this is OK because for 3D by and axis=(1,2),
@@ -1244,33 +1253,78 @@ def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
12441253

12451254

12461255
def _assert_by_is_aligned(shape, by):
1247-
if shape[-by.ndim :] != by.shape:
1248-
raise ValueError(
1249-
"`array` and `by` arrays must be aligned "
1250-
"i.e. array.shape[-by.ndim :] == by.shape. "
1251-
"for every array in `by`."
1252-
f"Received array of shape {shape} but "
1253-
f"`by` has shape {by.shape}."
1256+
for idx, b in enumerate(by):
1257+
if shape[-b.ndim :] != b.shape:
1258+
raise ValueError(
1259+
"`array` and `by` arrays must be aligned "
1260+
"i.e. array.shape[-by.ndim :] == by.shape. "
1261+
"for every array in `by`."
1262+
f"Received array of shape {shape} but "
1263+
f"array {idx} in `by` has shape {b.shape}."
1264+
)
1265+
1266+
1267+
def _convert_expected_groups_to_index(
1268+
expected_groups: tuple, isbin: bool, sort: bool
1269+
) -> pd.Index | None:
1270+
out = []
1271+
for ex, isbin_ in zip(expected_groups, isbin):
1272+
if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
1273+
if sort:
1274+
ex = ex.sort_values()
1275+
out.append(ex)
1276+
elif ex is not None:
1277+
if isbin_:
1278+
out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:]))
1279+
else:
1280+
if sort:
1281+
ex = np.sort(ex)
1282+
out.append(pd.Index(ex))
1283+
else:
1284+
assert ex is None
1285+
out.append(None)
1286+
return tuple(out)
1287+
1288+
1289+
def _lazy_factorize_wrapper(*by, **kwargs):
1290+
group_idx, *rest = factorize_(by, **kwargs)
1291+
return group_idx
1292+
1293+
1294+
def _factorize_multiple(by, expected_groups, by_is_dask):
1295+
kwargs = dict(
1296+
expected_groups=expected_groups,
1297+
axis=None, # always None, we offset later if necessary.
1298+
fastpath=True,
1299+
)
1300+
if by_is_dask:
1301+
import dask.array
1302+
1303+
group_idx = dask.array.map_blocks(
1304+
_lazy_factorize_wrapper,
1305+
*np.broadcast_arrays(*by),
1306+
meta=np.array((), dtype=np.int64),
1307+
**kwargs,
12541308
)
1309+
found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by)
1310+
grp_shape = tuple(len(e) for e in expected_groups)
1311+
else:
1312+
group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
12551313

1314+
final_groups = tuple(
1315+
found if expect is None else expect.to_numpy()
1316+
for found, expect in zip(found_groups, expected_groups)
1317+
)
12561318

1257-
def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
1258-
if isinstance(expected_groups, pd.IntervalIndex) or (
1259-
isinstance(expected_groups, pd.Index) and not isbin
1260-
):
1261-
return expected_groups
1262-
if isbin:
1263-
return pd.IntervalIndex.from_arrays(expected_groups[:-1], expected_groups[1:])
1264-
elif expected_groups is not None:
1265-
return pd.Index(expected_groups)
1266-
return expected_groups
1319+
if any(grp is None for grp in final_groups):
1320+
raise ValueError("Please provide expected_groups when grouping by a dask array.")
1321+
return (group_idx,), final_groups, grp_shape
12671322

12681323

12691324
def groupby_reduce(
12701325
array: np.ndarray | DaskArray,
1271-
by: np.ndarray | DaskArray,
1326+
*by: np.ndarray | DaskArray,
12721327
func: str | Aggregation,
1273-
*,
12741328
expected_groups: Sequence | np.ndarray | None = None,
12751329
sort: bool = True,
12761330
isbin: bool = False,
@@ -1383,18 +1437,38 @@ def groupby_reduce(
13831437
)
13841438
reindex = _validate_reindex(reindex, func, method, expected_groups)
13851439

1386-
if not is_duck_array(by):
1387-
by = np.asarray(by)
1440+
by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
1441+
nby = len(by)
1442+
by_is_dask = any(is_duck_dask_array(b) for b in by)
13881443
if not is_duck_array(array):
13891444
array = np.asarray(array)
1445+
if isinstance(isbin, bool):
1446+
isbin = (isbin,) * len(by)
1447+
if expected_groups is None:
1448+
expected_groups = (None,) * len(by)
13901449

13911450
_assert_by_is_aligned(array.shape, by)
13921451

1452+
if len(by) == 1 and not isinstance(expected_groups, tuple):
1453+
expected_groups = (np.asarray(expected_groups),)
1454+
elif len(expected_groups) != len(by):
1455+
raise ValueError("len(expected_groups) != len(by)")
1456+
13931457
# We convert to pd.Index since that lets us know if we are binning or not
13941458
# (pd.IntervalIndex or not)
1395-
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin)
1396-
if expected_groups is not None and sort:
1397-
expected_groups = expected_groups.sort_values()
1459+
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort)
1460+
1461+
# when grouping by multiple variables, we factorize early.
1462+
# TODO: could restrict this to dask-only
1463+
if nby > 1:
1464+
by, final_groups, grp_shape = _factorize_multiple(
1465+
by, expected_groups, by_is_dask=by_is_dask
1466+
)
1467+
expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
1468+
1469+
assert len(by) == 1
1470+
by = by[0]
1471+
expected_groups = expected_groups[0]
13981472

13991473
if axis is None:
14001474
axis = tuple(array.ndim + np.arange(-by.ndim, 0))
@@ -1408,7 +1482,7 @@ def groupby_reduce(
14081482

14091483
# TODO: make sure expected_groups is unique
14101484
if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
1411-
if not is_duck_dask_array(by):
1485+
if not by_is_dask:
14121486
expected_groups = _get_expected_groups(by, sort)
14131487
else:
14141488
# When we reduce along all axes, we are guaranteed to see all
@@ -1422,6 +1496,7 @@ def groupby_reduce(
14221496
"Please provide ``expected_groups`` when not reducing along all axes."
14231497
)
14241498

1499+
assert len(axis) <= by.ndim
14251500
if len(axis) < by.ndim:
14261501
by = _move_reduce_dims_to_end(by, -array.ndim + np.array(axis) + by.ndim)
14271502
array = _move_reduce_dims_to_end(array, axis)
@@ -1514,7 +1589,7 @@ def groupby_reduce(
15141589
result, *groups = partial_agg(
15151590
array,
15161591
by,
1517-
expected_groups=expected_groups,
1592+
expected_groups=None if method == "blockwise" else expected_groups,
15181593
reindex=reindex,
15191594
method=method,
15201595
sort=sort,
@@ -1526,4 +1601,10 @@ def groupby_reduce(
15261601
result = result[..., sorted_idx]
15271602
groups = (groups[0][sorted_idx],)
15281603

1604+
if nby > 1:
1605+
# nan group labels are factorized to -1, and preserved
1606+
# now we get rid of them
1607+
nanmask = groups[0] == -1
1608+
groups = final_groups
1609+
result = result[..., ~nanmask].reshape(result.shape[:-1] + grp_shape)
15291610
return (result, *groups)

0 commit comments

Comments
 (0)