@@ -468,21 +468,10 @@ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fr
         if not isinstance(min_count, int):
             raise TypeError("min_count must be integer")

-        if min_count > 0:
-
-            def first(col: Column) -> Column:
-                return F.when(
-                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
-                ).otherwise(F.first(col, ignorenulls=True))
-
-        else:
-
-            def first(col: Column) -> Column:
-                return F.first(col, ignorenulls=True)
-
         return self._reduce_for_stat_function(
-            first,
+            lambda col: F.first(col, ignorenulls=True),
             accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
+            min_count=min_count,
         )

     def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
@@ -549,21 +538,10 @@ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fra
         if not isinstance(min_count, int):
             raise TypeError("min_count must be integer")

-        if min_count > 0:
-
-            def last(col: Column) -> Column:
-                return F.when(
-                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
-                ).otherwise(F.last(col, ignorenulls=True))
-
-        else:
-
-            def last(col: Column) -> Column:
-                return F.last(col, ignorenulls=True)
-
         return self._reduce_for_stat_function(
-            last,
+            lambda col: F.last(col, ignorenulls=True),
             accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
+            min_count=min_count,
         )

     def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
@@ -624,20 +602,10 @@ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram
         if not isinstance(min_count, int):
             raise TypeError("min_count must be integer")

-        if min_count > 0:
-
-            def max(col: Column) -> Column:
-                return F.when(
-                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
-                ).otherwise(F.max(col))
-
-        else:
-
-            def max(col: Column) -> Column:
-                return F.max(col)
-
         return self._reduce_for_stat_function(
-            max, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None
+            F.max,
+            accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
+            min_count=min_count,
         )

     def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
@@ -802,20 +770,10 @@ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram
         if not isinstance(min_count, int):
             raise TypeError("min_count must be integer")

-        if min_count > 0:
-
-            def min(col: Column) -> Column:
-                return F.when(
-                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
-                ).otherwise(F.min(col))
-
-        else:
-
-            def min(col: Column) -> Column:
-                return F.min(col)
-
         return self._reduce_for_stat_function(
-            min, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None
+            F.min,
+            accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
+            min_count=min_count,
         )

     # TODO: sync the doc.
@@ -944,20 +902,11 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL
                     f"numeric_only=False, skip unsupported columns: {unsupported}"
                 )

-        if min_count > 0:
-
-            def sum(col: Column) -> Column:
-                return F.when(
-                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
-                ).otherwise(F.sum(col))
-
-        else:
-
-            def sum(col: Column) -> Column:
-                return F.sum(col)
-
         return self._reduce_for_stat_function(
-            sum, accepted_spark_types=(NumericType,), bool_to_numeric=True
+            F.sum,
+            accepted_spark_types=(NumericType, BooleanType),
+            bool_to_numeric=True,
+            min_count=min_count,
         )

     # TODO: sync the doc.
@@ -1324,22 +1273,11 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame

         self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")

-        if min_count > 0:
-
-            def prod(col: Column) -> Column:
-                return F.when(
-                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
-                ).otherwise(SF.product(col, True))
-
-        else:
-
-            def prod(col: Column) -> Column:
-                return SF.product(col, True)
-
         return self._reduce_for_stat_function(
-            prod,
+            lambda col: SF.product(col, True),
             accepted_spark_types=(NumericType, BooleanType),
             bool_to_numeric=True,
+            min_count=min_count,
         )

     def all(self, skipna: bool = True) -> FrameLike:
@@ -3596,6 +3534,7 @@ def _reduce_for_stat_function(
         sfun: Callable[[Column], Column],
         accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None,
         bool_to_numeric: bool = False,
+        **kwargs: Any,
     ) -> FrameLike:
         """Apply an aggregate function `sfun` per column and reduce to a FrameLike.

@@ -3615,14 +3554,19 @@ def _reduce_for_stat_function(
         psdf: DataFrame = DataFrame(internal)

         if len(psdf._internal.column_labels) > 0:
+            min_count = kwargs.get("min_count", 0)
             stat_exprs = []
             for label in psdf._internal.column_labels:
                 psser = psdf._psser_for(label)
-                stat_exprs.append(
-                    sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias(
-                        psser._internal.data_spark_column_names[0]
+                input_scol = psser._dtype_op.nan_to_null(psser).spark.column
+                output_scol = sfun(input_scol)
+
+                if min_count > 0:
+                    output_scol = F.when(
+                        F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol
                     )
-                )
+
+                stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0]))
             sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
         else:
             sdf = sdf.select(*groupkey_names).distinct()
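Illustrative sketch (not part of the patch): the min_count gating that _reduce_for_stat_function now applies centrally, written against a plain PySpark DataFrame. The DataFrame, the column names, and the gated() helper are invented for the example; only the F.when / F.count pattern is taken from the change above.

# min_count_sketch.py -- hypothetical, standalone demonstration
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a", 1.0), ("a", 2.0), ("b", None)], ["key", "val"]
)

def gated(output_scol: Column, input_scol: Column, min_count: int) -> Column:
    # Keep the aggregate only when the group has at least `min_count` non-null
    # inputs; otherwise the expression stays NULL (F.when without otherwise),
    # which mirrors pandas' min_count semantics.
    if min_count > 0:
        return F.when(
            F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol
        )
    return output_scol

agg = df.groupBy("key").agg(
    gated(F.sum(df["val"]), df["val"], min_count=2).alias("val_sum")
)
agg.show()
# Group 'a' has two non-null values and keeps its sum; group 'b' has none and yields NULL.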