Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/core/src/execution/session_state_defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl SessionStateDefaults {
default_catalog
}

/// NOTE!
/// returns the list of default [`ExprPlanner`]s
pub fn default_expr_planners() -> Vec<Arc<dyn ExprPlanner>> {
let expr_planners: Vec<Arc<dyn ExprPlanner>> = vec![
Expand Down
229 changes: 123 additions & 106 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,124 +178,129 @@ impl<'a> BinaryTypeCoercer<'a> {
use arrow::datatypes::DataType::*;
use Operator::*;
let result = match self.op {
Eq |
NotEq |
Lt |
LtEq |
Gt |
GtEq |
IsDistinctFrom |
IsNotDistinctFrom => {
comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for comparison operation {} {} {}",
self.lhs,
self.op,
self.rhs
)
})
}
And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) {
// Logical binary boolean operators can only be evaluated for
// boolean or null arguments.
Ok(Signature::uniform(Boolean))
} else {
plan_err!(
"Cannot infer common argument type for logical boolean operation {} {} {}", self.lhs, self.op, self.rhs
)
}
RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => {
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => {
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => {
bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
StringConcat => {
string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
AtArrow | ArrowAt => {
// Array contains or search (similar to LIKE) operation
array_coercion(lhs, rhs)
.or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| {
Eq | NotEq | Lt | LtEq | Gt | GtEq | IsDistinctFrom | IsNotDistinctFrom => {
comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs
"Cannot infer common argument type for comparison operation {} {} {}",
self.lhs,
self.op,
self.rhs
)
})
}
AtAt => {
// text search has similar signature to LIKE
like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs
}
And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) {
// Logical binary boolean operators can only be evaluated for
// boolean or null arguments.
Ok(Signature::uniform(Boolean))
} else {
plan_err!(
"Cannot infer common argument type for logical boolean operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
Plus | Minus | Multiply | Divide | Modulo => {
if let Ok(ret) = self.get_result(lhs, rhs) {
// Temporal arithmetic, e.g. Date32 + Interval
Ok(Signature{
lhs: lhs.clone(),
rhs: rhs.clone(),
ret,
}
RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => {
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
)
})
} else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) {
// Temporal arithmetic by first coercing to a common time representation
// e.g. Date32 - Timestamp
let ret = self.get_result(&coerced, &coerced).map_err(|e| {
}
LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => {
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
)
})?;
Ok(Signature{
lhs: coerced.clone(),
rhs: coerced,
ret,
})
} else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
// Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0)
let ret = self.get_result(&lhs, &rhs).map_err(|e| {
}
BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => {
bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
plan_datafusion_err!(
"Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
"Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs
)
})?;
Ok(Signature{
lhs,
rhs,
ret,
})
} else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
// Numeric arithmetic, e.g. Int32 + Int32
Ok(Signature::uniform(numeric))
} else {
plan_err!(
"Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs
)
}
},
IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow
| HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => {
not_impl_err!("Operator {} is not yet supported", self.op)
}
};
StringConcat => {
string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
AtArrow | ArrowAt => {
// Array contains or search (similar to LIKE) operation
array_coercion(lhs, rhs)
.or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
AtAt => {
// text search has similar signature to LIKE
like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
plan_datafusion_err!(
"Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs
)
})
}
Plus | Minus | Multiply | Divide | Modulo => {
if let Ok(ret) = self.get_result(lhs, rhs) {
// Temporal arithmetic, e.g. Date32 + Interval
Ok(Signature{
lhs: lhs.clone(),
rhs: rhs.clone(),
ret,
})
} else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) {
// Temporal arithmetic by first coercing to a common time representation
// e.g. Date32 - Timestamp
let ret = self.get_result(&coerced, &coerced).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op
)
})?;
Ok(Signature{
lhs: coerced.clone(),
rhs: coerced,
ret,
})
} else if let Some((lhs, rhs)) = temporal_coercion_resolve_ints_to_intervals(lhs, rhs) {
// e.g. Date32 + Int32
let ret = self.get_result(&lhs, &rhs).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for temporal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
)
})?;
Ok(Signature{
lhs: lhs.clone(),
rhs: rhs.clone(),
ret,
})
} else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
// decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0)
let ret = self.get_result(&lhs, &rhs).map_err(|e| {
plan_datafusion_err!(
"Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
)
})?;
Ok(Signature{
lhs,
rhs,
ret,
})
} else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
// Numeric arithmetic, e.g. Int32 + Int32
Ok(Signature::uniform(numeric))
} else {
plan_err!(
"Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs
)
}
},
IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => {
not_impl_err!("Operator {} is not yet supported", self.op)
}
};

result.map_err(|err| {
let diagnostic =
Diagnostic::new_error("expressions have incompatible types", self.span())
Expand Down Expand Up @@ -1433,6 +1438,18 @@ fn temporal_coercion_nonstrict_timezone(
}
}

fn temporal_coercion_resolve_ints_to_intervals(
lhs: &DataType,
rhs: &DataType,
) -> Option<(DataType, DataType)> {
use arrow::datatypes::DataType::{Date32, Int32, Int64, Interval};
use arrow::datatypes::IntervalUnit::DayTime;
match (lhs, rhs) {
(Date32, Int32 | Int64) => Some((lhs.clone(), Interval(DayTime))),
_ => None,
}
}

/// Strict Timezone coercion is useful in scenarios where we cannot guarantee a stable relationship
/// between two timestamps with different timezones or do not want implicit coercion between them.
///
Expand Down
53 changes: 45 additions & 8 deletions datafusion/functions-nested/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use datafusion_expr::{
use datafusion_functions::core::get_field as get_field_inner;
use datafusion_functions::expr_fn::get_field;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
use sqlparser::ast::BinaryOperator;
use std::sync::Arc;

use crate::map::map_udf;
Expand All @@ -51,7 +52,7 @@ impl ExprPlanner for NestedFunctionPlanner {
) -> Result<PlannerResult<RawBinaryExpr>> {
let RawBinaryExpr { op, left, right } = expr;

if op == sqlparser::ast::BinaryOperator::StringConcat {
if op == BinaryOperator::StringConcat {
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
let left_list_ndims = list_ndims(&left_type);
Expand All @@ -75,25 +76,61 @@ impl ExprPlanner for NestedFunctionPlanner {
} else if left_list_ndims < right_list_ndims {
return Ok(PlannerResult::Planned(array_prepend(left, right)));
}
} else if matches!(
op,
sqlparser::ast::BinaryOperator::AtArrow
| sqlparser::ast::BinaryOperator::ArrowAt
) {
} else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) {
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
let left_list_ndims = list_ndims(&left_type);
let right_list_ndims = list_ndims(&right_type);
// if both are list
if left_list_ndims > 0 && right_list_ndims > 0 {
if op == sqlparser::ast::BinaryOperator::AtArrow {
if op == BinaryOperator::AtArrow {
// array1 @> array2 -> array_has_all(array1, array2)
return Ok(PlannerResult::Planned(array_has_all(left, right)));
} else {
// array1 <@ array2 -> array_has_all(array2, array1)
return Ok(PlannerResult::Planned(array_has_all(right, left)));
}
}
// } else if matches!(op, BinaryOperator::Plus | BinaryOperator::Minus)
// && matches!(left.get_type(schema)?, DataType::Date32)
// && matches!(right.get_type(schema)?, DataType::Int32 | DataType::Int64)
// {
// use arrow::datatypes::IntervalDayTime;
// use datafusion_common::ScalarValue;
// use datafusion_expr::BinaryExpr;
// use datafusion_expr::Operator;
// use sqlparser::ast::BinaryOperator;
//
// let op: Operator = match op {
// BinaryOperator::Plus => Operator::Plus,
// BinaryOperator::Minus => Operator::Minus,
// _ => unreachable!(),
// };
//
// let new_right: Expr = match right {
// Expr::Literal(ScalarValue::Int32(Some(i)), meta) => Expr::Literal(
// ScalarValue::IntervalDayTime(Some(IntervalDayTime {
// days: i,
// milliseconds: 0,
// })),
// meta,
// ),
// Expr::Literal(ScalarValue::Int64(Some(i)), meta) => Expr::Literal(
// ScalarValue::IntervalDayTime(Some(IntervalDayTime {
// days: i as i32,
// milliseconds: 0,
// })),
// meta,
// ),
// _ => unreachable!(),
// };
//
// let planned = Expr::BinaryExpr(BinaryExpr {
// left: Box::new(left.clone()),
// right: Box::new(new_right),
// op,
// });
// return Ok(PlannerResult::Planned(planned));
}

Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
Expand Down Expand Up @@ -123,7 +160,7 @@ impl ExprPlanner for NestedFunctionPlanner {
}

fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
if expr.op == sqlparser::ast::BinaryOperator::Eq {
if expr.op == BinaryOperator::Eq {
Ok(PlannerResult::Planned(Expr::ScalarFunction(
ScalarFunction::new_udf(
array_has_udf(),
Expand Down
10 changes: 8 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};

use super::inlist_simplifier::ShortenInListSimplifier;
use super::make_interval_simplifier::MakeIntervalSimplifier;
use super::utils::*;
use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
Expand All @@ -55,6 +56,7 @@ use crate::simplify_expressions::unwrap_cast::{
unwrap_cast_in_comparison_for_binary,
};
use crate::simplify_expressions::SimplifyInfo;
use arrow::datatypes::IntervalDayTime;
use datafusion_expr::expr::FieldMetadata;
use datafusion_expr_common::casts::try_cast_literal_to_type;
use indexmap::IndexSet;
Expand Down Expand Up @@ -256,6 +258,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
}
// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;

let mut make_int_interval = MakeIntervalSimplifier::new();
expr = expr.rewrite(&mut make_int_interval).data()?;

Ok((
Transformed::new_transformed(expr, has_transformed),
num_cycles,
Expand Down Expand Up @@ -767,8 +773,8 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
use datafusion_expr::Operator::{
And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
RegexNotIMatch, RegexNotMatch,
Divide, Eq, Minus, Modulo, Multiply, NotEq, Or, Plus, RegexIMatch,
RegexMatch, RegexNotIMatch, RegexNotMatch,
};

let info = self.info;
Expand Down
Loading
Loading