
Commit 58bc9be (parent 15abf49)

Fix missing resampling groups. (#312)

* Fix missing resampling groups. Closes pydata/xarray#8592
* Update climatology notebook

File tree: 4 files changed, +112 / -81 lines

asv_bench/benchmarks/cohorts.py

Lines changed: 5 additions & 1 deletion
@@ -73,7 +73,11 @@ class NWMMidwest(Cohorts):
     def setup(self, *args, **kwargs):
         x = np.repeat(np.arange(30), 150)
         y = np.repeat(np.arange(30), 60)
-        self.by = x[np.newaxis, :] * y[:, np.newaxis]
+        by = x[np.newaxis, :] * y[:, np.newaxis]
+
+        self.by = flox.core._factorize_multiple(
+            (by,), expected_groups=(None,), any_by_dask=False, reindex=False
+        )[0][0]
 
         self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
         self.axis = (-2, -1)
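For readers who don't want to chase the private `flox.core._factorize_multiple` helper, the effect of this change is roughly the following (a sketch using `pandas.factorize` as a stand-in; the exact codes flox produces may differ):

```python
import numpy as np
import pandas as pd

# Roughly what the new benchmark setup does: replace the 2-D product of labels
# (whose values have gaps) with dense integer codes before handing it to flox.
x = np.repeat(np.arange(30), 150)
y = np.repeat(np.arange(30), 60)
by = x[np.newaxis, :] * y[:, np.newaxis]   # shape (1800, 4500)

codes, uniques = pd.factorize(by.ravel())  # dense codes in 0..n_groups-1
by_codes = codes.reshape(by.shape)         # approximately what self.by now holds
print(by_codes.shape, uniques.size)
```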

docs/source/user-stories/climatology.ipynb

Lines changed: 71 additions & 74 deletions
@@ -22,8 +22,6 @@
    "outputs": [],
    "source": [
     "import dask.array\n",
-    "import matplotlib.pyplot as plt\n",
-    "import numpy as np\n",
     "import pandas as pd\n",
     "import xarray as xr\n",
     "\n",
@@ -56,6 +54,27 @@
     "oisst"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "b7f519ee-e575-492c-a70b-8dad63a8c222",
+   "metadata": {},
+   "source": [
+    "To account for Feb-29 being present in some years, we'll construct a time vector to group by as \"mmm-dd\" string.\n",
+    "\n",
+    "For more options, see https://strftime.org/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c42a618-47bc-4c83-a902-ec4cf3420180",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "day = oisst.time.dt.strftime(\"%h-%d\").rename(\"day\")\n",
+    "day"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "6d913e7f-25bd-43c4-98b6-93bcb420c524",
@@ -80,7 +99,7 @@
    "source": [
     "flox.xarray.xarray_reduce(\n",
     "    oisst,\n",
-    "    oisst.time.dt.dayofyear,\n",
+    "    day,\n",
     "    func=\"mean\",\n",
     "    method=\"map-reduce\",\n",
     ")"
@@ -106,7 +125,7 @@
    "source": [
     "flox.xarray.xarray_reduce(\n",
     "    oisst.chunk({\"lat\": -1, \"lon\": 120}),\n",
-    "    oisst.time.dt.dayofyear,\n",
+    "    day,\n",
     "    func=\"mean\",\n",
     "    method=\"map-reduce\",\n",
     ")"
@@ -143,7 +162,7 @@
    "source": [
     "flox.xarray.xarray_reduce(\n",
     "    oisst,\n",
-    "    oisst.time.dt.dayofyear,\n",
+    "    day,\n",
     "    func=\"mean\",\n",
     "    method=\"cohorts\",\n",
     ")"
@@ -160,10 +179,7 @@
     "[click here](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts)).\n",
     "Now we have the opposite problem: the chunk sizes on the output are too small.\n",
     "\n",
-    "Looking more closely, We can see the cohorts that `flox` has detected are not\n",
-    "really cohorts, each cohort is a single group label. We've replicated Xarray's\n",
-    "current strategy; what flox calls\n",
-    "[\"split-reduce\"](https://flox.readthedocs.io/en/latest/implementation.html#method-split-reduce-xarray-s-current-groupby-strategy)\n"
+    "Let us inspect the cohorts"
    ]
   },
   {
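Throughout the notebook diff above, `oisst.time.dt.dayofyear` is replaced by the new `day` grouping variable. A self-contained sketch of what that variable looks like, using a synthetic time axis in place of the notebook's `oisst` data:

```python
import pandas as pd
import xarray as xr

# "%h" is the abbreviated month name (equivalent to "%b" on most platforms),
# so labels look like "Dec-31", "Jan-01", ... and Feb-29 gets its own label.
time = xr.DataArray(
    pd.date_range("2019-12-30", "2020-03-02", freq="D"), dims="time", name="time"
)
day = time.dt.strftime("%h-%d").rename("day")
print(day.values[:4])           # ['Dec-30' 'Dec-31' 'Jan-01' 'Jan-02']
print("Feb-29" in day.values)   # True: the leap day is its own group
```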
@@ -173,112 +189,81 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "flox.core.find_group_cohorts(\n",
-    "    labels=oisst.time.dt.dayofyear.data,\n",
+    "# integer codes for each \"day\"\n",
+    "codes, _ = pd.factorize(day.data)\n",
+    "cohorts = flox.core.find_group_cohorts(\n",
+    "    labels=codes,\n",
     "    chunks=(oisst.chunksizes[\"time\"],),\n",
-    ").values()"
+    ")\n",
+    "print(len(cohorts))"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "bcbdbb3b-2aed-4f3f-ad20-efabb52b5e68",
+   "id": "068b4109-b7f4-4c16-918d-9a18ff2ed183",
    "metadata": {},
    "source": [
-    "## Rechunking data for cohorts\n",
-    "\n",
-    "Can we fix the \"out of phase\" problem by rechunking along time?\n",
-    "\n",
-    "First lets see where the current chunk boundaries are\n"
+    "Looking more closely, we can see many cohorts with a single entry."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "90a884bc-1b71-4874-8143-73b3b5c41458",
+   "id": "57983cd0-a2e0-4d16-abe6-9572f6f252bf",
    "metadata": {},
    "outputs": [],
    "source": [
-    "array = oisst.data\n",
-    "labels = oisst.time.dt.dayofyear.data\n",
-    "axis = oisst.get_axis_num(\"time\")\n",
-    "oldchunks = array.chunks[axis]\n",
-    "oldbreaks = np.insert(np.cumsum(oldchunks), 0, 0)\n",
-    "labels_at_breaks = labels[oldbreaks[:-1]]\n",
-    "labels_at_breaks"
+    "cohorts.values()"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "4b2573e5-0d30-4cb8-b5af-751b824f0689",
+   "id": "bcbdbb3b-2aed-4f3f-ad20-efabb52b5e68",
    "metadata": {},
    "source": [
-    "Now we'll use a convenient function `rechunk_for_cohorts` to rechunk the `oisst`\n",
-    "dataset along time. We'll ask it to rechunk so that a new chunk starts at each\n",
-    "of the elements\n",
+    "## Rechunking data for cohorts\n",
     "\n",
-    "```\n",
-    "[244, 264, 284, 304, 324, 344, 364, 19, 39, 59, 79, 99, 119,\n",
-    " 139, 159, 179, 199, 219, 239]\n",
-    "```\n",
+    "Can we fix the \"out of phase\" problem by rechunking along time?\n",
     "\n",
-    "These are labels at the chunk boundaries in the first year of data. We are\n",
-    "forcing that chunking pattern to repeat as much as possible. We also tell the\n",
-    "function to ignore any existing chunk boundaries.\n"
+    "First let's see where the current chunk boundaries are"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a9ab6382-e93b-49e9-8e2e-1ba526046aea",
+   "id": "40d393a5-7a4e-4d33-997b-4c422a0b8100",
    "metadata": {},
    "outputs": [],
    "source": [
-    "rechunked = flox.xarray.rechunk_for_cohorts(\n",
-    "    oisst,\n",
-    "    dim=\"time\",\n",
-    "    labels=oisst.time.dt.dayofyear,\n",
-    "    force_new_chunk_at=[\n",
-    "        244,\n",
-    "        264,\n",
-    "        284,\n",
-    "        304,\n",
-    "        324,\n",
-    "        344,\n",
-    "        364,\n",
-    "        19,\n",
-    "        39,\n",
-    "        59,\n",
-    "        79,\n",
-    "        99,\n",
-    "        119,\n",
-    "        139,\n",
-    "        159,\n",
-    "        179,\n",
-    "        199,\n",
-    "        219,\n",
-    "        239,\n",
-    "    ],\n",
-    "    ignore_old_chunks=True,\n",
-    ")\n",
-    "rechunked"
+    "oisst.chunksizes[\"time\"][:10]"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "570d869b-9612-4de9-83ee-336a35c1fdad",
+   "id": "cd0033a3-d211-4aef-a284-c9fd3f75f6e4",
+   "metadata": {},
+   "source": [
+    "We'll choose to rechunk such that a single month is in each chunk. This is not too different from the current chunking, but will help with the periodicity problem."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5914a350-a7db-49b3-9504-6d63ff874f5e",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "We see that chunks are mostly 20 elements long in time with some differences\n"
+    "newchunks = xr.ones_like(day).astype(int).resample(time=\"M\").count()"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "86bb4461-d921-40f8-9ff7-8d6e7e8c7e4b",
+   "id": "90a884bc-1b71-4874-8143-73b3b5c41458",
    "metadata": {},
    "outputs": [],
    "source": [
-    "plt.plot(rechunked.chunksizes[\"time\"], marker=\"x\", ls=\"none\")"
+    "rechunked = oisst.chunk(time=tuple(newchunks.data))"
    ]
   },
   {
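The replacement cells above compute month-sized chunks directly from the calendar instead of calling `rechunk_for_cohorts`. A self-contained sketch of the same idea on synthetic data (standing in for the notebook's `oisst`; dask needs to be installed for `.chunk`):

```python
import numpy as np
import pandas as pd
import xarray as xr

# One year of synthetic daily data along "time".
time = pd.date_range("2020-01-01", "2020-12-31", freq="D")
arr = xr.DataArray(np.ones(time.size), dims="time", coords={"time": time})

day = arr.time.dt.strftime("%h-%d").rename("day")
# Days per month become the new chunk sizes, so each chunk holds one month.
newchunks = xr.ones_like(day).astype(int).resample(time="M").count()
rechunked = arr.chunk(time=tuple(newchunks.data))
print(rechunked.chunksizes["time"])   # (31, 29, 31, 30, ...): one chunk per month
```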
@@ -296,10 +281,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "flox.core.find_group_cohorts(\n",
-    "    labels=rechunked.time.dt.dayofyear.data,\n",
+    "new_cohorts = flox.core.find_group_cohorts(\n",
+    "    labels=codes,\n",
     "    chunks=(rechunked.chunksizes[\"time\"],),\n",
-    ").values()"
+    ")\n",
+    "# one cohort per month!\n",
+    "len(new_cohorts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4e2b6f70-c057-4783-ad55-21b20ff27e7f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "new_cohorts.values()"
    ]
   },
   {
@@ -318,7 +315,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "flox.xarray.xarray_reduce(rechunked, rechunked.time.dt.dayofyear, func=\"mean\", method=\"cohorts\")"
+    "flox.xarray.xarray_reduce(rechunked, day, func=\"mean\", method=\"cohorts\")"
    ]
   },
   {
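With month-aligned chunks, every "mmm-dd" label of a given month only ever occurs in that month's chunks, so `find_group_cohorts` groups labels month by month, and `method="cohorts"` then reduces each month's chunks independently. A numpy-only toy sketch of the detection step (against the `find_group_cohorts` in this commit, which returns a dict of cohorts; the return type has changed across flox versions):

```python
import numpy as np
import flox.core

# Two "years" of twelve 3-day "months": 36 distinct day labels, with one chunk
# per month along time, so each label only appears in its month's chunks.
labels = np.tile(np.arange(36), 2)    # factorized day codes, as pd.factorize would give
month_chunks = (3,) * 24              # 24 chunks of 3 days each
cohorts = flox.core.find_group_cohorts(labels=labels, chunks=(month_chunks,))
print(len(cohorts))                   # expect 12: one cohort per "month"
```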

flox/core.py

Lines changed: 9 additions & 6 deletions
@@ -293,13 +293,16 @@ def find_group_cohorts(
 
     # can happen when `expected_groups` is passed but not all labels are present
     # (binning, resampling)
-    present_labels = chunks_per_label != 0
-    if not present_labels.all():
-        bitmask = bitmask[..., present_labels]
+    present_labels = np.arange(bitmask.shape[LABEL_AXIS])
+    present_labels_mask = chunks_per_label != 0
+    if not present_labels_mask.all():
+        present_labels = present_labels[present_labels_mask]
+        bitmask = bitmask[..., present_labels_mask]
+        chunks_per_label = chunks_per_label[present_labels_mask]
 
     label_chunks = {
-        lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
-        for lab in range(bitmask.shape[-1])
+        present_labels[idx]: bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])]
+        for idx in range(bitmask.shape[LABEL_AXIS])
     }
 
     # Invert the label_chunks mapping so we know which labels occur together.
@@ -334,7 +337,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
     # - S is the existing set
     MIN_CONTAINMENT = 0.75  # arbitrary
     asfloat = bitmask.astype(float)
-    containment = ((asfloat.T @ asfloat) / chunks_per_label[present_labels]).tocsr()
+    containment = ((asfloat.T @ asfloat) / chunks_per_label).tocsr()
     mask = containment.data < MIN_CONTAINMENT
     containment.data[mask] = 0
     containment.eliminate_zeros()
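The bug these hunks fix: when some expected groups never occur (empty bins from binning or resampling), the empty columns of `bitmask` were dropped, but the cohort dictionary then used the compressed column positions as group labels. A toy numpy illustration of that off-by-one (not flox internals):

```python
import numpy as np

# Five expected groups, but group 2 never occurs in any chunk (an empty bin).
chunks_per_label = np.array([2, 1, 0, 3, 1])
mask = chunks_per_label != 0

# Before: compressed column positions were reused as label ids, so every label
# after the gap referred to the wrong group.
compressed_positions = np.arange(mask.sum())               # [0 1 2 3]
# After: keep the original label ids of the surviving columns, and filter
# chunks_per_label once so later computations stay aligned with them.
present_labels = np.arange(chunks_per_label.size)[mask]    # [0 1 3 4]
chunks_per_label = chunks_per_label[mask]                  # [2 1 3 1]
print(compressed_positions, present_labels)
```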

tests/test_xarray.py

Lines changed: 27 additions & 0 deletions
@@ -629,3 +629,30 @@ def test_groupby_2d_dataset():
         expected.counts.dims == actual.counts.dims
     )  # https://github.com/pydata/xarray/issues/8292
     xr.testing.assert_identical(expected, actual)
+
+
+@pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
+def test_resampling_missing_groups(chunk):
+    # Regression test for https://github.com/pydata/xarray/issues/8592
+    time_coords = pd.to_datetime(
+        ["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"]
+    )
+
+    latitude_coords = [0.0]
+    longitude_coords = [0.0]
+
+    data = [[[1.0]], [[2.0]], [[3.0]]]
+
+    da = xr.DataArray(
+        data,
+        coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords},
+        dims=["time", "latitude", "longitude"],
+    )
+    if chunk:
+        da = da.chunk(time=1)
+    # Without chunking the dataarray, it works:
+    with xr.set_options(use_flox=False):
+        expected = da.resample(time="1D").mean()
+    with xr.set_options(use_flox=True):
+        actual = da.resample(time="1D").mean()
+    xr.testing.assert_identical(expected, actual)
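For context on what "missing groups" means here: resampling these three timestamps to daily bins creates a 2018-06-14 bin with no observations, which should come out as NaN. Plain xarray already produced that; the chunked flox path mishandled the absent bin before this fix. A minimal xarray-only sketch of the expected result:

```python
import pandas as pd
import xarray as xr

time = pd.to_datetime(["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"])
da = xr.DataArray([1.0, 2.0, 3.0], coords={"time": time}, dims="time")
print(da.resample(time="1D").mean())
# 2018-06-13 -> 1.5, 2018-06-14 -> nan, 2018-06-15 -> 3.0
```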
