Skip to content

Commit 55ebc17

Browse files
committed
Update count and count_star with approprate aggregation options
1 parent 8d16a3c commit 55ebc17

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

python/datafusion/functions.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,18 @@ def col(name: str) -> Expr:
365365
return Expr(f.col(name))
366366

367367

368-
def count_star() -> Expr:
369-
"""Create a COUNT(1) aggregate expression."""
370-
return Expr(f.count(Expr.literal(1)))
368+
def count_star(filter: Optional[Expr] = None) -> Expr:
369+
"""Create a COUNT(1) aggregate expression.
370+
371+
This aggregate function will count all of the rows in the partition.
372+
373+
If using the builder functions described in ref:`_aggregation` this function ignores
374+
the options ``order_by``, ``distinct``, and ``null_treatment``.
375+
376+
Args:
377+
filter: If provided, only count rows for which the filter is true
378+
"""
379+
return count(Expr.literal(1), filter=filter)
371380

372381

373382
def case(expr: Expr) -> CaseBuilder:
@@ -1660,15 +1669,33 @@ def corr(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
16601669
return Expr(f.corr(value_y.expr, value_x.expr, filter=filter_raw))
16611670

16621671

1663-
def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr:
1664-
"""Returns the number of rows that match the given arguments."""
1665-
if args is None:
1666-
return count(Expr.literal(1), distinct=distinct)
1667-
if isinstance(args, list):
1668-
args = [arg.expr for arg in args]
1669-
elif isinstance(args, Expr):
1670-
args = [args.expr]
1671-
return Expr(f.count(*args, distinct=distinct))
1672+
def count(
1673+
expressions: Expr | list[Expr] | None = None,
1674+
distinct: bool = False,
1675+
filter: Optional[Expr] = None,
1676+
) -> Expr:
1677+
"""Returns the number of rows that match the given arguments.
1678+
1679+
This aggregate function will count the non-null rows provided in the expression.
1680+
1681+
If using the builder functions described in ref:`_aggregation` this function ignores
1682+
the options ``order_by`` and ``null_treatment``.
1683+
1684+
Args:
1685+
expressions: Argument to perform bitwise calculation on
1686+
distinct: If True, a single entry for each distinct value will be in the result
1687+
filter: If provided, only compute against rows for which the filter is true
1688+
"""
1689+
filter_raw = filter.expr if filter is not None else None
1690+
1691+
if expressions is None:
1692+
args = [Expr.literal(1).expr]
1693+
elif isinstance(expressions, list):
1694+
args = [arg.expr for arg in expressions]
1695+
else:
1696+
args = [expressions.expr]
1697+
1698+
return Expr(f.count(*args, distinct=distinct, filter=filter_raw))
16721699

16731700

16741701
def covar(y: Expr, x: Expr) -> Expr:

python/datafusion/tests/test_aggregation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
136136
False,
137137
),
138138
(f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False),
139+
(f.count(column("b"), distinct=True), pa.array([2]), False),
140+
(f.count(column("b"), filter=column("a") != 3), pa.array([2]), False),
141+
(f.count(), pa.array([3]), False),
142+
(f.count(column("e")), pa.array([2]), False),
143+
(f.count_star(filter=column("a") != 3), pa.array([2]), False),
139144
],
140145
)
141146
def test_aggregation(df, agg_expr, expected, array_sort):

0 commit comments

Comments
 (0)