
Commit b98415e

Feat: Support array_join
1 parent c25060e commit b98415e

4 files changed: 39 additions & 0 deletions

native/core/src/execution/planner.rs

Lines changed: 16 additions & 0 deletions
@@ -67,6 +67,7 @@ use datafusion::{
 use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
 use datafusion_functions_nested::concat::ArrayAppend;
 use datafusion_functions_nested::remove::array_remove_all_udf;
+use datafusion_functions_nested::string::array_to_string_udf;
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 
 use crate::execution::shuffle::CompressionCodec;
@@ -765,6 +766,21 @@ impl PhysicalPlanner {
 
                 Ok(Arc::new(case_expr))
             }
+            ExprStruct::ArrayJoin(expr) => {
+                let src_array_expr =
+                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let key_expr =
+                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let args = vec![Arc::clone(&src_array_expr), key_expr];
+                let datafusion_array_to_string = array_to_string_udf();
+                let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
+                    "array_join",
+                    datafusion_array_to_string,
+                    args,
+                    DataType::Utf8,
+                ));
+                Ok(array_intersect_expr)
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
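
On the native side, Spark's two-argument array_join(array, delimiter) is delegated to DataFusion's array_to_string UDF (from datafusion-functions-nested), with the array and delimiter as the two arguments and Utf8 as the return type. For reference, a minimal Spark-side sketch of the behaviour that mapping has to reproduce, assuming an active SparkSession named spark (e.g. in spark-shell); this is not part of the commit:

import org.apache.spark.sql.functions.{array, array_join, col}
import spark.implicits._

// Two-argument array_join: concatenate the array's string elements with the delimiter.
val df = Seq(("a", "b", "c")).toDF("x", "y", "z")
df.select(array_join(array(col("x"), col("y"), col("z")), " @ ").as("joined")).show()
// joined = "a @ b @ c"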

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ message Expr {
     ArrayInsert array_insert = 59;
     BinaryExpr array_contains = 60;
     BinaryExpr array_remove = 61;
+    BinaryExpr array_join = 63;
   }
 }

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 0 deletions
@@ -2284,6 +2284,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           expr.children(1),
           inputs,
           (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
+      case _ if expr.prettyName == "array_join" =>
+        createBinaryExpr(
+          expr.children(0),
+          expr.children(1),
+          inputs,
+          (builder, binaryExpr) => builder.setArrayJoin(binaryExpr))
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
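
The new serde branch keys on prettyName == "array_join" and serializes children(0) and children(1), i.e. the array and the delimiter of the two-argument form, into the BinaryExpr array_join field added in expr.proto. A small spark-shell style sketch of the Catalyst expression being matched, assuming Spark's ArrayJoin(array, delimiter, nullReplacement: Option[Expression]) constructor; not commit code:

import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, Literal}
import org.apache.spark.sql.types.{ArrayType, StringType}

// Two-argument form: no nullReplacement, so the children are exactly (array, delimiter).
val arr = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
val delim = Literal(" @ ")
val join = ArrayJoin(arr, delim, None)

println(join.prettyName)    // "array_join" -- the name the case guard above matches on
println(join.children.size) // 2 -- children(0) is the array, children(1) the delimiter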

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 16 additions & 0 deletions
@@ -2545,4 +2545,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("array_join") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1"))
+      }
+    }
+  }
 }
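
Spark's two-argument array_join filters out null elements (there is no nullReplacement child in this form), so a hypothetical extra query in the same style as the test above could cover that path as well; a sketch only, not part of the commit:

// Hypothetical additional case: a null element in the array is skipped by the
// two-argument form, and the Comet plan should still match Spark's answer.
checkSparkAnswerAndOperator(
  sql("SELECT array_join(array(cast(_1 as string), null, cast(_6 as string)), ' @ ') from t1"))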
