Skip to content

Commit fa31c78

Browse files
authored
Improve coerce API so it does not need DFSchema (#10331)
1 parent 6d77748 commit fa31c78

File tree

4 files changed

+54
-69
lines changed

4 files changed

+54
-69
lines changed

datafusion-examples/examples/expr_api.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr
258258
ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone()));
259259

260260
// apply type coercion here to ensure types match
261-
let expr = simplifier.coerce(expr, df_schema.clone())?;
261+
let expr = simplifier.coerce(expr, &df_schema)?;
262262

263263
create_physical_expr(&expr, df_schema.as_ref(), &props)
264264
}

datafusion/core/src/test_util/parquet.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ impl TestParquetFile {
169169
let parquet_options = ctx.copied_table_options().parquet;
170170
if let Some(filter) = maybe_filter {
171171
let simplifier = ExprSimplifier::new(context);
172-
let filter = simplifier.coerce(filter, df_schema.clone()).unwrap();
172+
let filter = simplifier.coerce(filter, &df_schema).unwrap();
173173
let physical_filter_expr =
174174
create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?;
175175
let parquet_exec = Arc::new(ParquetExec::new(

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 48 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
2525
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
2626
use datafusion_common::{
2727
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
28-
DFSchemaRef, DataFusionError, Result, ScalarValue,
28+
DataFusionError, Result, ScalarValue,
2929
};
3030
use datafusion_expr::expr::{
3131
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
@@ -99,9 +99,7 @@ fn analyze_internal(
9999
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
100100
schema.merge(external_schema);
101101

102-
let mut expr_rewrite = TypeCoercionRewriter {
103-
schema: Arc::new(schema),
104-
};
102+
let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };
105103

106104
let new_expr = plan
107105
.expressions()
@@ -116,11 +114,11 @@ fn analyze_internal(
116114
plan.with_new_exprs(new_expr, new_inputs)
117115
}
118116

119-
pub(crate) struct TypeCoercionRewriter {
120-
pub(crate) schema: DFSchemaRef,
117+
pub(crate) struct TypeCoercionRewriter<'a> {
118+
pub(crate) schema: &'a DFSchema,
121119
}
122120

123-
impl TreeNodeRewriter for TypeCoercionRewriter {
121+
impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
124122
type Node = Expr;
125123

126124
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
@@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
132130
subquery,
133131
outer_ref_columns,
134132
}) => {
135-
let new_plan = analyze_internal(&self.schema, &subquery)?;
133+
let new_plan = analyze_internal(self.schema, &subquery)?;
136134
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
137135
subquery: Arc::new(new_plan),
138136
outer_ref_columns,
139137
})))
140138
}
141139
Expr::Exists(Exists { subquery, negated }) => {
142-
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
140+
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
143141
Ok(Transformed::yes(Expr::Exists(Exists {
144142
subquery: Subquery {
145143
subquery: Arc::new(new_plan),
@@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
153151
subquery,
154152
negated,
155153
}) => {
156-
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
157-
let expr_type = expr.get_type(&self.schema)?;
154+
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
155+
let expr_type = expr.get_type(self.schema)?;
158156
let subquery_type = new_plan.schema().field(0).data_type();
159157
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
160158
"expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
@@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
165163
outer_ref_columns: subquery.outer_ref_columns,
166164
};
167165
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
168-
Box::new(expr.cast_to(&common_type, &self.schema)?),
166+
Box::new(expr.cast_to(&common_type, self.schema)?),
169167
cast_subquery(new_subquery, &common_type)?,
170168
negated,
171169
))))
172170
}
173171
Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
174172
*expr,
175-
&self.schema,
173+
self.schema,
176174
)?))),
177175
Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
178-
get_casted_expr_for_bool_op(*expr, &self.schema)?,
176+
get_casted_expr_for_bool_op(*expr, self.schema)?,
179177
))),
180178
Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
181-
get_casted_expr_for_bool_op(*expr, &self.schema)?,
179+
get_casted_expr_for_bool_op(*expr, self.schema)?,
182180
))),
183181
Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
184-
get_casted_expr_for_bool_op(*expr, &self.schema)?,
182+
get_casted_expr_for_bool_op(*expr, self.schema)?,
185183
))),
186184
Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
187-
get_casted_expr_for_bool_op(*expr, &self.schema)?,
185+
get_casted_expr_for_bool_op(*expr, self.schema)?,
188186
))),
189187
Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
190-
get_casted_expr_for_bool_op(*expr, &self.schema)?,
188+
get_casted_expr_for_bool_op(*expr, self.schema)?,
191189
))),
192190
Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
193-
get_casted_expr_for_bool_op(*expr, &self.schema)?,
191+
get_casted_expr_for_bool_op(*expr, self.schema)?,
194192
))),
195193
Expr::Like(Like {
196194
negated,
@@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
199197
escape_char,
200198
case_insensitive,
201199
}) => {
202-
let left_type = expr.get_type(&self.schema)?;
203-
let right_type = pattern.get_type(&self.schema)?;
200+
let left_type = expr.get_type(self.schema)?;
201+
let right_type = pattern.get_type(self.schema)?;
204202
let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
205203
let op_name = if case_insensitive {
206204
"ILIKE"
@@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
211209
"There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
212210
)
213211
})?;
214-
let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?);
215-
let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?);
212+
let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?);
213+
let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
216214
Ok(Transformed::yes(Expr::Like(Like::new(
217215
negated,
218216
expr,
@@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
223221
}
224222
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
225223
let (left_type, right_type) = get_input_types(
226-
&left.get_type(&self.schema)?,
224+
&left.get_type(self.schema)?,
227225
&op,
228-
&right.get_type(&self.schema)?,
226+
&right.get_type(self.schema)?,
229227
)?;
230228
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
231-
Box::new(left.cast_to(&left_type, &self.schema)?),
229+
Box::new(left.cast_to(&left_type, self.schema)?),
232230
op,
233-
Box::new(right.cast_to(&right_type, &self.schema)?),
231+
Box::new(right.cast_to(&right_type, self.schema)?),
234232
))))
235233
}
236234
Expr::Between(Between {
@@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
239237
low,
240238
high,
241239
}) => {
242-
let expr_type = expr.get_type(&self.schema)?;
243-
let low_type = low.get_type(&self.schema)?;
240+
let expr_type = expr.get_type(self.schema)?;
241+
let low_type = low.get_type(self.schema)?;
244242
let low_coerced_type = comparison_coercion(&expr_type, &low_type)
245243
.ok_or_else(|| {
246244
DataFusionError::Internal(format!(
247245
"Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
248246
))
249247
})?;
250-
let high_type = high.get_type(&self.schema)?;
248+
let high_type = high.get_type(self.schema)?;
251249
let high_coerced_type = comparison_coercion(&expr_type, &low_type)
252250
.ok_or_else(|| {
253251
DataFusionError::Internal(format!(
@@ -262,21 +260,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
262260
))
263261
})?;
264262
Ok(Transformed::yes(Expr::Between(Between::new(
265-
Box::new(expr.cast_to(&coercion_type, &self.schema)?),
263+
Box::new(expr.cast_to(&coercion_type, self.schema)?),
266264
negated,
267-
Box::new(low.cast_to(&coercion_type, &self.schema)?),
268-
Box::new(high.cast_to(&coercion_type, &self.schema)?),
265+
Box::new(low.cast_to(&coercion_type, self.schema)?),
266+
Box::new(high.cast_to(&coercion_type, self.schema)?),
269267
))))
270268
}
271269
Expr::InList(InList {
272270
expr,
273271
list,
274272
negated,
275273
}) => {
276-
let expr_data_type = expr.get_type(&self.schema)?;
274+
let expr_data_type = expr.get_type(self.schema)?;
277275
let list_data_types = list
278276
.iter()
279-
.map(|list_expr| list_expr.get_type(&self.schema))
277+
.map(|list_expr| list_expr.get_type(self.schema))
280278
.collect::<Result<Vec<_>>>()?;
281279
let result_type =
282280
get_coerce_type_for_list(&expr_data_type, &list_data_types);
@@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
286284
),
287285
Some(coerced_type) => {
288286
// find the coerced type
289-
let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
287+
let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
290288
let cast_list_expr = list
291289
.into_iter()
292290
.map(|list_expr| {
293-
list_expr.cast_to(&coerced_type, &self.schema)
291+
list_expr.cast_to(&coerced_type, self.schema)
294292
})
295293
.collect::<Result<Vec<_>>>()?;
296294
Ok(Transformed::yes(Expr::InList(InList ::new(
@@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
302300
}
303301
}
304302
Expr::Case(case) => {
305-
let case = coerce_case_expression(case, &self.schema)?;
303+
let case = coerce_case_expression(case, self.schema)?;
306304
Ok(Transformed::yes(Expr::Case(case)))
307305
}
308306
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
309307
ScalarFunctionDefinition::UDF(fun) => {
310308
let new_expr = coerce_arguments_for_signature(
311309
args,
312-
&self.schema,
310+
self.schema,
313311
fun.signature(),
314312
)?;
315-
let new_expr =
316-
coerce_arguments_for_fun(new_expr, &self.schema, &fun)?;
313+
let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?;
317314
Ok(Transformed::yes(Expr::ScalarFunction(
318315
ScalarFunction::new_udf(fun, new_expr),
319316
)))
@@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
331328
let new_expr = coerce_agg_exprs_for_signature(
332329
&fun,
333330
args,
334-
&self.schema,
331+
self.schema,
335332
&fun.signature(),
336333
)?;
337334
Ok(Transformed::yes(Expr::AggregateFunction(
@@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
348345
AggregateFunctionDefinition::UDF(fun) => {
349346
let new_expr = coerce_arguments_for_signature(
350347
args,
351-
&self.schema,
348+
self.schema,
352349
fun.signature(),
353350
)?;
354351
Ok(Transformed::yes(Expr::AggregateFunction(
@@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
375372
null_treatment,
376373
}) => {
377374
let window_frame =
378-
coerce_window_frame(window_frame, &self.schema, &order_by)?;
375+
coerce_window_frame(window_frame, self.schema, &order_by)?;
379376

380377
let args = match &fun {
381378
expr::WindowFunctionDefinition::AggregateFunction(fun) => {
382379
coerce_agg_exprs_for_signature(
383380
fun,
384381
args,
385-
&self.schema,
382+
self.schema,
386383
&fun.signature(),
387384
)?
388385
}
@@ -495,7 +492,7 @@ fn coerce_frame_bound(
495492
// For example, ROWS and GROUPS frames use `UInt64` during calculations.
496493
fn coerce_window_frame(
497494
window_frame: WindowFrame,
498-
schema: &DFSchemaRef,
495+
schema: &DFSchema,
499496
expressions: &[Expr],
500497
) -> Result<WindowFrame> {
501498
let mut window_frame = window_frame;
@@ -531,7 +528,7 @@ fn coerce_window_frame(
531528

532529
// Supports type coercion for the `IsTrue`, `IsNotTrue`, `IsFalse`, and `IsNotFalse` operators.
533530
// The above ops will be rewritten to binary ops when creating the physical op.
534-
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
531+
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
535532
let left_type = expr.get_type(schema)?;
536533
get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
537534
expr.cast_to(&DataType::Boolean, schema)
@@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature(
615612
.collect()
616613
}
617614

618-
fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
615+
fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
619616
// Given expressions like:
620617
//
621618
// CASE a1
@@ -1238,7 +1235,7 @@ mod test {
12381235
vec![Field::new("a", DataType::Int64, true)].into(),
12391236
std::collections::HashMap::new(),
12401237
)?);
1241-
let mut rewriter = TypeCoercionRewriter { schema };
1238+
let mut rewriter = TypeCoercionRewriter { schema: &schema };
12421239
let expr = is_true(lit(12i32).gt(lit(13i64)));
12431240
let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
12441241
let result = expr.rewrite(&mut rewriter).data()?;
@@ -1249,7 +1246,7 @@ mod test {
12491246
vec![Field::new("a", DataType::Int64, true)].into(),
12501247
std::collections::HashMap::new(),
12511248
)?);
1252-
let mut rewriter = TypeCoercionRewriter { schema };
1249+
let mut rewriter = TypeCoercionRewriter { schema: &schema };
12531250
let expr = is_true(lit(12i32).eq(lit(13i64)));
12541251
let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
12551252
let result = expr.rewrite(&mut rewriter).data()?;
@@ -1260,7 +1257,7 @@ mod test {
12601257
vec![Field::new("a", DataType::Int64, true)].into(),
12611258
std::collections::HashMap::new(),
12621259
)?);
1263-
let mut rewriter = TypeCoercionRewriter { schema };
1260+
let mut rewriter = TypeCoercionRewriter { schema: &schema };
12641261
let expr = is_true(lit(12i32).lt(lit(13i64)));
12651262
let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
12661263
let result = expr.rewrite(&mut rewriter).data()?;

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ use datafusion_common::{
3131
cast::{as_large_list_array, as_list_array},
3232
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
3333
};
34-
use datafusion_common::{
35-
internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
36-
};
34+
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
3735
use datafusion_expr::expr::{InList, InSubquery};
3836
use datafusion_expr::simplify::ExprSimplifyResult;
3937
use datafusion_expr::{
@@ -208,14 +206,8 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
208206
///
209207
/// See the [type coercion module](datafusion_expr::type_coercion)
210208
/// documentation for more details on type coercion
211-
///
212-
// Would be nice if this API could use the SimplifyInfo
213-
// rather than creating an DFSchemaRef coerces rather than doing
214-
// it manually.
215-
// https://github.com/apache/datafusion/issues/3793
216-
pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result<Expr> {
209+
pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
217210
let mut expr_rewrite = TypeCoercionRewriter { schema };
218-
219211
expr.rewrite(&mut expr_rewrite).data()
220212
}
221213

@@ -1686,7 +1678,7 @@ mod tests {
16861678
sync::Arc,
16871679
};
16881680

1689-
use datafusion_common::{assert_contains, ToDFSchema};
1681+
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
16901682
use datafusion_expr::{interval_arithmetic::Interval, *};
16911683

16921684
use crate::simplify_expressions::SimplifyContext;
@@ -1721,11 +1713,7 @@ mod tests {
17211713
// should fully simplify to 3 < i (though i has been coerced to i64)
17221714
let expected = lit(3i64).lt(col("i"));
17231715

1724-
// Would be nice if this API could use the SimplifyInfo
1725-
// rather than creating an DFSchemaRef coerces rather than doing
1726-
// it manually.
1727-
// https://github.com/apache/datafusion/issues/3793
1728-
let expr = simplifier.coerce(expr, schema).unwrap();
1716+
let expr = simplifier.coerce(expr, &schema).unwrap();
17291717

17301718
assert_eq!(expected, simplifier.simplify(expr).unwrap());
17311719
}

0 commit comments

Comments
 (0)