Skip to content

Commit 9278233

Browse files
authored
fix duplicated schema name error from count wildcard (#14824)
* fix name * upd doc * drop table * real count() * clippy * fix tests * fix test * fix other tests * fix proto test * fix substrait test * fmt * alias for count wildcard * substrait * fix tests * add test * fmt * fix avro * upd test * alias whole expr aggr * fmt * window * fix * fix tests * avro * upd testing * tpch * upd test
1 parent 212f424 commit 9278233

34 files changed

+666
-533
lines changed

datafusion-testing

Submodule datafusion-testing updated 257 files

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,9 +1145,9 @@ async fn test_count_wildcard() -> Result<()> {
11451145
.build()
11461146
.unwrap();
11471147

1148-
let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
1149-
\n Projection: count(*) [count(*):Int64]\
1150-
\n Aggregate: groupBy=[[test.b]], aggr=[[count(*)]] [b:UInt32, count(*):Int64]\
1148+
let expected = "Sort: count(Int64(1)) ASC NULLS LAST [count(Int64(1)):Int64]\
1149+
\n Projection: count(Int64(1)) [count(Int64(1)):Int64]\
1150+
\n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1))]] [b:UInt32, count(Int64(1)):Int64]\
11511151
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
11521152

11531153
let formatted_plan = plan.display_indent_schema().to_string();

datafusion/core/tests/dataframe/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2455,7 +2455,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
24552455
let ctx = create_join_context()?;
24562456

24572457
let sql_results = ctx
2458-
.sql("select b,count(*) from t1 group by b order by count(*)")
2458+
.sql("select b,count(1) from t1 group by b order by count(1)")
24592459
.await?
24602460
.explain(false, false)?
24612461
.collect()
@@ -2481,7 +2481,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
24812481
async fn test_count_wildcard_on_where_in() -> Result<()> {
24822482
let ctx = create_join_context()?;
24832483
let sql_results = ctx
2484-
.sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
2484+
.sql("SELECT a,b FROM t1 WHERE a in (SELECT count(1) FROM t2)")
24852485
.await?
24862486
.explain(false, false)?
24872487
.collect()
@@ -2522,7 +2522,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
25222522
async fn test_count_wildcard_on_where_exist() -> Result<()> {
25232523
let ctx = create_join_context()?;
25242524
let sql_results = ctx
2525-
.sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
2525+
.sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(1) FROM t2)")
25262526
.await?
25272527
.explain(false, false)?
25282528
.collect()
@@ -2559,7 +2559,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
25592559
let ctx = create_join_context()?;
25602560

25612561
let sql_results = ctx
2562-
.sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
2562+
.sql("select count(1) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
25632563
.await?
25642564
.explain(false, false)?
25652565
.collect()
@@ -2598,7 +2598,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
25982598
register_alltypes_tiny_pages_parquet(&ctx).await?;
25992599

26002600
let sql_results = ctx
2601-
.sql("select count(*) from t1")
2601+
.sql("select count(1) from t1")
26022602
.await?
26032603
.explain(false, false)?
26042604
.collect()
@@ -2628,7 +2628,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
26282628
let ctx = create_join_context()?;
26292629

26302630
let sql_results = ctx
2631-
.sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
2631+
.sql("select a,b from t1 where (select count(1) from t2 where t1.a = t2.a)>0;")
26322632
.await?
26332633
.explain(false, false)?
26342634
.collect()

datafusion/core/tests/sql/explain_analyze.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ async fn explain_analyze_baseline_metrics() {
7171
);
7272
assert_metrics!(
7373
&formatted,
74-
"ProjectionExec: expr=[count(*)",
75-
"metrics=[output_rows=1, elapsed_compute="
74+
"ProjectionExec: expr=[]",
75+
"metrics=[output_rows=5, elapsed_compute="
7676
);
7777
assert_metrics!(
7878
&formatted,
@@ -687,7 +687,7 @@ async fn csv_explain_analyze() {
687687
// Only test basic plumbing and try to avoid having to change too
688688
// many things. explain_analyze_baseline_metrics covers the values
689689
// in greater depth
690-
let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[count(*)], metrics=[output_rows=5";
690+
let needle = "ProjectionExec: expr=[count(Int64(1))@1 as count(*), c1@0 as c1], metrics=[output_rows=5";
691691
assert_contains!(&formatted, needle);
692692

693693
let verbose_needle = "Output Rows";
@@ -778,13 +778,11 @@ async fn explain_logical_plan_only() {
778778
let actual = normalize_vec_for_explain(actual);
779779

780780
let expected = vec![
781-
vec![
782-
"logical_plan",
783-
"Aggregate: groupBy=[[]], aggr=[[count(*)]]\
784-
\n SubqueryAlias: t\
785-
\n Projection: \
786-
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
787-
]];
781+
vec!["logical_plan", "Projection: count(Int64(1)) AS count(*)\
782+
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
783+
\n SubqueryAlias: t\
784+
\n Projection: \
785+
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"]];
788786
assert_eq!(expected, actual);
789787
}
790788

datafusion/expr/src/expr_rewriter/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ pub struct NamePreserver {
286286

287287
/// If the qualified name of an expression is remembered, it will be preserved
288288
/// when rewriting the expression
289+
#[derive(Debug)]
289290
pub enum SavedName {
290291
/// Saved qualified name to be preserved
291292
Saved {

datafusion/expr/src/udaf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ impl fmt::Display for AggregateUDF {
100100
}
101101

102102
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103+
#[derive(Debug)]
103104
pub struct StatisticsArgs<'a> {
104105
/// The statistics of the aggregate input
105106
pub statistics: &'a Statistics,

datafusion/functions-aggregate/src/count.rs

Lines changed: 2 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,11 @@
1717

1818
use ahash::RandomState;
1919
use datafusion_common::stats::Precision;
20-
use datafusion_expr::expr::{
21-
schema_name_from_exprs, schema_name_from_sorts, AggregateFunctionParams,
22-
WindowFunctionParams,
23-
};
2420
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
2521
use datafusion_macros::user_doc;
2622
use datafusion_physical_expr::expressions;
2723
use std::collections::HashSet;
28-
use std::fmt::{Debug, Write};
24+
use std::fmt::Debug;
2925
use std::mem::{size_of, size_of_val};
3026
use std::ops::BitAnd;
3127
use std::sync::Arc;
@@ -51,11 +47,11 @@ use datafusion_common::{
5147
downcast_value, internal_err, not_impl_err, Result, ScalarValue,
5248
};
5349
use datafusion_expr::function::StateFieldsArgs;
54-
use datafusion_expr::{expr_vec_fmt, Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
5550
use datafusion_expr::{
5651
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
5752
Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
5853
};
54+
use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
5955
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
6056
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
6157
PrimitiveDistinctCountAccumulator,
@@ -148,185 +144,6 @@ impl AggregateUDFImpl for Count {
148144
"count"
149145
}
150146

151-
fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
152-
let AggregateFunctionParams {
153-
args,
154-
distinct,
155-
filter,
156-
order_by,
157-
null_treatment,
158-
} = params;
159-
160-
let mut schema_name = String::new();
161-
162-
if is_count_wildcard(args) {
163-
schema_name.write_str("count(*)")?;
164-
} else {
165-
schema_name.write_fmt(format_args!(
166-
"{}({}{})",
167-
self.name(),
168-
if *distinct { "DISTINCT " } else { "" },
169-
schema_name_from_exprs(args)?
170-
))?;
171-
}
172-
173-
if let Some(null_treatment) = null_treatment {
174-
schema_name.write_fmt(format_args!(" {}", null_treatment))?;
175-
}
176-
177-
if let Some(filter) = filter {
178-
schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
179-
};
180-
181-
if let Some(order_by) = order_by {
182-
schema_name.write_fmt(format_args!(
183-
" ORDER BY [{}]",
184-
schema_name_from_sorts(order_by)?
185-
))?;
186-
};
187-
188-
Ok(schema_name)
189-
}
190-
191-
fn window_function_schema_name(
192-
&self,
193-
params: &WindowFunctionParams,
194-
) -> Result<String> {
195-
let WindowFunctionParams {
196-
args,
197-
partition_by,
198-
order_by,
199-
window_frame,
200-
null_treatment,
201-
} = params;
202-
203-
let mut schema_name = String::new();
204-
205-
if is_count_wildcard(args) {
206-
schema_name.write_str("count(*)")?;
207-
} else {
208-
schema_name.write_fmt(format_args!(
209-
"{}({})",
210-
self.name(),
211-
schema_name_from_exprs(args)?
212-
))?;
213-
}
214-
215-
if let Some(null_treatment) = null_treatment {
216-
schema_name.write_fmt(format_args!(" {}", null_treatment))?;
217-
}
218-
219-
if !partition_by.is_empty() {
220-
schema_name.write_fmt(format_args!(
221-
" PARTITION BY [{}]",
222-
schema_name_from_exprs(partition_by)?
223-
))?;
224-
}
225-
226-
if !order_by.is_empty() {
227-
schema_name.write_fmt(format_args!(
228-
" ORDER BY [{}]",
229-
schema_name_from_sorts(order_by)?
230-
))?;
231-
};
232-
233-
schema_name.write_fmt(format_args!(" {window_frame}"))?;
234-
235-
Ok(schema_name)
236-
}
237-
238-
fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
239-
let AggregateFunctionParams {
240-
args,
241-
distinct,
242-
filter,
243-
order_by,
244-
null_treatment,
245-
} = params;
246-
247-
let mut display_name = String::new();
248-
249-
if is_count_wildcard(args) {
250-
display_name.write_str("count(*)")?;
251-
} else {
252-
display_name.write_fmt(format_args!(
253-
"{}({}{})",
254-
self.name(),
255-
if *distinct { "DISTINCT " } else { "" },
256-
args.iter()
257-
.map(|arg| format!("{arg}"))
258-
.collect::<Vec<String>>()
259-
.join(", ")
260-
))?;
261-
}
262-
263-
if let Some(nt) = null_treatment {
264-
display_name.write_fmt(format_args!(" {}", nt))?;
265-
}
266-
if let Some(fe) = filter {
267-
display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
268-
}
269-
if let Some(ob) = order_by {
270-
display_name.write_fmt(format_args!(
271-
" ORDER BY [{}]",
272-
ob.iter()
273-
.map(|o| format!("{o}"))
274-
.collect::<Vec<String>>()
275-
.join(", ")
276-
))?;
277-
}
278-
279-
Ok(display_name)
280-
}
281-
282-
fn window_function_display_name(
283-
&self,
284-
params: &WindowFunctionParams,
285-
) -> Result<String> {
286-
let WindowFunctionParams {
287-
args,
288-
partition_by,
289-
order_by,
290-
window_frame,
291-
null_treatment,
292-
} = params;
293-
294-
let mut display_name = String::new();
295-
296-
if is_count_wildcard(args) {
297-
display_name.write_str("count(*)")?;
298-
} else {
299-
display_name.write_fmt(format_args!(
300-
"{}({})",
301-
self.name(),
302-
expr_vec_fmt!(args)
303-
))?;
304-
}
305-
306-
if let Some(null_treatment) = null_treatment {
307-
display_name.write_fmt(format_args!(" {}", null_treatment))?;
308-
}
309-
310-
if !partition_by.is_empty() {
311-
display_name.write_fmt(format_args!(
312-
" PARTITION BY [{}]",
313-
expr_vec_fmt!(partition_by)
314-
))?;
315-
}
316-
317-
if !order_by.is_empty() {
318-
display_name
319-
.write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
320-
};
321-
322-
display_name.write_fmt(format_args!(
323-
" {} BETWEEN {} AND {}",
324-
window_frame.units, window_frame.start_bound, window_frame.end_bound
325-
))?;
326-
327-
Ok(display_name)
328-
}
329-
330147
fn signature(&self) -> &Signature {
331148
&self.signature
332149
}
@@ -547,15 +364,6 @@ impl AggregateUDFImpl for Count {
547364
}
548365
}
549366

550-
fn is_count_wildcard(args: &[Expr]) -> bool {
551-
match args {
552-
[] => true, // count()
553-
// All const should be coerced to int64 or rejected by the signature
554-
[Expr::Literal(ScalarValue::Int64(Some(_)))] => true, // count(1)
555-
_ => false, // More than one argument or non-matching cases
556-
}
557-
}
558-
559367
#[derive(Debug)]
560368
struct CountAccumulator {
561369
count: i64,

0 commit comments

Comments
 (0)