Skip to content

Commit c5f8e7e

Browse files
Support array_join function nullReplacement parameter
1 parent 5d3c94d commit c5f8e7e

File tree

4 files changed

+57
-14
lines changed

4 files changed

+57
-14
lines changed

native/core/src/execution/planner.rs

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -767,19 +767,28 @@ impl PhysicalPlanner {
767767
Ok(Arc::new(case_expr))
768768
}
769769
ExprStruct::ArrayJoin(expr) => {
770-
let src_array_expr =
771-
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
772-
let key_expr =
773-
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
774-
let args = vec![Arc::clone(&src_array_expr), key_expr];
770+
let array_expr =
771+
self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
772+
let delimiter_expr = self.create_expr(
773+
expr.delimiter_expr.as_ref().unwrap(),
774+
Arc::clone(&input_schema),
775+
)?;
776+
777+
let mut args = vec![Arc::clone(&array_expr), delimiter_expr];
778+
if expr.null_replacement_expr.is_some() {
779+
let null_replacement_expr =
780+
self.create_expr(expr.null_replacement_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
781+
args.push(null_replacement_expr)
782+
}
783+
775784
let datafusion_array_to_string = array_to_string_udf();
776-
let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
785+
let array_join_expr = Arc::new(ScalarFunctionExpr::new(
777786
"array_join",
778787
datafusion_array_to_string,
779788
args,
780789
DataType::Utf8,
781790
));
782-
Ok(array_intersect_expr)
791+
Ok(array_join_expr)
783792
}
784793
expr => Err(ExecutionError::GeneralError(format!(
785794
"Not implemented: {:?}",

native/proto/src/proto/expr.proto

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ message Expr {
8686
ArrayInsert array_insert = 59;
8787
BinaryExpr array_contains = 60;
8888
BinaryExpr array_remove = 61;
89-
BinaryExpr array_join = 63;
89+
ArrayJoin array_join = 62;
9090
}
9191
}
9292

@@ -415,6 +415,12 @@ message ArrayInsert {
415415
bool legacy_negative_index = 4;
416416
}
417417

418+
message ArrayJoin {
419+
Expr array_expr = 1;
420+
Expr delimiter_expr = 2;
421+
Expr null_replacement_expr = 3;
422+
}
423+
418424
message DataType {
419425
enum DataTypeId {
420426
BOOL = 0;

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 32 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2284,12 +2284,38 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
22842284
expr.children(1),
22852285
inputs,
22862286
(builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
2287-
case _ if expr.prettyName == "array_join" =>
2288-
createBinaryExpr(
2289-
expr.children(0),
2290-
expr.children(1),
2291-
inputs,
2292-
(builder, binaryExpr) => builder.setArrayJoin(binaryExpr))
2287+
case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
2288+
val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
2289+
val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
2290+
2291+
if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
2292+
val arrayJoinBuilder = nullReplacementExpr match {
2293+
case Some(nrExpr) =>
2294+
val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
2295+
ExprOuterClass.ArrayJoin
2296+
.newBuilder()
2297+
.setArrayExpr(arrayExprProto.get)
2298+
.setDelimiterExpr(delimiterExprProto.get)
2299+
.setNullReplacementExpr(nullReplacementExprProto.get)
2300+
case None =>
2301+
ExprOuterClass.ArrayJoin
2302+
.newBuilder()
2303+
.setArrayExpr(arrayExprProto.get)
2304+
.setDelimiterExpr(delimiterExprProto.get)
2305+
}
2306+
Some(
2307+
ExprOuterClass.Expr
2308+
.newBuilder()
2309+
.setArrayJoin(arrayJoinBuilder)
2310+
.build())
2311+
} else {
2312+
val exprs: List[Expression] = nullReplacementExpr match {
2313+
case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr)
2314+
case None => List(arrayExpr, delimiterExpr)
2315+
}
2316+
withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
2317+
None
2318+
}
22932319
case _ =>
22942320
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
22952321
None

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2554,6 +2554,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
25542554
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
25552555
checkSparkAnswerAndOperator(sql(
25562556
"SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1"))
2557+
checkSparkAnswerAndOperator(sql(
2558+
"SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1"))
25572559
checkSparkAnswerAndOperator(sql(
25582560
"SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null"))
25592561
checkSparkAnswerAndOperator(

0 commit comments

Comments
 (0)