@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
81
81
let Ok ( expr_type) = info. get_data_type ( & expr) else {
82
82
return internal_err ! ( "Can't get the data type of the expr {:?}" , & expr) ;
83
83
} ;
84
+
85
+ if let Some ( value) = cast_literal_to_type_with_op ( & lit_value, & expr_type, op)
86
+ {
87
+ return Ok ( Transformed :: yes ( Expr :: BinaryExpr ( BinaryExpr {
88
+ left : expr,
89
+ op,
90
+ right : Box :: new ( lit ( value) ) ,
91
+ } ) ) ) ;
92
+ } ;
93
+
84
94
// if the lit_value can be casted to the type of internal_left_expr
85
95
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
86
96
let Some ( value) = try_cast_literal_to_type ( & lit_value, & expr_type) else {
@@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
105
115
> (
106
116
info : & S ,
107
117
expr : & Expr ,
118
+ op : Operator ,
108
119
literal : & Expr ,
109
120
) -> bool {
110
121
match ( expr, literal) {
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
125
136
return false ;
126
137
} ;
127
138
139
+ if cast_literal_to_type_with_op ( lit_val, & expr_type, op) . is_some ( ) {
140
+ return true ;
141
+ }
142
+
128
143
try_cast_literal_to_type ( lit_val, & expr_type) . is_some ( )
129
144
&& is_supported_type ( & expr_type)
130
145
&& is_supported_type ( & lit_type)
@@ -215,6 +230,52 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool {
215
230
DataType :: Dictionary ( _, inner) if is_supported_type( inner) )
216
231
}
217
232
233
+ ///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./
234
+ ///
235
+ /// Specifically, rewrites
236
+ /// ```sql
237
+ /// cast(col) <op> <literal>
238
+ /// ```
239
+ ///
240
+ /// To
241
+ ///
242
+ /// ```sql
243
+ /// col <op> cast(<literal>)
244
+ /// col <op> <casted_literal>
245
+ /// ```
246
+ fn cast_literal_to_type_with_op (
247
+ lit_value : & ScalarValue ,
248
+ target_type : & DataType ,
249
+ op : Operator ,
250
+ ) -> Option < ScalarValue > {
251
+ match ( op, lit_value) {
252
+ (
253
+ Operator :: Eq | Operator :: NotEq ,
254
+ ScalarValue :: Utf8 ( Some ( _) )
255
+ | ScalarValue :: Utf8View ( Some ( _) )
256
+ | ScalarValue :: LargeUtf8 ( Some ( _) ) ,
257
+ ) => {
258
+ // Only try for integer types (TODO can we do this for other types
259
+ // like timestamps)?
260
+ use DataType :: * ;
261
+ if matches ! (
262
+ target_type,
263
+ Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
264
+ ) {
265
+ let casted = lit_value. cast_to ( target_type) . ok ( ) ?;
266
+ let round_tripped = casted. cast_to ( & lit_value. data_type ( ) ) . ok ( ) ?;
267
+ if lit_value != & round_tripped {
268
+ return None ;
269
+ }
270
+ Some ( casted)
271
+ } else {
272
+ None
273
+ }
274
+ }
275
+ _ => None ,
276
+ }
277
+ }
278
+
218
279
/// Convert a literal value from one data type to another
219
280
pub ( super ) fn try_cast_literal_to_type (
220
281
lit_value : & ScalarValue ,
@@ -468,6 +529,24 @@ mod tests {
468
529
// the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
469
530
let expr_lt = cast ( col ( "c1" ) , DataType :: Int64 ) . lt ( lit ( 99999999999i64 ) ) ;
470
531
assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
532
+
533
+ // cast(c1, UTF8) < '123', only eq/not_eq should be optimized
534
+ let expr_lt = cast ( col ( "c1" ) , DataType :: Utf8 ) . lt ( lit ( "123" ) ) ;
535
+ assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
536
+
537
+ // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
538
+ // be casted
539
+ let expr_lt = cast ( col ( "c1" ) , DataType :: Utf8 ) . lt ( lit ( "0123" ) ) ;
540
+ assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
541
+
542
+ // cast(c1, UTF8) = 'not a number', should not be able to cast to column type
543
+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "not a number" ) ) ;
544
+ assert_eq ! ( optimize_test( expr_input. clone( ) , & schema) , expr_input) ;
545
+
546
+ // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
547
+ // not be optimized to integer comparison
548
+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "99999999999" ) ) ;
549
+ assert_eq ! ( optimize_test( expr_input. clone( ) , & schema) , expr_input) ;
471
550
}
472
551
473
552
#[ test]
@@ -496,6 +575,21 @@ mod tests {
496
575
let lit_lt_lit = cast ( null_i8 ( ) , DataType :: Int32 ) . lt ( lit ( 12i32 ) ) ;
497
576
let expected = null_bool ( ) ;
498
577
assert_eq ! ( optimize_test( lit_lt_lit, & schema) , expected) ;
578
+
579
+ // cast(c1, UTF8) = '123' => c1 = 123
580
+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "123" ) ) ;
581
+ let expected = col ( "c1" ) . eq ( lit ( 123i32 ) ) ;
582
+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
583
+
584
+ // cast(c1, UTF8) != '123' => c1 != 123
585
+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . not_eq ( lit ( "123" ) ) ;
586
+ let expected = col ( "c1" ) . not_eq ( lit ( 123i32 ) ) ;
587
+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
588
+
589
+ // cast(c1, UTF8) = NULL => c1 = NULL
590
+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( ScalarValue :: Utf8 ( None ) ) ) ;
591
+ let expected = col ( "c1" ) . eq ( lit ( ScalarValue :: Int32 ( None ) ) ) ;
592
+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
499
593
}
500
594
501
595
#[ test]
@@ -505,6 +599,16 @@ mod tests {
505
599
let expr_input = cast ( col ( "c6" ) , DataType :: UInt64 ) . eq ( lit ( 0u64 ) ) ;
506
600
let expected = col ( "c6" ) . eq ( lit ( 0u32 ) ) ;
507
601
assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
602
+
603
+ // cast(c6, UTF8) = "123" => c6 = 123
604
+ let expr_input = cast ( col ( "c6" ) , DataType :: Utf8 ) . eq ( lit ( "123" ) ) ;
605
+ let expected = col ( "c6" ) . eq ( lit ( 123u32 ) ) ;
606
+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
607
+
608
+ // cast(c6, UTF8) != "123" => c6 != 123
609
+ let expr_input = cast ( col ( "c6" ) , DataType :: Utf8 ) . not_eq ( lit ( "123" ) ) ;
610
+ let expected = col ( "c6" ) . not_eq ( lit ( 123u32 ) ) ;
611
+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
508
612
}
509
613
510
614
#[ test]
0 commit comments