Commit d471622

Support numba backend
Closes #17

commit 7c79cf65ab797201b8fcfc264a5596ddeb85fbdc
Author: dcherian <deepak@cherian.net>
Date:   Sun Oct 3 18:18:40 2021 +0530

    Signature improvements.
1 parent 4b6a64e commit d471622

File tree

5 files changed (+66, -14 lines)

ci/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@ dependencies:
   - numpy_groupies
   - pooch
   - toolz
+  - numba
   - pip:
     - icecream
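numba lands only in the CI environment here: numpy_groupies ships its numba kernels as an optional extra, so the library keeps working without it and the default backend stays "numpy". A minimal sketch of how a caller could pick a backend based on availability (the choose_backend helper is hypothetical, not part of this commit):

def choose_backend() -> str:
    # Prefer numba's JIT-compiled kernels when numba is importable;
    # otherwise fall back to the pure-numpy implementation.
    try:
        import numba  # noqa: F401
    except ImportError:
        return "numpy"
    return "numba"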

dask_groupby/core.py

Lines changed: 30 additions & 3 deletions
@@ -32,6 +32,17 @@
 FinalResultsDict = Dict[str, Union["DaskArray", np.ndarray]]
 
 
+def _get_aggregate(backend):
+    if backend == "numba":
+        return npg.aggregate_numba.aggregate
+    elif backend == "numpy":
+        return npg.aggregate_numpy.aggregate
+    else:
+        raise ValueError(
+            f"Expected backend to be one of ['numpy', 'numba']. Received {backend} instead."
+        )
+
+
 def _get_chunk_reduction(reduction_type: str) -> Callable:
     if reduction_type == "reduce":
         return chunk_reduce
@@ -353,6 +364,7 @@ def chunk_argreduce(
     dtype=None,
     reindex: bool = False,
     isbin: bool = False,
+    backend: str = "numpy",
 ) -> IntermediateDict:
     """
     Per-chunk arg reduction.
@@ -371,6 +383,7 @@ def chunk_argreduce(
         fill_value=fill_value,
         isbin=isbin,
         dtype=dtype,
+        backend=backend,
     )
     if not np.isnan(results["groups"]).all():
         # will not work for empty groups...
@@ -398,6 +411,7 @@ def chunk_reduce(
     dtype=None,
     reindex: bool = False,
     isbin: bool = False,
+    backend: str = "numpy",
 ) -> IntermediateDict:
     """
     Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -503,7 +517,7 @@ def chunk_reduce(
             fill_value=fv,
         )
     else:
-        result = npg.aggregate_numpy.aggregate(
+        result = _get_aggregate(backend)(
             group_idx,
             array,
             axis=-1,
@@ -612,9 +626,10 @@ def _npg_aggregate(
     group_ndim: int,
     fill_value: Any = None,
     min_count: Optional[int] = None,
+    backend: str = "numpy",
 ) -> FinalResultsDict:
     """Final aggregation step of tree reduction"""
-    results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim)
+    results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim, backend)
     return _finalize_results(results, agg, axis, expected_groups, fill_value, min_count)
 
 
@@ -624,6 +639,7 @@ def _npg_combine(
     axis: Sequence,
     keepdims: bool,
     group_ndim: int,
+    backend: str,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
     from dask.array.core import _concatenate2
@@ -682,6 +698,7 @@ def _conc2(key1, key2=None, axis=None) -> np.ndarray:
                 expected_groups=None,
                 fill_value=agg.fill_value["intermediate"][slicer],
                 dtype=agg.dtype,
+                backend=backend,
             )
 
     if agg.chunk[-1] == _count:
@@ -696,6 +713,7 @@ def _conc2(key1, key2=None, axis=None) -> np.ndarray:
                     expected_groups=None,
                     fill_value=(0,),
                     dtype=np.intp,
+                    backend=backend,
                 )["intermediates"][0]
             )
 
@@ -720,6 +738,7 @@ def _conc2(key1, key2=None, axis=None) -> np.ndarray:
                 axis=axis,
                 expected_groups=None,
                 fill_value=fv,
+                backend=backend,
             )
             results["intermediates"].append(*_results["intermediates"])
             results["groups"] = _results["groups"]
@@ -769,6 +788,7 @@ def groupby_agg(
     method: str = "mapreduce",
     min_count: Optional[int] = None,
     isbin: bool = False,
+    backend: str = "numpy",
 ) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
 
     import dask.array
@@ -806,6 +826,7 @@ def groupby_agg(
             fill_value=agg.fill_value["intermediate"],
             isbin=isbin,
             reindex=split_out > 1,
+            backend=backend,
         ),
         inds,
         array,
@@ -851,8 +872,9 @@ def groupby_agg(
             group_ndim=by.ndim,
             fill_value=fill_value,
             min_count=min_count,
+            backend=backend,
         ),
-        combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim),
+        combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim, backend=backend),
         name=f"{name}-reduce",
         dtype=array.dtype,
         axis=axis,
@@ -880,6 +902,7 @@ def groupby_agg(
             group_ndim=by.ndim,
             fill_value=fill_value,
             min_count=min_count,
+            backend=backend,
             axis=axis,
             keepdims=True,
         ),
@@ -963,6 +986,7 @@ def groupby_reduce(
     min_count: Optional[int] = None,
     split_out: int = 1,
     method: str = "mapreduce",
+    backend: str = "numpy",
 ) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
     """
     GroupBy reductions using tree reductions for dask.array
@@ -1005,6 +1029,8 @@ def groupby_reduce(
         This works well for many time groupings where the group labels repeat
        at regular intervals like 'hour', 'month', 'dayofyear' etc. Optimize
        chunking ``array`` for this method by first rechunking using ``rechunk_for_cohorts``.
+    backend: {"numpy", "numba"}, optional
+        Backend for numpy_groupies. numpy by default.
 
     Returns
     -------
@@ -1148,6 +1174,7 @@ def groupby_reduce(
         fill_value=fill_value,
         min_count=min_count,
         isbin=isbin,
+        backend=backend,
     )
     if method == "cohorts":
         assert len(axis) == 1
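Taken together, these changes thread a single `backend` keyword from `groupby_reduce` down to the per-chunk `npg.aggregate` call, where `_get_aggregate` selects the implementation. A minimal usage sketch (the arrays are invented and the import path follows the test modules; with "numba", the first call pays a one-time JIT-compilation cost and later calls reuse the compiled kernels):

import numpy as np

from dask_groupby.core import groupby_reduce

array = np.array([1.0, 2.0, 4.0, 8.0, 16.0])
by = np.array([0, 0, 1, 1, 1])

# Same reduction through both numpy_groupies backends; the group sums
# should agree.
result_np, groups = groupby_reduce(array, by, func="sum", backend="numpy")
result_nb, _ = groupby_reduce(array, by, func="sum", backend="numba")
np.testing.assert_array_equal(result_np, result_nb)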

dask_groupby/xarray.py

Lines changed: 10 additions & 1 deletion
@@ -57,6 +57,7 @@ def xarray_reduce(
     split_out: int = 1,
     fill_value=None,
     method: str = "mapreduce",
+    backend: str = "numpy",
     keep_attrs: bool = True,
     skipna: bool = True,
     min_count: Optional[int] = None,
@@ -105,8 +106,15 @@ def xarray_reduce(
         'month', 'dayofyear' etc. Optimize chunking ``array`` for this
         method by first rechunking using ``rechunk_for_cohorts``.
 
-    skipna: bool
+    backend: {"numpy", "numba"}, optional
+        Backend for numpy_groupies.
+    keep_attrs: bool, optional
+        Preserve attrs?
+    skipna: bool, optional
         Use NaN-skipping aggregations like nanmean?
+    min_count: int, optional
+        NaN out when number of non-NaN values in aggregation is < min_count.
+        Only applies to nansum, nanprod.
 
     Raises
     ------
@@ -266,6 +274,7 @@ def wrapper(*args, **kwargs):
         "fill_value": fill_value,
         "method": method,
         "min_count": min_count,
+        "backend": backend,
         # The following mess exists because for multiple `by`s I factorize eagerly
         # here before passing it on; this means I have to handle the
         # "binning by single by variable" case explicitly where the factorization

tests/test_core.py

Lines changed: 20 additions & 8 deletions
@@ -31,6 +31,7 @@ def test_alignment_error():
         groupby_reduce(da, labels, func="mean")
 
 
+@pytest.mark.parametrize("backend", ["numpy", "numba"])
 @pytest.mark.parametrize("dtype", (float, int))
 @pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
 @pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
@@ -59,7 +60,9 @@ def test_alignment_error():
         # (np.ones((12,)), np.array([labels, labels])),  # form 4
     ],
 )
-def test_groupby_reduce(array, by, expected, func, expected_groups, chunk, split_out, dtype):
+def test_groupby_reduce(
+    array, by, expected, func, expected_groups, chunk, split_out, dtype, backend
+):
     array = array.astype(dtype)
     if chunk:
         if expected_groups is None:
@@ -81,10 +84,12 @@ def test_groupby_reduce(array, by, expected, func, expected_groups, chunk, split
         expected_groups=expected_groups,
         fill_value=123,
         split_out=split_out,
+        backend=backend,
     )
     assert_equal(expected, result)
 
 
+@pytest.mark.parametrize("backend", ["numpy", "numba"])
 @pytest.mark.parametrize("size", ((12,), (12, 5)))
 @pytest.mark.parametrize(
     "func",
@@ -109,7 +114,7 @@ def test_groupby_reduce(array, by, expected, func, expected_groups, chunk, split
         pytest.param("nanargmin", marks=(pytest.mark.xfail,)),
     ),
 )
-def test_groupby_reduce_all(size, func):
+def test_groupby_reduce_all(size, func, backend):
 
     array = np.random.randn(*size)
     by = np.ones(size[-1])
@@ -123,13 +128,15 @@ def test_groupby_reduce_all(size, func):
     expected = getattr(np, func)(array, axis=-1)
     expected = np.expand_dims(expected, -1)
 
-    actual, _ = groupby_reduce(array, by, func=func)
+    actual, _ = groupby_reduce(array, by, func=func, backend=backend)
     if "arg" in func:
         assert actual.dtype.kind == "i"
     assert_equal(actual, expected)
 
     for method in ["mapreduce", "cohorts"]:
-        actual, _ = groupby_reduce(da.from_array(array, chunks=3), by, func=func, method=method)
+        actual, _ = groupby_reduce(
+            da.from_array(array, chunks=3), by, func=func, method=method, backend=backend
+        )
         if "arg" in func:
             assert actual.dtype.kind == "i"
         assert_equal(actual, expected)
@@ -336,14 +343,15 @@ def test_dask_reduce_axis_subset():
     )
 
 
+@pytest.mark.parametrize("backend", ["numpy", "numba"])
 @pytest.mark.parametrize(
     "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)]
 )
-def test_groupby_reduce_axis_subset_against_numpy(axis):
+def test_groupby_reduce_axis_subset_against_numpy(axis, backend):
     # tests against the numpy output to make sure dask compute matches
     by = np.broadcast_to(labels2d, (3, *labels2d.shape))
     array = np.ones_like(by)
-    kwargs = dict(func="count", axis=axis, expected_groups=[0, 2], fill_value=123)
+    kwargs = dict(func="count", axis=axis, expected_groups=[0, 2], fill_value=123, backend=backend)
     with raise_if_dask_computes():
         actual, _ = groupby_reduce(
             da.from_array(array, chunks=(-1, 2, 3)),
@@ -354,6 +362,7 @@ def test_groupby_reduce_axis_subset_against_numpy(axis):
     assert_equal(actual, expected)
 
 
+@pytest.mark.parametrize("backend", ["numpy", "numba"])
 @pytest.mark.parametrize("chunks", [None, (2, 2, 3)])
 @pytest.mark.parametrize(
     "axis, groups, expected_shape",
@@ -363,7 +372,7 @@ def test_groupby_reduce_axis_subset_against_numpy(axis):
         (None, [0], (1,)),  # global reduction; 0 shaped group axis; 1 group
     ],
 )
-def test_groupby_reduce_nans(chunks, axis, groups, expected_shape):
+def test_groupby_reduce_nans(chunks, axis, groups, expected_shape, backend):
     def _maybe_chunk(arr):
         if chunks:
             return da.from_array(arr, chunks=chunks)
@@ -383,6 +392,7 @@ def _maybe_chunk(arr):
         expected_groups=groups,
         axis=axis,
         fill_value=0,
+        backend=backend,
     )
     assert_equal(result, np.zeros(expected_shape, dtype=np.int64))
 
@@ -394,7 +404,8 @@ def _maybe_chunk(arr):
     # by = np.broadcast_to(labels2d, (3, *labels2d.shape))
 
 
-def test_groupby_all_nan_blocks():
+@pytest.mark.parametrize("backend", ["numpy", "numba"])
+def test_groupby_all_nan_blocks(backend):
     labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
     nan_labels = labels.astype(float)  # copy
     nan_labels[:5] = np.nan
@@ -410,6 +421,7 @@ def test_groupby_all_nan_blocks():
         da.from_array(by, chunks=(1, 3)),
         func="sum",
         expected_groups=None,
+        backend=backend,
     )
     assert_equal(actual, expected)
tests/test_xarray.py

Lines changed: 5 additions & 2 deletions
@@ -16,10 +16,11 @@
 dask.config.set(scheduler="sync")
 
 
+@pytest.mark.parametrize("backend", ["numpy", "numba"])
 @pytest.mark.parametrize("min_count", [None, 1, 3])
 @pytest.mark.parametrize("add_nan", [True, False])
 @pytest.mark.parametrize("skipna", [True, False])
-def test_xarray_reduce(skipna, add_nan, min_count):
+def test_xarray_reduce(skipna, add_nan, min_count, backend):
     arr = np.ones((4, 12))
 
     if add_nan:
@@ -38,7 +39,9 @@ def test_xarray_reduce(skipna, add_nan, min_count):
     ).expand_dims(z=4)
 
     expected = da.groupby("labels").sum(skipna=skipna, min_count=min_count)
-    actual = xarray_reduce(da, "labels", func="sum", skipna=skipna, min_count=min_count)
+    actual = xarray_reduce(
+        da, "labels", func="sum", skipna=skipna, min_count=min_count, backend=backend
+    )
     assert_equal(expected, actual)
 
     # test dimension ordering
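Both test modules opt in the same way: a `@pytest.mark.parametrize("backend", ["numpy", "numba"])` decorator on each affected test, which keeps the extra dimension of the test matrix visible at every test site. An alternative sketch (not in this commit) that yields the same matrix with less repetition is a shared fixture in conftest.py:

import pytest

@pytest.fixture(params=["numpy", "numba"])
def backend(request):
    # Any test that accepts a `backend` argument runs once per backend,
    # without repeating the parametrize decorator.
    return request.param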
