Skip to content

Commit 9543626

Browse files
committed
Update corr python wrapper to expose only builder parameters used
1 parent 32d8ddd commit 9543626

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

python/datafusion/functions.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,9 +1643,21 @@ def avg(
16431643
return Expr(f.avg(expression.expr, filter=filter_raw))
16441644

16451645

1646-
def corr(value1: Expr, value2: Expr, distinct: bool = False) -> Expr:
1647-
"""Returns the correlation coefficient between ``value1`` and ``value2``."""
1648-
return Expr(f.corr(value1.expr, value2.expr, distinct=distinct))
1646+
def corr(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
1647+
"""Returns the correlation coefficient between ``value1`` and ``value2``.
1648+
1649+
This aggregate function expects both values to be numeric and will return a float.
1650+
1651+
If using the builder functions described in ref:`_aggregation` this function ignores
1652+
the options ``order_by``, ``null_treatment``, and ``distinct``.
1653+
1654+
Args:
1655+
value_y: The dependent variable for correlation
1656+
value_x: The independent variable for correlation
1657+
filter: If provided, only compute against rows for which the filter is true
1658+
"""
1659+
filter_raw = filter.expr if filter is not None else None
1660+
return Expr(f.corr(value_y.expr, value_x.expr, filter=filter_raw))
16491661

16501662

16511663
def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr:

python/datafusion/tests/test_aggregation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,16 @@ def test_aggregation(df, agg_expr, expected, array_sort):
169169
),
170170
[83, 68, 122, 124, 117],
171171
),
172+
(
173+
"corr",
174+
f.corr(column("c3"), column("c2")),
175+
[-0.1056, -0.2808, 0.0023, 0.0022, -0.2473],
176+
),
177+
(
178+
"corr_w_filter",
179+
f.corr(column("c3"), column("c2"), filter=column("c3") > lit(0)),
180+
[-0.3298, 0.2925, 0.2467, -0.2269, 0.0358],
181+
),
172182
],
173183
)
174184
def test_aggregate_100(df_aggregate_100, name, expr, expected):

0 commit comments

Comments
 (0)