Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use datafusion_spark::function::math::expm1::SparkExpm1;
use datafusion_spark::function::math::hex::SparkHex;
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
use datafusion_spark::function::string::char::CharFunc;
use datafusion_spark::function::array::array_contains::SparkArrayContains;
use datafusion_spark::function::string::concat::SparkConcat;
use datafusion_spark::function::string::space::SparkSpace;
use futures::poll;
Expand Down Expand Up @@ -387,6 +388,7 @@ fn prepare_datafusion_session_context(

// Register UDFs from the datafusion-spark crate.
fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkArrayContains::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
Expand Down
35 changes: 5 additions & 30 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,36 +133,11 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
val keyExprProto = exprToProto(expr.children(1), inputs, binding)

val arrayContainsScalarExpr =
scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto)

// Handle NULL array input - return NULL if array is NULL (matching Spark's behavior)
val isNotNullExpr = createUnaryExpr(
expr,
expr.children.head,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))

val nullLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)

if (arrayContainsScalarExpr.isDefined && isNotNullExpr.isDefined &&
nullLiteralProto.isDefined) {
val caseWhenExpr = ExprOuterClass.CaseWhen
.newBuilder()
.addWhen(isNotNullExpr.get)
.addThen(arrayContainsScalarExpr.get)
.setElseExpr(nullLiteralProto.get)
.build()
Some(
ExprOuterClass.Expr
.newBuilder()
.setCaseWhen(caseWhenExpr)
.build())
} else {
withInfo(expr, expr.children: _*)
None
}
// Delegate to datafusion-spark's SparkArrayContains, which handles Spark's
// three-valued NULL semantics natively, so no CASE WHEN wrapper is needed.
val arrayContainsExpr =
scalarFunctionExprToProto("array_contains", arrayExprProto, keyExprProto)
optExprWithInfo(arrayContainsExpr, expr)
}
}

Expand Down