Skip to content

Commit faa5a3f

Browse files
authored
Tsaucer/find window fn (#747)
* Add a search order when attempting to locate the appropriate window function * Remove unnecessary markings * Linting * Code cleanup
1 parent a3908ed commit faa5a3f

12 files changed

+100
-40
lines changed

examples/tpch/_tests.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from datafusion import col, lit, functions as F
2222
from util import get_answer_file
2323

24+
2425
def df_selection(col_name, col_type):
2526
if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
2627
return F.round(col(col_name), lit(2)).alias(col_name)
@@ -29,14 +30,16 @@ def df_selection(col_name, col_type):
2930
else:
3031
return col(col_name)
3132

33+
3234
def load_schema(col_name, col_type):
3335
if col_type == pa.int64() or col_type == pa.int32():
3436
return col_name, pa.string()
3537
elif isinstance(col_type, pa.Decimal128Type):
3638
return col_name, pa.float64()
3739
else:
3840
return col_name, col_type
39-
41+
42+
4043
def expected_selection(col_name, col_type):
4144
if col_type == pa.int64() or col_type == pa.int32():
4245
return F.trim(col(col_name)).cast(col_type).alias(col_name)
@@ -45,20 +48,23 @@ def expected_selection(col_name, col_type):
4548
else:
4649
return col(col_name)
4750

51+
4852
def selections_and_schema(original_schema):
49-
columns = [ (c, original_schema.field(c).type) for c in original_schema.names ]
53+
columns = [(c, original_schema.field(c).type) for c in original_schema.names]
5054

51-
df_selections = [ df_selection(c, t) for (c, t) in columns]
52-
expected_schema = [ load_schema(c, t) for (c, t) in columns]
53-
expected_selections = [ expected_selection(c, t) for (c, t) in columns]
55+
df_selections = [df_selection(c, t) for (c, t) in columns]
56+
expected_schema = [load_schema(c, t) for (c, t) in columns]
57+
expected_selections = [expected_selection(c, t) for (c, t) in columns]
5458

5559
return (df_selections, expected_schema, expected_selections)
5660

61+
5762
def check_q17(df):
5863
raw_value = float(df.collect()[0]["avg_yearly"][0].as_py())
5964
value = round(raw_value, 2)
6065
assert abs(value - 348406.05) < 0.001
6166

67+
6268
@pytest.mark.parametrize(
6369
("query_code", "answer_file"),
6470
[
@@ -72,9 +78,7 @@ def check_q17(df):
7278
("q08_market_share", "q8"),
7379
("q09_product_type_profit_measure", "q9"),
7480
("q10_returned_item_reporting", "q10"),
75-
pytest.param(
76-
"q11_important_stock_identification", "q11",
77-
),
81+
("q11_important_stock_identification", "q11"),
7882
("q12_ship_mode_order_priority", "q12"),
7983
("q13_customer_distribution", "q13"),
8084
("q14_promotion_effect", "q14"),
@@ -97,13 +101,20 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
97101
if answer_file == "q17":
98102
return check_q17(df)
99103

100-
(df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema())
104+
(df_selections, expected_schema, expected_selections) = selections_and_schema(
105+
df.schema()
106+
)
101107

102108
df = df.select(*df_selections)
103109

104110
read_schema = pa.schema(expected_schema)
105111

106-
df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out")
112+
df_expected = module.ctx.read_csv(
113+
get_answer_file(answer_file),
114+
schema=read_schema,
115+
delimiter="|",
116+
file_extension=".out",
117+
)
107118

108119
df_expected = df_expected.select(*expected_selections)
109120

examples/tpch/convert_data_to_parquet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,14 @@
117117

118118
curr_dir = os.path.dirname(os.path.abspath(__file__))
119119
for filename, curr_schema in all_schemas.items():
120-
121120
# For convenience, go ahead and convert the schema column names to lowercase
122121
curr_schema = [(s[0].lower(), s[1]) for s in curr_schema]
123122

124123
# Pre-collect the output columns so we can ignore the null field we add
125124
# in to handle the trailing | in the file
126125
output_cols = [r[0] for r in curr_schema]
127126

128-
curr_schema = [ pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]
127+
curr_schema = [pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]
129128

130129
# The trailing | on each row requires an extra field during processing
131130
curr_schema.append(("some_null", pyarrow.null()))

examples/tpch/q08_market_share.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747

4848
ctx = SessionContext()
4949

50-
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
50+
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
51+
"p_partkey", "p_type"
52+
)
5153
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
5254
"s_suppkey", "s_nationkey"
5355
)

examples/tpch/q09_product_type_profit_measure.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939

4040
ctx = SessionContext()
4141

42-
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
42+
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
43+
"p_partkey", "p_name"
44+
)
4345
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
4446
"s_suppkey", "s_nationkey"
4547
)

examples/tpch/q13_customer_distribution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
4242
"o_custkey", "o_comment"
4343
)
44-
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns("c_custkey")
44+
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
45+
"c_custkey"
46+
)
4547

4648
# Use a regex to remove special cases
4749
df_orders = df_orders.filter(

examples/tpch/q14_promotion_effect.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@
4444
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
4545
"l_partkey", "l_shipdate", "l_extendedprice", "l_discount"
4646
)
47-
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
47+
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
48+
"p_partkey", "p_type"
49+
)
4850

4951

5052
# Check part type begins with PROMO

examples/tpch/q16_part_supplier_relationship.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@
6262
# Select the parts we are interested in
6363
df_part = df_part.filter(col("p_brand") != lit(BRAND))
6464
df_part = df_part.filter(
65-
F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) != lit(TYPE_TO_IGNORE)
65+
F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1))
66+
!= lit(TYPE_TO_IGNORE)
6667
)
6768

6869
# Python conversion of integer to literal casts it to int64 but the data for

examples/tpch/q17_small_quantity_order.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,13 @@
5656
# Find the average quantity
5757
window_frame = WindowFrame("rows", None, None)
5858
df = df.with_column(
59-
"avg_quantity", F.window("avg", [col("l_quantity")], window_frame=window_frame, partition_by=[col("l_partkey")])
59+
"avg_quantity",
60+
F.window(
61+
"avg",
62+
[col("l_quantity")],
63+
window_frame=window_frame,
64+
partition_by=[col("l_partkey")],
65+
),
6066
)
6167

6268
df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity"))

examples/tpch/q20_potential_part_promotion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040

4141
ctx = SessionContext()
4242

43-
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
43+
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
44+
"p_partkey", "p_name"
45+
)
4446
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
4547
"l_shipdate", "l_partkey", "l_suppkey", "l_quantity"
4648
)

examples/tpch/q22_global_sales_opportunity.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
3939
"c_phone", "c_acctbal", "c_custkey"
4040
)
41-
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns("o_custkey")
41+
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
42+
"o_custkey"
43+
)
4244

4345
# The nation code is a two digit number, but we need to convert it to a string literal
4446
nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES])

examples/tpch/util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@
2020
"""
2121

2222
import os
23-
from pathlib import Path
23+
2424

2525
def get_data_path(filename: str) -> str:
2626
path = os.path.dirname(os.path.abspath(__file__))
2727

2828
return os.path.join(path, "data", filename)
2929

30+
3031
def get_answer_file(answer_file: str) -> str:
3132
path = os.path.dirname(os.path.abspath(__file__))
3233

33-
return os.path.join(path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out")
34+
return os.path.join(
35+
path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out"
36+
)

src/functions.rs

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use datafusion::functions_aggregate::all_default_aggregate_functions;
1819
use pyo3::{prelude::*, wrap_pyfunction};
1920

2021
use crate::common::data_type::NullTreatment;
@@ -311,6 +312,50 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
311312
})
312313
}
313314

315+
/// Helper function to find the appropriate window function. First, if a session
316+
/// context is defined check it's registered functions. If no context is defined,
317+
/// attempt to find from all default functions. Lastly, as a fall back attempt
318+
/// to use built in window functions, which are being deprecated.
319+
fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowFunctionDefinition> {
320+
let mut maybe_fn = match &ctx {
321+
Some(ctx) => {
322+
let session_state = ctx.ctx.state();
323+
324+
match session_state.window_functions().contains_key(name) {
325+
true => session_state
326+
.window_functions()
327+
.get(name)
328+
.map(|f| WindowFunctionDefinition::WindowUDF(f.clone())),
329+
false => session_state
330+
.aggregate_functions()
331+
.get(name)
332+
.map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())),
333+
}
334+
}
335+
None => {
336+
let default_aggregate_fns = all_default_aggregate_functions();
337+
338+
default_aggregate_fns
339+
.iter()
340+
.find(|v| v.aliases().contains(&name.to_string()))
341+
.map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()))
342+
}
343+
};
344+
345+
if maybe_fn.is_none() {
346+
maybe_fn = find_df_window_func(name).or_else(|| {
347+
ctx.and_then(|ctx| {
348+
ctx.ctx
349+
.udaf(name)
350+
.map(WindowFunctionDefinition::AggregateUDF)
351+
.ok()
352+
})
353+
});
354+
}
355+
356+
maybe_fn.ok_or(DataFusionError::Common("window function not found".to_string()).into())
357+
}
358+
314359
/// Creates a new Window function expression
315360
#[pyfunction]
316361
fn window(
@@ -321,24 +366,7 @@ fn window(
321366
window_frame: Option<PyWindowFrame>,
322367
ctx: Option<PySessionContext>,
323368
) -> PyResult<PyExpr> {
324-
// workaround for https://github.com/apache/datafusion-python/issues/730
325-
let fun = if name == "sum" {
326-
let sum_udf = functions_aggregate::sum::sum_udaf();
327-
Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
328-
} else {
329-
find_df_window_func(name).or_else(|| {
330-
ctx.and_then(|ctx| {
331-
ctx.ctx
332-
.udaf(name)
333-
.map(WindowFunctionDefinition::AggregateUDF)
334-
.ok()
335-
})
336-
})
337-
};
338-
if fun.is_none() {
339-
return Err(DataFusionError::Common("window function not found".to_string()).into());
340-
}
341-
let fun = fun.unwrap();
369+
let fun = find_window_fn(name, ctx)?;
342370
let window_frame = window_frame
343371
.unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
344372
.into();

0 commit comments

Comments
 (0)