2525from ... import opcodes as OperandDef
2626from ...config import options
2727from ...core .custom_log import redirect_custom_log
28- from ...core import ENTITY_TYPE , OutputType
28+ from ...core import ENTITY_TYPE , OutputType , recursive_tile
2929from ...core .context import get_context
3030from ...core .operand import OperandStage
3131from ...serialization .serializables import (
6464
6565_support_get_group_without_as_index = pd_release_version [:2 ] > (1 , 0 )
6666
67+ _FUNCS_PREFER_SHUFFLE = {"nunique" }
68+
6769
6870class SizeRecorder :
6971 def __init__ (self ):
@@ -163,6 +165,8 @@ class DataFrameGroupByAgg(DataFrameOperand, DataFrameOperandMixin):
163165 method = StringField ("method" )
164166 use_inf_as_na = BoolField ("use_inf_as_na" )
165167
168+ map_on_shuffle = AnyField ("map_on_shuffle" )
169+
166170 # for chunk
167171 combine_size = Int32Field ("combine_size" )
168172 chunk_store_limit = Int64Field ("chunk_store_limit" )
@@ -421,10 +425,29 @@ def _tile_with_shuffle(
421425 in_df : TileableType ,
422426 out_df : TileableType ,
423427 func_infos : ReductionSteps ,
428+ agg_chunks : List [ChunkType ] = None ,
424429 ):
425- # First, perform groupby and aggregation on each chunk.
426- agg_chunks = cls ._gen_map_chunks (op , in_df .chunks , out_df , func_infos )
427- return cls ._perform_shuffle (op , agg_chunks , in_df , out_df , func_infos )
430+ if op .map_on_shuffle is None :
431+ op .map_on_shuffle = all (
432+ agg_fun .custom_reduction is None for agg_fun in func_infos .agg_funcs
433+ )
434+
435+ if not op .map_on_shuffle :
436+ groupby_params = op .groupby_params .copy ()
437+ selection = groupby_params .pop ("selection" , None )
438+ groupby = in_df .groupby (** groupby_params )
439+ if selection :
440+ groupby = groupby [selection ]
441+ result = groupby .transform (
442+ op .raw_func , _call_agg = True , index = out_df .index_value
443+ )
444+ return (yield from recursive_tile (result ))
445+ else :
446+ # First, perform groupby and aggregation on each chunk.
447+ agg_chunks = agg_chunks or cls ._gen_map_chunks (
448+ op , in_df .chunks , out_df , func_infos
449+ )
450+ return cls ._perform_shuffle (op , agg_chunks , in_df , out_df , func_infos )
428451
429452 @classmethod
430453 def _perform_shuffle (
@@ -624,8 +647,10 @@ def _tile_auto(
624647 else :
625648 # otherwise, use shuffle
626649 logger .debug ("Choose shuffle method for groupby operand %s" , op )
627- return cls ._perform_shuffle (
628- op , chunks + left_chunks , in_df , out_df , func_infos
650+ return (
651+ yield from cls ._tile_with_shuffle (
652+ op , in_df , out_df , func_infos , chunks + left_chunks
653+ )
629654 )
630655
631656 @classmethod
@@ -638,12 +663,16 @@ def tile(cls, op: "DataFrameGroupByAgg"):
638663 func_infos = cls ._compile_funcs (op , in_df )
639664
640665 if op .method == "auto" :
641- if len (in_df .chunks ) <= op .combine_size :
666+ if set (op .func ) & _FUNCS_PREFER_SHUFFLE :
667+ return (
668+ yield from cls ._tile_with_shuffle (op , in_df , out_df , func_infos )
669+ )
670+ elif len (in_df .chunks ) <= op .combine_size :
642671 return cls ._tile_with_tree (op , in_df , out_df , func_infos )
643672 else :
644673 return (yield from cls ._tile_auto (op , in_df , out_df , func_infos ))
645674 if op .method == "shuffle" :
646- return cls ._tile_with_shuffle (op , in_df , out_df , func_infos )
675+ return ( yield from cls ._tile_with_shuffle (op , in_df , out_df , func_infos ) )
647676 elif op .method == "tree" :
648677 return cls ._tile_with_tree (op , in_df , out_df , func_infos )
649678 else : # pragma: no cover
@@ -1075,7 +1104,15 @@ def execute(cls, ctx, op: "DataFrameGroupByAgg"):
10751104 pd .reset_option ("mode.use_inf_as_na" )
10761105
10771106
1078- def agg (groupby , func = None , method = "auto" , combine_size = None , * args , ** kwargs ):
1107+ def agg (
1108+ groupby ,
1109+ func = None ,
1110+ method = "auto" ,
1111+ combine_size = None ,
1112+ map_on_shuffle = None ,
1113+ * args ,
1114+ ** kwargs ,
1115+ ):
10791116 """
10801117 Aggregate using one or more operations on grouped data.
10811118
@@ -1091,7 +1128,11 @@ def agg(groupby, func=None, method="auto", combine_size=None, *args, **kwargs):
10911128 in distributed mode and use 'tree' in local mode.
10921129 combine_size : int
10931130 The number of chunks to combine when method is 'tree'
1094-
1131+ map_on_shuffle : bool
1132+ When not specified, will decide whether to perform aggregation on the
1133+ map stage of shuffle (currently no aggregation when there is custom
1134+ reduction in functions). Otherwise, whether to call map on map stage
1135+ of shuffle is determined by the value.
10951136
10961137 Returns
10971138 -------
@@ -1138,5 +1179,6 @@ def agg(groupby, func=None, method="auto", combine_size=None, *args, **kwargs):
11381179 combine_size = combine_size or options .combine_size ,
11391180 chunk_store_limit = options .chunk_store_limit ,
11401181 use_inf_as_na = use_inf_as_na ,
1182+ map_on_shuffle = map_on_shuffle ,
11411183 )
11421184 return agg_op (groupby )
0 commit comments