
Commit 951d6b9

Add in user example that compares two different approaches to UDFs (#770)

* Add in user example that compares two different approaches to UDFs
* add license
1 parent 66bfe36 commit 951d6b9

examples/python-udf-comparisons.py

Lines changed: 186 additions & 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datafusion import SessionContext, col, lit, udf, functions as F
import os
import pyarrow as pa
import pyarrow.compute as pc
import time

path = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(path, "../tpch/data/lineitem.parquet")

# This example serves to demonstrate alternate approaches to answering the
# question "return all of the rows that have a specific combination of these
# values". We have the combinations we care about provided as a Python
# list of tuples. There is no built-in function that supports this operation,
# but it can be explicitly specified via a single expression or we can
# use a user defined function.

ctx = SessionContext()

# These part keys and suppliers are chosen because there are
# cases where two suppliers each have two of the part keys
# but we are interested in these specific combinations.

values_of_interest = [
    (1530, 4031, "N"),
    (6530, 1531, "N"),
    (5618, 619, "N"),
    (8118, 8119, "N"),
]

partkeys = [lit(r[0]) for r in values_of_interest]
suppkeys = [lit(r[1]) for r in values_of_interest]
returnflags = [lit(r[2]) for r in values_of_interest]

df_lineitem = ctx.read_parquet(filepath).select(
    "l_partkey", "l_suppkey", "l_returnflag"
)

start_time = time.time()

df_simple_filter = df_lineitem.filter(
    F.in_list(col("l_partkey"), partkeys),
    F.in_list(col("l_suppkey"), suppkeys),
    F.in_list(col("l_returnflag"), returnflags),
)

num_rows = df_simple_filter.count()
print(
    f"Simple filtering has number {num_rows} rows and took {time.time() - start_time} s"
)
print("This is the incorrect number of rows!")
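
# The in_list filters above check each column independently, so any row whose
# l_partkey, l_suppkey, and l_returnflag each appear somewhere in the lists
# will pass, even when they come from different tuples. For example, a row
# with partkey 1530, suppkey 1531, and returnflag "N" would pass even though
# that combination is not in values_of_interest. That is why the count above
# is too large.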

start_time = time.time()

# Explicitly check for the combinations of interest.
# This works but is not scalable.

filter_expr = (
    (
        (col("l_partkey") == values_of_interest[0][0])
        & (col("l_suppkey") == values_of_interest[0][1])
        & (col("l_returnflag") == values_of_interest[0][2])
    )
    | (
        (col("l_partkey") == values_of_interest[1][0])
        & (col("l_suppkey") == values_of_interest[1][1])
        & (col("l_returnflag") == values_of_interest[1][2])
    )
    | (
        (col("l_partkey") == values_of_interest[2][0])
        & (col("l_suppkey") == values_of_interest[2][1])
        & (col("l_returnflag") == values_of_interest[2][2])
    )
    | (
        (col("l_partkey") == values_of_interest[3][0])
        & (col("l_suppkey") == values_of_interest[3][1])
        & (col("l_returnflag") == values_of_interest[3][2])
    )
)

df_explicit_filter = df_lineitem.filter(filter_expr)

num_rows = df_explicit_filter.count()
print(
    f"Explicit filtering has number {num_rows} rows and took {time.time() - start_time} s"
)
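
# The hard coded expression above can also be built programmatically from
# values_of_interest with functools.reduce, which scales to any number of
# combinations. This is a minimal sketch of the same explicit-filter approach,
# not a separate technique.
from functools import reduce

programmatic_filter_expr = reduce(
    lambda left, right: left | right,
    [
        (col("l_partkey") == pk)
        & (col("l_suppkey") == sk)
        & (col("l_returnflag") == rf)
        for (pk, sk, rf) in values_of_interest
    ],
)

num_rows = df_lineitem.filter(programmatic_filter_expr).count()
print(f"Programmatically built explicit filter has number {num_rows} rows")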

start_time = time.time()

# Instead try a python UDF


def is_of_interest_impl(
    partkey_arr: pa.Array,
    suppkey_arr: pa.Array,
    returnflag_arr: pa.Array,
) -> pa.Array:
    result = []
    for idx, partkey in enumerate(partkey_arr):
        partkey = partkey.as_py()
        suppkey = suppkey_arr[idx].as_py()
        returnflag = returnflag_arr[idx].as_py()
        value = (partkey, suppkey, returnflag)
        result.append(value in values_of_interest)

    return pa.array(result)


is_of_interest = udf(
    is_of_interest_impl,
    [pa.int32(), pa.int32(), pa.utf8()],
    pa.bool_(),
    "stable",
)

df_udf_filter = df_lineitem.filter(
    is_of_interest(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_filter.count()
print(f"UDF filtering has number {num_rows} rows and took {time.time() - start_time} s")
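
# Note that the UDF above converts every Arrow value into a Python object with
# .as_py() and loops over the rows in Python, which is where its extra cost
# comes from compared to the expression-based filters.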

start_time = time.time()

# Now use a user defined function but lean on the built-in pyarrow array
# functions so we never convert rows to Python objects.

# For other pyarrow compute functions, see
# https://arrow.apache.org/docs/python/api/compute.html
#
# It is important that the number of rows in the returned array
# matches the original array, so we cannot use functions like
# filtered_partkey_arr.filter(filtered_suppkey_arr).
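#
# The pc.equal calls below compare the whole column against a single scalar at
# once and return a boolean array, and pc.and_ / pc.or_ combine those arrays
# element-wise, so the result keeps exactly one boolean per input row.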
def udf_using_pyarrow_compute_impl(
    partkey_arr: pa.Array,
    suppkey_arr: pa.Array,
    returnflag_arr: pa.Array,
) -> pa.Array:
    results = None
    for partkey, suppkey, returnflag in values_of_interest:
        filtered_partkey_arr = pc.equal(partkey_arr, partkey)
        filtered_suppkey_arr = pc.equal(suppkey_arr, suppkey)
        filtered_returnflag_arr = pc.equal(returnflag_arr, returnflag)

        resultant_arr = pc.and_(filtered_partkey_arr, filtered_suppkey_arr)
        resultant_arr = pc.and_(resultant_arr, filtered_returnflag_arr)

        if results is None:
            results = resultant_arr
        else:
            results = pc.or_(results, resultant_arr)

    return results


udf_using_pyarrow_compute = udf(
    udf_using_pyarrow_compute_impl,
    [pa.int32(), pa.int32(), pa.utf8()],
    pa.bool_(),
    "stable",
)

df_udf_pyarrow_compute = df_lineitem.filter(
    udf_using_pyarrow_compute(col("l_partkey"), col("l_suppkey"), col("l_returnflag"))
)

num_rows = df_udf_pyarrow_compute.count()
print(
    f"UDF filtering using pyarrow compute has number {num_rows} rows and took {time.time() - start_time} s"
)
start_time = time.time()
