
Commit 22464f0

MINOR: Partial fix for SQL aggregate queries with aliases (#2464)
1 parent 522ea52 commit 22464f0
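
The scenario this commit targets is the one exercised by the new test added in datafusion/core/tests/sql/group_by.rs below: a SELECT list whose aliases (`c1 AS c12`, `avg(c12) AS c1`) collide with the names of the input columns, combined with a `GROUP BY` on one of those names. A condensed, commented restatement of that test follows; the `aggregate_test_100` table and the `register_aggregate_csv` / `execute_to_batches` helpers come from the existing test suite, and the test name here is hypothetical:

#[tokio::test]
async fn csv_query_group_by_with_conflicting_aliases() -> Result<()> {
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx).await?;
    // `c12` and `c1` are both projection aliases *and* input column names;
    // `GROUP BY c1` must keep referring to the input column `c1`, not to the
    // `avg(c12) AS c1` aggregate defined by the projection.
    let sql = "SELECT c1 AS c12, avg(c12) AS c1 FROM aggregate_test_100 GROUP BY c1";
    let actual = execute_to_batches(&ctx, sql).await;
    // one output row per distinct value of the input column c1 ("a" through "e")
    assert_eq!(actual.iter().map(|b| b.num_rows()).sum::<usize>(), 5);
    Ok(())
}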

File tree: 4 files changed, +142 -9 lines changed

datafusion/core/src/sql/planner.rs
datafusion/core/src/sql/utils.rs
datafusion/core/tests/sql/group_by.rs
datafusion/expr/src/expr.rs

datafusion/core/src/sql/planner.rs

Lines changed: 33 additions & 5 deletions
@@ -1006,6 +1006,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .map(|e| {
                 let group_by_expr =
                     self.sql_expr_to_logical_expr(e, &combined_schema, ctes)?;
+                // aliases from the projection can conflict with same-named expressions in the input
+                let mut alias_map = alias_map.clone();
+                for f in plan.schema().fields() {
+                    alias_map.remove(f.name());
+                }
                 let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?;
                 let group_by_expr =
                     resolve_positions_to_exprs(&group_by_expr, &select_exprs)
@@ -1020,7 +1025,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .collect::<Result<Vec<Expr>>>()?;

         // process group by, aggregation or having
-        let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) =
+        let (plan, select_exprs_post_aggr, having_expr_post_aggr) =
             if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
                 self.aggregate(
                     plan,
@@ -1048,7 +1053,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 (plan, select_exprs, having_expr_opt)
             };

-        let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr_opt {
+        let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr {
             LogicalPlanBuilder::from(plan)
                 .filter(having_expr_post_aggr)?
                 .build()?
@@ -1107,7 +1112,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         LogicalPlanBuilder::from(input).project(expr)?.build()
     }

-    /// Wrap a plan in an aggregate
+    /// Create an aggregate plan.
+    ///
+    /// An aggregate plan consists of grouping expressions, aggregate expressions, and an
+    /// optional HAVING expression (which is a filter on the output of the aggregate).
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - The input plan that will be aggregated. The grouping, aggregate, and
+    ///   "having" expressions must all be resolvable from this plan.
+    /// * `select_exprs` - The projection expressions from the SELECT clause.
+    /// * `having_expr_opt` - Optional HAVING clause.
+    /// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column
+    ///   references or more complex expressions.
+    /// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`.
+    ///
+    /// # Return
+    ///
+    /// The return value is a triplet of the following items:
+    ///
+    /// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate.
+    /// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from
+    ///   the aggregate
+    /// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from
+    ///   the aggregate
     fn aggregate(
         &self,
         input: LogicalPlan,
@@ -1148,7 +1176,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

         // Rewrite the HAVING expression to use the columns produced by the
         // aggregation.
-        let having_expr_post_aggr_opt = if let Some(having_expr) = having_expr_opt {
+        let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt {
             let having_expr_post_aggr =
                 rebase_expr(having_expr, &aggr_projection_exprs, &input)?;

@@ -1163,7 +1191,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             None
         };

-        Ok((plan, select_exprs_post_aggr, having_expr_post_aggr_opt))
+        Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
     }

     /// Wrap a plan in a limit
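
The pruning loop added in the first hunk above can be pictured with a plain standard-library sketch (this is not DataFusion code, just an illustration of the rule): any alias whose name matches a column of the input plan is dropped from the alias map before `GROUP BY` aliases are resolved, so the input column wins over a same-named projection alias.

use std::collections::HashMap;

fn main() {
    // aliases defined by the SELECT list of the test query:
    //   SELECT c1 AS c12, avg(c12) AS c1 ... GROUP BY c1
    let mut alias_map = HashMap::from([
        ("c12".to_string(), "c1".to_string()),
        ("c1".to_string(), "avg(c12)".to_string()),
    ]);

    // a few of the input plan's column names (aggregate_test_100 has c1..c13)
    let input_columns = ["c1", "c2", "c12"];

    // the new pruning step: an alias that collides with an input column is ignored
    for name in input_columns {
        alias_map.remove(name);
    }

    // GROUP BY c1 now resolves to the input column `c1`, not to `avg(c12)`
    assert!(!alias_map.contains_key("c1"));
}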

datafusion/core/src/sql/utils.rs

Lines changed: 17 additions & 4 deletions
@@ -27,6 +27,7 @@ use crate::{
     error::{DataFusionError, Result},
     logical_plan::{Column, ExpressionVisitor, Recursion},
 };
+use datafusion_expr::expr::find_columns_referenced_by_expr;
 use std::collections::HashMap;

 /// Collect all deeply nested `Expr::AggregateFunction` and
@@ -58,9 +59,13 @@ pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
 }

 /// Collect all deeply nested `Expr::Column`'s. They are returned in order of
-/// appearance (depth first), with duplicates omitted.
+/// appearance (depth first), and may contain duplicates.
 pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
-    find_exprs_in_exprs(exprs, &|nested_expr| matches!(nested_expr, Expr::Column(_)))
+    exprs
+        .iter()
+        .flat_map(find_columns_referenced_by_expr)
+        .map(Expr::Column)
+        .collect()
 }

 /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
@@ -137,8 +142,16 @@
 /// Convert any `Expr` to an `Expr::Column`.
 pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
     match expr {
-        Expr::Column(_) => Ok(expr.clone()),
-        _ => Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?))),
+        Expr::Column(col) => {
+            let field = plan.schema().field_from_column(col)?;
+            Ok(Expr::Column(field.qualified_column()))
+        }
+        _ => {
+            // we should not be trying to create a name for the expression
+            // based on the input schema but this is the current behavior
+            // see https://github.com/apache/arrow-datafusion/issues/2456
+            Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?)))
+        }
     }
 }

datafusion/core/tests/sql/group_by.rs

Lines changed: 21 additions & 0 deletions
@@ -232,6 +232,27 @@ async fn csv_query_group_by_avg() -> Result<()> {
     Ok(())
 }

+#[tokio::test]
+async fn csv_query_group_by_with_aliases() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT c1 AS c12, avg(c12) AS c1 FROM aggregate_test_100 GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----+---------------------+",
+        "| c12 | c1                  |",
+        "+-----+---------------------+",
+        "| a   | 0.48754517466109415 |",
+        "| b   | 0.41040709263815384 |",
+        "| c   | 0.6600456536439784  |",
+        "| d   | 0.48855379387549824 |",
+        "| e   | 0.48600669271341534 |",
+        "+-----+---------------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_group_by_int_count() -> Result<()> {
     let ctx = SessionContext::new();

datafusion/expr/src/expr.rs

Lines changed: 71 additions & 0 deletions
@@ -251,6 +251,77 @@ pub enum Expr {
     QualifiedWildcard { qualifier: String },
 }

+/// Recursively find all columns referenced by an expression
+pub fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
+    match e {
+        Expr::Alias(expr, _)
+        | Expr::Negative(expr)
+        | Expr::Cast { expr, .. }
+        | Expr::TryCast { expr, .. }
+        | Expr::Sort { expr, .. }
+        | Expr::InList { expr, .. }
+        | Expr::InSubquery { expr, .. }
+        | Expr::GetIndexedField { expr, .. }
+        | Expr::Not(expr)
+        | Expr::IsNotNull(expr)
+        | Expr::IsNull(expr) => find_columns_referenced_by_expr(expr),
+        Expr::Column(c) => vec![c.clone()],
+        Expr::BinaryExpr { left, right, .. } => {
+            let mut cols = vec![];
+            cols.extend(find_columns_referenced_by_expr(left.as_ref()));
+            cols.extend(find_columns_referenced_by_expr(right.as_ref()));
+            cols
+        }
+        Expr::Case {
+            expr,
+            when_then_expr,
+            else_expr,
+        } => {
+            let mut cols = vec![];
+            if let Some(expr) = expr {
+                cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
+            }
+            for (w, t) in when_then_expr {
+                cols.extend(find_columns_referenced_by_expr(w.as_ref()));
+                cols.extend(find_columns_referenced_by_expr(t.as_ref()));
+            }
+            if let Some(else_expr) = else_expr {
+                cols.extend(find_columns_referenced_by_expr(else_expr.as_ref()));
+            }
+            cols
+        }
+        Expr::ScalarFunction { args, .. } => args
+            .iter()
+            .flat_map(find_columns_referenced_by_expr)
+            .collect(),
+        Expr::AggregateFunction { args, .. } => args
+            .iter()
+            .flat_map(find_columns_referenced_by_expr)
+            .collect(),
+        Expr::ScalarVariable(_, _)
+        | Expr::Exists { .. }
+        | Expr::Wildcard
+        | Expr::QualifiedWildcard { .. }
+        | Expr::ScalarSubquery(_)
+        | Expr::Literal(_) => vec![],
+        Expr::Between {
+            expr, low, high, ..
+        } => {
+            let mut cols = vec![];
+            cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
+            cols.extend(find_columns_referenced_by_expr(low.as_ref()));
+            cols.extend(find_columns_referenced_by_expr(high.as_ref()));
+            cols
+        }
+        Expr::ScalarUDF { args, .. }
+        | Expr::WindowFunction { args, .. }
+        | Expr::AggregateUDF { args, .. } => args
+            .iter()
+            .flat_map(find_columns_referenced_by_expr)
+            .collect(),
+    }
+}
+
 /// Fixed seed for the hashing so that Ords are consistent across runs
 const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);

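For readers unfamiliar with the new helper, a usage sketch for `find_columns_referenced_by_expr` follows. It builds only on items visible in this diff plus `datafusion_expr::Operator` and the `Column` type from `datafusion_common`; the exact re-export paths are assumptions, not taken from the commit:

use datafusion_common::Column;
use datafusion_expr::expr::find_columns_referenced_by_expr;
use datafusion_expr::{Expr, Operator};

fn main() {
    // build the expression `a + a` by hand, without any planner helpers
    let a = || Expr::Column(Column::from_name("a"));
    let expr = Expr::BinaryExpr {
        left: Box::new(a()),
        op: Operator::Plus,
        right: Box::new(a()),
    };

    // columns are collected depth first and, unlike the old find_column_exprs,
    // duplicates are preserved
    let cols = find_columns_referenced_by_expr(&expr);
    assert_eq!(cols.len(), 2);
}

Because duplicates are preserved, the doc comment on `find_column_exprs` in utils.rs is updated above from "with duplicates omitted" to "may contain duplicates".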