Skip to content

Commit 675eb82

Browse files
authored
Fix bugs in SQL planner with GROUP BY scalar function and alias (#2457)
1 parent 22464f0 commit 675eb82

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

datafusion/core/src/sql/planner.rs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use crate::logical_plan::{
3737
use crate::optimizer::utils::exprlist_to_columns;
3838
use crate::prelude::JoinType;
3939
use crate::scalar::ScalarValue;
40-
use crate::sql::utils::{make_decimal_type, normalize_ident};
40+
use crate::sql::utils::{make_decimal_type, normalize_ident, resolve_columns};
4141
use crate::{
4242
error::{DataFusionError, Result},
4343
physical_plan::aggregates,
@@ -1144,30 +1144,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
11441144
group_by_exprs: Vec<Expr>,
11451145
aggr_exprs: Vec<Expr>,
11461146
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
1147+
// create the aggregate plan
1148+
let plan = LogicalPlanBuilder::from(input.clone())
1149+
.aggregate(group_by_exprs.clone(), aggr_exprs.clone())?
1150+
.build()?;
1151+
1152+
// in this next section of code we are re-writing the projection to refer to columns
1153+
// output by the aggregate plan. For example, if the projection contains the expression
1154+
// `SUM(a)` then we replace that with a reference to a column `#SUM(a)` produced by
1155+
// the aggregate plan.
1156+
1157+
// combine the original grouping and aggregate expressions into one list (note that
1158+
// we do not add the "having" expression since that is not part of the projection)
11471159
let aggr_projection_exprs = group_by_exprs
11481160
.iter()
11491161
.chain(aggr_exprs.iter())
11501162
.cloned()
11511163
.collect::<Vec<Expr>>();
11521164

1153-
let plan = LogicalPlanBuilder::from(input.clone())
1154-
.aggregate(group_by_exprs, aggr_exprs)?
1155-
.build()?;
1165+
// now attempt to resolve columns and replace with fully-qualified columns
1166+
let aggr_projection_exprs = aggr_projection_exprs
1167+
.iter()
1168+
.map(|expr| resolve_columns(expr, &input))
1169+
.collect::<Result<Vec<Expr>>>()?;
11561170

1157-
// After aggregation, these are all of the columns that will be
1158-
// available to next phases of planning.
1171+
// next we replace any expressions that are not a column with a column referencing
1172+
// an output column from the aggregate schema
11591173
let column_exprs_post_aggr = aggr_projection_exprs
11601174
.iter()
11611175
.map(|expr| expr_as_column_expr(expr, &input))
11621176
.collect::<Result<Vec<Expr>>>()?;
11631177

1164-
// Rewrite the SELECT expression to use the columns produced by the
1165-
// aggregation.
1178+
// next we re-write the projection
11661179
let select_exprs_post_aggr = select_exprs
11671180
.iter()
11681181
.map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input))
11691182
.collect::<Result<Vec<Expr>>>()?;
11701183

1184+
// finally, we have some validation that the re-written projection can be resolved
1185+
// from the aggregate output columns
11711186
check_columns_satisfy_exprs(
11721187
&column_exprs_post_aggr,
11731188
&select_exprs_post_aggr,

datafusion/core/src/sql/utils.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,22 @@ pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Exp
155155
}
156156
}
157157

158+
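/// Rewrite every column in the expression tree as a fully-qualified column looked up in the schema of `plan`; a failed schema lookup is returned as an error from the replacement closure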
159+
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
160+
clone_with_replacement(expr, &|nested_expr| {
161+
match nested_expr {
162+
Expr::Column(col) => {
163+
let field = plan.schema().field_from_column(col)?;
164+
Ok(Some(Expr::Column(field.qualified_column())))
165+
}
166+
_ => {
167+
// keep recursing
168+
Ok(None)
169+
}
170+
}
171+
})
172+
}
173+
158174
/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
159175
///
160176
/// For example, the expression `a + b < 1` would require, as input, the 2

datafusion/core/tests/sql/group_by.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,32 @@ async fn csv_query_having_without_group_by() -> Result<()> {
211211
Ok(())
212212
}
213213

214+
#[tokio::test]
215+
async fn csv_query_group_by_substr() -> Result<()> {
216+
let ctx = SessionContext::new();
217+
register_aggregate_csv(&ctx).await?;
218+
// there is an input column "c1" as well as a projection expression aliased as "c1"
219+
let sql = "SELECT substr(c1, 1, 1) c1 \
220+
FROM aggregate_test_100 \
221+
GROUP BY substr(c1, 1, 1) \
222+
";
223+
let actual = execute_to_batches(&ctx, sql).await;
224+
#[rustfmt::skip]
225+
let expected = vec![
226+
"+----+",
227+
"| c1 |",
228+
"+----+",
229+
"| a |",
230+
"| b |",
231+
"| c |",
232+
"| d |",
233+
"| e |",
234+
"+----+",
235+
];
236+
assert_batches_sorted_eq!(expected, &actual);
237+
Ok(())
238+
}
239+
214240
#[tokio::test]
215241
async fn csv_query_group_by_avg() -> Result<()> {
216242
let ctx = SessionContext::new();

0 commit comments

Comments
 (0)