Skip to content

Commit 26c0c49

Browse files
alan910127 and alamb
authored
perf: unwrap cast for comparing ints =/!= strings (#15110)
* perf: unwrap cast for comparing ints =/!= strings * fix: update casting logic * test: add more unit test and new sqllogictest * Tweak slt tests * Revert "perf: unwrap cast for comparing ints =/!= strings" This reverts commit 808d6ab. * fix: eliminate column cast and cast literal before coercion * fix: physical expr coercion test * feat: unwrap cast after round-trip cast verification * fix: unwrap cast on round-trip cast stable strings * revert: remove avoid cast changes * refactor: apply review suggestions --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 67d0dab commit 26c0c49

File tree

4 files changed

+171
-2
lines changed

4 files changed

+171
-2
lines changed

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ impl<'a> TypeCoercionRewriter<'a> {
296296
&right.get_type(right_schema)?,
297297
)
298298
.get_input_types()?;
299+
299300
Ok((
300301
left.cast_to(&left_type, left_schema)?,
301302
right.cast_to(&right_type, right_schema)?,

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,7 +1758,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
17581758
// try_cast/cast(expr as data_type) op literal
17591759
Expr::BinaryExpr(BinaryExpr { left, op, right })
17601760
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1761-
info, &left, &right,
1761+
info, &left, op, &right,
17621762
) && op.supports_propagation() =>
17631763
{
17641764
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
@@ -1768,7 +1768,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
17681768
// try_cast/cast(expr as data_type) op_swap literal
17691769
Expr::BinaryExpr(BinaryExpr { left, op, right })
17701770
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1771-
info, &right, &left,
1771+
info, &right, op, &left,
17721772
) && op.supports_propagation()
17731773
&& op.swap().is_some() =>
17741774
{

datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
8181
let Ok(expr_type) = info.get_data_type(&expr) else {
8282
return internal_err!("Can't get the data type of the expr {:?}", &expr);
8383
};
84+
85+
if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
86+
{
87+
return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
88+
left: expr,
89+
op,
90+
right: Box::new(lit(value)),
91+
})));
92+
};
93+
8494
// if the lit_value can be casted to the type of internal_left_expr
8595
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
8696
let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
@@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
105115
>(
106116
info: &S,
107117
expr: &Expr,
118+
op: Operator,
108119
literal: &Expr,
109120
) -> bool {
110121
match (expr, literal) {
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
125136
return false;
126137
};
127138

139+
if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
140+
return true;
141+
}
142+
128143
try_cast_literal_to_type(lit_val, &expr_type).is_some()
129144
&& is_supported_type(&expr_type)
130145
&& is_supported_type(&lit_type)
@@ -215,6 +230,52 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool {
215230
DataType::Dictionary(_, inner) if is_supported_type(inner))
216231
}
217232

233+
/// Tries to move a cast from an expression (such as a column) to the literal on the other side of a comparison operator.
234+
///
235+
/// Specifically, rewrites
236+
/// ```sql
237+
/// cast(col) <op> <literal>
238+
/// ```
239+
///
240+
/// To
241+
///
242+
/// ```sql
243+
/// col <op> cast(<literal>)
244+
/// col <op> <casted_literal>
245+
/// ```
///
/// Only `=` / `!=` comparisons against a non-null `Utf8`, `Utf8View` or `LargeUtf8` literal
/// are considered, and only when `target_type` is a signed or unsigned integer type.
///
/// Returns `Some(casted)` when the literal can be cast to `target_type` and casting that
/// result back to the literal's own data type yields a value equal to the original literal.
/// Returns `None` for every other operator, literal or target type, when either cast fails,
/// or when the round trip produces a different value.
246+
fn cast_literal_to_type_with_op(
247+
lit_value: &ScalarValue,
248+
target_type: &DataType,
249+
op: Operator,
250+
) -> Option<ScalarValue> {
251+
match (op, lit_value) {
252+
(
253+
Operator::Eq | Operator::NotEq,
254+
ScalarValue::Utf8(Some(_))
255+
| ScalarValue::Utf8View(Some(_))
256+
| ScalarValue::LargeUtf8(Some(_)),
257+
) => {
258+
// Only try for integer types (TODO: can we do this for other types,
259+
// like timestamps?)
260+
use DataType::*;
261+
if matches!(
262+
target_type,
263+
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
264+
) {
265+
let casted = lit_value.cast_to(target_type).ok()?;
266+
let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
267+
if lit_value != &round_tripped {
268+
return None;
269+
}
270+
Some(casted)
271+
} else {
272+
None
273+
}
274+
}
275+
_ => None,
276+
}
277+
}
278+
218279
/// Convert a literal value from one data type to another
219280
pub(super) fn try_cast_literal_to_type(
220281
lit_value: &ScalarValue,
@@ -468,6 +529,24 @@ mod tests {
468529
// the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
469530
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
470531
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
532+
533+
// cast(c1, UTF8) < '123', only eq/not_eq should be optimized
534+
let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
535+
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
536+
537+
// cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
538+
// be casted
539+
let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123"));
540+
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
541+
542+
// cast(c1, UTF8) = 'not a number', should not be able to cast to column type
543+
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number"));
544+
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
545+
546+
// cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
547+
// not be optimized to integer comparison
548+
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999"));
549+
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
471550
}
472551

473552
#[test]
@@ -496,6 +575,21 @@ mod tests {
496575
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
497576
let expected = null_bool();
498577
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
578+
579+
// cast(c1, UTF8) = '123' => c1 = 123
580+
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
581+
let expected = col("c1").eq(lit(123i32));
582+
assert_eq!(optimize_test(expr_input, &schema), expected);
583+
584+
// cast(c1, UTF8) != '123' => c1 != 123
585+
let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
586+
let expected = col("c1").not_eq(lit(123i32));
587+
assert_eq!(optimize_test(expr_input, &schema), expected);
588+
589+
// cast(c1, UTF8) = NULL => c1 = NULL
590+
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
591+
let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
592+
assert_eq!(optimize_test(expr_input, &schema), expected);
499593
}
500594

501595
#[test]
@@ -505,6 +599,16 @@ mod tests {
505599
let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
506600
let expected = col("c6").eq(lit(0u32));
507601
assert_eq!(optimize_test(expr_input, &schema), expected);
602+
603+
// cast(c6, UTF8) = "123" => c6 = 123
604+
let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
605+
let expected = col("c6").eq(lit(123u32));
606+
assert_eq!(optimize_test(expr_input, &schema), expected);
607+
608+
// cast(c6, UTF8) != "123" => c6 != 123
609+
let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
610+
let expected = col("c6").not_eq(lit(123u32));
611+
assert_eq!(optimize_test(expr_input, &schema), expected);
508612
}
509613

510614
#[test]

datafusion/sqllogictest/test_files/push_down_filter.slt

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,74 @@ select * from test_filter_with_limit where value = 2 limit 1;
188188
----
189189
2 2
190190

191+
191192
# Tear down test_filter_with_limit table:
192193
statement ok
193194
DROP TABLE test_filter_with_limit;
194195

195196
# Tear down src_table table:
196197
statement ok
197198
DROP TABLE src_table;
199+
200+
201+
query I
202+
COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10))
203+
TO 'test_files/scratch/push_down_filter/t.parquet'
204+
STORED AS PARQUET;
205+
----
206+
10
207+
208+
statement ok
209+
CREATE EXTERNAL TABLE t
210+
(
211+
a INT
212+
)
213+
STORED AS PARQUET
214+
LOCATION 'test_files/scratch/push_down_filter/t.parquet';
215+
216+
217+
# The predicate should not have a column cast when the value is a valid i32
218+
query TT
219+
explain select a from t where a = '100';
220+
----
221+
logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
222+
223+
# The predicate should not have a column cast when the value is a valid i32
224+
query TT
225+
explain select a from t where a != '100';
226+
----
227+
logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)]
228+
229+
# The predicate should still have the column cast when the value is NOT a valid i32
230+
query TT
231+
explain select a from t where a = '99999999999';
232+
----
233+
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")]
234+
235+
# The predicate should still have the column cast when the value is NOT a valid i32
236+
query TT
237+
explain select a from t where a = '99.99';
238+
----
239+
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")]
240+
241+
# The predicate should still have the column cast when the value is NOT a valid i32
242+
query TT
243+
explain select a from t where a = '';
244+
----
245+
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")]
246+
247+
# The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information.
248+
query TT
249+
explain select a from t where cast(a as string) = '100';
250+
----
251+
logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
252+
253+
# The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost).
254+
query TT
255+
explain select a from t where CAST(a AS string) = '0123';
256+
----
257+
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")]
258+
259+
260+
statement ok
261+
drop table t;

0 commit comments

Comments
 (0)