Skip to content

Commit 9974cee

Browse files
authored
fix NamedStructField should be rewritten in OperatorToFunction in subquery (#10103)
1 parent 1fa25ae commit 9974cee

File tree

2 files changed

+144
-40
lines changed

2 files changed

+144
-40
lines changed

datafusion/optimizer/src/analyzer/function_rewrite.rs

Lines changed: 89 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,10 @@ use super::AnalyzerRule;
2121
use datafusion_common::config::ConfigOptions;
2222
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
2323
use datafusion_common::{DFSchema, Result};
24+
use datafusion_expr::expr::{Exists, InSubquery};
2425
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
2526
use datafusion_expr::utils::merge_schema;
26-
use datafusion_expr::{Expr, LogicalPlan};
27+
use datafusion_expr::{Expr, LogicalPlan, Subquery};
2728
use std::sync::Arc;
2829

2930
/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
@@ -45,52 +46,66 @@ impl AnalyzerRule for ApplyFunctionRewrites {
4546
}
4647

4748
fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
48-
self.analyze_internal(&plan, options)
49+
analyze_internal(&plan, &self.function_rewrites, options)
4950
}
5051
}
5152

52-
impl ApplyFunctionRewrites {
53-
fn analyze_internal(
54-
&self,
55-
plan: &LogicalPlan,
56-
options: &ConfigOptions,
57-
) -> Result<LogicalPlan> {
58-
// optimize child plans first
59-
let new_inputs = plan
60-
.inputs()
61-
.iter()
62-
.map(|p| self.analyze_internal(p, options))
63-
.collect::<Result<Vec<_>>>()?;
64-
65-
// get schema representing all available input fields. This is used for data type
66-
// resolution only, so order does not matter here
67-
let mut schema = merge_schema(new_inputs.iter().collect());
68-
69-
if let LogicalPlan::TableScan(ts) = plan {
70-
let source_schema =
71-
DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
72-
schema.merge(&source_schema);
73-
}
53+
fn analyze_internal(
54+
plan: &LogicalPlan,
55+
function_rewrites: &[Arc<dyn FunctionRewrite + Send + Sync>],
56+
options: &ConfigOptions,
57+
) -> Result<LogicalPlan> {
58+
// optimize child plans first
59+
let new_inputs = plan
60+
.inputs()
61+
.iter()
62+
.map(|p| analyze_internal(p, function_rewrites, options))
63+
.collect::<Result<Vec<_>>>()?;
7464

75-
let mut expr_rewrite = OperatorToFunctionRewriter {
76-
function_rewrites: &self.function_rewrites,
77-
options,
78-
schema: &schema,
79-
};
65+
// get schema representing all available input fields. This is used for data type
66+
// resolution only, so order does not matter here
67+
let mut schema = merge_schema(new_inputs.iter().collect());
8068

81-
let new_expr = plan
82-
.expressions()
83-
.into_iter()
84-
.map(|expr| {
85-
// ensure names don't change:
86-
// https://github.com/apache/arrow-datafusion/issues/3555
87-
rewrite_preserving_name(expr, &mut expr_rewrite)
88-
})
89-
.collect::<Result<Vec<_>>>()?;
90-
91-
plan.with_new_exprs(new_expr, new_inputs)
69+
if let LogicalPlan::TableScan(ts) = plan {
70+
let source_schema = DFSchema::try_from_qualified_schema(
71+
ts.table_name.clone(),
72+
&ts.source.schema(),
73+
)?;
74+
schema.merge(&source_schema);
9275
}
76+
77+
let mut expr_rewrite = OperatorToFunctionRewriter {
78+
function_rewrites,
79+
options,
80+
schema: &schema,
81+
};
82+
83+
let new_expr = plan
84+
.expressions()
85+
.into_iter()
86+
.map(|expr| {
87+
// ensure names don't change:
88+
// https://github.com/apache/arrow-datafusion/issues/3555
89+
rewrite_preserving_name(expr, &mut expr_rewrite)
90+
})
91+
.collect::<Result<Vec<_>>>()?;
92+
93+
plan.with_new_exprs(new_expr, new_inputs)
9394
}
95+
96+
fn rewrite_subquery(
97+
mut subquery: Subquery,
98+
function_rewrites: &[Arc<dyn FunctionRewrite + Send + Sync>],
99+
options: &ConfigOptions,
100+
) -> Result<Subquery> {
101+
subquery.subquery = Arc::new(analyze_internal(
102+
&subquery.subquery,
103+
function_rewrites,
104+
options,
105+
)?);
106+
Ok(subquery)
107+
}
108+
94109
struct OperatorToFunctionRewriter<'a> {
95110
function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
96111
options: &'a ConfigOptions,
@@ -111,6 +126,40 @@ impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
111126
expr = result.data
112127
}
113128

129+
// recurse into subqueries if needed
130+
let expr = match expr {
131+
Expr::ScalarSubquery(subquery) => Expr::ScalarSubquery(rewrite_subquery(
132+
subquery,
133+
self.function_rewrites,
134+
self.options,
135+
)?),
136+
137+
Expr::Exists(Exists { subquery, negated }) => Expr::Exists(Exists {
138+
subquery: rewrite_subquery(
139+
subquery,
140+
self.function_rewrites,
141+
self.options,
142+
)?,
143+
negated,
144+
}),
145+
146+
Expr::InSubquery(InSubquery {
147+
expr,
148+
subquery,
149+
negated,
150+
}) => Expr::InSubquery(InSubquery {
151+
expr,
152+
subquery: rewrite_subquery(
153+
subquery,
154+
self.function_rewrites,
155+
self.options,
156+
)?,
157+
negated,
158+
}),
159+
160+
expr => expr,
161+
};
162+
114163
Ok(if transformed {
115164
Transformed::yes(expr)
116165
} else {

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1060,3 +1060,58 @@ logical_plan
10601060
Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
10611061
--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
10621062
----TableScan: t projection=[a]
1063+
1064+
###
1065+
## Ensure that operators are rewritten in subqueries
1066+
###
1067+
1068+
statement ok
1069+
create table foo(x int) as values (1);
1070+
1071+
# Show input data
1072+
query ?
1073+
select struct(1, 'b')
1074+
----
1075+
{c0: 1, c1: b}
1076+
1077+
1078+
query T
1079+
select (select struct(1, 'b')['c1']);
1080+
----
1081+
b
1082+
1083+
query T
1084+
select 'foo' || (select struct(1, 'b')['c1']);
1085+
----
1086+
foob
1087+
1088+
query I
1089+
SELECT * FROM (VALUES (1), (2))
1090+
WHERE column1 IN (SELECT struct(1, 'b')['c0']);
1091+
----
1092+
1
1093+
1094+
# also add an expression so the subquery is the output expr
1095+
query I
1096+
SELECT * FROM (VALUES (1), (2))
1097+
WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']);
1098+
----
1099+
1
1100+
1101+
1102+
query I
1103+
SELECT * FROM foo
1104+
WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
1105+
----
1106+
1
1107+
1108+
# also add an expression so the subquery is the output expr
1109+
query I
1110+
SELECT * FROM foo
1111+
WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
1112+
----
1113+
1
1114+
1115+
1116+
statement ok
1117+
drop table foo;

0 commit comments

Comments
 (0)