@@ -195,7 +195,6 @@ class FactorizeKwargs(TypedDict, total=False):
195195 by : T_Bys
196196 axes : T_Axes
197197 fastpath : bool
198- expected_groups : T_ExpectIndexOptTuple | None
199198 reindex : bool
200199 sort : bool
201200
@@ -844,6 +843,67 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
844843 return offset , size
845844
846845
846+ def _factorize_single (by , expect , * , sort : bool , reindex : bool ):
847+ flat = by .reshape (- 1 )
848+ if isinstance (expect , pd .RangeIndex ):
849+ # idx is a view of the original `by` array
850+ # copy here so we don't have a race condition with the
851+ # group_idx[nanmask] = nan_sentinel assignment later
852+ # this is important in shared-memory parallelism with dask
853+ # TODO: figure out how to avoid this
854+ idx = flat .copy ()
855+ found_groups = np .array (expect )
856+ # TODO: fix by using masked integers
857+ idx [idx > expect [- 1 ]] = - 1
858+
859+ elif isinstance (expect , pd .IntervalIndex ):
860+ if expect .closed == "both" :
861+ raise NotImplementedError
862+ bins = np .concatenate ([expect .left .to_numpy (), expect .right .to_numpy ()[[- 1 ]]])
863+
864+ # digitize is 0 or idx.max() for values outside the bounds of all intervals
865+ # make it behave like pd.cut which uses -1:
866+ if len (bins ) > 1 :
867+ right = expect .closed_right
868+ idx = np .digitize (
869+ flat ,
870+ bins = bins .view (np .int64 ) if bins .dtype .kind == "M" else bins ,
871+ right = right ,
872+ )
873+ idx -= 1
874+ within_bins = flat <= bins .max () if right else flat < bins .max ()
875+ idx [~ within_bins ] = - 1
876+ else :
877+ idx = np .zeros_like (flat , dtype = np .intp ) - 1
878+ found_groups = np .array (expect )
879+ else :
880+ if expect is not None and reindex :
881+ sorter = np .argsort (expect )
882+ groups = expect [(sorter ,)] if sort else expect
883+ idx = np .searchsorted (expect , flat , sorter = sorter )
884+ mask = ~ np .isin (flat , expect ) | isnull (flat ) | (idx == len (expect ))
885+ if not sort :
886+ # idx is the index in to the sorted array.
887+ # if we didn't want sorting, unsort it back
888+ idx [(idx == len (expect ),)] = - 1
889+ idx = sorter [(idx ,)]
890+ idx [mask ] = - 1
891+ else :
892+ idx , groups = pd .factorize (flat , sort = sort )
893+ found_groups = np .array (groups )
894+
895+ return (found_groups , idx .reshape (by .shape ))
896+
897+
898+ def _ravel_factorized (* factorized : np .ndarray , grp_shape : tuple [int , ...]) -> np .ndarray :
899+ group_idx = np .ravel_multi_index (factorized , grp_shape , mode = "wrap" )
900+ # NaNs; as well as values outside the bins are coded by -1
901+ # Restore these after the raveling
902+ nan_by_mask = reduce (np .logical_or , [(f == - 1 ) for f in factorized ])
903+ group_idx [nan_by_mask ] = - 1
904+ return group_idx
905+
906+
847907@overload
848908def factorize_ (
849909 by : T_Bys ,
@@ -890,7 +950,7 @@ def factorize_(
890950 fastpath : bool = False ,
891951) -> tuple [np .ndarray , tuple [np .ndarray , ...], tuple [int , ...], int , int , FactorProps | None ]:
892952 """
893- Returns an array of integer codes for groups (and associated data)
953+ Returns an array of integer codes for groups (and associated data)
894954 by wrapping pd.cut and pd.factorize (depending on isbin).
895955 This method handles reindex and sort so that we don't spend time reindexing / sorting
896956 a possibly large results array. Instead we set up the appropriate integer codes (group_idx)
@@ -899,75 +959,32 @@ def factorize_(
899959 if expected_groups is None :
900960 expected_groups = (None ,) * len (by )
901961
902- factorized = []
903- found_groups = []
904- for groupvar , expect in zip (by , expected_groups ):
905- flat = groupvar .reshape (- 1 )
906- if isinstance (expect , pd .RangeIndex ):
907- # idx is a view of the original `by` array
908- # copy here so we don't have a race condition with the
909- # group_idx[nanmask] = nan_sentinel assignment later
910- # this is important in shared-memory parallelism with dask
911- # TODO: figure out how to avoid this
912- idx = flat .copy ()
913- found_groups .append (np .array (expect ))
914- # TODO: fix by using masked integers
915- idx [idx > expect [- 1 ]] = - 1
916-
917- elif isinstance (expect , pd .IntervalIndex ):
918- if expect .closed == "both" :
919- raise NotImplementedError
920- bins = np .concatenate ([expect .left .to_numpy (), expect .right .to_numpy ()[[- 1 ]]])
921-
922- # digitize is 0 or idx.max() for values outside the bounds of all intervals
923- # make it behave like pd.cut which uses -1:
924- if len (bins ) > 1 :
925- right = expect .closed_right
926- idx = np .digitize (
927- flat ,
928- bins = bins .view (np .int64 ) if bins .dtype .kind == "M" else bins ,
929- right = right ,
930- )
931- idx -= 1
932- within_bins = flat <= bins .max () if right else flat < bins .max ()
933- idx [~ within_bins ] = - 1
934- else :
935- idx = np .zeros_like (flat , dtype = np .intp ) - 1
936-
937- found_groups .append (np .array (expect ))
938- else :
939- if expect is not None and reindex :
940- sorter = np .argsort (expect )
941- groups = expect [(sorter ,)] if sort else expect
942- idx = np .searchsorted (expect , flat , sorter = sorter )
943- mask = ~ np .isin (flat , expect ) | isnull (flat ) | (idx == len (expect ))
944- if not sort :
945- # idx is the index in to the sorted array.
946- # if we didn't want sorting, unsort it back
947- idx [(idx == len (expect ),)] = - 1
948- idx = sorter [(idx ,)]
949- idx [mask ] = - 1
950- else :
951- idx , groups = pd .factorize (flat , sort = sort )
952-
953- found_groups .append (np .array (groups ))
954- factorized .append (idx .reshape (groupvar .shape ))
962+ if len (by ) > 2 :
963+ with ThreadPoolExecutor () as executor :
964+ futures = [
965+ executor .submit (partial (_factorize_single , sort = sort , reindex = reindex ), groupvar , expect )
966+ for groupvar , expect in zip (by , expected_groups )
967+ ]
968+ results = tuple (f .result () for f in futures )
969+ else :
970+ results = tuple (
971+ _factorize_single (groupvar , expect , sort = sort , reindex = reindex )
972+ for groupvar , expect in zip (by , expected_groups )
973+ )
974+ found_groups = [r [0 ] for r in results ]
975+ factorized = [r [1 ] for r in results ]
955976
956977 grp_shape = tuple (len (grp ) for grp in found_groups )
957978 ngroups = math .prod (grp_shape )
958979 if len (by ) > 1 :
959- group_idx = np .ravel_multi_index (factorized , grp_shape , mode = "wrap" )
960- # NaNs; as well as values outside the bins are coded by -1
961- # Restore these after the raveling
962- nan_by_mask = reduce (np .logical_or , [(f == - 1 ) for f in factorized ])
963- group_idx [nan_by_mask ] = - 1
980+ group_idx = _ravel_factorized (* factorized , grp_shape = grp_shape )
964981 else :
965- group_idx = factorized [ 0 ]
982+ ( group_idx ,) = factorized
966983
967984 if fastpath :
968985 return group_idx , tuple (found_groups ), grp_shape , ngroups , ngroups , None
969986
970- if len (axes ) == 1 and groupvar .ndim > 1 :
987+ if len (axes ) == 1 and by [ 0 ] .ndim > 1 :
971988 # Not reducing along all dimensions of by
972989 # this is OK because for 3D by and axis=(1,2),
973990 # we collapse to a 2D by and axis=-1
@@ -2258,7 +2275,6 @@ def _factorize_multiple(
22582275) -> tuple [tuple [np .ndarray ], tuple [np .ndarray , ...], tuple [int , ...]]:
22592276 kwargs : FactorizeKwargs = dict (
22602277 axes = (), # always (), we offset later if necessary.
2261- expected_groups = expected_groups ,
22622278 fastpath = True ,
22632279 # This is the only way it makes sense I think.
22642280 # reindex controls what's actually allocated in chunk_reduce
@@ -2272,34 +2288,36 @@ def _factorize_multiple(
22722288 # unifying chunks will make sure all arrays in `by` are dask arrays
22732289 # with compatible chunks, even if there was originally a numpy array
22742290 inds = tuple (range (by [0 ].ndim ))
2275- chunks , by_ = dask .array .unify_chunks (* itertools .chain (* zip (by , (inds ,) * len (by ))))
2276-
2277- group_idx = dask .array .map_blocks (
2278- _lazy_factorize_wrapper ,
2279- * by_ ,
2280- chunks = tuple (chunks .values ()),
2281- meta = np .array ((), dtype = np .int64 ),
2282- ** kwargs ,
2283- )
2284-
2285- fg , gs = [], []
22862291 for by_ , expect in zip (by , expected_groups ):
2287- if expect is None :
2288- if is_duck_dask_array (by_ ):
2289- raise ValueError ("Please provide expected_groups when grouping by a dask array." )
2292+ if expect is None and is_duck_dask_array (by_ ):
2293+ raise ValueError ("Please provide expected_groups when grouping by a dask array." )
22902294
2291- found_group = pd .unique (by_ .reshape (- 1 ))
2292- else :
2293- found_group = expect .to_numpy ()
2295+ found_groups = tuple (
2296+ pd .unique (by_ .reshape (- 1 )) if expect is None else expect .to_numpy ()
2297+ for by_ , expect in zip (by , expected_groups )
2298+ )
2299+ grp_shape = tuple (map (len , found_groups ))
22942300
2295- fg .append (found_group )
2296- gs .append (len (found_group ))
2301+ chunks , by_chunked = dask .array .unify_chunks (* itertools .chain (* zip (by , (inds ,) * len (by ))))
2302+ group_idxs = [
2303+ dask .array .map_blocks (
2304+ _lazy_factorize_wrapper ,
2305+ by_ ,
2306+ expected_groups = (expect_ ,),
2307+ meta = np .array ((), dtype = np .int64 ),
2308+ ** kwargs ,
2309+ )
2310+ for by_ , expect_ in zip (by_chunked , expected_groups )
2311+ ]
2312+ # This could be avoied but we'd use `np.where`
2313+ # instead `_ravel_factorized` instead i.e. a copy.
2314+ group_idx = dask .array .map_blocks (
2315+ _ravel_factorized , * group_idxs , grp_shape = grp_shape , chunks = tuple (chunks .values ()), dtype = np .int64
2316+ )
22972317
2298- found_groups = tuple (fg )
2299- grp_shape = tuple (gs )
23002318 else :
23012319 kwargs ["by" ] = by
2302- group_idx , found_groups , grp_shape , * _ = factorize_ (** kwargs )
2320+ group_idx , found_groups , grp_shape , * _ = factorize_ (** kwargs , expected_groups = expected_groups )
23032321
23042322 return (group_idx ,), found_groups , grp_shape
23052323
0 commit comments