@@ -22,7 +22,7 @@ use std::sync::Arc;
22
22
use arrow:: datatypes:: { DataType , IntervalUnit } ;
23
23
24
24
use datafusion_common:: config:: ConfigOptions ;
25
- use datafusion_common:: tree_node:: { Transformed , TreeNodeRewriter } ;
25
+ use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRewriter } ;
26
26
use datafusion_common:: {
27
27
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema ,
28
28
DataFusionError , Result , ScalarValue ,
@@ -31,8 +31,8 @@ use datafusion_expr::expr::{
31
31
self , AggregateFunctionDefinition , Between , BinaryExpr , Case , Exists , InList ,
32
32
InSubquery , Like , ScalarFunction , WindowFunction ,
33
33
} ;
34
- use datafusion_expr:: expr_rewriter:: rewrite_preserving_name;
35
34
use datafusion_expr:: expr_schema:: cast_subquery;
35
+ use datafusion_expr:: logical_plan:: tree_node:: unwrap_arc;
36
36
use datafusion_expr:: logical_plan:: Subquery ;
37
37
use datafusion_expr:: type_coercion:: binary:: {
38
38
comparison_coercion, get_input_types, like_coercion,
@@ -52,6 +52,7 @@ use datafusion_expr::{
52
52
} ;
53
53
54
54
use crate :: analyzer:: AnalyzerRule ;
55
+ use crate :: utils:: NamePreserver ;
55
56
56
57
#[ derive( Default ) ]
57
58
pub struct TypeCoercion { }
@@ -68,26 +69,28 @@ impl AnalyzerRule for TypeCoercion {
68
69
}
69
70
70
71
fn analyze ( & self , plan : LogicalPlan , _: & ConfigOptions ) -> Result < LogicalPlan > {
71
- analyze_internal ( & DFSchema :: empty ( ) , & plan)
72
+ let empty_schema = DFSchema :: empty ( ) ;
73
+
74
+ let transformed_plan = plan
75
+ . transform_up_with_subqueries ( |plan| analyze_internal ( & empty_schema, plan) ) ?
76
+ . data ;
77
+
78
+ Ok ( transformed_plan)
72
79
}
73
80
}
74
81
82
+ /// use the external schema to handle the correlated subqueries case
83
+ ///
84
+ /// Assumes that children have already been optimized
75
85
fn analyze_internal (
76
- // use the external schema to handle the correlated subqueries case
77
86
external_schema : & DFSchema ,
78
- plan : & LogicalPlan ,
79
- ) -> Result < LogicalPlan > {
80
- // optimize child plans first
81
- let new_inputs = plan
82
- . inputs ( )
83
- . iter ( )
84
- . map ( |p| analyze_internal ( external_schema, p) )
85
- . collect :: < Result < Vec < _ > > > ( ) ?;
87
+ plan : LogicalPlan ,
88
+ ) -> Result < Transformed < LogicalPlan > > {
86
89
// get schema representing all available input fields. This is used for data type
87
90
// resolution only, so order does not matter here
88
- let mut schema = merge_schema ( new_inputs . iter ( ) . collect ( ) ) ;
91
+ let mut schema = merge_schema ( plan . inputs ( ) ) ;
89
92
90
- if let LogicalPlan :: TableScan ( ts) = plan {
93
+ if let LogicalPlan :: TableScan ( ts) = & plan {
91
94
let source_schema = DFSchema :: try_from_qualified_schema (
92
95
ts. table_name . clone ( ) ,
93
96
& ts. source . schema ( ) ,
@@ -100,25 +103,75 @@ fn analyze_internal(
100
103
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
101
104
schema. merge ( external_schema) ;
102
105
103
- let mut expr_rewrite = TypeCoercionRewriter { schema : & schema } ;
104
-
105
- let new_expr = plan
106
- . expressions ( )
107
- . into_iter ( )
108
- . map ( | expr| {
109
- // ensure aggregate names don't change:
110
- // https://github.com/apache/datafusion/issues/3555
111
- rewrite_preserving_name ( expr , & mut expr_rewrite )
112
- } )
113
- . collect :: < Result < Vec < _ > > > ( ) ? ;
114
-
115
- plan . with_new_exprs ( new_expr , new_inputs )
106
+ let mut expr_rewrite = TypeCoercionRewriter :: new ( & schema) ;
107
+
108
+ let name_preserver = NamePreserver :: new ( & plan) ;
109
+ // apply coercion rewrite all expressions in the plan individually
110
+ plan . map_expressions ( |expr| {
111
+ let original_name = name_preserver . save ( & expr) ? ;
112
+ expr . rewrite ( & mut expr_rewrite ) ?
113
+ . map_data ( |expr| original_name . restore ( expr ) )
114
+ } ) ?
115
+ // coerce join expressions specially
116
+ . map_data ( |plan| expr_rewrite . coerce_joins ( plan ) ) ?
117
+ // recompute the schema after the expressions have been rewritten as the types may have changed
118
+ . map_data ( |plan| plan . recompute_schema ( ) )
116
119
}
117
120
118
121
pub ( crate ) struct TypeCoercionRewriter < ' a > {
119
122
pub ( crate ) schema : & ' a DFSchema ,
120
123
}
121
124
125
+ impl < ' a > TypeCoercionRewriter < ' a > {
126
+ fn new ( schema : & ' a DFSchema ) -> Self {
127
+ Self { schema }
128
+ }
129
+
130
+ /// Coerce join equality expressions
131
+ ///
132
+ /// Joins must be treated specially as their equality expressions are stored
133
+ /// as a parallel list of left and right expressions, rather than a single
134
+ /// equality expression
135
+ ///
136
+ /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
137
+ /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
138
+ fn coerce_joins ( & mut self , plan : LogicalPlan ) -> Result < LogicalPlan > {
139
+ let LogicalPlan :: Join ( mut join) = plan else {
140
+ return Ok ( plan) ;
141
+ } ;
142
+
143
+ join. on = join
144
+ . on
145
+ . into_iter ( )
146
+ . map ( |( lhs, rhs) | {
147
+ // coerce the arguments as though they were a single binary equality
148
+ // expression
149
+ let ( lhs, rhs) = self . coerce_binary_op ( lhs, Operator :: Eq , rhs) ?;
150
+ Ok ( ( lhs, rhs) )
151
+ } )
152
+ . collect :: < Result < Vec < _ > > > ( ) ?;
153
+
154
+ Ok ( LogicalPlan :: Join ( join) )
155
+ }
156
+
157
+ fn coerce_binary_op (
158
+ & self ,
159
+ left : Expr ,
160
+ op : Operator ,
161
+ right : Expr ,
162
+ ) -> Result < ( Expr , Expr ) > {
163
+ let ( left_type, right_type) = get_input_types (
164
+ & left. get_type ( self . schema ) ?,
165
+ & op,
166
+ & right. get_type ( self . schema ) ?,
167
+ ) ?;
168
+ Ok ( (
169
+ left. cast_to ( & left_type, self . schema ) ?,
170
+ right. cast_to ( & right_type, self . schema ) ?,
171
+ ) )
172
+ }
173
+ }
174
+
122
175
impl < ' a > TreeNodeRewriter for TypeCoercionRewriter < ' a > {
123
176
type Node = Expr ;
124
177
@@ -131,14 +184,15 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
131
184
subquery,
132
185
outer_ref_columns,
133
186
} ) => {
134
- let new_plan = analyze_internal ( self . schema , & subquery) ? ;
187
+ let new_plan = analyze_internal ( self . schema , unwrap_arc ( subquery) ) ? . data ;
135
188
Ok ( Transformed :: yes ( Expr :: ScalarSubquery ( Subquery {
136
189
subquery : Arc :: new ( new_plan) ,
137
190
outer_ref_columns,
138
191
} ) ) )
139
192
}
140
193
Expr :: Exists ( Exists { subquery, negated } ) => {
141
- let new_plan = analyze_internal ( self . schema , & subquery. subquery ) ?;
194
+ let new_plan =
195
+ analyze_internal ( self . schema , unwrap_arc ( subquery. subquery ) ) ?. data ;
142
196
Ok ( Transformed :: yes ( Expr :: Exists ( Exists {
143
197
subquery : Subquery {
144
198
subquery : Arc :: new ( new_plan) ,
@@ -152,7 +206,8 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
152
206
subquery,
153
207
negated,
154
208
} ) => {
155
- let new_plan = analyze_internal ( self . schema , & subquery. subquery ) ?;
209
+ let new_plan =
210
+ analyze_internal ( self . schema , unwrap_arc ( subquery. subquery ) ) ?. data ;
156
211
let expr_type = expr. get_type ( self . schema ) ?;
157
212
let subquery_type = new_plan. schema ( ) . field ( 0 ) . data_type ( ) ;
158
213
let common_type = comparison_coercion ( & expr_type, subquery_type) . ok_or ( plan_datafusion_err ! (
@@ -221,15 +276,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
221
276
) ) ) )
222
277
}
223
278
Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
224
- let ( left_type, right_type) = get_input_types (
225
- & left. get_type ( self . schema ) ?,
226
- & op,
227
- & right. get_type ( self . schema ) ?,
228
- ) ?;
279
+ let ( left, right) = self . coerce_binary_op ( * left, op, * right) ?;
229
280
Ok ( Transformed :: yes ( Expr :: BinaryExpr ( BinaryExpr :: new (
230
- Box :: new ( left. cast_to ( & left_type , self . schema ) ? ) ,
281
+ Box :: new ( left) ,
231
282
op,
232
- Box :: new ( right. cast_to ( & right_type , self . schema ) ? ) ,
283
+ Box :: new ( right) ,
233
284
) ) ) )
234
285
}
235
286
Expr :: Between ( Between {
0 commit comments