Skip to content

Commit b71bec0

Browse files
authored
feat: implement Unary Expr in substrait (#8534)
Signed-off-by: Ruihang Xia <[email protected]>
1 parent b7fde3c commit b71bec0

File tree

3 files changed

+169
-86
lines changed

3 files changed

+169
-86
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,9 @@ struct BuiltinExprBuilder {
12531253
impl BuiltinExprBuilder {
12541254
pub fn try_from_name(name: &str) -> Option<Self> {
12551255
match name {
1256-
"not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self {
1256+
"not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true"
1257+
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
1258+
| "is_not_unknown" | "negative" => Some(Self {
12571259
expr_name: name.to_string(),
12581260
}),
12591261
_ => None,
@@ -1267,37 +1269,51 @@ impl BuiltinExprBuilder {
12671269
extensions: &HashMap<u32, &String>,
12681270
) -> Result<Arc<Expr>> {
12691271
match self.expr_name.as_str() {
1270-
"not" => Self::build_not_expr(f, input_schema, extensions).await,
12711272
"like" => Self::build_like_expr(false, f, input_schema, extensions).await,
12721273
"ilike" => Self::build_like_expr(true, f, input_schema, extensions).await,
1273-
"is_null" => {
1274-
Self::build_is_null_expr(false, f, input_schema, extensions).await
1275-
}
1276-
"is_not_null" => {
1277-
Self::build_is_null_expr(true, f, input_schema, extensions).await
1274+
"not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false"
1275+
| "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => {
1276+
Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await
12781277
}
12791278
_ => {
12801279
not_impl_err!("Unsupported builtin expression: {}", self.expr_name)
12811280
}
12821281
}
12831282
}
12841283

1285-
async fn build_not_expr(
1284+
async fn build_unary_expr(
1285+
fn_name: &str,
12861286
f: &ScalarFunction,
12871287
input_schema: &DFSchema,
12881288
extensions: &HashMap<u32, &String>,
12891289
) -> Result<Arc<Expr>> {
12901290
if f.arguments.len() != 1 {
1291-
return not_impl_err!("Expect one argument for `NOT` expr");
1291+
return substrait_err!("Expect one argument for {fn_name} expr");
12921292
}
12931293
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
1294-
return not_impl_err!("Invalid arguments type for `NOT` expr");
1294+
return substrait_err!("Invalid arguments type for {fn_name} expr");
12951295
};
1296-
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
1296+
let arg = from_substrait_rex(expr_substrait, input_schema, extensions)
12971297
.await?
12981298
.as_ref()
12991299
.clone();
1300-
Ok(Arc::new(Expr::Not(Box::new(expr))))
1300+
let arg = Box::new(arg);
1301+
1302+
let expr = match fn_name {
1303+
"not" => Expr::Not(arg),
1304+
"negative" => Expr::Negative(arg),
1305+
"is_null" => Expr::IsNull(arg),
1306+
"is_not_null" => Expr::IsNotNull(arg),
1307+
"is_true" => Expr::IsTrue(arg),
1308+
"is_false" => Expr::IsFalse(arg),
1309+
"is_not_true" => Expr::IsNotTrue(arg),
1310+
"is_not_false" => Expr::IsNotFalse(arg),
1311+
"is_unknown" => Expr::IsUnknown(arg),
1312+
"is_not_unknown" => Expr::IsNotUnknown(arg),
1313+
_ => return not_impl_err!("Unsupported builtin expression: {}", fn_name),
1314+
};
1315+
1316+
Ok(Arc::new(expr))
13011317
}
13021318

13031319
async fn build_like_expr(
@@ -1308,25 +1324,25 @@ impl BuiltinExprBuilder {
13081324
) -> Result<Arc<Expr>> {
13091325
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
13101326
if f.arguments.len() != 3 {
1311-
return not_impl_err!("Expect three arguments for `{fn_name}` expr");
1327+
return substrait_err!("Expect three arguments for `{fn_name}` expr");
13121328
}
13131329

13141330
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
1315-
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
1331+
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
13161332
};
13171333
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
13181334
.await?
13191335
.as_ref()
13201336
.clone();
13211337
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
1322-
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
1338+
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
13231339
};
13241340
let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions)
13251341
.await?
13261342
.as_ref()
13271343
.clone();
13281344
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
1329-
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
1345+
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
13301346
};
13311347
let escape_char_expr =
13321348
from_substrait_rex(escape_char_substrait, input_schema, extensions)
@@ -1347,30 +1363,4 @@ impl BuiltinExprBuilder {
13471363
case_insensitive,
13481364
})))
13491365
}
1350-
1351-
async fn build_is_null_expr(
1352-
is_not: bool,
1353-
f: &ScalarFunction,
1354-
input_schema: &DFSchema,
1355-
extensions: &HashMap<u32, &String>,
1356-
) -> Result<Arc<Expr>> {
1357-
let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" };
1358-
let arg = f.arguments.first().ok_or_else(|| {
1359-
substrait_datafusion_err!("expect one argument for `{fn_name}` expr")
1360-
})?;
1361-
match &arg.arg_type {
1362-
Some(ArgType::Value(e)) => {
1363-
let expr = from_substrait_rex(e, input_schema, extensions)
1364-
.await?
1365-
.as_ref()
1366-
.clone();
1367-
if is_not {
1368-
Ok(Arc::new(Expr::IsNotNull(Box::new(expr))))
1369-
} else {
1370-
Ok(Arc::new(Expr::IsNull(Box::new(expr))))
1371-
}
1372-
}
1373-
_ => substrait_err!("Invalid arguments for `{fn_name}` expression"),
1374-
}
1375-
}
13761366
}

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 97 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,50 +1083,76 @@ pub fn to_substrait_rex(
10831083
col_ref_offset,
10841084
extension_info,
10851085
),
1086-
Expr::IsNull(arg) => {
1087-
let arguments: Vec<FunctionArgument> = vec![FunctionArgument {
1088-
arg_type: Some(ArgType::Value(to_substrait_rex(
1089-
arg,
1090-
schema,
1091-
col_ref_offset,
1092-
extension_info,
1093-
)?)),
1094-
}];
1095-
1096-
let function_name = "is_null".to_string();
1097-
let function_anchor = _register_function(function_name, extension_info);
1098-
Ok(Expression {
1099-
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
1100-
function_reference: function_anchor,
1101-
arguments,
1102-
output_type: None,
1103-
args: vec![],
1104-
options: vec![],
1105-
})),
1106-
})
1107-
}
1108-
Expr::IsNotNull(arg) => {
1109-
let arguments: Vec<FunctionArgument> = vec![FunctionArgument {
1110-
arg_type: Some(ArgType::Value(to_substrait_rex(
1111-
arg,
1112-
schema,
1113-
col_ref_offset,
1114-
extension_info,
1115-
)?)),
1116-
}];
1117-
1118-
let function_name = "is_not_null".to_string();
1119-
let function_anchor = _register_function(function_name, extension_info);
1120-
Ok(Expression {
1121-
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
1122-
function_reference: function_anchor,
1123-
arguments,
1124-
output_type: None,
1125-
args: vec![],
1126-
options: vec![],
1127-
})),
1128-
})
1129-
}
1086+
Expr::Not(arg) => to_substrait_unary_scalar_fn(
1087+
"not",
1088+
arg,
1089+
schema,
1090+
col_ref_offset,
1091+
extension_info,
1092+
),
1093+
Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
1094+
"is_null",
1095+
arg,
1096+
schema,
1097+
col_ref_offset,
1098+
extension_info,
1099+
),
1100+
Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
1101+
"is_not_null",
1102+
arg,
1103+
schema,
1104+
col_ref_offset,
1105+
extension_info,
1106+
),
1107+
Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
1108+
"is_true",
1109+
arg,
1110+
schema,
1111+
col_ref_offset,
1112+
extension_info,
1113+
),
1114+
Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
1115+
"is_false",
1116+
arg,
1117+
schema,
1118+
col_ref_offset,
1119+
extension_info,
1120+
),
1121+
Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
1122+
"is_unknown",
1123+
arg,
1124+
schema,
1125+
col_ref_offset,
1126+
extension_info,
1127+
),
1128+
Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
1129+
"is_not_true",
1130+
arg,
1131+
schema,
1132+
col_ref_offset,
1133+
extension_info,
1134+
),
1135+
Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
1136+
"is_not_false",
1137+
arg,
1138+
schema,
1139+
col_ref_offset,
1140+
extension_info,
1141+
),
1142+
Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
1143+
"is_not_unknown",
1144+
arg,
1145+
schema,
1146+
col_ref_offset,
1147+
extension_info,
1148+
),
1149+
Expr::Negative(arg) => to_substrait_unary_scalar_fn(
1150+
"negative",
1151+
arg,
1152+
schema,
1153+
col_ref_offset,
1154+
extension_info,
1155+
),
11301156
_ => {
11311157
not_impl_err!("Unsupported expression: {expr:?}")
11321158
}
@@ -1591,6 +1617,33 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
15911617
})
15921618
}
15931619

1620+
/// Util to generate substrait [RexType::ScalarFunction] with one argument
1621+
fn to_substrait_unary_scalar_fn(
1622+
fn_name: &str,
1623+
arg: &Expr,
1624+
schema: &DFSchemaRef,
1625+
col_ref_offset: usize,
1626+
extension_info: &mut (
1627+
Vec<extensions::SimpleExtensionDeclaration>,
1628+
HashMap<String, u32>,
1629+
),
1630+
) -> Result<Expression> {
1631+
let function_anchor = _register_function(fn_name.to_string(), extension_info);
1632+
let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?;
1633+
1634+
Ok(Expression {
1635+
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
1636+
function_reference: function_anchor,
1637+
arguments: vec![FunctionArgument {
1638+
arg_type: Some(ArgType::Value(substrait_expr)),
1639+
}],
1640+
output_type: None,
1641+
options: vec![],
1642+
..Default::default()
1643+
})),
1644+
})
1645+
}
1646+
15941647
fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
15951648
let default_nullability = r#type::Nullability::Nullable as i32;
15961649
match v {

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,46 @@ async fn roundtrip_ilike() -> Result<()> {
483483
roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await
484484
}
485485

486+
#[tokio::test]
487+
async fn roundtrip_not() -> Result<()> {
488+
roundtrip("SELECT * FROM data WHERE NOT d").await
489+
}
490+
491+
#[tokio::test]
492+
async fn roundtrip_negative() -> Result<()> {
493+
roundtrip("SELECT * FROM data WHERE -a = 1").await
494+
}
495+
496+
#[tokio::test]
497+
async fn roundtrip_is_true() -> Result<()> {
498+
roundtrip("SELECT * FROM data WHERE d IS TRUE").await
499+
}
500+
501+
#[tokio::test]
502+
async fn roundtrip_is_false() -> Result<()> {
503+
roundtrip("SELECT * FROM data WHERE d IS FALSE").await
504+
}
505+
506+
#[tokio::test]
507+
async fn roundtrip_is_not_true() -> Result<()> {
508+
roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await
509+
}
510+
511+
#[tokio::test]
512+
async fn roundtrip_is_not_false() -> Result<()> {
513+
roundtrip("SELECT * FROM data WHERE d IS NOT FALSE").await
514+
}
515+
516+
#[tokio::test]
517+
async fn roundtrip_is_unknown() -> Result<()> {
518+
roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await
519+
}
520+
521+
#[tokio::test]
522+
async fn roundtrip_is_not_unknown() -> Result<()> {
523+
roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await
524+
}
525+
486526
#[tokio::test]
487527
async fn roundtrip_union() -> Result<()> {
488528
roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await

0 commit comments

Comments
 (0)