Skip to content

Commit f42f3ff

Browse files
authored
Pass ddof through for numbagg (#302)
* Support ddof with numbagg
* Fix tests
1 parent 1368f0f commit f42f3ff

File tree

4 files changed

+27
-15
lines changed

4 files changed

+27
-15
lines changed

ci/environment.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ dependencies:
1919
- pytest-xdist
2020
- xarray
2121
- pre-commit
22-
- numbagg>=0.3
2322
- numpy_groupies>=0.9.19
2423
- pooch
2524
- toolz

ci/no-dask.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ dependencies:
1919
- pooch
2020
- toolz
2121
- numba
22+
- numbagg>=0.3

flox/aggregate_numbagg.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import numbagg
44
import numbagg.grouped
55
import numpy as np
6+
from packaging.version import Version
7+
8+
NUMBAGG_SUPPORTS_DDOF = Version(numbagg.__version__) >= Version("0.7.0")
69

710
DEFAULT_FILL_VALUE = {
811
"nansum": 0,
@@ -42,6 +45,7 @@ def _numbagg_wrapper(
4245
size=None,
4346
fill_value=None,
4447
dtype=None,
48+
**kwargs,
4549
):
4650
cast_to = CAST_TO.get(func, None)
4751
if cast_to:
@@ -56,6 +60,7 @@ def _numbagg_wrapper(
5660
group_idx,
5761
axis=axis,
5862
num_labels=size,
63+
**kwargs,
5964
# The following are unsupported
6065
# fill_value=fill_value,
6166
# dtype=dtype,
@@ -65,30 +70,36 @@ def _numbagg_wrapper(
6570

6671

6772
def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
68-
assert ddof != 0
69-
73+
kwargs = {}
74+
if NUMBAGG_SUPPORTS_DDOF:
75+
kwargs["ddof"] = ddof
76+
elif ddof != 1:
77+
raise ValueError("Need numbagg >= v0.7.0 to support ddof != 1")
7078
return _numbagg_wrapper(
7179
group_idx,
7280
array,
7381
axis=axis,
7482
size=size,
7583
func="nanvar",
76-
# ddof=0,
84+
**kwargs,
7785
# fill_value=fill_value,
7886
# dtype=dtype,
7987
)
8088

8189

8290
def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
83-
assert ddof != 0
84-
91+
kwargs = {}
92+
if NUMBAGG_SUPPORTS_DDOF:
93+
kwargs["ddof"] = ddof
94+
elif ddof != 1:
95+
raise ValueError("Need numbagg >= v0.7.0 to support ddof != 1")
8596
return _numbagg_wrapper(
8697
group_idx,
8798
array,
8899
axis=axis,
89100
size=size,
90-
func="nanstd"
91-
# ddof=0,
101+
func="nanstd",
102+
**kwargs,
92103
# fill_value=fill_value,
93104
# dtype=dtype,
94105
)

flox/aggregations.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,16 @@ def generic_aggregate(
7373
from . import aggregate_numbagg
7474

7575
try:
76-
if (
77-
# numabgg hardcodes ddof=1
78-
("var" in func or "std" in func)
79-
and kwargs.get("ddof", 0) == 0
80-
):
81-
method = get_npg_aggregation(func, engine="numpy")
82-
76+
if "var" in func or "std" in func:
77+
ddof = kwargs.get("ddof", 0)
78+
if aggregate_numbagg.NUMBAGG_SUPPORTS_DDOF or (ddof != 0):
79+
method = getattr(aggregate_numbagg, func)
80+
else:
81+
logger.debug(f"numbagg too old for ddof={ddof}. Falling back to numpy")
82+
method = get_npg_aggregation(func, engine="numpy")
8383
else:
8484
method = getattr(aggregate_numbagg, func)
85+
8586
except AttributeError:
8687
logger.debug(f"Couldn't find {func} for engine='numbagg'. Falling back to numpy")
8788
method = get_npg_aggregation(func, engine="numpy")

0 commit comments

Comments
 (0)