fix duplicated schema name error from count wildcard #14824


Merged: 28 commits, Feb 26, 2025
2 changes: 1 addition & 1 deletion datafusion-testing
6 changes: 3 additions & 3 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -1145,9 +1145,9 @@ async fn test_count_wildcard() -> Result<()> {
.build()
.unwrap();

let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[test.b]], aggr=[[count(*)]] [b:UInt32, count(*):Int64]\
let expected = "Sort: count(Int64(1)) ASC NULLS LAST [count(Int64(1)):Int64]\
jayzhan211 (Contributor, Author) commented:

count_all() is count(1)

+ \n  Projection: count(Int64(1)) [count(Int64(1)):Int64]\
+ \n    Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1))]] [b:UInt32, count(Int64(1)):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

let formatted_plan = plan.display_indent_schema().to_string();
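For reference, a minimal sketch of how a plan like this is built with the logical plan builder API; the schema and the exact builder calls are assumptions inferred from the expected plan above, not code from this diff:

    // count_all() now lowers to count(Int64(1)) instead of count(*),
    // which is why every expected plan string in this test changes.
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_expr::{col, logical_plan::table_scan};
    use datafusion_functions_aggregate::count::count_all;

    let schema = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::UInt32, false),
        Field::new("c", DataType::UInt32, false),
    ]);
    let plan = table_scan(Some("test"), &schema, None)?
        .aggregate(vec![col("b")], vec![count_all()])? // aggr=[[count(Int64(1))]]
        .project(vec![count_all()])?                   // Projection: count(Int64(1))
        .sort(vec![count_all().sort(true, false)])?    // Sort: ... ASC NULLS LAST
        .build()?;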
12 changes: 6 additions & 6 deletions datafusion/core/tests/dataframe/mod.rs
@@ -2455,7 +2455,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
let ctx = create_join_context()?;

let sql_results = ctx
.sql("select b,count(*) from t1 group by b order by count(*)")
.sql("select b,count(1) from t1 group by b order by count(1)")
jayzhan211 (Contributor, Author) commented:

count_all() is count(1) now, so we need to change to count(1) to have a consistent name (count(*) is count(1) AS count(*) now).

A Contributor commented:

I had to double check -- the reason this needs to change is that the test is comparing against a dataframe built with count_all(), which now uses count(1).

Though maybe we could change count_all() to return count(1) as "count(*)" so it would be consistent with older versions?
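
For reference, a sketch of the comparison this test performs; the exact call shapes and the pretty_format_batches helper are assumptions, not the verbatim test:

    // The SQL side and the DataFrame side must render identical plans,
    // so the SQL text must produce the same name count_all() now produces.
    let sql_results = ctx
        .sql("select b, count(1) from t1 group by b order by count(1)")
        .await?
        .explain(false, false)?
        .collect()
        .await?;
    let df_results = ctx
        .table("t1")
        .await?
        .aggregate(vec![col("b")], vec![count_all()])?
        .sort(vec![count_all().sort(true, false)])?
        .explain(false, false)?
        .collect()
        .await?;
    assert_eq!(
        pretty_format_batches(&sql_results)?.to_string(),
        pretty_format_batches(&df_results)?.to_string()
    );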

jayzhan211 (Contributor, Author) commented on Feb 26, 2025:

left:
+---------------+------------------------------------------------------------------------------------------------------------+
| plan_type     | plan                                                                                                       |
+---------------+------------------------------------------------------------------------------------------------------------+
| logical_plan  | Projection: t1.b, count(*)                                                                                 |
|               |   Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST                                             |
|               |     Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1))                                         |
|               |       Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]]                                                |
|               |         TableScan: t1 projection=[b]                                                                       |
| physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)]                                                    |
|               |   SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST]                                              |
|               |     SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true]                        |
|               |       ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |
|               |         AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))]                       |
|               |           CoalesceBatchesExec: target_batch_size=8192                                                      |
|               |             RepartitionExec: partitioning=Hash([b@0], 12), input_partitions=12                             |
|               |               RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1                        |
|               |                 AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))]                        |
|               |                   DataSourceExec: partitions=1, partition_sizes=[1]                                        |
+---------------+------------------------------------------------------------------------------------------------------------+

right:
+---------------+-----------------------------------------------------------------------------------+
| plan_type     | plan                                                                              |
+---------------+-----------------------------------------------------------------------------------+
| logical_plan  | Sort: count(Int64(1)) ASC NULLS LAST                                              |
|               |   Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]]                           |
|               |     TableScan: t1 projection=[b]                                                  |
| physical_plan | SortPreservingMergeExec: [count(Int64(1))@1 ASC NULLS LAST]                       |
|               |   SortExec: expr=[count(Int64(1))@1 ASC NULLS LAST], preserve_partitioning=[true] |
|               |     AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))]  |
|               |       CoalesceBatchesExec: target_batch_size=8192                                 |
|               |         RepartitionExec: partitioning=Hash([b@0], 12), input_partitions=12        |
|               |           RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1   |
|               |             AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))]   |
|               |               DataSourceExec: partitions=1, partition_sizes=[1]                   |
+---------------+-----------------------------------------------------------------------------------+

This is the error after changing it back to count(*): we now have an additional projection, which carries both the aliased count(Int64(1)) AS count(*) for the select list and the bare count(Int64(1)) that the sort refers to:

Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1))

jayzhan211 (Contributor, Author) commented on Feb 26, 2025:
double alias in sort 😕

query I
SELECT count(*) order by count(*);
----
1

query TT
explain SELECT count(*) order by count(*);
----
logical_plan
01)Projection: count(*)
02)--Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST
03)----Projection: count(Int64(1)) AS count(*), count(Int64(1))
04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
05)--------EmptyRelation
physical_plan
01)ProjectionExec: expr=[1 as count(*)]
02)--PlaceholderRowExec
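
A minimal sketch of where the double alias comes from, assuming count_all() returns a pre-aliased expression as proposed above:

    use datafusion_expr::lit;
    use datafusion_functions_aggregate::count::count;

    // count_all() would return count(Int64(1)) AS count(*); when the planner's
    // name preservation re-applies the saved name, the alias is wrapped again:
    let aliased = count(lit(1)).alias("count(*)");
    let doubled = aliased.alias("count(*)");
    // Renders as: count(Int64(1)) AS count(*) AS count(*)
    println!("{doubled}");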

A Contributor commented:
I found I could avoid the double alias by adding a check in Expr::alias:

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index f8baf9c94..2f3c2c575 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1276,7 +1276,14 @@ impl Expr {

     /// Return `self AS name` alias expression
     pub fn alias(self, name: impl Into<String>) -> Expr {
-        Expr::Alias(Alias::new(self, None::<&str>, name.into()))
+        let name = name.into();
+        // don't realias the same thing
+        if matches!(&self, Expr::Alias(Alias {name: existing_name, ..} ) if existing_name == &name)
+        {
+            self
+        } else {
+            Expr::Alias(Alias::new(self, None::<&str>, name))
+        }
     }

     /// Return `self AS name` alias expression with a specific qualifier
@@ -1285,7 +1292,15 @@ impl Expr {
         relation: Option<impl Into<TableReference>>,
         name: impl Into<String>,
     ) -> Expr {
-        Expr::Alias(Alias::new(self, relation, name.into()))
+        let relation = relation.map(|r| r.into());
+        let name = name.into();
+        // don't realias the same thing
+        if matches!(&self, Expr::Alias(Alias {name: existing_name, relation: existing_relation, ..} ) if existing_name == &name && relation.as_ref()==existing_relation.as_ref() )
+        {
+            self
+        } else {
+            Expr::Alias(Alias::new(self, relation, name))
+        }
     }

     /// Remove an alias from an expression if one exists.
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index a3339f0fc..1faf1968b 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -81,7 +81,7 @@ pub fn count_distinct(expr: Expr) -> Expr {

 /// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
 pub fn count_all() -> Expr {
-    count(Expr::Literal(COUNT_STAR_EXPANSION))
+    count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
 }

 #[user_doc(
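
With that guard in place, re-aliasing an expression with its existing name becomes a no-op; a quick illustrative check (assumed behavior of the patch above):

    use datafusion_expr::col;

    let once = col("a").alias("x");
    let twice = once.clone().alias("x");
    // The second alias call returns the expression unchanged,
    // so it still displays as `a AS x`, not `a AS x AS x`.
    assert_eq!(once, twice);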

.await?
.explain(false, false)?
.collect()
@@ -2481,7 +2481,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
async fn test_count_wildcard_on_where_in() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
.sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
.sql("SELECT a,b FROM t1 WHERE a in (SELECT count(1) FROM t2)")
.await?
.explain(false, false)?
.collect()
@@ -2522,7 +2522,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
async fn test_count_wildcard_on_where_exist() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
.sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
.sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(1) FROM t2)")
.await?
.explain(false, false)?
.collect()
@@ -2559,7 +2559,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
let ctx = create_join_context()?;

let sql_results = ctx
.sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
.sql("select count(1) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
.await?
.explain(false, false)?
.collect()
@@ -2598,7 +2598,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
register_alltypes_tiny_pages_parquet(&ctx).await?;

let sql_results = ctx
.sql("select count(*) from t1")
.sql("select count(1) from t1")
.await?
.explain(false, false)?
.collect()
@@ -2628,7 +2628,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
let ctx = create_join_context()?;

let sql_results = ctx
.sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
.sql("select a,b from t1 where (select count(1) from t2 where t1.a = t2.a)>0;")
.await?
.explain(false, false)?
.collect()
18 changes: 8 additions & 10 deletions datafusion/core/tests/sql/explain_analyze.rs
@@ -71,8 +71,8 @@ async fn explain_analyze_baseline_metrics() {
);
assert_metrics!(
&formatted,
"ProjectionExec: expr=[count(*)",
"metrics=[output_rows=1, elapsed_compute="
"ProjectionExec: expr=[]",
"metrics=[output_rows=5, elapsed_compute="
);
assert_metrics!(
&formatted,
@@ -687,7 +687,7 @@ async fn csv_explain_analyze() {
// Only test basic plumbing and try to avoid having to change too
// many things. explain_analyze_baseline_metrics covers the values
// in greater depth
let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[count(*)], metrics=[output_rows=5";
let needle = "ProjectionExec: expr=[count(Int64(1))@1 as count(*), c1@0 as c1], metrics=[output_rows=5";
assert_contains!(&formatted, needle);

let verbose_needle = "Output Rows";
@@ -778,13 +778,11 @@ async fn explain_logical_plan_only() {
let actual = normalize_vec_for_explain(actual);

let expected = vec![
- vec![
- "logical_plan",
- "Aggregate: groupBy=[[]], aggr=[[count(*)]]\
- \n  SubqueryAlias: t\
- \n    Projection: \
- \n      Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
- ]];
+ vec!["logical_plan", "Projection: count(Int64(1)) AS count(*)\
+ \n  Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
+ \n    SubqueryAlias: t\
+ \n      Projection: \
+ \n        Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"]];
assert_eq!(expected, actual);
}

1 change: 1 addition & 0 deletions datafusion/expr/src/expr_rewriter/mod.rs
@@ -286,6 +286,7 @@ pub struct NamePreserver {

/// If the qualified name of an expression is remembered, it will be preserved
/// when rewriting the expression
+ #[derive(Debug)]
A Contributor commented:

💯

pub enum SavedName {
/// Saved qualified name to be preserved
Saved {
1 change: 1 addition & 0 deletions datafusion/expr/src/udaf.rs
@@ -100,6 +100,7 @@ impl fmt::Display for AggregateUDF {
}

/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
+ #[derive(Debug)]
pub struct StatisticsArgs<'a> {
/// The statistics of the aggregate input
pub statistics: &'a Statistics,
196 changes: 2 additions & 194 deletions datafusion/functions-aggregate/src/count.rs
@@ -17,15 +17,11 @@

use ahash::RandomState;
use datafusion_common::stats::Precision;
- use datafusion_expr::expr::{
-     schema_name_from_exprs, schema_name_from_sorts, AggregateFunctionParams,
-     WindowFunctionParams,
- };
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_macros::user_doc;
use datafusion_physical_expr::expressions;
use std::collections::HashSet;
- use std::fmt::{Debug, Write};
+ use std::fmt::Debug;
use std::mem::{size_of, size_of_val};
use std::ops::BitAnd;
use std::sync::Arc;
@@ -51,11 +47,11 @@ use datafusion_common::{
downcast_value, internal_err, not_impl_err, Result, ScalarValue,
};
use datafusion_expr::function::StateFieldsArgs;
- use datafusion_expr::{expr_vec_fmt, Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
};
+ use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
@@ -148,185 +144,6 @@ impl AggregateUDFImpl for Count {
"count"
}

fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
let AggregateFunctionParams {
args,
distinct,
filter,
order_by,
null_treatment,
} = params;

let mut schema_name = String::new();

if is_count_wildcard(args) {
schema_name.write_str("count(*)")?;
} else {
schema_name.write_fmt(format_args!(
"{}({}{})",
self.name(),
if *distinct { "DISTINCT " } else { "" },
schema_name_from_exprs(args)?
))?;
}

if let Some(null_treatment) = null_treatment {
schema_name.write_fmt(format_args!(" {}", null_treatment))?;
}

if let Some(filter) = filter {
schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
};

if let Some(order_by) = order_by {
schema_name.write_fmt(format_args!(
" ORDER BY [{}]",
schema_name_from_sorts(order_by)?
))?;
};

Ok(schema_name)
}

fn window_function_schema_name(
&self,
params: &WindowFunctionParams,
) -> Result<String> {
let WindowFunctionParams {
args,
partition_by,
order_by,
window_frame,
null_treatment,
} = params;

let mut schema_name = String::new();

if is_count_wildcard(args) {
schema_name.write_str("count(*)")?;
} else {
schema_name.write_fmt(format_args!(
"{}({})",
self.name(),
schema_name_from_exprs(args)?
))?;
}

if let Some(null_treatment) = null_treatment {
schema_name.write_fmt(format_args!(" {}", null_treatment))?;
}

if !partition_by.is_empty() {
schema_name.write_fmt(format_args!(
" PARTITION BY [{}]",
schema_name_from_exprs(partition_by)?
))?;
}

if !order_by.is_empty() {
schema_name.write_fmt(format_args!(
" ORDER BY [{}]",
schema_name_from_sorts(order_by)?
))?;
};

schema_name.write_fmt(format_args!(" {window_frame}"))?;

Ok(schema_name)
}

fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
let AggregateFunctionParams {
args,
distinct,
filter,
order_by,
null_treatment,
} = params;

let mut display_name = String::new();

if is_count_wildcard(args) {
display_name.write_str("count(*)")?;
} else {
display_name.write_fmt(format_args!(
"{}({}{})",
self.name(),
if *distinct { "DISTINCT " } else { "" },
args.iter()
.map(|arg| format!("{arg}"))
.collect::<Vec<String>>()
.join(", ")
))?;
}

if let Some(nt) = null_treatment {
display_name.write_fmt(format_args!(" {}", nt))?;
}
if let Some(fe) = filter {
display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
}
if let Some(ob) = order_by {
display_name.write_fmt(format_args!(
" ORDER BY [{}]",
ob.iter()
.map(|o| format!("{o}"))
.collect::<Vec<String>>()
.join(", ")
))?;
}

Ok(display_name)
}

fn window_function_display_name(
&self,
params: &WindowFunctionParams,
) -> Result<String> {
let WindowFunctionParams {
args,
partition_by,
order_by,
window_frame,
null_treatment,
} = params;

let mut display_name = String::new();

if is_count_wildcard(args) {
display_name.write_str("count(*)")?;
} else {
display_name.write_fmt(format_args!(
"{}({})",
self.name(),
expr_vec_fmt!(args)
))?;
}

if let Some(null_treatment) = null_treatment {
display_name.write_fmt(format_args!(" {}", null_treatment))?;
}

if !partition_by.is_empty() {
display_name.write_fmt(format_args!(
" PARTITION BY [{}]",
expr_vec_fmt!(partition_by)
))?;
}

if !order_by.is_empty() {
display_name
.write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
};

display_name.write_fmt(format_args!(
" {} BETWEEN {} AND {}",
window_frame.units, window_frame.start_bound, window_frame.end_bound
))?;

Ok(display_name)
}

fn signature(&self) -> &Signature {
&self.signature
}
@@ -547,15 +364,6 @@ impl AggregateUDFImpl for Count {
}
}

fn is_count_wildcard(args: &[Expr]) -> bool {
match args {
[] => true, // count()
// All const should be coerced to int64 or rejected by the signature
[Expr::Literal(ScalarValue::Int64(Some(_)))] => true, // count(1)
_ => false, // More than one argument or non-matching cases
}
}

#[derive(Debug)]
struct CountAccumulator {
count: i64,