Skip to content

Commit 7824059

Browse files
Feat: Support array_join function
1 parent 2e34b5f commit 7824059

File tree

4 files changed

+83
-0
lines changed

4 files changed

+83
-0
lines changed

native/core/src/execution/planner.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
6868
use datafusion_functions_nested::concat::ArrayAppend;
6969
use datafusion_functions_nested::remove::array_remove_all_udf;
7070
use datafusion_functions_nested::set_ops::array_intersect_udf;
71+
use datafusion_functions_nested::string::array_to_string_udf;
7172
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
7273

7374
use crate::execution::shuffle::CompressionCodec;
@@ -791,6 +792,32 @@ impl PhysicalPlanner {
791792
));
792793
Ok(array_intersect_expr)
793794
}
795+
ExprStruct::ArrayJoin(expr) => {
    // Spark's array_join(array, delimiter[, null_replacement]) maps onto
    // DataFusion's array_to_string UDF, which takes the same argument list.
    let array_expr =
        self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
    let delimiter_expr = self.create_expr(
        expr.delimiter_expr.as_ref().unwrap(),
        Arc::clone(&input_schema),
    )?;

    // `array_expr` is moved into `args` directly; the previous
    // `Arc::clone(&array_expr)` was a redundant clone (the original
    // handle was never used again).
    let mut args = vec![array_expr, delimiter_expr];
    // Optional third argument: the string substituted for NULL elements.
    if let Some(null_replacement) = expr.null_replacement_expr.as_ref() {
        args.push(self.create_expr(null_replacement, Arc::clone(&input_schema))?);
    }

    let datafusion_array_to_string = array_to_string_udf();
    let array_join_expr = Arc::new(ScalarFunctionExpr::new(
        "array_join",
        datafusion_array_to_string,
        args,
        DataType::Utf8,
    ));
    Ok(array_join_expr)
}
794821
expr => Err(ExecutionError::GeneralError(format!(
795822
"Not implemented: {:?}",
796823
expr

native/proto/src/proto/expr.proto

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ message Expr {
8787
BinaryExpr array_contains = 60;
8888
BinaryExpr array_remove = 61;
8989
BinaryExpr array_intersect = 62;
90+
ArrayJoin array_join = 63;
9091
}
9192
}
9293

@@ -415,6 +416,12 @@ message ArrayInsert {
415416
bool legacy_negative_index = 4;
416417
}
417418

419+
// Serialized form of Spark's array_join(array, delimiter[, nullReplacement]).
message ArrayJoin {
  // The array whose elements are concatenated.
  Expr array_expr = 1;
  // Separator string placed between consecutive elements.
  Expr delimiter_expr = 2;
  // Optional replacement for NULL elements; when unset, NULLs are skipped
  // (matching Spark's two-argument array_join).
  Expr null_replacement_expr = 3;
}
424+
418425
message DataType {
419426
enum DataTypeId {
420427
BOOL = 0;

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,6 +2308,38 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
23082308
expr.children(1),
23092309
inputs,
23102310
(builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
2311+
case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
  val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
  val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
  // Serialize the optional null-replacement child up front. The previous code
  // called `.get` on its proto inside the builder, which threw
  // NoSuchElementException when that child was unsupported instead of falling
  // back to Spark like every other expression here.
  val nullReplacementExprProto = nullReplacementExpr.map(exprToProto(_, inputs, binding))

  // All present children must have serialized successfully.
  val childrenSupported = arrayExprProto.isDefined && delimiterExprProto.isDefined &&
    nullReplacementExprProto.forall(_.isDefined)

  if (childrenSupported) {
    val arrayJoinBuilder = ExprOuterClass.ArrayJoin
      .newBuilder()
      .setArrayExpr(arrayExprProto.get)
      .setDelimiterExpr(delimiterExprProto.get)
    // Only set the optional field when the argument was supplied.
    nullReplacementExprProto.flatten.foreach(arrayJoinBuilder.setNullReplacementExpr)
    Some(
      ExprOuterClass.Expr
        .newBuilder()
        .setArrayJoin(arrayJoinBuilder)
        .build())
  } else {
    // Fall back to Spark, tagging the plan with the unsupported children.
    val exprs: List[Expression] = arrayExpr :: delimiterExpr :: nullReplacementExpr.toList
    withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
    None
  }
23112343
case _ =>
23122344
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
23132345
None

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,4 +2691,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
26912691
}
26922692
}
26932693

2694+
test("array_join") {
  Seq(true, false).foreach { dictionaryEnabled =>
    withTempDir { dir =>
      val path = new Path(dir.toURI.toString, "test.parquet")
      makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
      spark.read.parquet(path.toString).createOrReplaceTempView("t1")
      // Cover the two-argument form, the three-argument (null replacement)
      // form, and literal arrays with and without NULL-producing columns.
      Seq(
        "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1",
        "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1",
        "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null",
        "SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1")
        .foreach(query => checkSparkAnswerAndOperator(sql(query)))
    }
  }
}
26942711
}

0 commit comments

Comments
 (0)