Skip to content

Commit af3e3ce

Browse files
Raise error if multiple by's are used with Ellipsis (#149)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ab29d2c commit af3e3ce

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

flox/xarray.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def xarray_reduce(
194194
if skipna is not None and isinstance(func, Aggregation):
195195
raise ValueError("skipna must be None when func is an Aggregation.")
196196

197+
nby = len(by)
197198
for b in by:
198199
if isinstance(b, xr.DataArray) and b.name is None:
199200
raise ValueError("Cannot group by unnamed DataArrays.")
@@ -203,11 +204,11 @@ def xarray_reduce(
203204
keep_attrs = True
204205

205206
if isinstance(isbin, bool):
206-
isbin = (isbin,) * len(by)
207+
isbin = (isbin,) * nby
207208
if expected_groups is None:
208-
expected_groups = (None,) * len(by)
209+
expected_groups = (None,) * nby
209210
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
210-
if len(by) == 1:
211+
if nby == 1:
211212
expected_groups = (expected_groups,)
212213
else:
213214
raise ValueError("Needs better message.")
@@ -239,6 +240,8 @@ def xarray_reduce(
239240
ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
240241

241242
if dim is Ellipsis:
243+
if nby > 1:
244+
raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
242245
dim = tuple(obj.dims)
243246
if by[0].name in ds.dims and not isbin[0]:
244247
dim = tuple(d for d in dim if d != by[0].name)
@@ -351,7 +354,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
351354
missing_dim[k] = v
352355

353356
input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
354-
input_core_dims += [input_core_dims[-1]] * (len(by) - 1)
357+
input_core_dims += [input_core_dims[-1]] * (nby - 1)
355358

356359
actual = xr.apply_ufunc(
357360
wrapper,
@@ -409,7 +412,7 @@ def wrapper(array, *by, func, skipna, **kwargs):
409412
if unindexed_dims:
410413
actual = actual.drop_vars(unindexed_dims)
411414

412-
if len(by) == 1:
415+
if nby == 1:
413416
for var in actual:
414417
if isinstance(obj, xr.DataArray):
415418
template = obj

tests/test_xarray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
159159
actual = xarray_reduce(da, "labels", "labels2", **kwargs)
160160
xr.testing.assert_identical(expected, actual)
161161

162+
with pytest.raises(NotImplementedError):
163+
xarray_reduce(da, "labels", "labels2", dim=..., **kwargs)
164+
162165

163166
@requires_dask
164167
def test_dask_groupers_error():

0 commit comments

Comments
 (0)