Skip to content

Commit 6167ce9

Browse files
authored
feat: add substrait support for Interval types and literals (#10646)
* feat: support interval types Signed-off-by: Ruihang Xia <[email protected]> * impl literals Signed-off-by: Ruihang Xia <[email protected]> * fix deadlink in doc Signed-off-by: Ruihang Xia <[email protected]> --------- Signed-off-by: Ruihang Xia <[email protected]>
1 parent 52c4f3c commit 6167ce9

File tree

4 files changed

+273
-10
lines changed

4 files changed

+273
-10
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use async_recursion::async_recursion;
19-
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
19+
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
2020
use datafusion::common::{
2121
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
2222
};
@@ -39,6 +39,7 @@ use datafusion::{
3939
scalar::ScalarValue,
4040
};
4141
use substrait::proto::exchange_rel::ExchangeKind;
42+
use substrait::proto::expression::literal::user_defined::Val;
4243
use substrait::proto::expression::subquery::SubqueryType;
4344
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
4445
use substrait::proto::{
@@ -71,9 +72,10 @@ use std::sync::Arc;
7172

7273
use crate::variation_const::{
7374
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
74-
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
75-
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
76-
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
75+
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
76+
INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
77+
LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF,
78+
TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
7779
};
7880

7981
enum ScalarFunctionType {
@@ -1162,6 +1164,24 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
11621164
"Unsupported Substrait type variation {v} of type {s_kind:?}"
11631165
),
11641166
},
1167+
r#type::Kind::UserDefined(u) => {
1168+
match u.type_reference {
1169+
INTERVAL_YEAR_MONTH_TYPE_REF => {
1170+
Ok(DataType::Interval(IntervalUnit::YearMonth))
1171+
}
1172+
INTERVAL_DAY_TIME_TYPE_REF => {
1173+
Ok(DataType::Interval(IntervalUnit::DayTime))
1174+
}
1175+
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
1176+
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
1177+
}
1178+
_ => not_impl_err!(
1179+
"Unsupported Substrait user defined type with ref {} and variation {}",
1180+
u.type_reference,
1181+
u.type_variation_reference
1182+
),
1183+
}
1184+
},
11651185
r#type::Kind::Struct(s) => {
11661186
let mut fields = vec![];
11671187
for (i, f) in s.types.iter().enumerate() {
@@ -1387,6 +1407,54 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
13871407
builder.build()?
13881408
}
13891409
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
1410+
Some(LiteralType::UserDefined(user_defined)) => {
1411+
match user_defined.type_reference {
1412+
INTERVAL_YEAR_MONTH_TYPE_REF => {
1413+
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
1414+
return substrait_err!("Interval year month value is empty");
1415+
};
1416+
let value_slice: [u8; 4] =
1417+
raw_val.value.clone().try_into().map_err(|_| {
1418+
substrait_datafusion_err!(
1419+
"Failed to parse interval year month value"
1420+
)
1421+
})?;
1422+
ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice)))
1423+
}
1424+
INTERVAL_DAY_TIME_TYPE_REF => {
1425+
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
1426+
return substrait_err!("Interval day time value is empty");
1427+
};
1428+
let value_slice: [u8; 8] =
1429+
raw_val.value.clone().try_into().map_err(|_| {
1430+
substrait_datafusion_err!(
1431+
"Failed to parse interval day time value"
1432+
)
1433+
})?;
1434+
ScalarValue::IntervalDayTime(Some(i64::from_le_bytes(value_slice)))
1435+
}
1436+
INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
1437+
let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
1438+
return substrait_err!("Interval month day nano value is empty");
1439+
};
1440+
let value_slice: [u8; 16] =
1441+
raw_val.value.clone().try_into().map_err(|_| {
1442+
substrait_datafusion_err!(
1443+
"Failed to parse interval month day nano value"
1444+
)
1445+
})?;
1446+
ScalarValue::IntervalMonthDayNano(Some(i128::from_le_bytes(
1447+
value_slice,
1448+
)))
1449+
}
1450+
_ => {
1451+
return not_impl_err!(
1452+
"Unsupported Substrait user defined type with ref {}",
1453+
user_defined.type_reference
1454+
)
1455+
}
1456+
}
1457+
}
13901458
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
13911459
};
13921460

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std::collections::HashMap;
1919
use std::ops::Deref;
2020
use std::sync::Arc;
2121

22+
use datafusion::arrow::datatypes::IntervalUnit;
2223
use datafusion::logical_expr::{
2324
CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
2425
};
@@ -43,9 +44,12 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
4344
use datafusion::prelude::Expr;
4445
use prost_types::Any as ProtoAny;
4546
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
47+
use substrait::proto::expression::literal::user_defined::Val;
48+
use substrait::proto::expression::literal::UserDefined;
4649
use substrait::proto::expression::literal::{List, Struct};
4750
use substrait::proto::expression::subquery::InPredicate;
4851
use substrait::proto::expression::window_function::BoundsType;
52+
use substrait::proto::r#type::{parameter, Parameter};
4953
use substrait::proto::{CrossRel, ExchangeRel};
5054
use substrait::{
5155
proto::{
@@ -84,9 +88,12 @@ use substrait::{
8488

8589
use crate::variation_const::{
8690
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
87-
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
88-
TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
89-
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
91+
DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
92+
INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
93+
INTERVAL_MONTH_DAY_NANO_TYPE_URL, INTERVAL_YEAR_MONTH_TYPE_REF,
94+
INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
95+
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
96+
UNSIGNED_INTEGER_TYPE_REF,
9097
};
9198

9299
/// Convert DataFusion LogicalPlan to Substrait Plan
@@ -1398,6 +1405,49 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::
13981405
nullability,
13991406
})),
14001407
}),
1408+
DataType::Interval(interval_unit) => {
1409+
// define two type parameters for convenience
1410+
let i32_param = Parameter {
1411+
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
1412+
kind: Some(r#type::Kind::I32(r#type::I32 {
1413+
type_variation_reference: DEFAULT_TYPE_REF,
1414+
nullability: default_nullability,
1415+
})),
1416+
})),
1417+
};
1418+
let i64_param = Parameter {
1419+
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
1420+
kind: Some(r#type::Kind::I64(r#type::I64 {
1421+
type_variation_reference: DEFAULT_TYPE_REF,
1422+
nullability: default_nullability,
1423+
})),
1424+
})),
1425+
};
1426+
1427+
let (type_parameters, type_reference) = match interval_unit {
1428+
IntervalUnit::YearMonth => {
1429+
let type_parameters = vec![i32_param];
1430+
(type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF)
1431+
}
1432+
IntervalUnit::DayTime => {
1433+
let type_parameters = vec![i64_param];
1434+
(type_parameters, INTERVAL_DAY_TIME_TYPE_REF)
1435+
}
1436+
IntervalUnit::MonthDayNano => {
1437+
// use 2 `i64` as `i128`
1438+
let type_parameters = vec![i64_param.clone(), i64_param];
1439+
(type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF)
1440+
}
1441+
};
1442+
Ok(substrait::proto::Type {
1443+
kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
1444+
type_reference,
1445+
type_variation_reference: DEFAULT_TYPE_REF,
1446+
nullability: default_nullability,
1447+
type_parameters,
1448+
})),
1449+
})
1450+
}
14011451
DataType::Binary => Ok(substrait::proto::Type {
14021452
kind: Some(r#type::Kind::Binary(r#type::Binary {
14031453
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
@@ -1735,6 +1785,75 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
17351785
}
17361786
ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF),
17371787
// Date64 literal is not supported in Substrait
1788+
ScalarValue::IntervalYearMonth(Some(i)) => {
1789+
let bytes = i.to_le_bytes();
1790+
(
1791+
LiteralType::UserDefined(UserDefined {
1792+
type_reference: INTERVAL_YEAR_MONTH_TYPE_REF,
1793+
type_parameters: vec![Parameter {
1794+
parameter: Some(parameter::Parameter::DataType(
1795+
substrait::proto::Type {
1796+
kind: Some(r#type::Kind::I32(r#type::I32 {
1797+
type_variation_reference: DEFAULT_TYPE_REF,
1798+
nullability: r#type::Nullability::Required as i32,
1799+
})),
1800+
},
1801+
)),
1802+
}],
1803+
val: Some(Val::Value(ProtoAny {
1804+
type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(),
1805+
value: bytes.to_vec(),
1806+
})),
1807+
}),
1808+
INTERVAL_YEAR_MONTH_TYPE_REF,
1809+
)
1810+
}
1811+
ScalarValue::IntervalMonthDayNano(Some(i)) => {
1812+
// treat `i128` as two contiguous `i64`
1813+
let bytes = i.to_le_bytes();
1814+
let i64_param = Parameter {
1815+
parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
1816+
kind: Some(r#type::Kind::I64(r#type::I64 {
1817+
type_variation_reference: DEFAULT_TYPE_REF,
1818+
nullability: r#type::Nullability::Required as i32,
1819+
})),
1820+
})),
1821+
};
1822+
(
1823+
LiteralType::UserDefined(UserDefined {
1824+
type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF,
1825+
type_parameters: vec![i64_param.clone(), i64_param],
1826+
val: Some(Val::Value(ProtoAny {
1827+
type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(),
1828+
value: bytes.to_vec(),
1829+
})),
1830+
}),
1831+
INTERVAL_MONTH_DAY_NANO_TYPE_REF,
1832+
)
1833+
}
1834+
ScalarValue::IntervalDayTime(Some(i)) => {
1835+
let bytes = i.to_le_bytes();
1836+
(
1837+
LiteralType::UserDefined(UserDefined {
1838+
type_reference: INTERVAL_DAY_TIME_TYPE_REF,
1839+
type_parameters: vec![Parameter {
1840+
parameter: Some(parameter::Parameter::DataType(
1841+
substrait::proto::Type {
1842+
kind: Some(r#type::Kind::I64(r#type::I64 {
1843+
type_variation_reference: DEFAULT_TYPE_REF,
1844+
nullability: r#type::Nullability::Required as i32,
1845+
})),
1846+
},
1847+
)),
1848+
}],
1849+
val: Some(Val::Value(ProtoAny {
1850+
type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(),
1851+
value: bytes.to_vec(),
1852+
})),
1853+
}),
1854+
INTERVAL_DAY_TIME_TYPE_REF,
1855+
)
1856+
}
17381857
ScalarValue::Binary(Some(b)) => {
17391858
(LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
17401859
}

datafusion/substrait/src/variation_const.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
//! - Default type reference is 0. It is used when the actual type is the same with the original type.
2626
//! - Extended variant type references start from 1, and ususlly increase by 1.
2727
28+
// For type variations
2829
pub const DEFAULT_TYPE_REF: u32 = 0;
2930
pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1;
3031
pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0;
@@ -37,3 +38,58 @@ pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0;
3738
pub const LARGE_CONTAINER_TYPE_REF: u32 = 1;
3839
pub const DECIMAL_128_TYPE_REF: u32 = 0;
3940
pub const DECIMAL_256_TYPE_REF: u32 = 1;
41+
42+
// For custom types
43+
/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
44+
///
45+
/// An `i32` for elapsed whole months. See also [`ScalarValue::IntervalYearMonth`]
46+
/// for the literal definition in DataFusion.
47+
///
48+
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
49+
/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
50+
/// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth
51+
pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1;
52+
53+
/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
54+
///
55+
/// An `i64` as:
56+
/// - days: `i32`
57+
/// - milliseconds: `i32`
58+
///
59+
/// See also [`ScalarValue::IntervalDayTime`] for the literal definition in DataFusion.
60+
///
61+
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
62+
/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
63+
/// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime
64+
pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;
65+
66+
/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
67+
///
68+
/// An `i128` as:
69+
/// - months: `i32`
70+
/// - days: `i32`
71+
/// - nanoseconds: `i64`
72+
///
73+
/// See also [`ScalarValue::IntervalMonthDayNano`] for the literal definition in DataFusion.
74+
///
75+
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
76+
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
77+
/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
78+
pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;
79+
80+
// For User Defined URLs
81+
/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
82+
///
83+
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
84+
/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
85+
pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month";
86+
/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
87+
///
88+
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
89+
/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
90+
pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time";
91+
/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
92+
///
93+
/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
94+
/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
95+
pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano";

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use datafusion_substrait::logical_plan::{
2525
use std::hash::Hash;
2626
use std::sync::Arc;
2727

28-
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
28+
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
2929
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
3030
use datafusion::error::Result;
3131
use datafusion::execution::context::SessionState;
@@ -496,6 +496,24 @@ async fn roundtrip_arithmetic_ops() -> Result<()> {
496496
Ok(())
497497
}
498498

499+
#[tokio::test]
500+
async fn roundtrip_interval_literal() -> Result<()> {
501+
roundtrip(
502+
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(YearMonth)')",
503+
)
504+
.await?;
505+
roundtrip(
506+
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(DayTime)')",
507+
)
508+
.await?;
509+
roundtrip(
510+
"SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(MonthDayNano)')",
511+
)
512+
.await?;
513+
514+
Ok(())
515+
}
516+
499517
#[tokio::test]
500518
async fn roundtrip_like() -> Result<()> {
501519
roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
@@ -1035,14 +1053,16 @@ async fn create_context() -> Result<SessionContext> {
10351053
.with_serializer_registry(Arc::new(MockSerializerRegistry));
10361054
let ctx = SessionContext::new_with_state(state);
10371055
let mut explicit_options = CsvReadOptions::new();
1038-
let schema = Schema::new(vec![
1056+
let fields = vec![
10391057
Field::new("a", DataType::Int64, true),
10401058
Field::new("b", DataType::Decimal128(5, 2), true),
10411059
Field::new("c", DataType::Date32, true),
10421060
Field::new("d", DataType::Boolean, true),
10431061
Field::new("e", DataType::UInt32, true),
10441062
Field::new("f", DataType::Utf8, true),
1045-
]);
1063+
Field::new("g", DataType::Interval(IntervalUnit::DayTime), true),
1064+
];
1065+
let schema = Schema::new(fields);
10461066
explicit_options.schema = Some(&schema);
10471067
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
10481068
.await?;

0 commit comments

Comments
 (0)