Skip to content

Commit d648448

Browse files
committed
Let xarray return np.nan for count in empty bins.
Xarray passes through min_count=1 so this is pretty clean. I removed the xarray test since it needn't live here.
1 parent b8e507c commit d648448

File tree

3 files changed

+6
-14
lines changed

3 files changed

+6
-14
lines changed

dask_groupby/core.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,15 +1146,14 @@ def groupby_reduce(
11461146
)
11471147
reduction.fill_value[func] = _get_fill_value(reduction.dtype, reduction.fill_value[func])
11481148

1149-
# TODO: delete?
1150-
# if fill_value is None:
1151-
# fill_value = reduction.fill_value[func]
1152-
11531149
if min_count is not None:
1154-
assert func in ["nansum", "nanprod"]
1150+
# Let this pass so that xarray can keep returning np.nan for bins with
1151+
# no observations. The restriction of min_count to nansum, nanprod
1152+
# seems to be an Xarray limitation so there's no reason we need to copy it.
11551153
# nansum, nanprod have fill_value=0, 1
11561154
# overwrite that when min_count is set
1157-
fill_value = np.nan
1155+
if func in ["nansum", "nanprod"] and fill_value is None:
1156+
fill_value = np.nan
11581157

11591158
# TODO: handle reduction being something custom not present in numpy_groupies
11601159
if not is_duck_dask_array(array) and not is_duck_dask_array(by):

tests/test_core.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,7 @@ def test_groupby_reduce(
116116

117117
@pytest.mark.parametrize("engine", ["numpy", "numba"])
118118
@pytest.mark.parametrize("size", ((12,), (12, 5)))
119-
@pytest.mark.parametrize(
120-
"func",
121-
ALL_FUNCS,
122-
)
119+
@pytest.mark.parametrize("func", ALL_FUNCS)
123120
def test_groupby_reduce_all(size, func, engine):
124121

125122
by = np.ones(size[-1])

tests/test_xarray.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,6 @@ def test_xarray_groupby_bins(chunks):
290290
)
291291
xr.testing.assert_equal(actual, expected)
292292

293-
# TODO: fix this test?
294-
# expected_xr = array.groupby_bins(labels, bins=[1, 2, 4, 5]).count() #.fillna(0)
295-
# xr.testing.assert_equal(actual, expected_xr)
296-
297293
da = xr.DataArray(np.random.randn(2, 3, 4))
298294
bins = [-1, 0, 1, 2]
299295
with xr.set_options(use_numpy_groupies=False):

0 commit comments

Comments
 (0)