
Commit a13ca36

Use nanlen instead of _count
1 parent d471622 commit a13ca36

File tree: 2 files changed (+15, -38 lines)


dask_groupby/aggregations.py

Lines changed: 7 additions & 21 deletions
@@ -116,22 +116,8 @@ def nansum_of_squares(group_idx, array, size=None, fill_value=None):
     return sum_of_squares(group_idx, array, func="nansum", size=size, fill_value=fill_value)


-def _count(group_idx, array, size=None, fill_value=None):
-    import numpy_groupies as npg
-
-    return npg.aggregate_numpy.aggregate(
-        group_idx,
-        ~np.isnan(array),
-        axis=-1,
-        func="sum",
-        size=size,
-        fill_value=fill_value,
-        dtype=np.intp,
-    )
-
-
 count = Aggregation(
-    "count", chunk=_count, combine="sum", fill_value=0, final_fill_value=0, dtype=int
+    "count", chunk="nanlen", combine="sum", fill_value=0, final_fill_value=0, dtype=np.intp
 )

 # note that the fill values are the result of np.func([np.nan, np.nan])
@@ -141,15 +127,15 @@ def _count(group_idx, array, size=None, fill_value=None):
 nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1, final_fill_value=1)
 mean = Aggregation(
     "mean",
-    chunk=("sum", _count),
+    chunk=("sum", "nanlen"),
     combine=("sum", "sum"),
     finalize=lambda sum_, count: sum_ / count,
     fill_value=(0, 0),
     dtype=np.float64,
 )
 nanmean = Aggregation(
     "nanmean",
-    chunk=("nansum", _count),
+    chunk=("nansum", "nanlen"),
     combine=("sum", "sum"),
     finalize=lambda sum_, count: sum_ / count,
     fill_value=(0, 0),
@@ -171,7 +157,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 # var, std always promote to float, so we set nan
 var = Aggregation(
     "var",
-    chunk=(sum_of_squares, "sum", _count),
+    chunk=(sum_of_squares, "sum", "nanlen"),
     combine=("sum", "sum", "sum"),
     finalize=_var_finalize,
     fill_value=0,
@@ -180,7 +166,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 )
 nanvar = Aggregation(
     "nanvar",
-    chunk=(nansum_of_squares, "nansum", _count),
+    chunk=(nansum_of_squares, "nansum", "nanlen"),
     combine=("sum", "sum", "sum"),
     finalize=_var_finalize,
     fill_value=0,
@@ -189,7 +175,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 )
 std = Aggregation(
     "std",
-    chunk=(sum_of_squares, "sum", _count),
+    chunk=(sum_of_squares, "sum", "nanlen"),
     combine=("sum", "sum", "sum"),
     finalize=_std_finalize,
     fill_value=0,
@@ -198,7 +184,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 )
 nanstd = Aggregation(
     "nanstd",
-    chunk=(nansum_of_squares, "nansum", _count),
+    chunk=(nansum_of_squares, "nansum", "nanlen"),
     combine=("sum", "sum", "sum"),
     finalize=_std_finalize,
     fill_value=0,
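Note: numpy_groupies ships a built-in "nanlen" aggregation that counts the non-NaN elements in each group, which is what the deleted _count helper computed by summing a ~np.isnan(array) mask; the Aggregation definitions above can therefore refer to it by name, and the count aggregation's dtype becomes np.intp to match what _count used to force. A minimal sketch of the equivalence (not part of the commit; the group labels and values are invented):

import numpy as np
import numpy_groupies as npg

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, np.nan, 2.0, np.nan, 3.0])

# npg's built-in "nanlen": number of non-NaN elements per group
counts = npg.aggregate(group_idx, array, func="nanlen", fill_value=0)

# what the removed _count helper computed: per-group sum of a non-NaN mask
manual = npg.aggregate(group_idx, ~np.isnan(array), func="sum", fill_value=0, dtype=np.intp)

assert counts.tolist() == [1, 2]
assert manual.tolist() == [1, 2]

The mean/var/std aggregations simply swap "nanlen" in as the counting member of their chunk tuples.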

dask_groupby/core.py

Lines changed: 8 additions & 17 deletions
@@ -21,7 +21,7 @@
 import pandas as pd

 from . import aggregations
-from .aggregations import Aggregation, _atleast_1d, _count, _get_fill_value
+from .aggregations import Aggregation, _atleast_1d, _get_fill_value
 from .xrutils import is_duck_array, is_duck_dask_array

 if TYPE_CHECKING:
@@ -525,7 +525,7 @@ def chunk_reduce(
             size=size,
             # important when reducing with "offset" groups
             fill_value=fv,
-            dtype=dtype,
+            dtype=np.intp if reduction == "nanlen" else dtype,
         )
         if np.any(~mask):
             # remove NaN group label which should be last
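Note on the dtype override above: the dtype forwarded here is whatever the overall aggregation requested (float64 for a mean, say), and numpy_groupies generally honours an explicit dtype argument, so without the override the counts intermediate could come back floating-point. Forcing np.intp for "nanlen" keeps it integral. A small sketch with made-up inputs (the exact output dtype shown is an assumption about numpy_groupies' behaviour):

import numpy as np
import numpy_groupies as npg

group_idx = np.array([0, 1, 1])
array = np.array([np.nan, 4.0, 5.0])

# force an integer count regardless of the dtype the main reduction wants
counts = npg.aggregate(group_idx, array, func="nanlen", dtype=np.intp)
print(counts, counts.dtype)  # expected: [0 2] intp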
@@ -578,13 +578,6 @@ def _finalize_results(
     2. Calling agg.finalize with intermediate results
     3. Mask using counts and fill with user-provided fill_value.
     4. reindex to expected_groups
-
-    Parameters
-    ----------
-    mask_counts: bool
-        Whether to mask out results using counts which is expected to be the last element in
-        results["intermediates"]. Should be False when dask arrays are not involved.
-
     """
     squeezed = _squeeze_results(results, axis)

@@ -682,7 +675,7 @@ def _conc2(key1, key2=None, axis=None) -> np.ndarray:
     if agg.reduction_type == "argreduce":

         # If _count was added for masking later, we need to account for that
-        if agg.chunk[-1] == _count:
+        if agg.chunk[-1] == "nanlen":
             slicer = slice(None, -1)
         else:
             slicer = slice(None, None)
@@ -701,7 +694,7 @@ def _conc2(key1, key2=None, axis=None) -> np.ndarray:
             backend=backend,
         )

-        if agg.chunk[-1] == _count:
+        if agg.chunk[-1] == "nanlen":
             counts = _conc2(key1="intermediates", key2=2, axis=axis)
             # sum the counts
             results["intermediates"].append(
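The two hunks above deal with the counts that get appended as the last intermediate: for argreduce they are sliced off before the regular combine, then concatenated and re-summed on their own. Summing is the right combine step because non-NaN counts are additive across chunks; a toy check (chunks invented for illustration):

import numpy as np

chunk1 = np.array([1.0, np.nan, 3.0])
chunk2 = np.array([np.nan, 5.0])

# counting per chunk and summing ...
per_chunk = np.sum(~np.isnan(chunk1)) + np.sum(~np.isnan(chunk2))
# ... matches counting over the concatenated data
overall = np.sum(~np.isnan(np.concatenate([chunk1, chunk2])))
assert per_chunk == overall == 3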
@@ -1116,12 +1109,10 @@ def groupby_reduce(
         # (agg.finalize = None). We still need to do the reindexing step in finalize
         # so that everything matches the dask version.
         reduction.finalize = None
-        # npg's count counts the number of groups
-        # we want to count the number of non-NaN array elements in each group
-        # So we use our custom _count instead of "count"
-        func = reduction.name if reduction.name != "count" else _count
+        # xarray's count is npg's nanlen
+        func = reduction.name if reduction.name != "count" else "nanlen"
         if min_count is not None:
-            func = (func, _count)
+            func = (func, "nanlen")

         results = chunk_reduce(
             array,
@@ -1162,7 +1153,7 @@ def groupby_reduce(

         # we need to explicitly track counts so that we can mask at the end
         if fill_value is not None or min_count is not None:
-            reduction.chunk += (_count,)
+            reduction.chunk += ("nanlen",)
             reduction.combine += ("sum",)
             reduction.fill_value["intermediate"] += (0,)
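Appending "nanlen" (combined with "sum" and an intermediate fill value of 0) gives the dask path a per-group count of non-NaN values to mask with at the end, which is what fill_value and min_count need. Roughly the idea, as a hedged sketch rather than the library's actual finalize code:

import numpy as np

# per-group reduced values and non-NaN counts (numbers invented)
result = np.array([6.0, 0.0, 9.0])
counts = np.array([3, 0, 1])

fill_value = np.nan
min_count = 2  # require at least 2 non-NaN values per group

masked = np.where(counts >= min_count, result, fill_value)
print(masked)  # [ 6. nan nan]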
