@@ -2823,9 +2823,6 @@ def groupby_scan(
28232823 # nothing to do, no NaNs!
28242824 return array
28252825
2826- is_bool_array = np .issubdtype (array .dtype , bool )
2827- array = array .astype (np .int_ ) if is_bool_array else array
2828-
28292826 if expected_groups is not None :
28302827 raise NotImplementedError ("Setting `expected_groups` and binning is not supported yet." )
28312828 expected_groups = _validate_expected_groups (nby , expected_groups )
@@ -2855,6 +2852,11 @@ def groupby_scan(
28552852 if array .dtype .kind in "Mm" :
28562853 cast_to = array .dtype
28572854 array = array .view (np .int64 )
2855+ elif array .dtype .kind == "b" :
2856+ array = array .view (np .int8 )
2857+ cast_to = None
2858+ if agg .preserves_dtype :
2859+ cast_to = bool
28582860 else :
28592861 cast_to = None
28602862
@@ -2869,6 +2871,7 @@ def groupby_scan(
28692871 agg .dtype = np .result_type (array .dtype , np .uint )
28702872 else :
28712873 agg .dtype = array .dtype if dtype is None else dtype
2874+ agg .identity = xrdtypes ._get_fill_value (agg .dtype , agg .identity )
28722875
28732876 (single_axis ,) = axis_ # type: ignore[misc]
28742877 # avoid some roundoff error when we can.
@@ -2887,7 +2890,7 @@ def groupby_scan(
28872890
28882891 if not has_dask :
28892892 final_state = chunk_scan (inp , axis = single_axis , agg = agg , dtype = agg .dtype )
2890- result = _finalize_scan (final_state )
2893+ result = _finalize_scan (final_state , dtype = agg . dtype )
28912894 else :
28922895 result = dask_groupby_scan (inp .array , inp .group_idx , axes = axis_ , agg = agg )
28932896
@@ -2940,9 +2943,9 @@ def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays:
29402943 return AlignedArrays (group_idx = group_idx , array = array )
29412944
29422945
2943- def _finalize_scan (block : ScanState ) -> np .ndarray :
2946+ def _finalize_scan (block : ScanState , dtype ) -> np .ndarray :
29442947 assert block .result is not None
2945- return block .result .array
2948+ return block .result .array . astype ( dtype , copy = False )
29462949
29472950
29482951def dask_groupby_scan (array , by , axes : T_Axes , agg : Scan ) -> DaskArray :
@@ -2985,7 +2988,7 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
29852988 )
29862989
29872990 # 3. Unzip and extract the final result array, discard groups
2988- result = map_blocks (_finalize_scan , accumulated , dtype = agg .dtype )
2991+ result = map_blocks (partial ( _finalize_scan , dtype = agg . dtype ) , accumulated , dtype = agg .dtype )
29892992
29902993 assert result .chunks == array .chunks
29912994
0 commit comments