Skip to content

Commit 76f5110

Browse files
authored
Implement TPCH substrait integration test, support tpch_1 (#10842)
* support tpch_1 consumer_producer_test * refactor and optimize code
1 parent 1e37066 commit 76f5110

File tree

7 files changed

+1013
-47
lines changed

7 files changed

+1013
-47
lines changed

datafusion/substrait/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ object_store = { workspace = true }
4141
pbjson-types = "0.6"
4242
prost = "0.12"
4343
substrait = { version = "0.34.0", features = ["serde"] }
44+
url = { workspace = true }
4445

4546
[dev-dependencies]
4647
datafusion-functions-aggregate = { workspace = true }

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 109 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ use datafusion::arrow::datatypes::{
2222
use datafusion::common::{
2323
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
2424
};
25+
use substrait::proto::expression::literal::IntervalDayToSecond;
26+
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
27+
use url::Url;
2528

2629
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
2730
use datafusion::execution::FunctionRegistry;
@@ -45,7 +48,7 @@ use datafusion::{
4548
use substrait::proto::exchange_rel::ExchangeKind;
4649
use substrait::proto::expression::literal::user_defined::Val;
4750
use substrait::proto::expression::subquery::SubqueryType;
48-
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
51+
use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
4952
use substrait::proto::{
5053
aggregate_function::AggregationInvocation,
5154
expression::{
@@ -129,14 +132,7 @@ fn scalar_function_type_from_str(
129132
name: &str,
130133
) -> Result<ScalarFunctionType> {
131134
let s = ctx.state();
132-
let name = match name.rsplit_once(':') {
133-
// Since 0.32.0, Substrait requires the function names to be in a compound format
134-
// https://substrait.io/extensions/#function-signature-compound-names
135-
// for example, `add:i8_i8`.
136-
// On the consumer side, we don't really care about the signature though, just the name.
137-
Some((name, _)) => name,
138-
None => name,
139-
};
135+
let name = substrait_fun_name(name);
140136

141137
if let Some(func) = s.scalar_functions().get(name) {
142138
return Ok(ScalarFunctionType::Udf(func.to_owned()));
@@ -153,6 +149,18 @@ fn scalar_function_type_from_str(
153149
not_impl_err!("Unsupported function name: {name:?}")
154150
}
155151

152+
/// Returns the bare function name from a Substrait function identifier.
///
/// Since 0.32.0, Substrait requires function names to be in a compound format
/// (<https://substrait.io/extensions/#function-signature-compound-names>),
/// for example `add:i8_i8`. On the consumer side the signature is not needed,
/// only the name, so everything from the last `:` onwards is dropped.
/// A name that contains no `:` is returned unchanged.
pub fn substrait_fun_name(name: &str) -> &str {
    name.rsplit_once(':')
        .map_or(name, |(base, _signature)| base)
}
163+
156164
fn split_eq_and_noneq_join_predicate_with_nulls_equality(
157165
filter: &Expr,
158166
) -> (Vec<(Column, Column)>, bool, Option<Expr>) {
@@ -239,6 +247,43 @@ pub async fn from_substrait_plan(
239247
}
240248
}
241249

250+
/// parse projection
251+
pub fn extract_projection(
252+
t: LogicalPlan,
253+
projection: &::core::option::Option<expression::MaskExpression>,
254+
) -> Result<LogicalPlan> {
255+
match projection {
256+
Some(MaskExpression { select, .. }) => match &select.as_ref() {
257+
Some(projection) => {
258+
let column_indices: Vec<usize> = projection
259+
.struct_items
260+
.iter()
261+
.map(|item| item.field as usize)
262+
.collect();
263+
match t {
264+
LogicalPlan::TableScan(mut scan) => {
265+
let fields = column_indices
266+
.iter()
267+
.map(|i| scan.projected_schema.qualified_field(*i))
268+
.map(|(qualifier, field)| {
269+
(qualifier.cloned(), Arc::new(field.clone()))
270+
})
271+
.collect();
272+
scan.projection = Some(column_indices);
273+
scan.projected_schema = DFSchemaRef::new(
274+
DFSchema::new_with_metadata(fields, HashMap::new())?,
275+
);
276+
Ok(LogicalPlan::TableScan(scan))
277+
}
278+
_ => plan_err!("unexpected plan for table"),
279+
}
280+
}
281+
_ => Ok(t),
282+
},
283+
_ => Ok(t),
284+
}
285+
}
286+
242287
/// Convert Substrait Rel to DataFusion DataFrame
243288
#[async_recursion]
244289
pub async fn from_substrait_rel(
@@ -408,7 +453,6 @@ pub async fn from_substrait_rel(
408453
};
409454
aggr_expr.push(agg_func?.as_ref().clone());
410455
}
411-
412456
input.aggregate(group_expr, aggr_expr)?.build()
413457
} else {
414458
not_impl_err!("Aggregate without an input is not valid")
@@ -489,41 +533,7 @@ pub async fn from_substrait_rel(
489533
};
490534
let t = ctx.table(table_reference).await?;
491535
let t = t.into_optimized_plan()?;
492-
match &read.projection {
493-
Some(MaskExpression { select, .. }) => match &select.as_ref() {
494-
Some(projection) => {
495-
let column_indices: Vec<usize> = projection
496-
.struct_items
497-
.iter()
498-
.map(|item| item.field as usize)
499-
.collect();
500-
match &t {
501-
LogicalPlan::TableScan(scan) => {
502-
let fields = column_indices
503-
.iter()
504-
.map(|i| {
505-
scan.projected_schema.qualified_field(*i)
506-
})
507-
.map(|(qualifier, field)| {
508-
(qualifier.cloned(), Arc::new(field.clone()))
509-
})
510-
.collect();
511-
let mut scan = scan.clone();
512-
scan.projection = Some(column_indices);
513-
scan.projected_schema =
514-
DFSchemaRef::new(DFSchema::new_with_metadata(
515-
fields,
516-
HashMap::new(),
517-
)?);
518-
Ok(LogicalPlan::TableScan(scan))
519-
}
520-
_ => plan_err!("unexpected plan for table"),
521-
}
522-
}
523-
_ => Ok(t),
524-
},
525-
_ => Ok(t),
526-
}
536+
extract_projection(t, &read.projection)
527537
}
528538
Some(ReadType::VirtualTable(vt)) => {
529539
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
@@ -569,7 +579,42 @@ pub async fn from_substrait_rel(
569579

570580
Ok(LogicalPlan::Values(Values { schema, values }))
571581
}
572-
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
582+
Some(ReadType::LocalFiles(lf)) => {
583+
fn extract_filename(name: &str) -> Option<String> {
584+
let corrected_url =
585+
if name.starts_with("file://") && !name.starts_with("file:///") {
586+
name.replacen("file://", "file:///", 1)
587+
} else {
588+
name.to_string()
589+
};
590+
591+
Url::parse(&corrected_url).ok().and_then(|url| {
592+
let path = url.path();
593+
std::path::Path::new(path)
594+
.file_name()
595+
.map(|filename| filename.to_string_lossy().to_string())
596+
})
597+
}
598+
599+
// we could use the file name to check the original table provider
600+
// TODO: currently does not support multiple local files
601+
let filename: Option<String> =
602+
lf.items.first().and_then(|x| match x.path_type.as_ref() {
603+
Some(UriFile(name)) => extract_filename(name),
604+
_ => None,
605+
});
606+
607+
if lf.items.len() > 1 || filename.is_none() {
608+
return not_impl_err!("Only single file reads are supported");
609+
}
610+
let name = filename.unwrap();
611+
// directly use unwrap here since we could determine it is a valid one
612+
let table_reference = TableReference::Bare { table: name.into() };
613+
let t = ctx.table(table_reference).await?;
614+
let t = t.into_optimized_plan()?;
615+
extract_projection(t, &read.projection)
616+
}
617+
_ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type),
573618
},
574619
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
575620
Ok(set_op) => match set_op {
@@ -810,14 +855,22 @@ pub async fn from_substrait_agg_func(
810855
f.function_reference
811856
);
812857
};
813-
858+
// function_name.split(':').next().unwrap_or(function_name);
859+
let function_name = substrait_fun_name((**function_name).as_str());
814860
// try udaf first, then built-in aggr fn.
815861
if let Ok(fun) = ctx.udaf(function_name) {
816862
Ok(Arc::new(Expr::AggregateFunction(
817863
expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
818864
)))
819865
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
820866
{
867+
match &fun {
868+
// deal with situation that count(*) got no arguments
869+
aggregate_function::AggregateFunction::Count if args.is_empty() => {
870+
args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
871+
}
872+
_ => {}
873+
}
821874
Ok(Arc::new(Expr::AggregateFunction(
822875
expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None),
823876
)))
@@ -1261,6 +1314,8 @@ fn from_substrait_type(
12611314
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
12621315
s, dfs_names, name_idx,
12631316
)?)),
1317+
r#type::Kind::Varchar(_) => Ok(DataType::Utf8),
1318+
r#type::Kind::FixedChar(_) => Ok(DataType::Utf8),
12641319
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
12651320
},
12661321
_ => not_impl_err!("`None` Substrait kind is not supported"),
@@ -1549,6 +1604,13 @@ fn from_substrait_literal(
15491604
Some(LiteralType::Null(ntype)) => {
15501605
from_substrait_null(ntype, dfs_names, name_idx)?
15511606
}
1607+
Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond {
1608+
days,
1609+
seconds,
1610+
microseconds,
1611+
})) => {
1612+
ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000))
1613+
}
15521614
Some(LiteralType::UserDefined(user_defined)) => {
15531615
match user_defined.type_reference {
15541616
INTERVAL_YEAR_MONTH_TYPE_REF => {
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! TPCH `substrait_consumer` tests
19+
//!
20+
//! This module tests that substrait plans as json encoded protobuf can be
21+
//! correctly read as DataFusion plans.
22+
//!
23+
//! The input data comes from <https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans>
24+
25+
#[cfg(test)]
26+
mod tests {
27+
use datafusion::common::Result;
28+
use datafusion::execution::options::CsvReadOptions;
29+
use datafusion::prelude::SessionContext;
30+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
31+
use std::fs::File;
32+
use std::io::BufReader;
33+
use substrait::proto::Plan;
34+
35+
/// Reads the JSON-encoded Substrait plan stored in `query_1.json`, converts
/// it with `from_substrait_plan`, and checks the `Debug` rendering of the
/// resulting plan against the expected text.
#[tokio::test]
36+
async fn tpch_test_1() -> Result<()> {
37+
let ctx = create_context().await?;
38+
// Relative path — presumably resolved against the crate root when run via
// `cargo test`; TODO confirm.
let path = "tests/testdata/tpch_substrait_plans/query_1.json";
39+
// Deserialize the JSON file into a Substrait `Plan`; a missing or malformed
// file panics with the `expect` message rather than returning an error.
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
40+
File::open(path).expect("file not found"),
41+
))
42+
.expect("failed to parse json");
43+
44+
let plan = from_substrait_plan(&ctx, &proto).await?;
45+
46+
// The comparison ignores ASCII case, so differences such as `SUM` vs `sum`
// in the rendered plan do not fail the test.
assert!(
47+
format!("{:?}", plan).eq_ignore_ascii_case(
48+
"Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \
49+
Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \
50+
Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \
51+
Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \
52+
TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
53+
)
54+
);
55+
Ok(())
56+
}
57+
58+
/// Builds a fresh `SessionContext` with `tests/testdata/tpch/lineitem.csv`
/// registered as the table `FILENAME_PLACEHOLDER_0`, using default CSV read
/// options.
// NOTE(review): the table name presumably matches the file name referenced by
// the plan's local-files read in `query_1.json` — confirm against that file.
async fn create_context() -> datafusion::common::Result<SessionContext> {
59+
let ctx = SessionContext::new();
60+
ctx.register_csv(
61+
"FILENAME_PLACEHOLDER_0",
62+
"tests/testdata/tpch/lineitem.csv",
63+
CsvReadOptions::default(),
64+
)
65+
.await?;
66+
Ok(ctx)
67+
}
68+
}

datafusion/substrait/tests/cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod consumer_integration;
1819
mod logical_plans;
1920
mod roundtrip_logical_plan;
2021
mod roundtrip_physical_plan;
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
l_orderkey,l_partkey,l_suppkey,l_linenumber,l_quantity,l_extendedprice,l_discount,l_tax,l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,l_shipinstruct,l_shipmode,l_comment
2+
1,1,1,1,17,21168.23,0.04,0.02,'N','O','1996-03-13','1996-02-12','1996-03-22','DELIVER IN PERSON','TRUCK','egular courts above the'
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
<!---
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
-->
19+
20+
# Apache DataFusion Substrait consumer integration test
21+
22+
These test JSON files come from the [consumer-testing](https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans) repository.

0 commit comments

Comments
 (0)