Skip to content

Commit 6c182da

Browse files
zhengruifeng authored and HyukjinKwon committed
[SPARK-40744][PS] Make _reduce_for_stat_function in groupby accept min_count
### What changes were proposed in this pull request?
Make `_reduce_for_stat_function` in `groupby` accept `min_count`.

### Why are the changes needed?
To simplify the implementations.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UTs.

Closes apache#38201 from zhengruifeng/ps_groupby_mc.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent d94c65e commit 6c182da

File tree

1 file changed

+26
-82
lines changed

1 file changed

+26
-82
lines changed

python/pyspark/pandas/groupby.py

+26-82
Original file line numberDiff line numberDiff line change
@@ -468,21 +468,10 @@ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fr
468468
if not isinstance(min_count, int):
469469
raise TypeError("min_count must be integer")
470470

471-
if min_count > 0:
472-
473-
def first(col: Column) -> Column:
474-
return F.when(
475-
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
476-
).otherwise(F.first(col, ignorenulls=True))
477-
478-
else:
479-
480-
def first(col: Column) -> Column:
481-
return F.first(col, ignorenulls=True)
482-
483471
return self._reduce_for_stat_function(
484-
first,
472+
lambda col: F.first(col, ignorenulls=True),
485473
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
474+
min_count=min_count,
486475
)
487476

488477
def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
@@ -549,21 +538,10 @@ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fra
549538
if not isinstance(min_count, int):
550539
raise TypeError("min_count must be integer")
551540

552-
if min_count > 0:
553-
554-
def last(col: Column) -> Column:
555-
return F.when(
556-
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
557-
).otherwise(F.last(col, ignorenulls=True))
558-
559-
else:
560-
561-
def last(col: Column) -> Column:
562-
return F.last(col, ignorenulls=True)
563-
564541
return self._reduce_for_stat_function(
565-
last,
542+
lambda col: F.last(col, ignorenulls=True),
566543
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
544+
min_count=min_count,
567545
)
568546

569547
def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
@@ -624,20 +602,10 @@ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram
624602
if not isinstance(min_count, int):
625603
raise TypeError("min_count must be integer")
626604

627-
if min_count > 0:
628-
629-
def max(col: Column) -> Column:
630-
return F.when(
631-
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
632-
).otherwise(F.max(col))
633-
634-
else:
635-
636-
def max(col: Column) -> Column:
637-
return F.max(col)
638-
639605
return self._reduce_for_stat_function(
640-
max, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None
606+
F.max,
607+
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
608+
min_count=min_count,
641609
)
642610

643611
def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
@@ -802,20 +770,10 @@ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram
802770
if not isinstance(min_count, int):
803771
raise TypeError("min_count must be integer")
804772

805-
if min_count > 0:
806-
807-
def min(col: Column) -> Column:
808-
return F.when(
809-
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
810-
).otherwise(F.min(col))
811-
812-
else:
813-
814-
def min(col: Column) -> Column:
815-
return F.min(col)
816-
817773
return self._reduce_for_stat_function(
818-
min, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None
774+
F.min,
775+
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
776+
min_count=min_count,
819777
)
820778

821779
# TODO: sync the doc.
@@ -944,20 +902,11 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL
944902
f"numeric_only=False, skip unsupported columns: {unsupported}"
945903
)
946904

947-
if min_count > 0:
948-
949-
def sum(col: Column) -> Column:
950-
return F.when(
951-
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
952-
).otherwise(F.sum(col))
953-
954-
else:
955-
956-
def sum(col: Column) -> Column:
957-
return F.sum(col)
958-
959905
return self._reduce_for_stat_function(
960-
sum, accepted_spark_types=(NumericType,), bool_to_numeric=True
906+
F.sum,
907+
accepted_spark_types=(NumericType, BooleanType),
908+
bool_to_numeric=True,
909+
min_count=min_count,
961910
)
962911

963912
# TODO: sync the doc.
@@ -1324,22 +1273,11 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame
13241273

13251274
self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")
13261275

1327-
if min_count > 0:
1328-
1329-
def prod(col: Column) -> Column:
1330-
return F.when(
1331-
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
1332-
).otherwise(SF.product(col, True))
1333-
1334-
else:
1335-
1336-
def prod(col: Column) -> Column:
1337-
return SF.product(col, True)
1338-
13391276
return self._reduce_for_stat_function(
1340-
prod,
1277+
lambda col: SF.product(col, True),
13411278
accepted_spark_types=(NumericType, BooleanType),
13421279
bool_to_numeric=True,
1280+
min_count=min_count,
13431281
)
13441282

13451283
def all(self, skipna: bool = True) -> FrameLike:
@@ -3596,6 +3534,7 @@ def _reduce_for_stat_function(
35963534
sfun: Callable[[Column], Column],
35973535
accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None,
35983536
bool_to_numeric: bool = False,
3537+
**kwargs: Any,
35993538
) -> FrameLike:
36003539
"""Apply an aggregate function `sfun` per column and reduce to a FrameLike.
36013540
@@ -3615,14 +3554,19 @@ def _reduce_for_stat_function(
36153554
psdf: DataFrame = DataFrame(internal)
36163555

36173556
if len(psdf._internal.column_labels) > 0:
3557+
min_count = kwargs.get("min_count", 0)
36183558
stat_exprs = []
36193559
for label in psdf._internal.column_labels:
36203560
psser = psdf._psser_for(label)
3621-
stat_exprs.append(
3622-
sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias(
3623-
psser._internal.data_spark_column_names[0]
3561+
input_scol = psser._dtype_op.nan_to_null(psser).spark.column
3562+
output_scol = sfun(input_scol)
3563+
3564+
if min_count > 0:
3565+
output_scol = F.when(
3566+
F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol
36243567
)
3625-
)
3568+
3569+
stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0]))
36263570
sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
36273571
else:
36283572
sdf = sdf.select(*groupkey_names).distinct()

0 commit comments

Comments (0)