Skip to content

Commit 86d8f57

Browse files
Support arrays_overlap function
1 parent 497e40b commit 86d8f57

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

native/core/src/execution/planner.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ use datafusion::{
6565
prelude::SessionContext,
6666
};
6767
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
68+
use datafusion_functions_nested::array_has::array_has_any_udf;
6869
use datafusion_functions_nested::concat::ArrayAppend;
6970
use datafusion_functions_nested::remove::array_remove_all_udf;
7071
use datafusion_functions_nested::set_ops::array_intersect_udf;
@@ -818,6 +819,21 @@ impl PhysicalPlanner {
818819
));
819820
Ok(array_join_expr)
820821
}
822+
ExprStruct::ArraysOverlap(expr) => {
823+
let left_array_expr =
824+
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
825+
let right_array_expr =
826+
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
827+
let args = vec![Arc::clone(&left_array_expr), right_array_expr];
828+
let datafusion_array_has_any = array_has_any_udf();
829+
let array_has_any_expr = Arc::new(ScalarFunctionExpr::new(
830+
"array_has_any",
831+
datafusion_array_has_any,
832+
args,
833+
DataType::Boolean,
834+
));
835+
Ok(array_has_any_expr)
836+
}
821837
expr => Err(ExecutionError::GeneralError(format!(
822838
"Not implemented: {:?}",
823839
expr

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ message Expr {
8888
BinaryExpr array_remove = 61;
8989
BinaryExpr array_intersect = 62;
9090
ArrayJoin array_join = 63;
91+
BinaryExpr arrays_overlap = 64;
9192
}
9293
}
9394

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,6 +2428,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
24282428
withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
24292429
None
24302430
}
2431+
case ArraysOverlap(leftArrayExpr, rightArrayExpr) =>
2432+
if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
2433+
createBinaryExpr(
2434+
expr,
2435+
leftArrayExpr,
2436+
rightArrayExpr,
2437+
inputs,
2438+
binding,
2439+
(builder, binaryExpr) => builder.setArraysOverlap(binaryExpr))
2440+
} else {
2441+
withInfo(
2442+
expr,
2443+
s"$expr is not supported yet. To enable all incompatible casts, set " +
2444+
s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
2445+
None
2446+
}
24312447
case _ =>
24322448
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
24332449
None

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2701,4 +2701,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
27012701
}
27022702
}
27032703
}
2704+
2705+
test("arrays_overlap") {
2706+
withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
2707+
Seq(true, false).foreach { dictionaryEnabled =>
2708+
withTempDir { dir =>
2709+
val path = new Path(dir.toURI.toString, "test.parquet")
2710+
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
2711+
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
2712+
checkSparkAnswerAndOperator(sql(
2713+
"SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null"))
2714+
checkSparkAnswerAndOperator(sql(
2715+
"SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null"))
2716+
checkSparkAnswerAndOperator(sql(
2717+
"SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null"))
2718+
checkSparkAnswerAndOperator(
2719+
spark.sql("SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"));
2720+
}
2721+
}
2722+
}
2723+
}
2724+
27042725
}

0 commit comments

Comments
 (0)