
Commit 8497cc9

Clean up binning by using pd.Index for expected_groups

1 parent ced653f

File tree

4 files changed: +68, -69 lines changed
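Note: the public API is unchanged; `groupby_reduce` still takes raw bin edges with `isbin=True`. The cleanup is internal: expected groups are normalized to a `pd.Index` up front, with bin edges becoming a `pd.IntervalIndex`, so later code can tell binning apart from plain grouping with a single type check. A minimal sketch of the idea (example values invented for illustration):

import numpy as np
import pandas as pd

# Plain grouping: labels become an ordinary pd.Index.
labels = pd.Index(["a", "b", "c"])

# Binning: edges [0, 1, 2, 3] become the intervals (0, 1], (1, 2], (2, 3].
edges = np.array([0, 1, 2, 3])
bins = pd.IntervalIndex.from_arrays(edges[:-1], edges[1:])

# Downstream code can now ask "are we binning?" by type alone:
print(isinstance(bins, pd.IntervalIndex))    # True
print(isinstance(labels, pd.IntervalIndex))  # False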

flox/core.py

Lines changed: 51 additions & 59 deletions
@@ -58,13 +58,14 @@ def _prepare_for_flox(group_idx, array):
     return group_idx, ordered_array


-def _get_expected_groups(by, raise_if_dask=True) -> np.ndarray | None:
+def _get_expected_groups(by, raise_if_dask=True) -> pd.Index | None:
     if is_duck_dask_array(by):
         if raise_if_dask:
             raise ValueError("Please provide `expected_groups`.")
         return None
     flatby = by.ravel()
-    return np.unique(flatby[~isnull(flatby)])
+    expected = np.unique(flatby[~isnull(flatby)])
+    return _convert_expected_groups_to_index(expected, isbin=False)


 def _get_chunk_reduction(reduction_type: str) -> Callable:
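For reference, a rough sketch of what `_get_expected_groups` now returns for a small NumPy `by` array (values invented; `pd.isnull` stands in for flox's internal `isnull` helper):

import numpy as np
import pandas as pd

by = np.array([[1.0, 2.0], [np.nan, 1.0]])
flatby = by.ravel()

# Unique non-NaN labels, wrapped in a pd.Index as the new return type requires.
expected = np.unique(flatby[~pd.isnull(flatby)])
print(pd.Index(expected))  # Index([1.0, 2.0], dtype='float64') on recent pandas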
@@ -324,31 +325,28 @@ def rechunk_for_blockwise(array, axis, labels):

 def reindex_(array: np.ndarray, from_, to, fill_value=None, axis: int = -1) -> np.ndarray:

+    assert isinstance(to, pd.Index)
     assert axis in (0, -1)

-    from_ = np.atleast_1d(from_)
-    to = np.atleast_1d(to)
-
     if to.ndim > 1:
         raise ValueError(f"Cannot reindex to a multidimensional array: {to}")

-    # short-circuit for trivial case
-    if len(from_) == len(to) and np.all(from_ == to):
-        return array
-
     if array.shape[axis] == 0:
         # all groups were NaN
         reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
         return reindexed

+    from_ = pd.Index(from_)
+    # short-circuit for trivial case
+    if from_.equals(to):
+        return array
+
     if from_.dtype.kind == "O" and isinstance(from_[0], tuple):
         raise NotImplementedError(
             "Currently does not support reindexing with object arrays of tuples. "
             "These occur when grouping by multi-indexed variables in xarray."
         )
-    idx = np.array(
-        [np.argwhere(np.array(from_) == label)[0, 0] if label in from_ else -1 for label in to]
-    )
+    idx = from_.get_indexer(to)
     indexer = [slice(None, None)] * array.ndim
     indexer[axis] = idx  # type: ignore
     reindexed = array[tuple(indexer)]
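The deleted comprehension scanned `from_` once per label in `to`; `pd.Index.get_indexer` performs the same lookup in one vectorized call and already returns -1 for missing labels, matching the old sentinel. A small standalone check (example labels invented):

import numpy as np
import pandas as pd

from_ = pd.Index(["b", "d", "a"])
to = pd.Index(["a", "b", "c"])

idx = from_.get_indexer(to)
print(idx)  # [ 2  0 -1] -- "c" is absent from `from_`, so it maps to -1

# Equivalent to the removed list comprehension:
old = np.array(
    [np.argwhere(np.asarray(from_) == label)[0, 0] if label in from_ else -1 for label in to]
)
assert (idx == old).all()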
@@ -384,29 +382,25 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
     return offset, size


-def factorize_(by: tuple, axis, expected_groups: tuple = None, isbin: tuple = None):
+def factorize_(by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None):
     if not isinstance(by, tuple):
         raise ValueError(f"Expected `by` to be a tuple. Received {type(by)} instead")

-    if isbin is None:
-        isbin = (False,) * len(by)
     if expected_groups is None:
         expected_groups = (None,) * len(by)

     factorized = []
     found_groups = []
-    for groupvar, expect, tobin in zip(by, expected_groups, isbin):
-        if tobin:
+    for groupvar, expect in zip(by, expected_groups):
+        if isinstance(expect, pd.IntervalIndex):
             # when binning we change expected groups to integers marking the interval
             # this makes the reindexing logic simpler.
             if expect is None:
                 raise ValueError("Please pass bin edges in expected_groups.")
             # idx = np.digitize(groupvar.ravel(), expect) - 1
-            idx = pd.cut(groupvar.ravel(), bins=expect, labels=False)
+            idx = pd.cut(groupvar.ravel(), bins=expect, labels=False).codes.copy()
             # same sentinel value as factorize
-            idx[np.isnan(idx)] = -1
-            idx = idx.astype(int, copy=False)
-            found_groups.append(np.arange(len(expect) - 1))
+            found_groups.append(expect)
         else:
             idx, groups = pd.factorize(groupvar.ravel())
             found_groups.append(np.array(groups))
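When `bins` is a `pd.IntervalIndex`, `pd.cut` ignores `labels` and returns a `Categorical`; its `.codes` are already integer bin ids with -1 for values outside every bin, the same sentinel `pd.factorize` uses for NaNs, which is why the explicit NaN-to--1 conversion could be dropped. Roughly (example values invented):

import numpy as np
import pandas as pd

edges = np.array([0, 1, 2, 3])
bins = pd.IntervalIndex.from_arrays(edges[:-1], edges[1:])  # (0, 1], (1, 2], (2, 3]

values = np.array([0.5, 1.5, 2.5, 99.0])  # 99.0 falls outside every bin
cut = pd.cut(values, bins=bins)

# .codes marks out-of-bounds values with -1, matching pd.factorize's sentinel.
print(cut.codes)  # [ 0  1  2 -1]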
@@ -447,12 +441,11 @@ def chunk_argreduce(
     array_plus_idx: tuple[np.ndarray, ...],
     by: np.ndarray,
     func: Sequence[str],
-    expected_groups: Sequence | np.ndarray | None,
+    expected_groups: pd.Index | None,
     axis: int | Sequence[int],
     fill_value: Mapping[str | Callable, Any],
     dtype=None,
     reindex: bool = False,
-    isbin: bool = False,
     engine: str = "numpy",
 ) -> IntermediateDict:
     """
@@ -470,7 +463,6 @@
         expected_groups=None,
         axis=axis,
         fill_value=fill_value,
-        isbin=isbin,
         dtype=dtype,
         engine=engine,
     )
@@ -493,12 +485,11 @@ def chunk_reduce(
     array: np.ndarray,
     by: np.ndarray,
     func: str | Callable | Sequence[str] | Sequence[Callable],
-    expected_groups: Sequence | np.ndarray = None,
+    expected_groups: pd.Index | None,
     axis: int | Sequence[int] = None,
     fill_value: Mapping[str | Callable, Any] = None,
     dtype=None,
     reindex: bool = False,
-    isbin: bool = False,
     engine: str = "numpy",
     kwargs=None,
 ) -> IntermediateDict:
@@ -574,12 +565,9 @@
     # indices=[0,0,0]. This is necessary when combining block results
     # factorize can handle strings etc unlike digitize
     group_idx, groups, _, ngroups, size, props = factorize_(
-        (by,), axis, expected_groups=(expected_groups,), isbin=(isbin,)
+        (by,), axis, expected_groups=(expected_groups,)
     )
     groups = groups[0]
-    # TODO: why?
-    if isbin:
-        expected_groups = groups

     # always reshape to 1D along group dimensions
     newshape = array.shape[: array.ndim - by.ndim] + (np.prod(array.shape[-by.ndim :]),)
@@ -590,7 +578,8 @@

     results: IntermediateDict = {"groups": [], "intermediates": []}
     if reindex and expected_groups is not None:
-        results["groups"] = np.array(expected_groups)
+        # TODO: what happens with binning here?
+        results["groups"] = expected_groups.values
     else:
         if empty:
             results["groups"] = np.array([np.nan])
@@ -654,16 +643,13 @@
         if props.offset_group:
             result = result.reshape(*final_array_shape[:-1], ngroups)
         if reindex:
-            if not isbin:
-                result = reindex_(result, groups, expected_groups, fill_value=fv)
+            result = reindex_(result, groups, expected_groups, fill_value=fv)
         else:
             result = result[..., sortidx]
         result = result.reshape(final_array_shape)
         results["intermediates"].append(result)
-    if final_groups_shape:
-        # This happens when to_group is broadcasted, and we reduce along the broadcast
-        # dimension
-        results["groups"] = np.broadcast_to(results["groups"], final_groups_shape)
+
+    results["groups"] = np.broadcast_to(results["groups"], final_groups_shape)
     return results

@@ -691,7 +677,7 @@ def _finalize_results(
     results: IntermediateDict,
     agg: Aggregation,
     axis: Sequence[int],
-    expected_groups: Sequence | np.ndarray | None,
+    expected_groups: pd.Index | None,
     fill_value: Any,
 ):
     """Finalize results by
@@ -731,7 +717,7 @@
         finalized[agg.name] = reindex_(
             finalized[agg.name], squeezed["groups"], expected_groups, fill_value=fill_value
         )
-        finalized["groups"] = expected_groups
+        finalized["groups"] = expected_groups.to_numpy()
     else:
         finalized["groups"] = squeezed["groups"]

@@ -742,7 +728,7 @@ def _aggregate(
     x_chunk,
     combine: Callable,
     agg: Aggregation,
-    expected_groups: Sequence | np.ndarray | None,
+    expected_groups: pd.Index | None,
     axis: Sequence,
     keepdims,
     fill_value: Any,
@@ -796,7 +782,9 @@ def reindex_intermediates(x, agg, unique_groups):
     new_shape = x["groups"].shape[:-1] + (len(unique_groups),)
     newx = {"groups": np.broadcast_to(unique_groups, new_shape)}
     newx["intermediates"] = tuple(
-        reindex_(v, from_=x["groups"].squeeze(), to=unique_groups, fill_value=f)
+        reindex_(
+            v, from_=np.atleast_1d(x["groups"].squeeze()), to=pd.Index(unique_groups), fill_value=f
+        )
         for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
     )
     return newx
@@ -940,7 +928,7 @@ def split_blocks(applied, split_out, expected_groups, split_name):
     return intermediate, group_chunks


-def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, isbin, engine):
+def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engine):
     """
     Blockwise groupby reduction that produces the final result. This code path is
     also used for non-dask array aggregations.
@@ -964,13 +952,12 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, isbin, engine):
         by,
         func=func,
         axis=axis,
-        expected_groups=expected_groups if isbin else None,
+        expected_groups=expected_groups,
         # This fill_value should only apply to groups that only contain NaN observations
         # BUT there is funkiness when axis is a subset of all possible values
         # (see below)
         fill_value=(agg.fill_value[agg.name], 0),
         dtype=(agg.dtype[agg.name], np.intp),
-        isbin=isbin,
         kwargs=finalize_kwargs,
         engine=engine,
     )  # type: ignore
@@ -995,9 +982,6 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, isbin, engine):
         value[mask] = np.nan
         results["intermediates"][0] = value

-    if isbin:
-        expected_groups = np.arange(len(expected_groups) - 1)
-
     # When axis is a subset of possible values; then npg will
     # apply it to groups that don't exist along a particular axis (for e.g.)
     # since these count as a group that is absent. thoo!
@@ -1022,12 +1006,11 @@ def dask_groupby_agg(
     array: DaskArray,
     by: DaskArray | np.ndarray,
     agg: Aggregation,
-    expected_groups: Sequence | np.ndarray | None,
+    expected_groups: pd.Index | None,
     axis: Sequence = None,
     split_out: int = 1,
     fill_value: Any = None,
     method: str = "map-reduce",
-    isbin: bool = False,
     reindex: bool = False,
     engine: str = "numpy",
 ) -> tuple[DaskArray, np.ndarray | DaskArray]:
@@ -1107,8 +1090,7 @@
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=expected_groups if reindex or split_out > 1 or isbin else None,
-            isbin=isbin,
+            expected_groups=expected_groups,
             engine=engine,
         ),
         inds,
@@ -1129,9 +1111,6 @@
         )
     else:
         intermediate = applied
-        # from this point on, we just work with bin indexes when binning
-        if isbin:
-            expected_groups = np.arange(len(expected_groups) - 1)
         if expected_groups is None:
             expected_groups = _get_expected_groups(by_input, raise_if_dask=False)
         group_chunks = (len(expected_groups),) if expected_groups is not None else (np.nan,)
@@ -1212,7 +1191,7 @@
     if method == "map-reduce":
         if expected_groups is None:
             expected_groups = _get_expected_groups(by_input)
-        groups = (expected_groups,)
+        groups = (expected_groups.values,)
     else:
         groups = (np.concatenate(groups_in_block),)

@@ -1269,12 +1248,22 @@ def _assert_by_is_aligned(shape, by):
     )


+def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
+    if isinstance(expected_groups, pd.Index):
+        return expected_groups
+    if isbin:
+        return pd.IntervalIndex.from_arrays(expected_groups[:-1], expected_groups[1:])
+    elif expected_groups is not None:
+        return pd.Index(expected_groups)
+    return None
+
+
 def groupby_reduce(
     array: np.ndarray | DaskArray,
     by: np.ndarray | DaskArray,
     func: str | Aggregation,
     *,
-    expected_groups: Sequence | np.ndarray = None,
+    expected_groups: Sequence | np.ndarray | None = None,
     sort: bool = True,
     isbin: bool = False,
     axis=None,
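A quick illustration of the new helper's three branches, assuming a flox build that includes this commit (output formatting shown for recent pandas):

import numpy as np
import pandas as pd

# _convert_expected_groups_to_index is a private helper added by this commit.
from flox.core import _convert_expected_groups_to_index

edges = np.array([0.0, 10.0, 20.0])
print(_convert_expected_groups_to_index(edges, isbin=True))
# IntervalIndex([(0.0, 10.0], (10.0, 20.0]], dtype='interval[float64, right]')

print(_convert_expected_groups_to_index(["jan", "feb"], isbin=False))
# Index(['jan', 'feb'], dtype='object')

idx = pd.Index([1, 2, 3])
assert _convert_expected_groups_to_index(idx, isbin=False) is idx  # passes through untouched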
@@ -1393,6 +1382,10 @@

     _assert_by_is_aligned(array.shape, by)

+    # We convert to pd.Index since that lets us know if we are binning or not
+    # (pd.IntervalIndex or not)
+    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin)
+
     if axis is None:
         axis = tuple(array.ndim + np.arange(-by.ndim, 0))
     else:
@@ -1435,7 +1428,7 @@
         assert isinstance(finalize_kwargs, dict)
         agg.finalize_kwargs = finalize_kwargs

-    kwargs = dict(axis=axis, fill_value=fill_value, isbin=isbin, engine=engine)
+    kwargs = dict(axis=axis, fill_value=fill_value, engine=engine)

     if not is_duck_dask_array(array) and not is_duck_dask_array(by):
         results = _reduce_blockwise(array, by, agg, expected_groups=expected_groups, **kwargs)
@@ -1484,7 +1477,7 @@
             r, *g = partial_agg(
                 array_subset,
                 by[np.ix_(*indexer)],
-                expected_groups=cohort,
+                expected_groups=pd.Index(cohort),
                 # reindex to expected_groups at the blockwise step.
                 # this approach avoids replacing non-cohort members with
                 # np.nan or some other sentinel value, and preserves dtypes
@@ -1504,7 +1497,6 @@
         if method == "blockwise" and by.ndim == 1:
             array = rechunk_for_blockwise(array, axis=-1, labels=by)

-        # TODO: test with mixed array kinds (numpy array + dask by)
         result, *groups = partial_agg(
             array, by, expected_groups=expected_groups, reindex=reindex, method=method
         )

flox/xarray.py

Lines changed: 1 addition & 2 deletions
@@ -259,7 +259,6 @@ def xarray_reduce(
             tuple(g.data for g in by),
             axis,
             expected_groups,
-            isbin,
         )
         to_group = xr.DataArray(group_idx, dims=dim, coords={d: by[0][d] for d in by[0].indexes})
     else:
@@ -307,7 +306,7 @@ def wrapper(array, to_group, *, func, skipna, **kwargs):
         reindexed = reindex_(
             result,
             from_=groups,
-            to=np.arange(np.prod(group_shape)),
+            to=pd.Index(np.arange(np.prod(group_shape))),
             fill_value=fill_value,
             axis=-1,
         )

tests/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -3,8 +3,11 @@
 from distutils import version

 import numpy as np
+import pandas as pd
 import pytest

+pd_types = (pd.Index,)
+
 try:
     import dask
     import dask.array as da
@@ -84,7 +87,9 @@ def assert_equal(a, b):
         a = np.array(a)
     if isinstance(b, list):
         b = np.array(b)
-    if has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
+    if isinstance(a, pd_types) or isinstance(b, pd_types):
+        pd.testing.assert_index_equal(a, b)
+    elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
         xr.testing.assert_identical(a, b)
     elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
         # does some validation of the dask graph
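With `pd.Index` objects now flowing through the helpers, `assert_equal` delegates index comparison to pandas. A hypothetical usage sketch:

import pandas as pd

# assert_index_equal checks values, dtype, and the exact Index subclass,
# so a pd.IntervalIndex will not silently compare equal to a plain pd.Index.
pd.testing.assert_index_equal(pd.Index([1, 2, 3]), pd.Index([1, 2, 3]))  # passes

try:
    pd.testing.assert_index_equal(pd.Index([1, 2, 3]), pd.Index([1, 2, 4]))
except AssertionError:
    print("mismatch detected")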
