Skip to content

Tsaucer/find window fn #747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions examples/tpch/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from datafusion import col, lit, functions as F
from util import get_answer_file


def df_selection(col_name, col_type):
if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
return F.round(col(col_name), lit(2)).alias(col_name)
Expand All @@ -29,14 +30,16 @@ def df_selection(col_name, col_type):
else:
return col(col_name)


def load_schema(col_name, col_type):
    """Map a result column to the (name, type) pair used to read the answer file.

    Integer columns are read as strings (the answer files pad them with
    whitespace, which is trimmed later in ``expected_selection``) and
    decimal columns are read as float64; all other types pass through.
    """
    if col_type == pa.int64() or col_type == pa.int32():
        return col_name, pa.string()
    elif isinstance(col_type, pa.Decimal128Type):
        return col_name, pa.float64()
    else:
        return col_name, col_type



def expected_selection(col_name, col_type):
if col_type == pa.int64() or col_type == pa.int32():
return F.trim(col(col_name)).cast(col_type).alias(col_name)
Expand All @@ -45,20 +48,23 @@ def expected_selection(col_name, col_type):
else:
return col(col_name)


def selections_and_schema(original_schema):
    """Derive, from the query result schema, everything needed to compare
    against an answer file.

    Returns a 3-tuple of:
    - the selections to normalize the DataFrame result (``df_selection``),
    - the schema with which to read the answer file (``load_schema``),
    - the selections to normalize the expected data (``expected_selection``).
    """
    columns = [(c, original_schema.field(c).type) for c in original_schema.names]

    df_selections = [df_selection(c, t) for (c, t) in columns]
    expected_schema = [load_schema(c, t) for (c, t) in columns]
    expected_selections = [expected_selection(c, t) for (c, t) in columns]

    return (df_selections, expected_schema, expected_selections)


def check_q17(df):
    """Verify query 17's single aggregate value against the known answer.

    Q17 returns one row with one ``avg_yearly`` value, so it is checked
    directly rather than against an answer file. Raises AssertionError on
    mismatch (tolerance of 0.001 after rounding to 2 decimal places).
    """
    raw_value = float(df.collect()[0]["avg_yearly"][0].as_py())
    value = round(raw_value, 2)
    assert abs(value - 348406.05) < 0.001


@pytest.mark.parametrize(
("query_code", "answer_file"),
[
Expand All @@ -72,9 +78,7 @@ def check_q17(df):
("q08_market_share", "q8"),
("q09_product_type_profit_measure", "q9"),
("q10_returned_item_reporting", "q10"),
pytest.param(
"q11_important_stock_identification", "q11",
),
("q11_important_stock_identification", "q11"),
("q12_ship_mode_order_priority", "q12"),
("q13_customer_distribution", "q13"),
("q14_promotion_effect", "q14"),
Expand All @@ -97,13 +101,20 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
if answer_file == "q17":
return check_q17(df)

(df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema())
(df_selections, expected_schema, expected_selections) = selections_and_schema(
df.schema()
)

df = df.select(*df_selections)

read_schema = pa.schema(expected_schema)

df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out")
df_expected = module.ctx.read_csv(
get_answer_file(answer_file),
schema=read_schema,
delimiter="|",
file_extension=".out",
)

df_expected = df_expected.select(*expected_selections)

Expand Down
3 changes: 1 addition & 2 deletions examples/tpch/convert_data_to_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,14 @@

curr_dir = os.path.dirname(os.path.abspath(__file__))
for filename, curr_schema in all_schemas.items():

# For convenience, go ahead and convert the schema column names to lowercase
curr_schema = [(s[0].lower(), s[1]) for s in curr_schema]

# Pre-collect the output columns so we can ignore the null field we add
# in to handle the trailing | in the file
output_cols = [r[0] for r in curr_schema]

curr_schema = [ pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]
curr_schema = [pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema]

# Trailing | requires extra field for in processing
curr_schema.append(("some_null", pyarrow.null()))
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q08_market_share.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@

ctx = SessionContext()

df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_type"
)
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q09_product_type_profit_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@

ctx = SessionContext()

df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_name"
)
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
"s_suppkey", "s_nationkey"
)
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q13_customer_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_custkey", "o_comment"
)
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns("c_custkey")
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_custkey"
)

# Use a regex to remove special cases
df_orders = df_orders.filter(
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q14_promotion_effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_partkey", "l_shipdate", "l_extendedprice", "l_discount"
)
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type")
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_type"
)


# Check part type begins with PROMO
Expand Down
3 changes: 2 additions & 1 deletion examples/tpch/q16_part_supplier_relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
# Select the parts we are interested in
df_part = df_part.filter(col("p_brand") != lit(BRAND))
df_part = df_part.filter(
F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) != lit(TYPE_TO_IGNORE)
F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1))
!= lit(TYPE_TO_IGNORE)
)

# Python conversion of integer to literal casts it to int64 but the data for
Expand Down
8 changes: 7 additions & 1 deletion examples/tpch/q17_small_quantity_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@
# Find the average quantity
window_frame = WindowFrame("rows", None, None)
df = df.with_column(
"avg_quantity", F.window("avg", [col("l_quantity")], window_frame=window_frame, partition_by=[col("l_partkey")])
"avg_quantity",
F.window(
"avg",
[col("l_quantity")],
window_frame=window_frame,
partition_by=[col("l_partkey")],
),
)

df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity"))
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q20_potential_part_promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@

ctx = SessionContext()

df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name")
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
"p_partkey", "p_name"
)
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
"l_shipdate", "l_partkey", "l_suppkey", "l_quantity"
)
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q22_global_sales_opportunity.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
"c_phone", "c_acctbal", "c_custkey"
)
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns("o_custkey")
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
"o_custkey"
)

# The nation code is a two digit number, but we need to convert it to a string literal
nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES])
Expand Down
7 changes: 5 additions & 2 deletions examples/tpch/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
"""

import os
from pathlib import Path


def get_data_path(filename: str) -> str:
    """Return the absolute path of ``filename`` inside this example's ``data`` directory."""
    path = os.path.dirname(os.path.abspath(__file__))

    return os.path.join(path, "data", filename)


def get_answer_file(answer_file: str) -> str:
    """Return the path to the TPC-H benchmark answer file for ``answer_file`` (e.g. ``"q1"``).

    The path is resolved relative to this file, pointing into the
    repository's ``benchmarks/tpch/data/answers`` directory.
    """
    path = os.path.dirname(os.path.abspath(__file__))

    return os.path.join(
        path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out"
    )
64 changes: 46 additions & 18 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::functions_aggregate::all_default_aggregate_functions;
use pyo3::{prelude::*, wrap_pyfunction};

use crate::common::data_type::NullTreatment;
Expand Down Expand Up @@ -311,6 +312,50 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
})
}

/// Helper function to find the appropriate window function. First, if a session
/// context is defined, check its registered functions. If no context is defined,
/// attempt to find from all default functions. Lastly, as a fall back attempt
/// to use built in window functions, which are being deprecated.
fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowFunctionDefinition> {
    let mut maybe_fn = match &ctx {
        Some(ctx) => {
            let session_state = ctx.ctx.state();

            // Prefer a registered window UDF; otherwise fall back to a
            // registered aggregate UDF. A single `get` per registry avoids the
            // double lookup of `contains_key` followed by `get`.
            session_state
                .window_functions()
                .get(name)
                .map(|f| WindowFunctionDefinition::WindowUDF(f.clone()))
                .or_else(|| {
                    session_state
                        .aggregate_functions()
                        .get(name)
                        .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()))
                })
        }
        None => {
            // Match on the canonical name as well as the aliases: `aliases()`
            // does not necessarily include the function's primary name, so an
            // alias-only search can miss an exact-name match.
            all_default_aggregate_functions()
                .iter()
                .find(|f| f.name() == name || f.aliases().contains(&name.to_string()))
                .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()))
        }
    };

    if maybe_fn.is_none() {
        // Fall back to the (deprecated) built-in window functions, then to any
        // user-defined aggregate registered on the context.
        maybe_fn = find_df_window_func(name).or_else(|| {
            ctx.and_then(|ctx| {
                ctx.ctx
                    .udaf(name)
                    .map(WindowFunctionDefinition::AggregateUDF)
                    .ok()
            })
        });
    }

    maybe_fn.ok_or(DataFusionError::Common("window function not found".to_string()).into())
}

/// Creates a new Window function expression
#[pyfunction]
fn window(
Expand All @@ -321,24 +366,7 @@ fn window(
window_frame: Option<PyWindowFrame>,
ctx: Option<PySessionContext>,
) -> PyResult<PyExpr> {
// workaround for https://github.com/apache/datafusion-python/issues/730
let fun = if name == "sum" {
let sum_udf = functions_aggregate::sum::sum_udaf();
Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
} else {
find_df_window_func(name).or_else(|| {
ctx.and_then(|ctx| {
ctx.ctx
.udaf(name)
.map(WindowFunctionDefinition::AggregateUDF)
.ok()
})
})
};
if fun.is_none() {
return Err(DataFusionError::Common("window function not found".to_string()).into());
}
let fun = fun.unwrap();
let fun = find_window_fn(name, ctx)?;
let window_frame = window_frame
.unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
.into();
Expand Down
Loading