Skip to content
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
776bc5a
use cumsum from flox
Illviljan Dec 6, 2025
ae27632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a5f9326
Update groupby.py
Illviljan Dec 6, 2025
50ccca4
Update groupby.py
Illviljan Dec 6, 2025
f55531e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
06ac372
Update groupby.py
Illviljan Dec 6, 2025
31244e6
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
dd47536
Update groupby.py
Illviljan Dec 6, 2025
e867f12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
88e0ebc
Update groupby.py
Illviljan Dec 6, 2025
181d4a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
a82ec39
use apply_ufunc for dataset and dataarray handling
Illviljan Dec 6, 2025
6c6abed
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
24c3f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d8d0eaa
Update groupby.py
Illviljan Dec 6, 2025
55ff46a
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
33d1360
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
c97ae98
sync protocols with each other
Illviljan Dec 6, 2025
06b52ae
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
84f9b44
typing
Illviljan Dec 6, 2025
2978877
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
0a9adee
add dataset and version requirement
Illviljan Dec 6, 2025
ae9a3d8
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 6, 2025
c056d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
d4873b9
Update _aggregations.py
Illviljan Dec 6, 2025
21cbde2
Update xarray/core/groupby.py
Illviljan Dec 6, 2025
4aebc47
Update groupby.py
Illviljan Dec 6, 2025
f4cab24
Update groupby.py
Illviljan Dec 6, 2025
23d9d50
Update groupby.py
Illviljan Dec 6, 2025
9b64db2
Update generate_aggregations.py
Illviljan Dec 6, 2025
928b158
Renove workaround in test
Illviljan Dec 7, 2025
130f98e
Update _aggregations.py
Illviljan Dec 7, 2025
5a3e754
Update _aggregations.py
Illviljan Dec 7, 2025
d912cda
Update test_groupby.py
Illviljan Dec 7, 2025
3bc8dc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2025
ec8ffd6
clean ups
Illviljan Dec 7, 2025
b0cf8c4
Merge branch 'main' into cumsum_flox
Illviljan Dec 7, 2025
07a4d35
Add expected groups, add options
Illviljan Dec 8, 2025
d0f7ed2
Update groupby.py
Illviljan Dec 8, 2025
098be30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2025
87d5f77
expeced_groups not supported in groupby_scan
Illviljan Dec 8, 2025
16c93ea
Merge branch 'cumsum_flox' of https://github.com/Illviljan/xarray int…
Illviljan Dec 8, 2025
dfe269a
Update _aggregations.py
Illviljan Dec 9, 2025
b2c3d51
Update _aggregations.py
Illviljan Dec 9, 2025
e28f458
Update generate_aggregations.py
Illviljan Dec 9, 2025
55a36ab
Update _aggregations.py
Illviljan Dec 9, 2025
ff531e1
Update _aggregations.py
Illviljan Dec 9, 2025
43aad2e
Update _aggregations.py
Illviljan Dec 9, 2025
8dfcc56
Update _aggregations.py
Illviljan Dec 9, 2025
9dac0a4
Update generate_aggregations.py
Illviljan Dec 9, 2025
0ba3504
Update _aggregations.py
Illviljan Dec 9, 2025
da2a3e3
Update _aggregations.py
Illviljan Dec 9, 2025
95e6fd3
Update _aggregations.py
Illviljan Dec 9, 2025
7d358b0
Update _aggregations.py
Illviljan Dec 9, 2025
f4fe7a0
Update _aggregations.py
Illviljan Dec 9, 2025
74f1073
Update _aggregations.py
Illviljan Dec 9, 2025
50f6209
Update _aggregations.py
Illviljan Dec 9, 2025
87675b2
Update _aggregations.py
Illviljan Dec 9, 2025
9aee62e
Update _aggregations.py
Illviljan Dec 9, 2025
02ee023
Update _aggregations.py
Illviljan Dec 9, 2025
82557c4
Update _aggregations.py
Illviljan Dec 9, 2025
9721574
Update _aggregations.py
Illviljan Dec 9, 2025
e1fba81
Update _aggregations.py
Illviljan Dec 9, 2025
5137fd8
Update _aggregations.py
Illviljan Dec 9, 2025
59a7f38
Update _aggregations.py
Illviljan Dec 9, 2025
7f519f0
Update _aggregations.py
Illviljan Dec 9, 2025
c4f5f83
Update _aggregations.py
Illviljan Dec 9, 2025
bf5197d
Update _aggregations.py
Illviljan Dec 9, 2025
5563600
Update _aggregations.py
Illviljan Dec 9, 2025
510300d
Update _aggregations.py
Illviljan Dec 9, 2025
5fe07df
Update _aggregations.py
Illviljan Dec 9, 2025
293cc1f
Update _aggregations.py
Illviljan Dec 9, 2025
d9f694c
Update _aggregations.py
Illviljan Dec 9, 2025
c9814db
Update _aggregations.py
Illviljan Dec 10, 2025
6ed0f99
Update _aggregations.py
Illviljan Dec 10, 2025
43a827d
Update test_groupby.py
Illviljan Dec 10, 2025
8d65562
Update test_groupby.py
Illviljan Dec 10, 2025
d19bbca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
acf4022
Update test_groupby.py
Illviljan Dec 10, 2025
f263da6
Update generate_aggregations.py
Illviljan Dec 10, 2025
e56d0b8
Update test_groupby.py
Illviljan Dec 10, 2025
8cbfd9d
Merge branch 'main' into cumsum_flox
Illviljan Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions xarray/core/_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3655,6 +3655,17 @@ def _flox_reduce(
) -> Dataset:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> Dataset:
raise NotImplementedError()
Copy link
Contributor Author

@Illviljan Illviljan Dec 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made these changes manually now.
I'm not getting pytest-accept to correctly fix the docstrings in _aggregations.py, it's for example not indenting correctly. I'm not sure if this is just a Windows 10 thing.


def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -5015,14 +5026,28 @@ def cumsum(
Data variables:
da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
numeric_only=True,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and module_available("flox", minversion="0.10.5")
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down Expand Up @@ -6647,6 +6672,17 @@ def _flox_reduce(
) -> DataArray:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
raise NotImplementedError()

def count(
self,
dim: Dims = None,
Expand Down Expand Up @@ -7904,13 +7940,28 @@ def cumsum(
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)
if (
flox_available
and OPTIONS["use_flox"]
and module_available("flox", minversion="0.10.5")
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_scan(
func="cumsum",
dim=dim,
skipna=skipna,
# fill_value=fill_value,
keep_attrs=keep_attrs,
**kwargs,
)
else:
return self.reduce(
duck_array_ops.cumsum,
dim=dim,
skipna=skipna,
keep_attrs=keep_attrs,
**kwargs,
)

def cumprod(
self,
Expand Down
93 changes: 77 additions & 16 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from packaging.version import Version

from xarray.computation import ops
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.computation.arithmetic import (
DataArrayGroupbyArithmetic,
DatasetGroupbyArithmetic,
Expand Down Expand Up @@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj):

return obj

def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]:
parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(self._original_obj.dims)
else:
parsed_dim = tuple(dim)

return parsed_dim

def _flox_reduce(
self,
dim: Dims,
Expand Down Expand Up @@ -1088,22 +1109,7 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
elif dim is None:
parsed_dim_list = list()
# preserve order
for dim_ in itertools.chain(
*(grouper.codes.dims for grouper in self.groupers)
):
if dim_ not in parsed_dim_list:
parsed_dim_list.append(dim_)
parsed_dim = tuple(parsed_dim_list)
elif dim is ...:
parsed_dim = tuple(obj.dims)
else:
parsed_dim = tuple(dim)
parsed_dim = self._parse_dim(dim)

# Do this so we raise the same error message whether flox is present or not.
# Better to control it here than in flox.
Expand Down Expand Up @@ -1202,6 +1208,61 @@ def _flox_reduce(

return result

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
from flox import groupby_scan

obj = self._original_obj

parsed_dim = self._parse_dim(dim)

axis = obj.get_axis_num(parsed_dim)
# axis = (axis_,) if isinstance(axis_, int) else axis_
codes = tuple(g.codes for g in self.groupers)

def wrapper(array, *by, func: str, skipna: bool | None, **kwargs):
if skipna or (skipna is None and obj.dtype.kind in "cfO"):
if "nan" not in func:
func = f"nan{func}"

return groupby_scan(array, *codes, func=func, **kwargs)

actual = apply_ufunc(
wrapper,
obj,
*codes,
# input_core_dims=input_core_dims,
# for xarray's test_groupby_duplicate_coordinate_labels
# exclude_dims=set(dim_tuple),
# output_core_dims=[output_core_dims],
dask="allowed",
# dask_gufunc_kwargs=dict(
# output_sizes=output_sizes,
# output_dtypes=[dtype] if dtype is not None else None,
# ),
keep_attrs=(
_get_keep_attrs(default=True) if keep_attrs is None else keep_attrs
),
kwargs=dict(
func=func,
skipna=skipna,
expected_groups=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be the same as _flox_reduce. This is an important optimization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expected_groups is not supported in groupby_scan. For a future PR I think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh dang

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look and remember why this. What do we do for group 0 when a user says grouped_scan(np.array([1, 2, 3], by=[0, 1, 2], expected_groups=[1, 2])?

Copy link
Contributor Author

@Illviljan Illviljan Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking it should be np.nan (or fill_value) for groups missing in expected groups.

An analog could be

  • groupby_reduce "uses" __getitem__ to mask missing groups.
  • groupby_scan will have to "use" np.where(mask, np.nan) to continue masking but with the same shape.

My expected result:

import flox
import numpy as np


# groupby_reduce omits 0:
flox.groupby_reduce(
    np.array([1, 2, 3]), [0, 1, 2], func="sum", expected_groups=[1, 2]
)
# (array([2, 3]), array([1, 2]))

flox.groupby_scan(
    np.array([1, 2, 3]), [0, 1, 2], func="cumsum", expected_groups=[1, 2]
)
# array([np.nan, 2, 3])

axis=axis,
dtype=None,
method=None,
engine=None,
),
)

return actual

def fillna(self, value: Any) -> T_Xarray:
"""Fill missing values in this object by group.

Expand Down
21 changes: 18 additions & 3 deletions xarray/util/generate_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import textwrap
from dataclasses import dataclass, field
from typing import NamedTuple
from typing import Literal, NamedTuple

MODULE_PREAMBLE = '''\
"""Mixin classes with reduction operations."""
Expand Down Expand Up @@ -132,6 +132,17 @@ def _flox_reduce(
dim: Dims,
**kwargs: Any,
) -> {obj}:
raise NotImplementedError()

def _flox_scan(
self,
dim: Dims,
*,
func: str,
skipna: bool | None = None,
keep_attrs: bool | None = None,
**kwargs: Any,
) -> DataArray:
raise NotImplementedError()"""

TEMPLATE_REDUCTION_SIGNATURE = '''
Expand Down Expand Up @@ -284,6 +295,7 @@ def __init__(
see_also_methods=(),
min_flox_version=None,
additional_notes="",
flox_aggregation_type: Literal["reduce", "scan"] = "reduce",
):
self.name = name
self.extra_kwargs = extra_kwargs
Expand All @@ -292,6 +304,7 @@ def __init__(
self.see_also_methods = see_also_methods
self.min_flox_version = min_flox_version
self.additional_notes = additional_notes
self.flox_aggregation_type = flox_aggregation_type
if bool_reduce:
self.array_method = f"array_{name}"
self.np_example_array = (
Expand Down Expand Up @@ -444,7 +457,7 @@ def generate_code(self, method, has_keep_attrs):

# median isn't enabled yet, because it would break if a single group was present in multiple
# chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median
method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod")
method_is_not_flox_supported = method.name in ("median", "cumprod")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI in a future PR, I'd like to use the new flox.is_supported_aggregation here. It's a little smarter about this dispatching. We'll also have to figure out what to do about median which currently auto-rechunks so it always works.

if method_is_not_flox_supported:
indent = 12
else:
Expand Down Expand Up @@ -476,7 +489,7 @@ def generate_code(self, method, has_keep_attrs):
+ f"""
and contains_only_chunked_or_numpy(self._obj)
):
return self._flox_reduce(
return self._flox_{method.flox_aggregation_type}(
func="{method.name}",
dim=dim,{extra_kwargs}
# fill_value=fill_value,
Expand Down Expand Up @@ -537,6 +550,8 @@ def generate_code(self, method, has_keep_attrs):
numeric_only=True,
see_also_methods=("cumulative",),
additional_notes=_CUM_NOTES,
min_flox_version="0.10.5",
flox_aggregation_type="scan",
),
Method(
"cumprod",
Expand Down
Loading