@@ -413,6 +413,7 @@ def chunk_reduce(
     reindex: bool = False,
     isbin: bool = False,
     backend: str = "numpy",
+    kwargs=None,
 ) -> IntermediateDict:
     """
     Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -458,6 +459,9 @@ def chunk_reduce(
     if not isinstance(fill_value, Sequence):
         fill_value = (fill_value,)

+    if kwargs is None:
+        kwargs = ({},) * len(func)
+
     # when axis is a tuple
     # collapse and move reduction dimensions to the end
     if isinstance(axis, Sequence) and len(axis) < by.ndim:
@@ -503,7 +507,7 @@ def chunk_reduce(
     final_array_shape += results["groups"].shape
     final_groups_shape += results["groups"].shape

-    for reduction, fv in zip(func, fill_value):
+    for reduction, fv, kw in zip(func, fill_value, kwargs):
         if empty:
             result = np.full(shape=final_array_shape, fill_value=fv)
         else:
@@ -516,6 +520,7 @@ def chunk_reduce(
                     size=size,
                     # important when reducing with "offset" groups
                     fill_value=fv,
+                    **kw,
                 )
             else:
                 result = _get_aggregate(backend)(
@@ -527,6 +532,7 @@ def chunk_reduce(
                     # important when reducing with "offset" groups
                     fill_value=fv,
                     dtype=np.intp if reduction == "nanlen" else dtype,
+                    **kw,
                 )
         if np.any(~mask):
             # remove NaN group label which should be last
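The hunks above thread a new per-reduction `kwargs` tuple through `chunk_reduce`: it is zipped against `func` and `fill_value`, and each dict is splatted into the corresponding `numpy_groupies` call. A minimal standalone sketch of that pattern, calling `numpy_groupies` directly rather than `chunk_reduce` (the example data and `ddof=1` are illustrative, not taken from the diff):

```python
# One keyword dict per reduction, aligned by position with func/fill_value,
# splatted into the npg call exactly as the loop above does with **kw.
import numpy as np
import numpy_groupies as npg

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

func = ("nanvar", "nanlen")
fill_value = (np.nan, 0)
kwargs = ({"ddof": 1}, {})  # ddof only applies to the variance reduction

results = []
for reduction, fv, kw in zip(func, fill_value, kwargs):
    results.append(
        npg.aggregate(group_idx, array, func=reduction, fill_value=fv, size=2, **kw)
    )

print(results[0])  # per-group sample variance (ddof=1)
print(results[1])  # per-group counts
```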
@@ -573,6 +579,7 @@ def _finalize_results(
     expected_groups: Union[Sequence, np.ndarray, None],
     fill_value: Any,
     min_count: Optional[int] = None,
+    finalize_kwargs: Optional[Mapping] = None,
 ):
     """Finalize results by
     1. Squeezing out dummy dimensions
@@ -595,10 +602,11 @@ def _finalize_results(
     if fill_value is not None:
         counts = squeezed["intermediates"][-1]
         squeezed["intermediates"] = squeezed["intermediates"][:-1]
-
         if min_count is None:
             min_count = 1
-        result[agg.name] = agg.finalize(*squeezed["intermediates"])
+        if finalize_kwargs is None:
+            finalize_kwargs = {}
+        result[agg.name] = agg.finalize(*squeezed["intermediates"], **finalize_kwargs)
         result[agg.name] = np.where(counts >= min_count, result[agg.name], fill_value)

     # Final reindexing has to be here to be lazy
@@ -621,10 +629,13 @@ def _npg_aggregate(
     fill_value: Any = None,
     min_count: Optional[int] = None,
     backend: str = "numpy",
+    finalize_kwargs: Optional[Mapping] = None,
 ) -> FinalResultsDict:
     """Final aggregation step of tree reduction"""
     results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim, backend)
-    return _finalize_results(results, agg, axis, expected_groups, fill_value, min_count)
+    return _finalize_results(
+        results, agg, axis, expected_groups, fill_value, min_count, finalize_kwargs
+    )


 def _npg_combine(
@@ -782,6 +793,7 @@ def groupby_agg(
     min_count: Optional[int] = None,
     isbin: bool = False,
     backend: str = "numpy",
+    finalize_kwargs: Optional[Mapping] = None,
 ) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:

     import dask.array
@@ -851,6 +863,14 @@ def groupby_agg(
     group_chunks = (len(expected_groups),) if expected_groups is not None else (np.nan,)
     expected_agg = expected_groups

+    agg_kwargs = dict(
+        group_ndim=by.ndim,
+        fill_value=fill_value,
+        min_count=min_count,
+        backend=backend,
+        finalize_kwargs=finalize_kwargs,
+    )
+
     if method == "mapreduce":
         # reduced is really a dict mapping reduction name to array
         # and "groups" to an array of group labels
@@ -862,10 +882,7 @@ def groupby_agg(
                 _npg_aggregate,
                 agg=agg,
                 expected_groups=expected_agg,
-                group_ndim=by.ndim,
-                fill_value=fill_value,
-                min_count=min_count,
-                backend=backend,
+                **agg_kwargs,
             ),
             combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim, backend=backend),
             name=f"{name}-reduce",
@@ -892,10 +909,7 @@ def groupby_agg(
                 _npg_aggregate,
                 agg=agg,
                 expected_groups=None,
-                group_ndim=by.ndim,
-                fill_value=fill_value,
-                min_count=min_count,
-                backend=backend,
+                **agg_kwargs,
                 axis=axis,
                 keepdims=True,
             ),
@@ -982,6 +996,7 @@ def groupby_reduce(
     split_out: int = 1,
     method: str = "mapreduce",
     backend: str = "numpy",
+    finalize_kwargs: Optional[Mapping] = None,
 ) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
     """
     GroupBy reductions using tree reductions for dask.array
@@ -1026,6 +1041,8 @@ def groupby_reduce(
         chunking ``array`` for this method by first rechunking using ``rechunk_for_cohorts``.
     backend: {"numpy", "numba"}, optional
         Backend for numpy_groupies. numpy by default.
+    finalize_kwargs: Mapping, optional
+        Kwargs passed to finalize the reduction, such as ``ddof`` for ``var`` and ``std``.

     Returns
     -------
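With `finalize_kwargs` documented above, `ddof` can be passed from the top-level `groupby_reduce` call down to the variance and standard-deviation reductions. A hedged usage sketch; the import path and positional argument order are assumptions, since only the keyword arguments appear in this diff:

```python
# Usage sketch for the new finalize_kwargs parameter (hypothetical import path).
import numpy as np

from dask_groupby import groupby_reduce  # assumed import, not shown in the diff

array = np.array([1.0, 2.0, np.nan, 4.0, 5.0])
by = np.array([0, 0, 1, 1, 1])

# Request the per-group sample variance (ddof=1) instead of the default ddof=0.
result, groups = groupby_reduce(array, by, func="nanvar", finalize_kwargs={"ddof": 1})
```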
@@ -1112,18 +1129,25 @@ def groupby_reduce(
         reduction.finalize = None
         # xarray's count is npg's nanlen
         func = reduction.name if reduction.name != "count" else "nanlen"
-        if min_count is not None:
+        if finalize_kwargs is None:
+            finalize_kwargs = {}
+        if isinstance(finalize_kwargs, Mapping):
+            finalize_kwargs = (finalize_kwargs,)
+        append_nanlen = min_count is not None or reduction.name in ["nanvar", "nanstd"]
+        if append_nanlen:
             func = (func, "nanlen")
+            finalize_kwargs = finalize_kwargs + ({},)

         results = chunk_reduce(
             array,
             by,
             func=func,
             axis=axis,
             expected_groups=expected_groups if isbin else None,
-            fill_value=(fill_value, 0) if min_count is not None else fill_value,
+            fill_value=(fill_value, 0) if append_nanlen else fill_value,
             dtype=reduction.dtype,
             isbin=isbin,
+            kwargs=finalize_kwargs,
         )  # type: ignore

         if reduction.name in ["argmin", "argmax", "nanargmax", "nanargmin"]:
@@ -1133,6 +1157,12 @@ def groupby_reduce(
             results["intermediates"][0] = np.unravel_index(
                 results["intermediates"][0], array.shape
             )[-1]
+        elif reduction.name in ["nanvar", "nanstd"]:
+            # Fix npg bug where all-NaN rows are 0 instead of NaN
+            value, counts = results["intermediates"]
+            mask = counts <= 0
+            value[mask] = np.nan
+            results["intermediates"] = (value,)

         if isbin:
             expected_groups = np.arange(len(expected_groups) - 1)
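The `elif` branch above works around the behavior described in its comment: `numpy_groupies` returns 0 for groups whose members are all NaN, whereas `np.nanvar`/`np.nanstd` return NaN, so the companion `nanlen` counts are used to mask those groups. A small sketch of that correction (the npg behavior is taken from the diff's comment, not verified against a specific version):

```python
# Demonstration of the all-NaN-group fix using the companion "nanlen" counts.
import numpy as np
import numpy_groupies as npg

group_idx = np.array([0, 0, 1, 1])
array = np.array([1.0, 3.0, np.nan, np.nan])  # group 1 is all-NaN

var = npg.aggregate(group_idx, array, func="nanvar", size=2)
counts = npg.aggregate(group_idx, array, func="nanlen", size=2)

# Mask groups with no valid points back to NaN, mirroring the fix in the diff.
var[counts <= 0] = np.nan
print(var)  # [1.0, nan] -- matches np.nanvar semantics for all-NaN groups
```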
@@ -1167,6 +1197,7 @@ def groupby_reduce(
             min_count=min_count,
             isbin=isbin,
             backend=backend,
+            finalize_kwargs=finalize_kwargs,
         )
         if method == "cohorts":
             assert len(axis) == 1