Skip to content

Commit 0b2b4fb

Browse files
sgrebnov and goldmedal authored
Support unparsing plans with both Aggregation and Window functions (#12705)
* Support unparsing plans with both Aggregation and Window functions (#35)
* Fix unparsing for aggregation grouping sets
* Add test for grouping set unparsing
* Update datafusion/sql/src/unparser/utils.rs (Co-authored-by: Jax Liu <[email protected]>)
* Update datafusion/sql/src/unparser/utils.rs (Co-authored-by: Jax Liu <[email protected]>)
* Update
* More tests
---------
Co-authored-by: Jax Liu <[email protected]>
1 parent cfd861c commit 0b2b4fb

File tree

3 files changed

+134
-46
lines changed

3 files changed

+134
-46
lines changed

datafusion/sql/src/unparser/plan.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ use super::{
3838
rewrite_plan_for_sort_on_non_projected_fields,
3939
subquery_alias_inner_query_and_columns, TableAliasRewriter,
4040
},
41-
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
41+
utils::{
42+
find_agg_node_within_select, find_window_nodes_within_select,
43+
unproject_window_exprs,
44+
},
4245
Unparser,
4346
};
4447

@@ -172,13 +175,17 @@ impl Unparser<'_> {
172175
p: &Projection,
173176
select: &mut SelectBuilder,
174177
) -> Result<()> {
175-
match find_agg_node_within_select(plan, None, true) {
176-
Some(AggVariant::Aggregate(agg)) => {
178+
match (
179+
find_agg_node_within_select(plan, true),
180+
find_window_nodes_within_select(plan, None, true),
181+
) {
182+
(Some(agg), window) => {
183+
let window_option = window.as_deref();
177184
let items = p
178185
.expr
179186
.iter()
180187
.map(|proj_expr| {
181-
let unproj = unproject_agg_exprs(proj_expr, agg)?;
188+
let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?;
182189
self.select_item_to_sql(&unproj)
183190
})
184191
.collect::<Result<Vec<_>>>()?;
@@ -192,7 +199,7 @@ impl Unparser<'_> {
192199
vec![],
193200
));
194201
}
195-
Some(AggVariant::Window(window)) => {
202+
(None, Some(window)) => {
196203
let items = p
197204
.expr
198205
.iter()
@@ -204,7 +211,7 @@ impl Unparser<'_> {
204211

205212
select.projection(items);
206213
}
207-
None => {
214+
_ => {
208215
let items = p
209216
.expr
210217
.iter()
@@ -287,10 +294,10 @@ impl Unparser<'_> {
287294
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
288295
}
289296
LogicalPlan::Filter(filter) => {
290-
if let Some(AggVariant::Aggregate(agg)) =
291-
find_agg_node_within_select(plan, None, select.already_projected())
297+
if let Some(agg) =
298+
find_agg_node_within_select(plan, select.already_projected())
292299
{
293-
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
300+
let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?;
294301
let filter_expr = self.expr_to_sql(&unprojected)?;
295302
select.having(Some(filter_expr));
296303
} else {

datafusion/sql/src/unparser/utils.rs

Lines changed: 97 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,83 @@
1818
use datafusion_common::{
1919
internal_err,
2020
tree_node::{Transformed, TreeNode},
21-
Result,
21+
Column, DataFusionError, Result,
22+
};
23+
use datafusion_expr::{
24+
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
2225
};
23-
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
2426

25-
/// One of the possible aggregation plans which can be found within a single select query.
26-
pub(crate) enum AggVariant<'a> {
27-
Aggregate(&'a Aggregate),
28-
Window(Vec<&'a Window>),
27+
/// Recursively searches children of [LogicalPlan] to find an Aggregate node, if one exists,
28+
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
29+
/// If an Aggregate node is not found prior to this or at all before reaching the end
30+
/// of the tree, None is returned.
31+
pub(crate) fn find_agg_node_within_select(
32+
plan: &LogicalPlan,
33+
already_projected: bool,
34+
) -> Option<&Aggregate> {
35+
// Note that none of the nodes that have a corresponding Aggregate node can have more
36+
// than 1 input node. E.g. Projection / Filter always have 1 input node.
37+
let input = plan.inputs();
38+
let input = if input.len() > 1 {
39+
return None;
40+
} else {
41+
input.first()?
42+
};
43+
// Agg nodes explicitly return immediately with a single node
44+
if let LogicalPlan::Aggregate(agg) = input {
45+
Some(agg)
46+
} else if let LogicalPlan::TableScan(_) = input {
47+
None
48+
} else if let LogicalPlan::Projection(_) = input {
49+
if already_projected {
50+
None
51+
} else {
52+
find_agg_node_within_select(input, true)
53+
}
54+
} else {
55+
find_agg_node_within_select(input, already_projected)
56+
}
2957
}
3058

31-
/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
59+
/// Recursively searches children of [LogicalPlan] to find Window nodes, if any exist,
3260
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
33-
/// If an Aggregate or window node is not found prior to this or at all before reaching the end
34-
/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both
35-
/// be found in a single select query.
36-
pub(crate) fn find_agg_node_within_select<'a>(
61+
/// If no Window node is found prior to this or at all before reaching the end
62+
/// of the tree, None is returned.
63+
pub(crate) fn find_window_nodes_within_select<'a>(
3764
plan: &'a LogicalPlan,
38-
mut prev_windows: Option<AggVariant<'a>>,
65+
mut prev_windows: Option<Vec<&'a Window>>,
3966
already_projected: bool,
40-
) -> Option<AggVariant<'a>> {
41-
// Note that none of the nodes that have a corresponding agg node can have more
67+
) -> Option<Vec<&'a Window>> {
68+
// Note that none of the nodes that have a corresponding Window node can have more
4269
// than 1 input node. E.g. Projection / Filter always have 1 input node.
4370
let input = plan.inputs();
4471
let input = if input.len() > 1 {
45-
return None;
72+
return prev_windows;
4673
} else {
4774
input.first()?
4875
};
4976

50-
// Agg nodes explicitly return immediately with a single node
5177
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
5278
match input {
53-
LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)),
5479
LogicalPlan::Window(window) => {
5580
prev_windows = match &mut prev_windows {
56-
Some(AggVariant::Window(windows)) => {
81+
Some(windows) => {
5782
windows.push(window);
5883
prev_windows
5984
}
60-
_ => Some(AggVariant::Window(vec![window])),
85+
_ => Some(vec![window]),
6186
};
62-
find_agg_node_within_select(input, prev_windows, already_projected)
87+
find_window_nodes_within_select(input, prev_windows, already_projected)
6388
}
6489
LogicalPlan::Projection(_) => {
6590
if already_projected {
6691
prev_windows
6792
} else {
68-
find_agg_node_within_select(input, prev_windows, true)
93+
find_window_nodes_within_select(input, prev_windows, true)
6994
}
7095
}
7196
LogicalPlan::TableScan(_) => prev_windows,
72-
_ => find_agg_node_within_select(input, prev_windows, already_projected),
97+
_ => find_window_nodes_within_select(input, prev_windows, already_projected),
7398
}
7499
}
75100

@@ -78,22 +103,34 @@ pub(crate) fn find_agg_node_within_select<'a>(
78103
///
79104
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
80105
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
81-
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
106+
pub(crate) fn unproject_agg_exprs(
107+
expr: &Expr,
108+
agg: &Aggregate,
109+
windows: Option<&[&Window]>,
110+
) -> Result<Expr> {
82111
expr.clone()
83112
.transform(|sub_expr| {
84113
if let Expr::Column(c) = sub_expr {
85-
// find the column in the agg schema
86-
if let Ok(n) = agg.schema.index_of_column(&c) {
87-
let unprojected_expr = agg
88-
.group_expr
89-
.iter()
90-
.chain(agg.aggr_expr.iter())
91-
.nth(n)
92-
.unwrap();
114+
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
93115
Ok(Transformed::yes(unprojected_expr.clone()))
116+
} else if let Some(mut unprojected_expr) =
117+
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
118+
{
119+
if let Expr::WindowFunction(func) = &mut unprojected_expr {
120+
// Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
121+
func.args.iter_mut().try_for_each(|arg| {
122+
if let Expr::Column(c) = arg {
123+
if let Some(expr) = find_agg_expr(agg, c)? {
124+
*arg = expr.clone();
125+
}
126+
}
127+
Ok::<(), DataFusionError>(())
128+
})?;
129+
}
130+
Ok(Transformed::yes(unprojected_expr))
94131
} else {
95132
internal_err!(
96-
"Tried to unproject agg expr not found in provided Aggregate!"
133+
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
97134
)
98135
}
99136
} else {
@@ -112,11 +149,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
112149
expr.clone()
113150
.transform(|sub_expr| {
114151
if let Expr::Column(c) = sub_expr {
115-
if let Some(unproj) = windows
116-
.iter()
117-
.flat_map(|w| w.window_expr.iter())
118-
.find(|window_expr| window_expr.schema_name().to_string() == c.name)
119-
{
152+
if let Some(unproj) = find_window_expr(windows, &c.name) {
120153
Ok(Transformed::yes(unproj.clone()))
121154
} else {
122155
Ok(Transformed::no(Expr::Column(c)))
@@ -127,3 +160,30 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
127160
})
128161
.map(|e| e.data)
129162
}
163+
164+
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
165+
if let Ok(index) = agg.schema.index_of_column(column) {
166+
if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
167+
// For grouping set expr, we must operate by expression list from the grouping set
168+
let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
169+
Ok(grouping_expr
170+
.into_iter()
171+
.chain(agg.aggr_expr.iter())
172+
.nth(index))
173+
} else {
174+
Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index))
175+
}
176+
} else {
177+
Ok(None)
178+
}
179+
}
180+
181+
fn find_window_expr<'a>(
182+
windows: &'a [&'a Window],
183+
column_name: &'a str,
184+
) -> Option<&'a Expr> {
185+
windows
186+
.iter()
187+
.flat_map(|w| w.window_expr.iter())
188+
.find(|expr| expr.schema_name().to_string() == column_name)
189+
}

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,26 @@ fn roundtrip_statement() -> Result<()> {
149149
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3",
150150
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4",
151151
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col",
152+
r#"SELECT id, first_name,
153+
SUM(id) AS total_sum,
154+
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
155+
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
156+
FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#,
157+
r#"SELECT id, first_name,
158+
SUM(id) AS total_sum,
159+
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
160+
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
161+
FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#,
162+
r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum
163+
FROM person
164+
JOIN orders ON person.id = orders.customer_id
165+
GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#,
166+
r#"SELECT id, first_name, last_name,
167+
SUM(id) AS total_sum,
168+
COUNT(*) AS total_count,
169+
SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
170+
FROM person
171+
GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#,
152172
];
153173

154174
// For each test sql string, we transform as follows:
@@ -164,6 +184,7 @@ fn roundtrip_statement() -> Result<()> {
164184
let state = MockSessionState::default()
165185
.with_aggregate_function(sum_udaf())
166186
.with_aggregate_function(count_udaf())
187+
.with_aggregate_function(max_udaf())
167188
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
168189
let context = MockContextProvider { state };
169190
let sql_to_rel = SqlToRel::new(&context);

0 commit comments

Comments
 (0)