Skip to content

Commit 8ea0cd1

Browse files
authored
Cleanup (#315)
* Cleanup * seen_groups for numbagg only * Use _atleast_1d in more places.
1 parent 0c4a7f9 commit 8ea0cd1

File tree

3 files changed

+15
-37
lines changed

3 files changed

+15
-37
lines changed

flox/aggregations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,10 @@ def _get_fill_value(dtype, fill_value):
133133
return fill_value
134134

135135

136-
def _atleast_1d(inp):
136+
def _atleast_1d(inp, min_length: int = 1):
137137
if xrutils.is_scalar(inp):
138-
inp = (inp,)
138+
inp = (inp,) * min_length
139+
assert len(inp) >= min_length
139140
return inp
140141

141142

flox/core.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -803,29 +803,11 @@ def chunk_reduce(
803803
dict
804804
"""
805805

806-
if not (isinstance(func, str) or callable(func)):
807-
funcs = func
808-
else:
809-
funcs = (func,)
806+
funcs = _atleast_1d(func)
810807
nfuncs = len(funcs)
811-
812-
if isinstance(dtype, Sequence):
813-
dtypes = dtype
814-
else:
815-
dtypes = (dtype,) * nfuncs
816-
assert len(dtypes) >= nfuncs
817-
818-
if isinstance(fill_value, Sequence):
819-
fill_values = fill_value
820-
else:
821-
fill_values = (fill_value,) * nfuncs
822-
assert len(fill_values) >= nfuncs
823-
824-
if isinstance(kwargs, Sequence):
825-
kwargss = kwargs
826-
else:
827-
kwargss = ({},) * nfuncs
828-
assert len(kwargss) >= nfuncs
808+
dtypes = _atleast_1d(dtype, nfuncs)
809+
fill_values = _atleast_1d(fill_value, nfuncs)
810+
kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs
829811

830812
if isinstance(axis, Sequence):
831813
axes: T_Axes = axis
@@ -862,7 +844,8 @@ def chunk_reduce(
862844

863845
# do this *before* possible broadcasting below.
864846
# factorize_ has already taken care of offsetting
865-
seen_groups = _unique(group_idx)
847+
if engine == "numbagg":
848+
seen_groups = _unique(group_idx)
866849

867850
order = "C"
868851
if nax > 1:
@@ -1551,12 +1534,9 @@ def dask_groupby_agg(
15511534
groups = _extract_unknown_groups(reduced, dtype=by.dtype)
15521535
group_chunks = ((np.nan,),)
15531536
else:
1554-
if expected_groups is None:
1555-
expected_groups_ = _get_expected_groups(by_input, sort=sort)
1556-
else:
1557-
expected_groups_ = expected_groups
1558-
groups = (expected_groups_.to_numpy(),)
1559-
group_chunks = ((len(expected_groups_),),)
1537+
assert expected_groups is not None
1538+
groups = (expected_groups.to_numpy(),)
1539+
group_chunks = ((len(expected_groups),),)
15601540

15611541
elif method == "cohorts":
15621542
chunks_cohorts = find_group_cohorts(
@@ -2063,10 +2043,7 @@ def groupby_reduce(
20632043
is_bool_array = np.issubdtype(array.dtype, bool)
20642044
array = array.astype(int) if is_bool_array else array
20652045

2066-
if isinstance(isbin, Sequence):
2067-
isbins = isbin
2068-
else:
2069-
isbins = (isbin,) * nby
2046+
isbins = _atleast_1d(isbin, nby)
20702047

20712048
_assert_by_is_aligned(array.shape, bys)
20722049

flox/xrutils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ def __dask_tokenize__(self):
8484
def is_scalar(value: Any, include_0d: bool = True) -> bool:
8585
"""Whether to treat a value as a scalar.
8686
87-
Any non-iterable, string, or 0-D array
87+
Any non-iterable, string, dict, or 0-D array
8888
"""
8989
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (dask_array_type, pd.Index)
9090

9191
if include_0d:
9292
include_0d = getattr(value, "ndim", None) == 0
9393
return (
9494
include_0d
95-
or isinstance(value, (str, bytes))
95+
or isinstance(value, (str, bytes, dict))
9696
or not (
9797
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
9898
or hasattr(value, "__array_function__")

0 commit comments

Comments
 (0)