Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions datafusion/physical-expr/src/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion_common::{
Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err,
plan_err,
};
use datafusion_expr_common::interval_arithmetic::Interval;

use datafusion_physical_expr_common::metrics::ExecutionPlanMetricsSet;
use datafusion_physical_expr_common::metrics::ExpressionEvaluatorMetrics;
Expand Down Expand Up @@ -714,9 +715,12 @@ impl ProjectionExprs {
}
}
} else {
// TODO stats: estimate more statistics from expressions
// (expressions should compute their statistics themselves)
ColumnStatistics::new_unknown()
// Propagate statistics through expressions (CAST, arithmetic, etc.)
// using the interval arithmetic system (evaluate_bounds).
project_column_statistics_through_expr(
expr.as_ref(),
&stats.column_statistics,
)
};
column_statistics.push(col_stats);
}
Expand All @@ -726,6 +730,70 @@ impl ProjectionExprs {
}
}

/// Propagate min/max statistics through an expression using
/// [`PhysicalExpr::evaluate_bounds`]. Works for any expression that
/// implements `evaluate_bounds` (CAST, negation, arithmetic with literals, etc.).
fn project_column_statistics_through_expr(
    expr: &dyn PhysicalExpr,
    column_stats: &[ColumnStatistics],
) -> ColumnStatistics {
    // If bounds cannot be derived for any part of the tree, fall back to
    // fully-unknown statistics rather than guessing.
    let Some((interval, all_exact)) = compute_bounds_and_exactness(expr, column_stats)
    else {
        return ColumnStatistics::new_unknown();
    };

    // Only min/max can be derived from interval arithmetic; every other
    // statistic (null/distinct counts, sums, sizes) stays unknown.
    ColumnStatistics {
        min_value: to_precision(interval.lower().clone(), all_exact),
        max_value: to_precision(interval.upper().clone(), all_exact),
        ..ColumnStatistics::new_unknown()
    }
}

/// Convert a bound value to the appropriate [`Precision`] level.
///
/// A null bound carries no information, so it maps to [`Precision::Absent`];
/// otherwise the value is wrapped as `Exact` or `Inexact` depending on
/// whether all contributing input statistics were exact.
fn to_precision(value: ScalarValue, exact: bool) -> Precision<ScalarValue> {
    match (value.is_null(), exact) {
        (true, _) => Precision::Absent,
        (false, true) => Precision::Exact(value),
        (false, false) => Precision::Inexact(value),
    }
}

/// Recursively compute the output [`Interval`] and whether all leaf
/// statistics are exact, in a single traversal of the expression tree.
///
/// Returns `None` when bounds cannot be derived: a column's min/max is
/// absent, a column index is out of range for `column_stats`, interval
/// construction fails, or `evaluate_bounds` is not supported by a node.
fn compute_bounds_and_exactness(
    expr: &dyn PhysicalExpr,
    column_stats: &[ColumnStatistics],
) -> Option<(Interval, bool)> {
    // Leaf: column reference — read bounds straight from the input statistics.
    if let Some(col) = expr.downcast_ref::<Column>() {
        // Checked indexing: a mismatched/stale schema must degrade to
        // "unknown statistics", not panic (the indexing operator would).
        let stats = column_stats.get(col.index())?;
        let min = stats.min_value.get_value()?;
        let max = stats.max_value.get_value()?;
        // The derived bound is only exact if both input bounds are exact.
        let exact = stats.min_value.is_exact().unwrap_or(false)
            && stats.max_value.is_exact().unwrap_or(false);
        return Some((Interval::try_new(min.clone(), max.clone()).ok()?, exact));
    }

    // Leaf: literal — a degenerate, always-exact interval [v, v].
    if let Some(lit) = expr.downcast_ref::<Literal>() {
        let val = lit.value();
        return Some((Interval::try_new(val.clone(), val.clone()).ok()?, true));
    }

    // Interior node: compute each child's interval, then let the expression
    // combine them via `evaluate_bounds`. Any child without usable bounds
    // makes the whole subtree unknown (`?` short-circuits).
    let children = expr.children();
    let mut child_intervals = Vec::with_capacity(children.len());
    let mut all_exact = true;
    for child in &children {
        let (interval, exact) =
            compute_bounds_and_exactness(child.as_ref(), column_stats)?;
        child_intervals.push(interval);
        all_exact &= exact;
    }
    let child_refs: Vec<&Interval> = child_intervals.iter().collect();
    Some((expr.evaluate_bounds(&child_refs).ok()?, all_exact))
}

impl<'a> IntoIterator for &'a ProjectionExprs {
type Item = &'a ProjectionExpr;
type IntoIter = std::slice::Iter<'a, ProjectionExpr>;
Expand Down Expand Up @@ -2772,13 +2840,17 @@ pub(crate) mod tests {
// Should have 2 column statistics
assert_eq!(output_stats.column_statistics.len(), 2);

// First column (expression) should have unknown statistics
// First column (col0 + 1) should have propagated min/max via evaluate_bounds
assert_eq!(
output_stats.column_statistics[0].distinct_count,
Precision::Absent
output_stats.column_statistics[0].min_value,
Precision::Exact(ScalarValue::Int64(Some(-3)))
);
assert_eq!(
output_stats.column_statistics[0].max_value,
Precision::Exact(ScalarValue::Int64(Some(22)))
);
assert_eq!(
output_stats.column_statistics[0].distinct_count,
Precision::Absent
);

Expand Down
46 changes: 33 additions & 13 deletions datafusion/physical-optimizer/src/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::scalar::ScalarValue;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_physical_plan::aggregates::{AggregateExec, AggregateInputMode};
use datafusion_physical_plan::aggregates::{
AggregateExec, AggregateInputMode, AggregateMode,
};
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
Expand Down Expand Up @@ -49,7 +51,7 @@ impl PhysicalOptimizerRule for AggregateStatistics {
plan: Arc<dyn ExecutionPlan>,
config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
if let Some(partial_agg_exec) = take_optimizable(&*plan) {
if let Some(partial_agg_exec) = take_optimizable(&plan) {
let partial_agg_exec = partial_agg_exec
.downcast_ref::<AggregateExec>()
.expect("take_optimizable() ensures that this is a AggregateExec");
Expand Down Expand Up @@ -106,19 +108,37 @@ impl PhysicalOptimizerRule for AggregateStatistics {
}
}

/// assert if the node passed as argument is a final `AggregateExec` node that can be optimized:
/// - its child (with possible intermediate layers) is a partial `AggregateExec` node
/// - they both have no grouping expression
/// Returns an `AggregateExec` whose statistics can be used to replace the
/// entire aggregate with literal values, if the plan is eligible.
///
/// If this is the case, return a ref to the partial `AggregateExec`, else `None`.
/// We would have preferred to return a casted ref to AggregateExec but the recursion requires
/// the `ExecutionPlan.children()` method that returns an owned reference.
fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
if let Some(final_agg_exec) = node.downcast_ref::<AggregateExec>()
&& final_agg_exec.mode().input_mode() == AggregateInputMode::Partial
&& final_agg_exec.group_expr().is_empty()
/// Two patterns are recognized:
///
/// 1. **Final wrapping Partial** (multi-partition): A final `AggregateExec`
/// (input mode = `Partial`) with no GROUP BY whose descendant is a partial
/// `AggregateExec` (input mode = `Raw`) with no GROUP BY and no filters.
/// Returns the inner partial aggregate.
///
/// 2. **Single / SinglePartitioned** (single-partition): A `Single` or
/// `SinglePartitioned` `AggregateExec` with no GROUP BY and no filters.
/// Returns the aggregate itself.
fn take_optimizable(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
let agg_exec = plan.downcast_ref::<AggregateExec>()?;

// Case 1: Single-mode aggregate — processes raw input, produces final output
if matches!(
agg_exec.mode(),
AggregateMode::Single | AggregateMode::SinglePartitioned
) && agg_exec.group_expr().is_empty()
&& agg_exec.filter_expr().iter().all(|e| e.is_none())
{
return Some(Arc::clone(plan));
}

// Case 2: Final aggregate wrapping a Partial aggregate
if agg_exec.mode().input_mode() == AggregateInputMode::Partial
&& agg_exec.group_expr().is_empty()
{
let mut child = Arc::clone(final_agg_exec.input());
let mut child = Arc::clone(agg_exec.input());
loop {
if let Some(partial_agg_exec) = child.downcast_ref::<AggregateExec>()
&& partial_agg_exec.mode().input_mode() == AggregateInputMode::Raw
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/clickbench.slt
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ logical_plan
03)----Projection: CAST(CAST(hits_raw.EventDate AS Int32) AS Date32) AS EventDate
04)------TableScan: hits_raw projection=[EventDate]
physical_plan
01)AggregateExec: mode=Single, gby=[], aggr=[min(hits.EventDate), max(hits.EventDate)]
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[CAST(CAST(EventDate@5 AS Int32) AS Date32) as EventDate], file_type=parquet
01)ProjectionExec: expr=[2013-07-15 as min(hits.EventDate), 2013-07-15 as max(hits.EventDate)]
02)--PlaceholderRowExec

query DD
SELECT MIN("EventDate"), MAX("EventDate") FROM hits;
Expand Down
Loading