Skip to content

Commit 93b3d9c

Browse files
authored
Handle alias when parsing sql(parse_sql_expr) (#12939)
* fix: Fix parse_sql_expr not handling alias
* cargo fmt
* fix parse_sql_expr example(remove alias)
* add testing
* add SUM udaf to TestContextProvider and modify test_sql_to_expr_with_alias for function
* revert change on example `parse_sql_expr`
1 parent 6196ff2 commit 93b3d9c

File tree

4 files changed

+82
-18
lines changed

4 files changed

+82
-18
lines changed

datafusion-examples/examples/parse_sql_expr.rs

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> {
121121

122122
assert_batches_eq!(
123123
&[
124-
"+------------+----------------------+",
125-
"| double_col | sum(?table?.int_col) |",
126-
"+------------+----------------------+",
127-
"| 10.1 | 4 |",
128-
"+------------+----------------------+",
124+
"+------------+-------------+",
125+
"| double_col | sum_int_col |",
126+
"+------------+-------------+",
127+
"| 10.1 | 4 |",
128+
"+------------+-------------+",
129129
],
130130
&result
131131
);

datafusion/core/src/execution/session_state.rs

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq
6868
use itertools::Itertools;
6969
use log::{debug, info};
7070
use object_store::ObjectStore;
71-
use sqlparser::ast::Expr as SQLExpr;
71+
use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias};
7272
use sqlparser::dialect::dialect_from_str;
7373
use std::any::Any;
7474
use std::collections::hash_map::Entry;
@@ -500,11 +500,22 @@ impl SessionState {
500500
sql: &str,
501501
dialect: &str,
502502
) -> datafusion_common::Result<SQLExpr> {
503+
self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr)
504+
}
505+
506+
/// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`].
507+
///
508+
/// See [`Self::create_logical_expr`] for parsing sql to [`Expr`].
509+
pub fn sql_to_expr_with_alias(
510+
&self,
511+
sql: &str,
512+
dialect: &str,
513+
) -> datafusion_common::Result<SQLExprWithAlias> {
503514
let dialect = dialect_from_str(dialect).ok_or_else(|| {
504515
plan_datafusion_err!(
505516
"Unsupported SQL dialect: {dialect}. Available dialects: \
506-
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
507-
MsSQL, ClickHouse, BigQuery, Ansi."
517+
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
518+
MsSQL, ClickHouse, BigQuery, Ansi."
508519
)
509520
})?;
510521

@@ -603,15 +614,15 @@ impl SessionState {
603614
) -> datafusion_common::Result<Expr> {
604615
let dialect = self.config.options().sql_parser.dialect.as_str();
605616

606-
let sql_expr = self.sql_to_expr(sql, dialect)?;
617+
let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?;
607618

608619
let provider = SessionContextProvider {
609620
state: self,
610621
tables: HashMap::new(),
611622
};
612623

613624
let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
614-
query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
625+
query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new())
615626
}
616627

617628
/// Returns the [`Analyzer`] for this session

datafusion/sql/src/expr/mod.rs

Lines changed: 56 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,8 @@ use datafusion_expr::planner::{
2323
use recursive::recursive;
2424
use sqlparser::ast::{
2525
BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField,
26-
Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value,
26+
Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript,
27+
TrimWhereField, Value,
2728
};
2829

2930
use datafusion_common::{
@@ -50,6 +51,19 @@ mod unary_op;
5051
mod value;
5152

5253
impl<S: ContextProvider> SqlToRel<'_, S> {
54+
pub(crate) fn sql_expr_to_logical_expr_with_alias(
55+
&self,
56+
sql: SQLExprWithAlias,
57+
schema: &DFSchema,
58+
planner_context: &mut PlannerContext,
59+
) -> Result<Expr> {
60+
let mut expr =
61+
self.sql_expr_to_logical_expr(sql.expr, schema, planner_context)?;
62+
if let Some(alias) = sql.alias {
63+
expr = expr.alias(alias.value);
64+
}
65+
Ok(expr)
66+
}
5367
pub(crate) fn sql_expr_to_logical_expr(
5468
&self,
5569
sql: SQLExpr,
@@ -131,6 +145,20 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
131145
)))
132146
}
133147

148+
pub fn sql_to_expr_with_alias(
149+
&self,
150+
sql: SQLExprWithAlias,
151+
schema: &DFSchema,
152+
planner_context: &mut PlannerContext,
153+
) -> Result<Expr> {
154+
let mut expr =
155+
self.sql_expr_to_logical_expr_with_alias(sql, schema, planner_context)?;
156+
expr = self.rewrite_partial_qualifier(expr, schema);
157+
self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
158+
let (expr, _) = expr.infer_placeholder_types(schema)?;
159+
Ok(expr)
160+
}
161+
134162
/// Generate a relational expression from a SQL expression
135163
pub fn sql_to_expr(
136164
&self,
@@ -1091,8 +1119,11 @@ mod tests {
10911119
None
10921120
}
10931121

1094-
fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
1095-
None
1122+
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
1123+
match name {
1124+
"sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()),
1125+
_ => None,
1126+
}
10961127
}
10971128

10981129
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
@@ -1112,7 +1143,7 @@ mod tests {
11121143
}
11131144

11141145
fn udaf_names(&self) -> Vec<String> {
1115-
Vec::new()
1146+
vec!["sum".to_string()]
11161147
}
11171148

11181149
fn udwf_names(&self) -> Vec<String> {
@@ -1167,4 +1198,25 @@ mod tests {
11671198
test_stack_overflow!(2048);
11681199
test_stack_overflow!(4096);
11691200
test_stack_overflow!(8192);
1201+
#[test]
1202+
fn test_sql_to_expr_with_alias() {
1203+
let schema = DFSchema::empty();
1204+
let mut planner_context = PlannerContext::default();
1205+
1206+
let expr_str = "SUM(int_col) as sum_int_col";
1207+
1208+
let dialect = GenericDialect {};
1209+
let mut parser = Parser::new(&dialect).try_with_sql(expr_str).unwrap();
1210+
// from sqlparser
1211+
let sql_expr = parser.parse_expr_with_alias().unwrap();
1212+
1213+
let context_provider = TestContextProvider::new();
1214+
let sql_to_rel = SqlToRel::new(&context_provider);
1215+
1216+
let expr = sql_to_rel
1217+
.sql_expr_to_logical_expr_with_alias(sql_expr, &schema, &mut planner_context)
1218+
.unwrap();
1219+
1220+
assert!(matches!(expr, Expr::Alias(_)));
1221+
}
11701222
}

datafusion/sql/src/parser.rs

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,10 @@
2020
use std::collections::VecDeque;
2121
use std::fmt;
2222

23+
use sqlparser::ast::ExprWithAlias;
2324
use sqlparser::{
2425
ast::{
25-
ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query,
26+
ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query,
2627
Statement as SQLStatement, TableConstraint, Value,
2728
},
2829
dialect::{keywords::Keyword, Dialect, GenericDialect},
@@ -328,7 +329,7 @@ impl<'a> DFParser<'a> {
328329
pub fn parse_sql_into_expr_with_dialect(
329330
sql: &str,
330331
dialect: &dyn Dialect,
331-
) -> Result<Expr, ParserError> {
332+
) -> Result<ExprWithAlias, ParserError> {
332333
let mut parser = DFParser::new_with_dialect(sql, dialect)?;
333334
parser.parse_expr()
334335
}
@@ -377,7 +378,7 @@ impl<'a> DFParser<'a> {
377378
}
378379
}
379380

380-
pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
381+
pub fn parse_expr(&mut self) -> Result<ExprWithAlias, ParserError> {
381382
if let Token::Word(w) = self.parser.peek_token().token {
382383
match w.keyword {
383384
Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => {
@@ -387,7 +388,7 @@ impl<'a> DFParser<'a> {
387388
}
388389
}
389390

390-
self.parser.parse_expr()
391+
self.parser.parse_expr_with_alias()
391392
}
392393

393394
/// Parse a SQL `COPY TO` statement

0 commit comments

Comments
 (0)