@@ -33,8 +33,8 @@ use datafusion_common::{
33
33
} ;
34
34
use datafusion_common:: { internal_err, DFSchema , DataFusionError , Result , ScalarValue } ;
35
35
use datafusion_expr:: {
36
- and, lit, or, BinaryExpr , Case , ColumnarValue , Expr , Like , Operator , Volatility ,
37
- WindowFunctionDefinition ,
36
+ and, binary :: BinaryTypeCoercer , lit, or, BinaryExpr , Case , ColumnarValue , Expr , Like ,
37
+ Operator , Volatility , WindowFunctionDefinition ,
38
38
} ;
39
39
use datafusion_expr:: { expr:: ScalarFunction , interval_arithmetic:: NullableInterval } ;
40
40
use datafusion_expr:: {
@@ -976,30 +976,39 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
976
976
// Rules for Multiply
977
977
//
978
978
979
- // A * 1 --> A
979
+ // A * 1 --> A (with type coercion if needed)
980
980
Expr :: BinaryExpr ( BinaryExpr {
981
981
left,
982
982
op : Multiply ,
983
983
right,
984
- } ) if is_one ( & right) => Transformed :: yes ( * left) ,
985
- // 1 * A --> A
984
+ } ) if is_one ( & right) => {
985
+ simplify_right_is_one_case ( info, left, & Multiply , & right) ?
986
+ }
987
+ // A * null --> null
986
988
Expr :: BinaryExpr ( BinaryExpr {
987
989
left,
988
990
op : Multiply ,
989
991
right,
990
- } ) if is_one ( & left) => Transformed :: yes ( * right) ,
991
- // A * null --> null
992
+ } ) if is_null ( & right) => {
993
+ simplify_right_is_null_case ( info, & left, & Multiply , right) ?
994
+ }
995
+ // 1 * A --> A
992
996
Expr :: BinaryExpr ( BinaryExpr {
993
- left : _ ,
997
+ left,
994
998
op : Multiply ,
995
999
right,
996
- } ) if is_null ( & right) => Transformed :: yes ( * right) ,
1000
+ } ) if is_one ( & left) => {
1001
+ // 1 * A is equivalent to A * 1
1002
+ simplify_right_is_one_case ( info, right, & Multiply , & left) ?
1003
+ }
997
1004
// null * A --> null
998
1005
Expr :: BinaryExpr ( BinaryExpr {
999
1006
left,
1000
1007
op : Multiply ,
1001
- right : _,
1002
- } ) if is_null ( & left) => Transformed :: yes ( * left) ,
1008
+ right,
1009
+ } ) if is_null ( & left) => {
1010
+ simplify_right_is_null_case ( info, & right, & Multiply , left) ?
1011
+ }
1003
1012
1004
1013
// A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
1005
1014
Expr :: BinaryExpr ( BinaryExpr {
@@ -1033,19 +1042,23 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
1033
1042
left,
1034
1043
op : Divide ,
1035
1044
right,
1036
- } ) if is_one ( & right) => Transformed :: yes ( * left) ,
1037
- // null / A --> null
1045
+ } ) if is_one ( & right) => {
1046
+ simplify_right_is_one_case ( info, left, & Divide , & right) ?
1047
+ }
1048
+ // A / null --> null
1038
1049
Expr :: BinaryExpr ( BinaryExpr {
1039
1050
left,
1040
1051
op : Divide ,
1041
- right : _,
1042
- } ) if is_null ( & left) => Transformed :: yes ( * left) ,
1043
- // A / null --> null
1052
+ right,
1053
+ } ) if is_null ( & right) => {
1054
+ simplify_right_is_null_case ( info, & left, & Divide , right) ?
1055
+ }
1056
+ // null / A --> null
1044
1057
Expr :: BinaryExpr ( BinaryExpr {
1045
- left : _ ,
1058
+ left,
1046
1059
op : Divide ,
1047
1060
right,
1048
- } ) if is_null ( & right ) => Transformed :: yes ( * right) ,
1061
+ } ) if is_null ( & left ) => simplify_null_div_other_case ( info , left , & right) ? ,
1049
1062
1050
1063
//
1051
1064
// Rules for Modulo
@@ -1997,6 +2010,84 @@ fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
1997
2010
}
1998
2011
}
1999
2012
2013
+ // A * 1 -> A
2014
+ // A / 1 -> A
2015
+ //
2016
+ // Move this function body out of the large match branch avoid stack overflow
2017
+ fn simplify_right_is_one_case < S : SimplifyInfo > (
2018
+ info : & S ,
2019
+ left : Box < Expr > ,
2020
+ op : & Operator ,
2021
+ right : & Expr ,
2022
+ ) -> Result < Transformed < Expr > > {
2023
+ // Check if resulting type would be different due to coercion
2024
+ let left_type = info. get_data_type ( & left) ?;
2025
+ let right_type = info. get_data_type ( right) ?;
2026
+ match BinaryTypeCoercer :: new ( & left_type, op, & right_type) . get_result_type ( ) {
2027
+ Ok ( result_type) => {
2028
+ // Only cast if the types differ
2029
+ if left_type != result_type {
2030
+ Ok ( Transformed :: yes ( Expr :: Cast ( Cast :: new ( left, result_type) ) ) )
2031
+ } else {
2032
+ Ok ( Transformed :: yes ( * left) )
2033
+ }
2034
+ }
2035
+ Err ( _) => Ok ( Transformed :: yes ( * left) ) ,
2036
+ }
2037
+ }
2038
+
2039
+ // A * null -> null
2040
+ // A / null -> null
2041
+ //
2042
+ // Move this function body out of the large match branch avoid stack overflow
2043
+ fn simplify_right_is_null_case < S : SimplifyInfo > (
2044
+ info : & S ,
2045
+ left : & Expr ,
2046
+ op : & Operator ,
2047
+ right : Box < Expr > ,
2048
+ ) -> Result < Transformed < Expr > > {
2049
+ // Check if resulting type would be different due to coercion
2050
+ let left_type = info. get_data_type ( left) ?;
2051
+ let right_type = info. get_data_type ( & right) ?;
2052
+ match BinaryTypeCoercer :: new ( & left_type, op, & right_type) . get_result_type ( ) {
2053
+ Ok ( result_type) => {
2054
+ // Only cast if the types differ
2055
+ if right_type != result_type {
2056
+ Ok ( Transformed :: yes ( Expr :: Cast ( Cast :: new ( right, result_type) ) ) )
2057
+ } else {
2058
+ Ok ( Transformed :: yes ( * right) )
2059
+ }
2060
+ }
2061
+ Err ( _) => Ok ( Transformed :: yes ( * right) ) ,
2062
+ }
2063
+ }
2064
+
2065
+ // null / A --> null
2066
+ //
2067
+ // Move this function body out of the large match branch avoid stack overflow
2068
+ fn simplify_null_div_other_case < S : SimplifyInfo > (
2069
+ info : & S ,
2070
+ left : Box < Expr > ,
2071
+ right : & Expr ,
2072
+ ) -> Result < Transformed < Expr > > {
2073
+ // Check if resulting type would be different due to coercion
2074
+ let left_type = info. get_data_type ( & left) ?;
2075
+ let right_type = info. get_data_type ( right) ?;
2076
+ match BinaryTypeCoercer :: new ( & left_type, & Operator :: Divide , & right_type)
2077
+ . get_result_type ( )
2078
+ {
2079
+ Ok ( result_type) => {
2080
+ // Only cast if the types differ
2081
+ if left_type != result_type {
2082
+ Ok ( Transformed :: yes ( Expr :: Cast ( Cast :: new ( left, result_type) ) ) )
2083
+ } else {
2084
+ Ok ( Transformed :: yes ( * left) )
2085
+ }
2086
+ }
2087
+ Err ( _) => Ok ( Transformed :: yes ( * left) ) ,
2088
+ }
2089
+ }
2090
+
2000
2091
#[ cfg( test) ]
2001
2092
mod tests {
2002
2093
use crate :: simplify_expressions:: SimplifyContext ;
@@ -2316,12 +2407,12 @@ mod tests {
2316
2407
// A / null --> null
2317
2408
let null = lit ( ScalarValue :: Null ) ;
2318
2409
{
2319
- let expr = col ( "c " ) / null. clone ( ) ;
2410
+ let expr = col ( "c1 " ) / null. clone ( ) ;
2320
2411
assert_eq ! ( simplify( expr) , null) ;
2321
2412
}
2322
2413
// null / A --> null
2323
2414
{
2324
- let expr = null. clone ( ) / col ( "c " ) ;
2415
+ let expr = null. clone ( ) / col ( "c1 " ) ;
2325
2416
assert_eq ! ( simplify( expr) , null) ;
2326
2417
}
2327
2418
}
0 commit comments