@@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
25
25
use datafusion_common:: tree_node:: { Transformed , TreeNodeRewriter } ;
26
26
use datafusion_common:: {
27
27
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema ,
28
- DFSchemaRef , DataFusionError , Result , ScalarValue ,
28
+ DataFusionError , Result , ScalarValue ,
29
29
} ;
30
30
use datafusion_expr:: expr:: {
31
31
self , AggregateFunctionDefinition , Between , BinaryExpr , Case , Exists , InList ,
@@ -99,9 +99,7 @@ fn analyze_internal(
99
99
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
100
100
schema. merge ( external_schema) ;
101
101
102
- let mut expr_rewrite = TypeCoercionRewriter {
103
- schema : Arc :: new ( schema) ,
104
- } ;
102
+ let mut expr_rewrite = TypeCoercionRewriter { schema : & schema } ;
105
103
106
104
let new_expr = plan
107
105
. expressions ( )
@@ -116,11 +114,11 @@ fn analyze_internal(
116
114
plan. with_new_exprs ( new_expr, new_inputs)
117
115
}
118
116
119
- pub ( crate ) struct TypeCoercionRewriter {
120
- pub ( crate ) schema : DFSchemaRef ,
117
+ pub ( crate ) struct TypeCoercionRewriter < ' a > {
118
+ pub ( crate ) schema : & ' a DFSchema ,
121
119
}
122
120
123
- impl TreeNodeRewriter for TypeCoercionRewriter {
121
+ impl < ' a > TreeNodeRewriter for TypeCoercionRewriter < ' a > {
124
122
type Node = Expr ;
125
123
126
124
fn f_up ( & mut self , expr : Expr ) -> Result < Transformed < Expr > > {
@@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
132
130
subquery,
133
131
outer_ref_columns,
134
132
} ) => {
135
- let new_plan = analyze_internal ( & self . schema , & subquery) ?;
133
+ let new_plan = analyze_internal ( self . schema , & subquery) ?;
136
134
Ok ( Transformed :: yes ( Expr :: ScalarSubquery ( Subquery {
137
135
subquery : Arc :: new ( new_plan) ,
138
136
outer_ref_columns,
139
137
} ) ) )
140
138
}
141
139
Expr :: Exists ( Exists { subquery, negated } ) => {
142
- let new_plan = analyze_internal ( & self . schema , & subquery. subquery ) ?;
140
+ let new_plan = analyze_internal ( self . schema , & subquery. subquery ) ?;
143
141
Ok ( Transformed :: yes ( Expr :: Exists ( Exists {
144
142
subquery : Subquery {
145
143
subquery : Arc :: new ( new_plan) ,
@@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
153
151
subquery,
154
152
negated,
155
153
} ) => {
156
- let new_plan = analyze_internal ( & self . schema , & subquery. subquery ) ?;
157
- let expr_type = expr. get_type ( & self . schema ) ?;
154
+ let new_plan = analyze_internal ( self . schema , & subquery. subquery ) ?;
155
+ let expr_type = expr. get_type ( self . schema ) ?;
158
156
let subquery_type = new_plan. schema ( ) . field ( 0 ) . data_type ( ) ;
159
157
let common_type = comparison_coercion ( & expr_type, subquery_type) . ok_or ( plan_datafusion_err ! (
160
158
"expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
@@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
165
163
outer_ref_columns : subquery. outer_ref_columns ,
166
164
} ;
167
165
Ok ( Transformed :: yes ( Expr :: InSubquery ( InSubquery :: new (
168
- Box :: new ( expr. cast_to ( & common_type, & self . schema ) ?) ,
166
+ Box :: new ( expr. cast_to ( & common_type, self . schema ) ?) ,
169
167
cast_subquery ( new_subquery, & common_type) ?,
170
168
negated,
171
169
) ) ) )
172
170
}
173
171
Expr :: Not ( expr) => Ok ( Transformed :: yes ( not ( get_casted_expr_for_bool_op (
174
172
* expr,
175
- & self . schema ,
173
+ self . schema ,
176
174
) ?) ) ) ,
177
175
Expr :: IsTrue ( expr) => Ok ( Transformed :: yes ( is_true (
178
- get_casted_expr_for_bool_op ( * expr, & self . schema ) ?,
176
+ get_casted_expr_for_bool_op ( * expr, self . schema ) ?,
179
177
) ) ) ,
180
178
Expr :: IsNotTrue ( expr) => Ok ( Transformed :: yes ( is_not_true (
181
- get_casted_expr_for_bool_op ( * expr, & self . schema ) ?,
179
+ get_casted_expr_for_bool_op ( * expr, self . schema ) ?,
182
180
) ) ) ,
183
181
Expr :: IsFalse ( expr) => Ok ( Transformed :: yes ( is_false (
184
- get_casted_expr_for_bool_op ( * expr, & self . schema ) ?,
182
+ get_casted_expr_for_bool_op ( * expr, self . schema ) ?,
185
183
) ) ) ,
186
184
Expr :: IsNotFalse ( expr) => Ok ( Transformed :: yes ( is_not_false (
187
- get_casted_expr_for_bool_op ( * expr, & self . schema ) ?,
185
+ get_casted_expr_for_bool_op ( * expr, self . schema ) ?,
188
186
) ) ) ,
189
187
Expr :: IsUnknown ( expr) => Ok ( Transformed :: yes ( is_unknown (
190
- get_casted_expr_for_bool_op ( * expr, & self . schema ) ?,
188
+ get_casted_expr_for_bool_op ( * expr, self . schema ) ?,
191
189
) ) ) ,
192
190
Expr :: IsNotUnknown ( expr) => Ok ( Transformed :: yes ( is_not_unknown (
193
- get_casted_expr_for_bool_op ( * expr, & self . schema ) ?,
191
+ get_casted_expr_for_bool_op ( * expr, self . schema ) ?,
194
192
) ) ) ,
195
193
Expr :: Like ( Like {
196
194
negated,
@@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
199
197
escape_char,
200
198
case_insensitive,
201
199
} ) => {
202
- let left_type = expr. get_type ( & self . schema ) ?;
203
- let right_type = pattern. get_type ( & self . schema ) ?;
200
+ let left_type = expr. get_type ( self . schema ) ?;
201
+ let right_type = pattern. get_type ( self . schema ) ?;
204
202
let coerced_type = like_coercion ( & left_type, & right_type) . ok_or_else ( || {
205
203
let op_name = if case_insensitive {
206
204
"ILIKE"
@@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
211
209
"There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
212
210
)
213
211
} ) ?;
214
- let expr = Box :: new ( expr. cast_to ( & coerced_type, & self . schema ) ?) ;
215
- let pattern = Box :: new ( pattern. cast_to ( & coerced_type, & self . schema ) ?) ;
212
+ let expr = Box :: new ( expr. cast_to ( & coerced_type, self . schema ) ?) ;
213
+ let pattern = Box :: new ( pattern. cast_to ( & coerced_type, self . schema ) ?) ;
216
214
Ok ( Transformed :: yes ( Expr :: Like ( Like :: new (
217
215
negated,
218
216
expr,
@@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
223
221
}
224
222
Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
225
223
let ( left_type, right_type) = get_input_types (
226
- & left. get_type ( & self . schema ) ?,
224
+ & left. get_type ( self . schema ) ?,
227
225
& op,
228
- & right. get_type ( & self . schema ) ?,
226
+ & right. get_type ( self . schema ) ?,
229
227
) ?;
230
228
Ok ( Transformed :: yes ( Expr :: BinaryExpr ( BinaryExpr :: new (
231
- Box :: new ( left. cast_to ( & left_type, & self . schema ) ?) ,
229
+ Box :: new ( left. cast_to ( & left_type, self . schema ) ?) ,
232
230
op,
233
- Box :: new ( right. cast_to ( & right_type, & self . schema ) ?) ,
231
+ Box :: new ( right. cast_to ( & right_type, self . schema ) ?) ,
234
232
) ) ) )
235
233
}
236
234
Expr :: Between ( Between {
@@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
239
237
low,
240
238
high,
241
239
} ) => {
242
- let expr_type = expr. get_type ( & self . schema ) ?;
243
- let low_type = low. get_type ( & self . schema ) ?;
240
+ let expr_type = expr. get_type ( self . schema ) ?;
241
+ let low_type = low. get_type ( self . schema ) ?;
244
242
let low_coerced_type = comparison_coercion ( & expr_type, & low_type)
245
243
. ok_or_else ( || {
246
244
DataFusionError :: Internal ( format ! (
247
245
"Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
248
246
) )
249
247
} ) ?;
250
- let high_type = high. get_type ( & self . schema ) ?;
248
+ let high_type = high. get_type ( self . schema ) ?;
251
249
let high_coerced_type = comparison_coercion ( & expr_type, & low_type)
252
250
. ok_or_else ( || {
253
251
DataFusionError :: Internal ( format ! (
@@ -262,21 +260,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
262
260
) )
263
261
} ) ?;
264
262
Ok ( Transformed :: yes ( Expr :: Between ( Between :: new (
265
- Box :: new ( expr. cast_to ( & coercion_type, & self . schema ) ?) ,
263
+ Box :: new ( expr. cast_to ( & coercion_type, self . schema ) ?) ,
266
264
negated,
267
- Box :: new ( low. cast_to ( & coercion_type, & self . schema ) ?) ,
268
- Box :: new ( high. cast_to ( & coercion_type, & self . schema ) ?) ,
265
+ Box :: new ( low. cast_to ( & coercion_type, self . schema ) ?) ,
266
+ Box :: new ( high. cast_to ( & coercion_type, self . schema ) ?) ,
269
267
) ) ) )
270
268
}
271
269
Expr :: InList ( InList {
272
270
expr,
273
271
list,
274
272
negated,
275
273
} ) => {
276
- let expr_data_type = expr. get_type ( & self . schema ) ?;
274
+ let expr_data_type = expr. get_type ( self . schema ) ?;
277
275
let list_data_types = list
278
276
. iter ( )
279
- . map ( |list_expr| list_expr. get_type ( & self . schema ) )
277
+ . map ( |list_expr| list_expr. get_type ( self . schema ) )
280
278
. collect :: < Result < Vec < _ > > > ( ) ?;
281
279
let result_type =
282
280
get_coerce_type_for_list ( & expr_data_type, & list_data_types) ;
@@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
286
284
) ,
287
285
Some ( coerced_type) => {
288
286
// find the coerced type
289
- let cast_expr = expr. cast_to ( & coerced_type, & self . schema ) ?;
287
+ let cast_expr = expr. cast_to ( & coerced_type, self . schema ) ?;
290
288
let cast_list_expr = list
291
289
. into_iter ( )
292
290
. map ( |list_expr| {
293
- list_expr. cast_to ( & coerced_type, & self . schema )
291
+ list_expr. cast_to ( & coerced_type, self . schema )
294
292
} )
295
293
. collect :: < Result < Vec < _ > > > ( ) ?;
296
294
Ok ( Transformed :: yes ( Expr :: InList ( InList :: new (
@@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
302
300
}
303
301
}
304
302
Expr :: Case ( case) => {
305
- let case = coerce_case_expression ( case, & self . schema ) ?;
303
+ let case = coerce_case_expression ( case, self . schema ) ?;
306
304
Ok ( Transformed :: yes ( Expr :: Case ( case) ) )
307
305
}
308
306
Expr :: ScalarFunction ( ScalarFunction { func_def, args } ) => match func_def {
309
307
ScalarFunctionDefinition :: UDF ( fun) => {
310
308
let new_expr = coerce_arguments_for_signature (
311
309
args,
312
- & self . schema ,
310
+ self . schema ,
313
311
fun. signature ( ) ,
314
312
) ?;
315
- let new_expr =
316
- coerce_arguments_for_fun ( new_expr, & self . schema , & fun) ?;
313
+ let new_expr = coerce_arguments_for_fun ( new_expr, self . schema , & fun) ?;
317
314
Ok ( Transformed :: yes ( Expr :: ScalarFunction (
318
315
ScalarFunction :: new_udf ( fun, new_expr) ,
319
316
) ) )
@@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
331
328
let new_expr = coerce_agg_exprs_for_signature (
332
329
& fun,
333
330
args,
334
- & self . schema ,
331
+ self . schema ,
335
332
& fun. signature ( ) ,
336
333
) ?;
337
334
Ok ( Transformed :: yes ( Expr :: AggregateFunction (
@@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
348
345
AggregateFunctionDefinition :: UDF ( fun) => {
349
346
let new_expr = coerce_arguments_for_signature (
350
347
args,
351
- & self . schema ,
348
+ self . schema ,
352
349
fun. signature ( ) ,
353
350
) ?;
354
351
Ok ( Transformed :: yes ( Expr :: AggregateFunction (
@@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
375
372
null_treatment,
376
373
} ) => {
377
374
let window_frame =
378
- coerce_window_frame ( window_frame, & self . schema , & order_by) ?;
375
+ coerce_window_frame ( window_frame, self . schema , & order_by) ?;
379
376
380
377
let args = match & fun {
381
378
expr:: WindowFunctionDefinition :: AggregateFunction ( fun) => {
382
379
coerce_agg_exprs_for_signature (
383
380
fun,
384
381
args,
385
- & self . schema ,
382
+ self . schema ,
386
383
& fun. signature ( ) ,
387
384
) ?
388
385
}
@@ -495,7 +492,7 @@ fn coerce_frame_bound(
495
492
// For example, ROWS and GROUPS frames use `UInt64` during calculations.
496
493
fn coerce_window_frame (
497
494
window_frame : WindowFrame ,
498
- schema : & DFSchemaRef ,
495
+ schema : & DFSchema ,
499
496
expressions : & [ Expr ] ,
500
497
) -> Result < WindowFrame > {
501
498
let mut window_frame = window_frame;
@@ -531,7 +528,7 @@ fn coerce_window_frame(
531
528
532
529
// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
533
530
// The above op will be rewrite to the binary op when creating the physical op.
534
- fn get_casted_expr_for_bool_op ( expr : Expr , schema : & DFSchemaRef ) -> Result < Expr > {
531
+ fn get_casted_expr_for_bool_op ( expr : Expr , schema : & DFSchema ) -> Result < Expr > {
535
532
let left_type = expr. get_type ( schema) ?;
536
533
get_input_types ( & left_type, & Operator :: IsDistinctFrom , & DataType :: Boolean ) ?;
537
534
expr. cast_to ( & DataType :: Boolean , schema)
@@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature(
615
612
. collect ( )
616
613
}
617
614
618
- fn coerce_case_expression ( case : Case , schema : & DFSchemaRef ) -> Result < Case > {
615
+ fn coerce_case_expression ( case : Case , schema : & DFSchema ) -> Result < Case > {
619
616
// Given expressions like:
620
617
//
621
618
// CASE a1
@@ -1238,7 +1235,7 @@ mod test {
1238
1235
vec ! [ Field :: new( "a" , DataType :: Int64 , true ) ] . into ( ) ,
1239
1236
std:: collections:: HashMap :: new ( ) ,
1240
1237
) ?) ;
1241
- let mut rewriter = TypeCoercionRewriter { schema } ;
1238
+ let mut rewriter = TypeCoercionRewriter { schema : & schema } ;
1242
1239
let expr = is_true ( lit ( 12i32 ) . gt ( lit ( 13i64 ) ) ) ;
1243
1240
let expected = is_true ( cast ( lit ( 12i32 ) , DataType :: Int64 ) . gt ( lit ( 13i64 ) ) ) ;
1244
1241
let result = expr. rewrite ( & mut rewriter) . data ( ) ?;
@@ -1249,7 +1246,7 @@ mod test {
1249
1246
vec ! [ Field :: new( "a" , DataType :: Int64 , true ) ] . into ( ) ,
1250
1247
std:: collections:: HashMap :: new ( ) ,
1251
1248
) ?) ;
1252
- let mut rewriter = TypeCoercionRewriter { schema } ;
1249
+ let mut rewriter = TypeCoercionRewriter { schema : & schema } ;
1253
1250
let expr = is_true ( lit ( 12i32 ) . eq ( lit ( 13i64 ) ) ) ;
1254
1251
let expected = is_true ( cast ( lit ( 12i32 ) , DataType :: Int64 ) . eq ( lit ( 13i64 ) ) ) ;
1255
1252
let result = expr. rewrite ( & mut rewriter) . data ( ) ?;
@@ -1260,7 +1257,7 @@ mod test {
1260
1257
vec ! [ Field :: new( "a" , DataType :: Int64 , true ) ] . into ( ) ,
1261
1258
std:: collections:: HashMap :: new ( ) ,
1262
1259
) ?) ;
1263
- let mut rewriter = TypeCoercionRewriter { schema } ;
1260
+ let mut rewriter = TypeCoercionRewriter { schema : & schema } ;
1264
1261
let expr = is_true ( lit ( 12i32 ) . lt ( lit ( 13i64 ) ) ) ;
1265
1262
let expected = is_true ( cast ( lit ( 12i32 ) , DataType :: Int64 ) . lt ( lit ( 13i64 ) ) ) ;
1266
1263
let result = expr. rewrite ( & mut rewriter) . data ( ) ?;
0 commit comments