Skip to content

Commit c25060e

Browse files
authored
feat: add support for array_remove expression (#1179)
* wip: array remove * added comet expression test * updated test cases * fixed array_remove function for null values * removed commented code * remove unnecessary code * updated the test for 'array_remove' * added test for array_remove in case the input array is null * wip: case array is empty * removed test case for empty array
1 parent d52038e commit c25060e

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

native/core/src/execution/planner.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ use datafusion::{
6666
};
6767
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
6868
use datafusion_functions_nested::concat::ArrayAppend;
69+
use datafusion_functions_nested::remove::array_remove_all_udf;
6970
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
7071

7172
use crate::execution::shuffle::CompressionCodec;
@@ -735,6 +736,35 @@ impl PhysicalPlanner {
735736
));
736737
Ok(array_has_expr)
737738
}
739+
ExprStruct::ArrayRemove(expr) => {
740+
let src_array_expr =
741+
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
742+
let key_expr =
743+
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
744+
let args = vec![Arc::clone(&src_array_expr), Arc::clone(&key_expr)];
745+
let return_type = src_array_expr.data_type(&input_schema)?;
746+
747+
let datafusion_array_remove = array_remove_all_udf();
748+
749+
let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
750+
"array_remove",
751+
datafusion_array_remove,
752+
args,
753+
return_type,
754+
));
755+
let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(key_expr));
756+
757+
let null_literal_expr: Arc<dyn PhysicalExpr> =
758+
Arc::new(Literal::new(ScalarValue::Null));
759+
760+
let case_expr = CaseExpr::try_new(
761+
None,
762+
vec![(is_null_expr, null_literal_expr)],
763+
Some(array_remove_expr),
764+
)?;
765+
766+
Ok(Arc::new(case_expr))
767+
}
738768
expr => Err(ExecutionError::GeneralError(format!(
739769
"Not implemented: {:?}",
740770
expr

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ message Expr {
8585
BinaryExpr array_append = 58;
8686
ArrayInsert array_insert = 59;
8787
BinaryExpr array_contains = 60;
88+
BinaryExpr array_remove = 61;
8889
}
8990
}
9091

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
22662266
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
22672267
None
22682268
}
2269+
case expr if expr.prettyName == "array_remove" =>
2270+
createBinaryExpr(
2271+
expr.children(0),
2272+
expr.children(1),
2273+
inputs,
2274+
(builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
22692275
case expr if expr.prettyName == "array_contains" =>
22702276
createBinaryExpr(
22712277
expr.children(0),

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,4 +2529,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
25292529
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
25302530
}
25312531
}
2532+
2533+
test("array_remove") {
2534+
Seq(true, false).foreach { dictionaryEnabled =>
2535+
withTempDir { dir =>
2536+
val path = new Path(dir.toURI.toString, "test.parquet")
2537+
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
2538+
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
2539+
checkSparkAnswerAndOperator(
2540+
sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null"))
2541+
checkSparkAnswerAndOperator(
2542+
sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null"))
2543+
checkSparkAnswerAndOperator(sql(
2544+
"SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1"))
2545+
}
2546+
}
2547+
}
25322548
}

0 commit comments

Comments
 (0)