Skip to content

Commit b8e507c

Browse files
committed
Fix binning by dimension coordinate
1 parent d67f60a commit b8e507c

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

dask_groupby/xarray.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,8 @@ def xarray_reduce(
175175
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
176176
if dim is Ellipsis:
177177
dim = tuple(obj.dims)
178-
if by[0].name in ds.dims:
178+
if by[0].name in ds.dims and not isbin[0]:
179179
dim = tuple(d for d in dim if d != by[0].name)
180-
dim = tuple(dim)
181180

182181
# TODO: do this for specific reductions only
183182
bad_dtypes = tuple(
@@ -203,14 +202,16 @@ def xarray_reduce(
203202
raise ValueError(f"cannot reduce over dimensions {dim}")
204203

205204
dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims)
206-
if dims_not_in_groupers == dim:
205+
if dims_not_in_groupers == dim and not any(isbin):
207206
# reducing along a dimension along which groups do not vary
208207
# This is really just a normal reduction.
208+
# This is not right when binning so we exclude.
209209
if skipna:
210210
dsfunc = func[3:]
211211
else:
212212
dsfunc = func
213-
result = getattr(ds, dsfunc)(dim=dim)
213+
# TODO: skipna needs test
214+
result = getattr(ds, dsfunc)(dim=dim, skipna=skipna)
214215
if isinstance(obj, xr.DataArray):
215216
return obj._from_temp_dataset(result)
216217
else:

tests/test_xarray.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,16 @@ def test_xarray_groupby_bins(chunks):
290290
)
291291
xr.testing.assert_equal(actual, expected)
292292

293-
# TODO: fix this test
294-
# expected_xr = array.groupby_bins(labels, bins=[1, 2, 4, 5]).count().fillna(0)
293+
# TODO: fix this test?
294+
# expected_xr = array.groupby_bins(labels, bins=[1, 2, 4, 5]).count() #.fillna(0)
295295
# xr.testing.assert_equal(actual, expected_xr)
296296

297+
da = xr.DataArray(np.random.randn(2, 3, 4))
298+
bins = [-1, 0, 1, 2]
299+
with xr.set_options(use_numpy_groupies=False):
300+
actual = da.groupby_bins("dim_0", bins).mean(...)
301+
with xr.set_options(use_numpy_groupies=True):
302+
expected = da.groupby_bins("dim_0", bins).mean(...)
303+
xr.testing.assert_allclose(actual, expected)
304+
297305
# TODO: test cut_kwargs

0 commit comments

Comments
 (0)