Skip to content

Commit 0242767

Browse files
authored
Add value_from_stats to AggregateUDFImpl, remove special case for min/max/count aggregate statistics (#12296)
* Removes min/max/count comparison based on name in aggregate statistics
* Abstracting away value from statistics
* Removing imports
* Introduced StatisticsArgs
* Fixed docs
1 parent ddb4fac commit 0242767

File tree

6 files changed

+154
-170
lines changed

6 files changed

+154
-170
lines changed

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub use logical_plan::*;
9090
pub use partition_evaluator::PartitionEvaluator;
9191
pub use sqlparser;
9292
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
93-
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
93+
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs};
9494
pub use udf::{ScalarUDF, ScalarUDFImpl};
9595
pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl};
9696
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

datafusion/expr/src/udaf.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ use std::vec;
2626

2727
use arrow::datatypes::{DataType, Field};
2828

29-
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
29+
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30+
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
3031

3132
use crate::expr::AggregateFunction;
3233
use crate::function::{
@@ -94,6 +95,19 @@ impl fmt::Display for AggregateUDF {
9495
}
9596
}
9697

98+
pub struct StatisticsArgs<'a> {
99+
pub statistics: &'a Statistics,
100+
pub return_type: &'a DataType,
101+
/// Whether the aggregate function is distinct.
102+
///
103+
/// ```sql
104+
/// SELECT COUNT(DISTINCT column1) FROM t;
105+
/// ```
106+
pub is_distinct: bool,
107+
/// The physical expressions of the arguments the aggregate function takes.
108+
pub exprs: &'a [Arc<dyn PhysicalExpr>],
109+
}
110+
97111
impl AggregateUDF {
98112
/// Create a new `AggregateUDF` from an [`AggregateUDFImpl`] trait object
99113
///
@@ -244,6 +258,13 @@ impl AggregateUDF {
244258
self.inner.is_descending()
245259
}
246260

261+
pub fn value_from_stats(
262+
&self,
263+
statistics_args: &StatisticsArgs,
264+
) -> Option<ScalarValue> {
265+
self.inner.value_from_stats(statistics_args)
266+
}
267+
247268
/// See [`AggregateUDFImpl::default_value`] for more details.
248269
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
249270
self.inner.default_value(data_type)
@@ -556,6 +577,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
556577
fn is_descending(&self) -> Option<bool> {
557578
None
558579
}
580+
/// Returns the value of this aggregate derived from the given statistics, or `None` (the default) if it cannot be determined from them.
581+
fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
582+
None
583+
}
559584

560585
/// Returns default value of the function given the input is all `null`.
561586
///

datafusion/functions-aggregate/src/count.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use ahash::RandomState;
19+
use datafusion_common::stats::Precision;
1920
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
21+
use datafusion_physical_expr::expressions;
2022
use std::collections::HashSet;
2123
use std::ops::BitAnd;
2224
use std::{fmt::Debug, sync::Arc};
@@ -46,14 +48,15 @@ use datafusion_expr::{
4648
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
4749
EmitTo, GroupsAccumulator, Signature, Volatility,
4850
};
49-
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
51+
use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
5052
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
5153
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
5254
PrimitiveDistinctCountAccumulator,
5355
};
5456
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
5557
use datafusion_physical_expr_common::binary_map::OutputType;
5658

59+
use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
5760
make_udaf_expr_and_func!(
5861
Count,
5962
count,
@@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count {
291294
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
292295
Ok(ScalarValue::Int64(Some(0)))
293296
}
297+
298+
fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
299+
if statistics_args.is_distinct {
300+
return None;
301+
}
302+
if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
303+
if statistics_args.exprs.len() == 1 {
304+
// TODO optimize with exprs other than Column
305+
if let Some(col_expr) = statistics_args.exprs[0]
306+
.as_any()
307+
.downcast_ref::<expressions::Column>()
308+
{
309+
let current_val = &statistics_args.statistics.column_statistics
310+
[col_expr.index()]
311+
.null_count;
312+
if let &Precision::Exact(val) = current_val {
313+
return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
314+
}
315+
} else if let Some(lit_expr) = statistics_args.exprs[0]
316+
.as_any()
317+
.downcast_ref::<expressions::Literal>()
318+
{
319+
if lit_expr.value() == &COUNT_STAR_EXPANSION {
320+
return Some(ScalarValue::Int64(Some(num_rows as i64)));
321+
}
322+
}
323+
}
324+
}
325+
None
326+
}
294327
}
295328

296329
#[derive(Debug)]

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// under the License.
1616

1717
//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
18-
//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function
18+
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
1919
2020
// distributed with this work for additional information
2121
// regarding copyright ownership. The ASF licenses this file
@@ -49,10 +49,12 @@ use arrow::datatypes::{
4949
UInt8Type,
5050
};
5151
use arrow_schema::IntervalUnit;
52+
use datafusion_common::stats::Precision;
5253
use datafusion_common::{
53-
downcast_value, exec_err, internal_err, DataFusionError, Result,
54+
downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
5455
};
5556
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
57+
use datafusion_physical_expr::expressions;
5658
use std::fmt::Debug;
5759

5860
use arrow::datatypes::i256;
@@ -63,10 +65,10 @@ use arrow::datatypes::{
6365
};
6466

6567
use datafusion_common::ScalarValue;
66-
use datafusion_expr::GroupsAccumulator;
6768
use datafusion_expr::{
6869
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
6970
};
71+
use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
7072
use half::f16;
7173
use std::ops::Deref;
7274

@@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator {
147149
}};
148150
}
149151

152+
trait FromColumnStatistics {
153+
fn value_from_column_statistics(
154+
&self,
155+
stats: &ColumnStatistics,
156+
) -> Option<ScalarValue>;
157+
158+
fn value_from_statistics(
159+
&self,
160+
statistics_args: &StatisticsArgs,
161+
) -> Option<ScalarValue> {
162+
if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
163+
match *num_rows {
164+
0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
165+
value if value > 0 => {
166+
let col_stats = &statistics_args.statistics.column_statistics;
167+
if statistics_args.exprs.len() == 1 {
168+
// TODO optimize with exprs other than Column
169+
if let Some(col_expr) = statistics_args.exprs[0]
170+
.as_any()
171+
.downcast_ref::<expressions::Column>()
172+
{
173+
return self.value_from_column_statistics(
174+
&col_stats[col_expr.index()],
175+
);
176+
}
177+
}
178+
}
179+
_ => {}
180+
}
181+
}
182+
None
183+
}
184+
}
185+
186+
impl FromColumnStatistics for Max {
187+
fn value_from_column_statistics(
188+
&self,
189+
col_stats: &ColumnStatistics,
190+
) -> Option<ScalarValue> {
191+
if let Precision::Exact(ref val) = col_stats.max_value {
192+
if !val.is_null() {
193+
return Some(val.clone());
194+
}
195+
}
196+
None
197+
}
198+
}
199+
150200
impl AggregateUDFImpl for Max {
151201
fn as_any(&self) -> &dyn std::any::Any {
152202
self
@@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max {
272322
fn is_descending(&self) -> Option<bool> {
273323
Some(true)
274324
}
325+
275326
fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
276327
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
277328
}
@@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max {
282333
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
283334
datafusion_expr::ReversedUDAF::Identical
284335
}
336+
fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
337+
self.value_from_statistics(statistics_args)
338+
}
285339
}
286340

287341
// Statically-typed version of min/max(array) -> ScalarValue for string types
@@ -926,6 +980,20 @@ impl Default for Min {
926980
}
927981
}
928982

983+
impl FromColumnStatistics for Min {
984+
fn value_from_column_statistics(
985+
&self,
986+
col_stats: &ColumnStatistics,
987+
) -> Option<ScalarValue> {
988+
if let Precision::Exact(ref val) = col_stats.min_value {
989+
if !val.is_null() {
990+
return Some(val.clone());
991+
}
992+
}
993+
None
994+
}
995+
}
996+
929997
impl AggregateUDFImpl for Min {
930998
fn as_any(&self) -> &dyn std::any::Any {
931999
self
@@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min {
10521120
Some(false)
10531121
}
10541122

1123+
fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1124+
self.value_from_statistics(statistics_args)
1125+
}
10551126
fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
10561127
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
10571128
}

0 commit comments

Comments
 (0)