Skip to content

Commit 4f93736

Browse files
committed
Update python wrapper for arguments appropriate to bool operators
1 parent fdee791 commit 4f93736

File tree

2 files changed

+37
-23
lines changed

2 files changed

+37
-23
lines changed

python/datafusion/functions.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1901,14 +1901,38 @@ def bit_xor(
19011901
return Expr(f.bit_xor(expression.expr, distinct=distinct, filter=filter_raw))
19021902

19031903

1904-
def bool_and(arg: Expr, distinct: bool = False) -> Expr:
1905-
"""Computes the boolean AND of the argument."""
1906-
return Expr(f.bool_and(arg.expr, distinct=distinct))
1904+
def bool_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
1905+
"""Computes the boolean AND of the argument.
19071906
1907+
This aggregate function will compare every value in the input partition. These are
1908+
expected to be boolean values.
19081909
1909-
def bool_or(arg: Expr, distinct: bool = False) -> Expr:
1910-
"""Computes the boolean OR of the argument."""
1911-
return Expr(f.bool_or(arg.expr, distinct=distinct))
1910+
If using the builder functions described in ref:`_aggregation` this function ignores
1911+
the options ``order_by``, ``null_treatment``, and ``distinct``.
1912+
1913+
Args:
1914+
expression: Argument to perform calculation on
1915+
filter: If provided, only compute against rows for which the filter is true
1916+
"""
1917+
filter_raw = filter.expr if filter is not None else None
1918+
return Expr(f.bool_and(expression.expr, filter=filter_raw))
1919+
1920+
1921+
def bool_or(expression: Expr, filter: Optional[Expr] = None) -> Expr:
1922+
"""Computes the boolean OR of the argument.
1923+
1924+
This aggregate function will compare every value in the input partition. These are
1925+
expected to be boolean values.
1926+
1927+
If using the builder functions described in ref:`_aggregation` this function ignores
1928+
the options ``order_by``, ``null_treatment``, and ``distinct``.
1929+
1930+
Args:
1931+
expression: Argument to perform calculation on
1932+
filter: If provided, only compute against rows for which the filter is true
1933+
"""
1934+
filter_raw = filter.expr if filter is not None else None
1935+
return Expr(f.bool_or(expression.expr, filter=filter_raw))
19121936

19131937

19141938
def lead(

python/datafusion/tests/test_aggregation.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_aggregate_100(df_aggregate_100):
179179
assert result.column("c5") == pa.array([83, 68, 122, 124, 117])
180180

181181

182-
data_test_bitwise_functions = [
182+
data_test_bitwise_and_boolean_functions = [
183183
("bit_and", f.bit_and(column("a")), [0]),
184184
("bit_and_filter", f.bit_and(column("a"), filter=column("a") != lit(2)), [1]),
185185
("bit_or", f.bit_or(column("b")), [6]),
@@ -192,29 +192,19 @@ def test_aggregate_100(df_aggregate_100):
192192
f.bit_xor(column("b"), distinct=True, filter=column("a") != lit(3)),
193193
[4],
194194
),
195+
("bool_and", f.bool_and(column("d")), [False]),
196+
("bool_and_filter", f.bool_and(column("d"), filter=column("a") != lit(3)), [True]),
197+
("bool_or", f.bool_or(column("d")), [True]),
198+
("bool_or_filter", f.bool_or(column("d"), filter=column("a") == lit(3)), [False]),
195199
]
196200

197201

198-
@pytest.mark.parametrize("name,expr,result", data_test_bitwise_functions)
199-
def test_bit_add_or_xor(df, name, expr, result):
202+
@pytest.mark.parametrize("name,expr,result", data_test_bitwise_and_boolean_functions)
203+
def test_bit_and_bool_fns(df, name, expr, result):
200204
df = df.aggregate([], [expr.alias(name)])
201205

202206
expected = {
203207
name: result,
204208
}
205209

206210
assert df.collect()[0].to_pydict() == expected
207-
208-
209-
def test_bool_and_or(df):
210-
df = df.aggregate(
211-
[],
212-
[
213-
f.bool_and(column("d")),
214-
f.bool_or(column("d")),
215-
],
216-
)
217-
result = df.collect()
218-
result = result[0]
219-
assert result.column(0) == pa.array([False])
220-
assert result.column(1) == pa.array([True])

0 commit comments

Comments
 (0)