Skip to content

Commit d477f3b

Browse files
committed
wip: array remove
1 parent 46a28db commit d477f3b

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

native/core/src/execution/planner.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ use datafusion::{
6666
};
6767
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
6868
use datafusion_functions_nested::concat::ArrayAppend;
69+
use datafusion_functions_nested::remove::array_remove_all_udf;
6970
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
7071

7172
use crate::execution::spark_plan::SparkPlan;
@@ -719,6 +720,24 @@ impl PhysicalPlanner {
719720
expr.legacy_negative_index,
720721
)))
721722
}
723+
ExprStruct::ArrayRemove(expr) => {
724+
let src_array_expr =
725+
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
726+
let key_expr =
727+
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
728+
let args = vec![Arc::clone(&src_array_expr), key_expr];
729+
let return_type = src_array_expr.data_type(&input_schema)?;
730+
731+
let datafusion_array_remove = array_remove_all_udf();
732+
733+
let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
734+
"array_remove",
735+
datafusion_array_remove,
736+
args,
737+
return_type,
738+
));
739+
Ok(array_remove_expr)
740+
}
722741
expr => Err(ExecutionError::GeneralError(format!(
723742
"Not implemented: {:?}",
724743
expr

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ message Expr {
8484
GetArrayStructFields get_array_struct_fields = 57;
8585
BinaryExpr array_append = 58;
8686
ArrayInsert array_insert = 59;
87+
BinaryExpr array_remove = 60;
8788
}
8889
}
8990

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
22662266
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
22672267
None
22682268
}
2269+
case expr if expr.prettyName == "array_remove" =>
2270+
createBinaryExpr(
2271+
expr.children(0),
2272+
expr.children(1),
2273+
inputs,
2274+
(builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
22692275
case _ if expr.prettyName == "array_append" =>
22702276
createBinaryExpr(
22712277
expr.children(0),

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,4 +2517,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
25172517
checkSparkAnswer(df.select("arrUnsupportedArgs"))
25182518
}
25192519
}
2520+
test("array_remove") {
  // Verify Comet's native array_remove agrees with Spark, both with and
  // without dictionary encoding enabled for the parquet fixture.
  for (dictEnabled <- Seq(true, false)) {
    withTempDir { dir =>
      // NOTE(review): the parquet file written here is not referenced by any
      // query below (they all operate on literal arrays) — presumably
      // scaffolding for follow-up column-based tests; confirm or drop.
      val path = new Path(dir.toURI.toString, "test.parquet")
      makeParquetFileAllTypes(path, dictEnabled, 10000)

      // Single matching element is removed.
      checkSparkAnswerAndOperator(sql("SELECT array_remove(array(1, 2, 3, null, 3), 2)"))
      // Every occurrence of the key is removed, not just the first.
      checkSparkAnswerAndOperator(sql("SELECT array_remove(array(1, 3, 3, null, 3), 3)"))
      // Null key — result is checked against whatever Spark produces.
      checkSparkAnswerAndOperator(sql("SELECT array_remove(array(1, 2, null, 4, null), null)"))
      // Key absent from the array: input comes back unchanged.
      checkSparkAnswerAndOperator(sql("SELECT array_remove(array(1, 2, 3), 5)"))
    }
  }
}
25202540
}

0 commit comments

Comments
 (0)