|
6 | 6 | xr = pytest.importorskip("xarray") |
7 | 7 | # isort: on |
8 | 8 |
|
| 9 | +from flox import xrdtypes as dtypes |
9 | 10 | from flox.xarray import rechunk_for_blockwise, xarray_reduce |
10 | 11 |
|
11 | 12 | from . import ( |
@@ -193,13 +194,25 @@ def test_validate_expected_groups(expected_groups): |
193 | 194 |
|
194 | 195 |
|
195 | 196 | @requires_cftime |
| 197 | +@pytest.mark.parametrize("indexer", [slice(None), pytest.param(slice(12), id="missing-group")]) |
| 198 | +@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2, 3]]) |
196 | 199 | @pytest.mark.parametrize("func", ["first", "last", "min", "max", "count"]) |
197 | | -def test_xarray_reduce_cftime_var(engine, func): |
| 200 | +def test_xarray_reduce_cftime_var(engine, indexer, expected_groups, func): |
198 | 201 | times = xr.date_range("1980-09-01 00:00", "1982-09-18 00:00", freq="ME", calendar="noleap") |
199 | 202 | ds = xr.Dataset({"var": ("time", times)}, coords={"time": np.repeat(np.arange(4), 6)}) |
| 203 | + ds = ds.isel(time=indexer) |
200 | 204 |
|
201 | | - actual = xarray_reduce(ds, ds.time, func=func) |
| 205 | + actual = xarray_reduce( |
| 206 | + ds, |
| 207 | + ds.time, |
| 208 | + func=func, |
| 209 | + fill_value=dtypes.NA if func in ["first", "last"] else np.nan, |
| 210 | + engine=engine, |
| 211 | + expected_groups=expected_groups, |
| 212 | + ) |
202 | 213 | expected = getattr(ds.groupby("time"), func)() |
| 214 | + if expected_groups is not None: |
| 215 | + expected = expected.reindex(time=expected_groups) |
203 | 216 | xr.testing.assert_identical(actual, expected) |
204 | 217 |
|
205 | 218 |
|
|
0 commit comments