Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions native/core/src/execution/expressions/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Array expression builders

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_proto::spark_expression::Expr;
use datafusion_comet_spark_expr::{ArrayExistsExpr, LambdaVariableExpr};

use crate::execution::operators::ExecutionError;
use crate::execution::planner::expression_registry::ExpressionBuilder;
use crate::execution::planner::PhysicalPlanner;
use crate::execution::serde::to_arrow_datatype;
use crate::extract_expr;

pub struct ArrayExistsBuilder;

impl ExpressionBuilder for ArrayExistsBuilder {
fn build(
&self,
spark_expr: &Expr,
input_schema: SchemaRef,
planner: &PhysicalPlanner,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let expr = extract_expr!(spark_expr, ArrayExists);
let array_expr =
planner.create_expr(expr.array.as_ref().unwrap(), Arc::clone(&input_schema))?;
let lambda_body = planner.create_expr(expr.lambda_body.as_ref().unwrap(), input_schema)?;
Ok(Arc::new(ArrayExistsExpr::new(
array_expr,
lambda_body,
expr.follow_three_valued_logic,
)))
}
}

pub struct LambdaVariableBuilder;

impl ExpressionBuilder for LambdaVariableBuilder {
fn build(
&self,
spark_expr: &Expr,
_input_schema: SchemaRef,
_planner: &PhysicalPlanner,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let expr = extract_expr!(spark_expr, LambdaVariable);
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(LambdaVariableExpr::new(data_type)))
}
}
1 change: 1 addition & 0 deletions native/core/src/execution/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Native DataFusion expressions

pub mod arithmetic;
pub mod array;
pub mod bitwise;
pub mod comparison;
pub mod logical;
Expand Down
19 changes: 19 additions & 0 deletions native/core/src/execution/planner/expression_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ pub enum ExpressionType {
Randn,
SparkPartitionId,
MonotonicallyIncreasingId,
ArrayExists,
LambdaVariable,

// Time functions
Hour,
Expand Down Expand Up @@ -184,6 +186,9 @@ impl ExpressionRegistry {

// Register temporal expressions
self.register_temporal_expressions();

// Register array expressions
self.register_array_expressions();
}

/// Register arithmetic expression builders
Expand Down Expand Up @@ -306,6 +311,18 @@ impl ExpressionRegistry {
);
}

/// Register array expression builders
fn register_array_expressions(&mut self) {
use crate::execution::expressions::array::*;

self.builders
.insert(ExpressionType::ArrayExists, Box::new(ArrayExistsBuilder));
self.builders.insert(
ExpressionType::LambdaVariable,
Box::new(LambdaVariableBuilder),
);
}

/// Extract expression type from Spark protobuf expression
fn get_expression_type(spark_expr: &Expr) -> Result<ExpressionType, ExecutionError> {
match spark_expr.expr_struct.as_ref() {
Expand Down Expand Up @@ -370,6 +387,8 @@ impl ExpressionRegistry {
Some(ExprStruct::MonotonicallyIncreasingId(_)) => {
Ok(ExpressionType::MonotonicallyIncreasingId)
}
Some(ExprStruct::ArrayExists(_)) => Ok(ExpressionType::ArrayExists),
Some(ExprStruct::LambdaVariable(_)) => Ok(ExpressionType::LambdaVariable),

Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour),
Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute),
Expand Down
16 changes: 16 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ message Expr {
UnixTimestamp unix_timestamp = 65;
FromJson from_json = 66;
ToCsv to_csv = 67;
ArrayExists array_exists = 68;
LambdaVariable lambda_variable = 69;
}
}

Expand Down Expand Up @@ -440,3 +442,17 @@ message ArrayJoin {
message Rand {
int64 seed = 1;
}

message ArrayExists {
Expr array = 1;
Expr lambda_body = 2;
bool follow_three_valued_logic = 3;
}

// Currently only supports a single lambda variable per expression. The variable
// is resolved by column index (always the last column in the expanded batch
// constructed by ArrayExistsExpr). Extending to multi-argument lambdas
// (e.g. transform(array, (x, i) -> ...)) would require adding an identifier.
message LambdaVariable {
DataType datatype = 1;
}
Loading
Loading