Skip to content

Commit 2cef60e

Browse files
authored
Fix xarray_reduce when by has NaN. (#54)
* Fix xarray_reduce when `by` has NaN.
* fix test
1 parent 4890787 commit 2cef60e

File tree

2 files changed

+21 -3 lines changed

2 files changed

+21 -3 lines changed

flox/xarray.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,7 @@ def xarray_reduce(
217217
dim = _atleast_1d(dim)
218218

219219
if any(d not in grouper_dims and d not in obj.dims for d in dim):
220-
raise ValueError(f"cannot reduce over dimensions {dim}")
220+
raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
221221

222222
dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims)
223223
if dims_not_in_groupers == dim and not any(isbin):
@@ -248,7 +248,11 @@ def xarray_reduce(
248248
to_group = xr.DataArray(group_idx, dims=dim, coords={d: by[0][d] for d in by[0].indexes})
249249
else:
250250
if expected_groups is None and isinstance(by[0].data, np.ndarray):
251-
expected_groups = (np.unique(by[0].data),)
251+
uniques = np.unique(by[0].data)
252+
nans = isnull(uniques)
253+
if nans.any():
254+
uniques = uniques[~nans]
255+
expected_groups = (uniques,)
252256
if expected_groups is None:
253257
raise NotImplementedError(
254258
"Please provide expected_groups if not grouping by a numpy-backed DataArray"
@@ -346,6 +350,12 @@ def wrapper(array, to_group, *, func, skipna, **kwargs):
346350
},
347351
)
348352

353+
# restore non-dim coord variables without the core dimension
354+
# TODO: shouldn't apply_ufunc handle this?
355+
for var in set(ds.variables) - set(ds.dims):
356+
if all(d not in ds[var].dims for d in dim):
357+
actual[var] = ds[var]
358+
349359
for name, expect, isbin_ in zip(group_names, expected_groups, isbin):
350360
if isbin_:
351361
expect = [pd.Interval(left, right) for left, right in zip(expect[:-1], expect[1:])]

tests/test_xarray.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,14 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine):
5656
)
5757
assert_equal(expected, actual)
5858

59+
da["labels2"] = da.labels2.astype(float)
60+
da["labels2"][0] = np.nan
61+
expected = da.groupby("labels2").sum(skipna=skipna, min_count=min_count)
62+
actual = xarray_reduce(
63+
da, "labels2", func="sum", skipna=skipna, min_count=min_count, engine=engine
64+
)
65+
assert_equal(expected, actual)
66+
5967
# test dimension ordering
6068
# actual = xarray_reduce(
6169
# da.transpose("y", ...), "labels", func="sum", skipna=skipna, min_count=min_count
@@ -175,7 +183,7 @@ def test_xarray_reduce_errors():
175183
xarray_reduce(da, by, func="mean")
176184

177185
by.name = "by"
178-
with pytest.raises(ValueError, match="cannot reduce over"):
186+
with pytest.raises(ValueError, match="Cannot reduce over"):
179187
xarray_reduce(da, by, func="mean", dim="foo")
180188

181189
if has_dask:

0 commit comments

Comments (0)