Skip to content

Commit e03f9f6

Browse files
authored
Remove CountWildcardRule in Analyzer and move the functionality in ExprPlanner, add plan_aggregate and plan_window to planner (#14689)
* count planner * window * update slt * remove rule * rm rule * doc * fix name * fix name * fix test * tpch test * fix avro * rename * switch to count(*) * use count(*) * rename * doc * rename window function * fmt * rm print * upd logic * count null
1 parent 22156b2 commit e03f9f6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+652
-442
lines changed

datafusion/core/src/execution/session_state_defaults.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ impl SessionStateDefaults {
9494
feature = "unicode_expressions"
9595
))]
9696
Arc::new(functions::planner::UserDefinedFunctionPlanner),
97+
Arc::new(functions_aggregate::planner::AggregateFunctionPlanner),
98+
Arc::new(functions_window::planner::WindowFunctionPlanner),
9799
];
98100

99101
expr_planners

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use arrow::{
2222
array::{Int32Array, StringArray},
2323
record_batch::RecordBatch,
2424
};
25+
use datafusion_functions_aggregate::count::count_all;
2526
use std::sync::Arc;
2627

2728
use datafusion::error::Result;
@@ -31,7 +32,7 @@ use datafusion::prelude::*;
3132
use datafusion::assert_batches_eq;
3233
use datafusion_common::{DFSchema, ScalarValue};
3334
use datafusion_expr::expr::Alias;
34-
use datafusion_expr::ExprSchemable;
35+
use datafusion_expr::{table_scan, ExprSchemable, LogicalPlanBuilder};
3536
use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont};
3637
use datafusion_functions_nested::map::map;
3738

@@ -1123,3 +1124,34 @@ async fn test_fn_map() -> Result<()> {
11231124

11241125
Ok(())
11251126
}
1127+
1128+
/// Call count wildcard from dataframe API
1129+
#[tokio::test]
1130+
async fn test_count_wildcard() -> Result<()> {
1131+
let schema = Schema::new(vec![
1132+
Field::new("a", DataType::UInt32, false),
1133+
Field::new("b", DataType::UInt32, false),
1134+
Field::new("c", DataType::UInt32, false),
1135+
]);
1136+
1137+
let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1138+
let plan = LogicalPlanBuilder::from(table_scan)
1139+
.aggregate(vec![col("b")], vec![count_all()])
1140+
.unwrap()
1141+
.project(vec![count_all()])
1142+
.unwrap()
1143+
.sort(vec![count_all().sort(true, false)])
1144+
.unwrap()
1145+
.build()
1146+
.unwrap();
1147+
1148+
let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
1149+
\n Projection: count(*) [count(*):Int64]\
1150+
\n Aggregate: groupBy=[[test.b]], aggr=[[count(*)]] [b:UInt32, count(*):Int64]\
1151+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
1152+
1153+
let formatted_plan = plan.display_indent_schema().to_string();
1154+
assert_eq!(formatted_plan, expected);
1155+
1156+
Ok(())
1157+
}

datafusion/core/tests/dataframe/mod.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ use arrow::datatypes::{
3232
};
3333
use arrow::error::ArrowError;
3434
use arrow::util::pretty::pretty_format_batches;
35-
use datafusion_functions_aggregate::count::count_udaf;
35+
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
36+
use datafusion_functions_aggregate::count::{count_all, count_udaf};
3637
use datafusion_functions_aggregate::expr_fn::{
3738
array_agg, avg, count, count_distinct, max, median, min, sum,
3839
};
@@ -72,7 +73,7 @@ use datafusion_expr::expr::{GroupingSet, Sort, WindowFunction};
7273
use datafusion_expr::var_provider::{VarProvider, VarType};
7374
use datafusion_expr::{
7475
cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder,
75-
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan,
76+
scalar_subquery, when, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan,
7677
ScalarFunctionImplementation, WindowFrame, WindowFrameBound, WindowFrameUnits,
7778
WindowFunctionDefinition,
7879
};
@@ -2463,8 +2464,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
24632464
let df_results = ctx
24642465
.table("t1")
24652466
.await?
2466-
.aggregate(vec![col("b")], vec![count(wildcard())])?
2467-
.sort(vec![count(wildcard()).sort(true, false)])?
2467+
.aggregate(vec![col("b")], vec![count_all()])?
2468+
.sort(vec![count_all().sort(true, false)])?
24682469
.explain(false, false)?
24692470
.collect()
24702471
.await?;
@@ -2498,8 +2499,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
24982499
Arc::new(
24992500
ctx.table("t2")
25002501
.await?
2501-
.aggregate(vec![], vec![count(wildcard())])?
2502-
.select(vec![count(wildcard())])?
2502+
.aggregate(vec![], vec![count_all()])?
2503+
.select(vec![count_all()])?
25032504
.into_optimized_plan()?,
25042505
),
25052506
))?
@@ -2532,8 +2533,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {
25322533
.filter(exists(Arc::new(
25332534
ctx.table("t2")
25342535
.await?
2535-
.aggregate(vec![], vec![count(wildcard())])?
2536-
.select(vec![count(wildcard())])?
2536+
.aggregate(vec![], vec![count_all()])?
2537+
.select(vec![count_all()])?
25372538
.into_unoptimized_plan(),
25382539
// Usually, into_optimized_plan() should be used here, but due to
25392540
// https://github.com/apache/datafusion/issues/5771,
@@ -2568,7 +2569,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
25682569
.await?
25692570
.select(vec![Expr::WindowFunction(WindowFunction::new(
25702571
WindowFunctionDefinition::AggregateUDF(count_udaf()),
2571-
vec![wildcard()],
2572+
vec![Expr::Literal(COUNT_STAR_EXPANSION)],
25722573
))
25732574
.order_by(vec![Sort::new(col("a"), false, true)])
25742575
.window_frame(WindowFrame::new_bounds(
@@ -2599,17 +2600,16 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
25992600
let sql_results = ctx
26002601
.sql("select count(*) from t1")
26012602
.await?
2602-
.select(vec![col("count(*)")])?
26032603
.explain(false, false)?
26042604
.collect()
26052605
.await?;
26062606

2607-
// add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node.
2607+
// add `.select(vec![count_all()])?` to make sure we can analyze all nodes instead of just the top node.
26082608
let df_results = ctx
26092609
.table("t1")
26102610
.await?
2611-
.aggregate(vec![], vec![count(wildcard())])?
2612-
.select(vec![count(wildcard())])?
2611+
.aggregate(vec![], vec![count_all()])?
2612+
.select(vec![count_all()])?
26132613
.explain(false, false)?
26142614
.collect()
26152615
.await?;
@@ -2646,8 +2646,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
26462646
ctx.table("t2")
26472647
.await?
26482648
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
2649-
.aggregate(vec![], vec![count(wildcard())])?
2650-
.select(vec![col(count(wildcard()).to_string())])?
2649+
.aggregate(vec![], vec![count_all()])?
2650+
.select(vec![col(count_all().to_string())])?
26512651
.into_unoptimized_plan(),
26522652
))
26532653
.gt(lit(ScalarValue::UInt8(Some(0)))),

datafusion/core/tests/sql/explain_analyze.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ async fn explain_logical_plan_only() {
780780
let expected = vec![
781781
vec![
782782
"logical_plan",
783-
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
783+
"Aggregate: groupBy=[[]], aggr=[[count(*)]]\
784784
\n SubqueryAlias: t\
785785
\n Projection: \
786786
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"

datafusion/expr/src/expr.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,6 @@ impl Display for SchemaDisplay<'_> {
22942294
| Expr::OuterReferenceColumn(..)
22952295
| Expr::Placeholder(_)
22962296
| Expr::Wildcard { .. } => write!(f, "{}", self.0),
2297-
22982297
Expr::AggregateFunction(AggregateFunction { func, params }) => {
22992298
match func.schema_name(params) {
23002299
Ok(name) => {

datafusion/expr/src/planner.rs

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ use datafusion_common::{
2525
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
2626
Result, TableReference,
2727
};
28-
use sqlparser::ast;
28+
use sqlparser::ast::{self, NullTreatment};
2929

30-
use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};
30+
use crate::{
31+
AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame,
32+
WindowFunctionDefinition, WindowUDF,
33+
};
3134

3235
/// Provides the `SQL` query planner meta-data about tables and
3336
/// functions referenced in SQL statements, without a direct dependency on the
@@ -138,7 +141,7 @@ pub trait ExprPlanner: Debug + Send + Sync {
138141

139142
/// Plan an array literal, such as `[1, 2, 3]`
140143
///
141-
/// Returns origin expression arguments if not possible
144+
/// Returns original expression arguments if not possible
142145
fn plan_array_literal(
143146
&self,
144147
exprs: Vec<Expr>,
@@ -149,14 +152,14 @@ pub trait ExprPlanner: Debug + Send + Sync {
149152

150153
/// Plan a `POSITION` expression, such as `POSITION(<expr> in <expr>)`
151154
///
152-
/// returns origin expression arguments if not possible
155+
/// Returns original expression arguments if not possible
153156
fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
154157
Ok(PlannerResult::Original(args))
155158
}
156159

157160
/// Plan a dictionary literal, such as `{ key: value, ...}`
158161
///
159-
/// Returns origin expression arguments if not possible
162+
/// Returns original expression arguments if not possible
160163
fn plan_dictionary_literal(
161164
&self,
162165
expr: RawDictionaryExpr,
@@ -167,14 +170,14 @@ pub trait ExprPlanner: Debug + Send + Sync {
167170

168171
/// Plan an extract expression, such as `EXTRACT(month FROM foo)`
169172
///
170-
/// Returns origin expression arguments if not possible
173+
/// Returns original expression arguments if not possible
171174
fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
172175
Ok(PlannerResult::Original(args))
173176
}
174177

175178
/// Plan a substring expression, such as `SUBSTRING(<expr> [FROM <expr>] [FOR <expr>])`
176179
///
177-
/// Returns origin expression arguments if not possible
180+
/// Returns original expression arguments if not possible
178181
fn plan_substring(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
179182
Ok(PlannerResult::Original(args))
180183
}
@@ -195,14 +198,14 @@ pub trait ExprPlanner: Debug + Send + Sync {
195198

196199
/// Plans an overlay expression, such as `overlay(str PLACING substr FROM pos [FOR count])`
197200
///
198-
/// Returns origin expression arguments if not possible
201+
/// Returns original expression arguments if not possible
199202
fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
200203
Ok(PlannerResult::Original(args))
201204
}
202205

203206
/// Plans a `make_map` expression, such as `make_map(key1, value1, key2, value2, ...)`
204207
///
205-
/// Returns origin expression arguments if not possible
208+
/// Returns original expression arguments if not possible
206209
fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
207210
Ok(PlannerResult::Original(args))
208211
}
@@ -230,6 +233,23 @@ pub trait ExprPlanner: Debug + Send + Sync {
230233
fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
231234
Ok(PlannerResult::Original(expr))
232235
}
236+
237+
/// Plans aggregate functions, such as `COUNT(<expr>)`
238+
///
239+
/// Returns original expression arguments if not possible
240+
fn plan_aggregate(
241+
&self,
242+
expr: RawAggregateExpr,
243+
) -> Result<PlannerResult<RawAggregateExpr>> {
244+
Ok(PlannerResult::Original(expr))
245+
}
246+
247+
/// Plans window functions, such as `COUNT(<expr>) OVER (...)`
248+
///
249+
/// Returns original expression arguments if not possible
250+
fn plan_window(&self, expr: RawWindowExpr) -> Result<PlannerResult<RawWindowExpr>> {
251+
Ok(PlannerResult::Original(expr))
252+
}
233253
}
234254

235255
/// An operator with two arguments to plan
@@ -266,6 +286,30 @@ pub struct RawDictionaryExpr {
266286
pub values: Vec<Expr>,
267287
}
268288

289+
/// This structure is used by `AggregateFunctionPlanner` to plan operators with
290+
/// custom expressions.
291+
#[derive(Debug, Clone)]
292+
pub struct RawAggregateExpr {
293+
pub func: Arc<AggregateUDF>,
294+
pub args: Vec<Expr>,
295+
pub distinct: bool,
296+
pub filter: Option<Box<Expr>>,
297+
pub order_by: Option<Vec<SortExpr>>,
298+
pub null_treatment: Option<NullTreatment>,
299+
}
300+
301+
/// This structure is used by `WindowFunctionPlanner` to plan operators with
302+
/// custom expressions.
303+
#[derive(Debug, Clone)]
304+
pub struct RawWindowExpr {
305+
pub func_def: WindowFunctionDefinition,
306+
pub args: Vec<Expr>,
307+
pub partition_by: Vec<Expr>,
308+
pub order_by: Vec<SortExpr>,
309+
pub window_frame: WindowFrame,
310+
pub null_treatment: Option<NullTreatment>,
311+
}
312+
269313
/// Result of planning a raw expr with [`ExprPlanner`]
270314
#[derive(Debug, Clone)]
271315
pub enum PlannerResult<T> {

datafusion/expr/src/udaf.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -515,27 +515,32 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
515515
null_treatment,
516516
} = params;
517517

518-
let mut schema_name = String::new();
518+
let mut display_name = String::new();
519519

520-
schema_name.write_fmt(format_args!(
520+
display_name.write_fmt(format_args!(
521521
"{}({}{})",
522522
self.name(),
523523
if *distinct { "DISTINCT " } else { "" },
524524
expr_vec_fmt!(args)
525525
))?;
526526

527527
if let Some(nt) = null_treatment {
528-
schema_name.write_fmt(format_args!(" {}", nt))?;
528+
display_name.write_fmt(format_args!(" {}", nt))?;
529529
}
530530
if let Some(fe) = filter {
531-
schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
531+
display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
532532
}
533-
if let Some(order_by) = order_by {
534-
schema_name
535-
.write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
533+
if let Some(ob) = order_by {
534+
display_name.write_fmt(format_args!(
535+
" ORDER BY [{}]",
536+
ob.iter()
537+
.map(|o| format!("{o}"))
538+
.collect::<Vec<String>>()
539+
.join(", ")
540+
))?;
536541
}
537542

538-
Ok(schema_name)
543+
Ok(display_name)
539544
}
540545

541546
/// Returns the user-defined display name of function, given the arguments

0 commit comments

Comments
 (0)