Skip to content

Commit 68dc5d7

Browse files
committed
Fix arg reductions and filling
I changed -1 to 0 for intermediate dask reductions; This *should* have no effect but allows an unravelling if needed
1 parent 4c1c3f2 commit 68dc5d7

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

dask_groupby/aggregations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _zip_index(array_, idx_):
288288
chunk=("max", "argmax"), # order is important
289289
combine=("max", "argmax"),
290290
reduction_type="argreduce",
291-
fill_value=(dtypes.NINF, -1),
291+
fill_value=(dtypes.NINF, 0),
292292
final_fill_value=-1,
293293
finalize=lambda *x: x[1],
294294
dtype=np.intp,
@@ -300,7 +300,7 @@ def _zip_index(array_, idx_):
300300
chunk=("min", "argmin"), # order is important
301301
combine=("min", "argmin"),
302302
reduction_type="argreduce",
303-
fill_value=(dtypes.INF, -1),
303+
fill_value=(dtypes.INF, 0),
304304
final_fill_value=-1,
305305
finalize=lambda *x: x[1],
306306
dtype=np.intp,

dask_groupby/core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,12 +1186,18 @@ def groupby_reduce(
11861186
) # type: ignore
11871187

11881188
if reduction.name in ["argmin", "argmax", "nanargmax", "nanargmin"]:
1189-
if array.ndim > 1 and by.ndim == 1:
1189+
if array.ndim > 1:
1190+
# default fill_value is -1; we can't unravel that;
1191+
# so replace -1 with 0; unravel; then replace 0 with -1
1192+
# UGH!
1193+
idx = results["intermediates"][0]
1194+
mask = idx == -1
1195+
idx[mask] = 0
11901196
# Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
11911197
# will return wrong indices
1192-
results["intermediates"][0] = np.unravel_index(
1193-
results["intermediates"][0], array.shape
1194-
)[-1]
1198+
idx = np.unravel_index(idx, array.shape)[-1]
1199+
idx[mask] = -1
1200+
results["intermediates"][0] = idx
11951201
elif reduction.name in ["nanvar", "nanstd"]:
11961202
# Fix npg bug where all-NaN rows are 0 instead of NaN
11971203
value, counts = results["intermediates"]

tests/test_core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,8 @@ def test_dask_reduce_axis_subset():
367367
"axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)]
368368
)
369369
def test_groupby_reduce_axis_subset_against_numpy(func, axis, backend):
370-
if not isinstance(axis, int):
371-
if "arg" in func and (axis is None or len(axis) > 1):
372-
pytest.skip()
370+
if not isinstance(axis, int) and "arg" in func and (axis is None or len(axis) > 1):
371+
pytest.skip()
373372
if func in ["all", "any"]:
374373
fill_value = False
375374
else:

0 commit comments

Comments
 (0)