Skip to content

Commit b9cef8c

Browse files
Preserve constant values across union operations (#13805)
* Add value tracking to ConstExpr for improved union optimization * Update PartialEq impl * Minor change * Add docstring for ConstExpr value * Improve constant propagation across union partitions * Add assertion for across_partitions * fix fmt * Update properties.rs * Remove redundant constant removal loop * Remove unnecessary mut * Set across_partitions=true when both sides are constant * Extract and use constant values in filter expressions * Add initial SLT for constant value tracking across UNION ALL * Assign values to ConstExpr where possible * Revert "Set across_partitions=true when both sides are constant" This reverts commit 3051cd4. * Temporarily take value from literal * Lint fixes * Cargo fmt * Add get_expr_constant_value * Make `with_value()` accept optional value * Add todo * Move test to union.slt * Fix changed slt after merge * Simplify constexpr * Update properties.rs --------- Co-authored-by: berkaysynnada <[email protected]>
1 parent 482b489 commit b9cef8c

File tree

8 files changed

+303
-90
lines changed

8 files changed

+303
-90
lines changed

datafusion/physical-expr/src/equivalence/class.rs

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::fmt::Display;
2424
use std::sync::Arc;
2525

2626
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
27-
use datafusion_common::JoinType;
27+
use datafusion_common::{JoinType, ScalarValue};
2828
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
2929

3030
use indexmap::{IndexMap, IndexSet};
@@ -55,13 +55,45 @@ use indexmap::{IndexMap, IndexSet};
5555
/// // create a constant expression from a physical expression
5656
/// let const_expr = ConstExpr::from(col);
5757
/// ```
58+
// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum:
59+
//
60+
// ```
61+
// enum PartitionValues {
62+
// Uniform(Option<ScalarValue>), // Same value across all partitions
63+
// Heterogeneous(Vec<Option<ScalarValue>>) // Different values per partition
64+
// }
65+
// ```
66+
//
67+
// This would provide more flexible representation of partition values.
68+
// Note: This is a breaking change for the equivalence API and should be
69+
// addressed in a separate issue/PR.
5870
#[derive(Debug, Clone)]
5971
pub struct ConstExpr {
6072
/// The expression that is known to be constant (e.g. a `Column`)
6173
expr: Arc<dyn PhysicalExpr>,
6274
/// Does the constant have the same value across all partitions? See
6375
/// struct docs for more details
64-
across_partitions: bool,
76+
across_partitions: AcrossPartitions,
77+
}
78+
79+
#[derive(PartialEq, Clone, Debug)]
80+
/// Represents whether a constant expression's value is uniform or varies across partitions.
81+
///
82+
/// The `AcrossPartitions` enum is used to describe the nature of a constant expression
83+
/// in a physical execution plan:
84+
///
85+
/// - `Heterogeneous`: The constant expression may have different values for different partitions.
86+
/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all partitions,
87+
/// or is `None` if the value is not specified.
88+
pub enum AcrossPartitions {
89+
Heterogeneous,
90+
Uniform(Option<ScalarValue>),
91+
}
92+
93+
impl Default for AcrossPartitions {
94+
fn default() -> Self {
95+
Self::Heterogeneous
96+
}
6597
}
6698

6799
impl PartialEq for ConstExpr {
@@ -79,23 +111,23 @@ impl ConstExpr {
79111
Self {
80112
expr,
81113
// By default, assume constant expressions are not same across partitions.
82-
across_partitions: false,
114+
across_partitions: Default::default(),
83115
}
84116
}
85117

86118
/// Set the `across_partitions` flag
87119
///
88120
/// See struct docs for more details
89-
pub fn with_across_partitions(mut self, across_partitions: bool) -> Self {
121+
pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
90122
self.across_partitions = across_partitions;
91123
self
92124
}
93125

94126
/// Is the expression the same across all partitions?
95127
///
96128
/// See struct docs for more details
97-
pub fn across_partitions(&self) -> bool {
98-
self.across_partitions
129+
pub fn across_partitions(&self) -> AcrossPartitions {
130+
self.across_partitions.clone()
99131
}
100132

101133
pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
@@ -113,7 +145,7 @@ impl ConstExpr {
113145
let maybe_expr = f(&self.expr);
114146
maybe_expr.map(|expr| Self {
115147
expr,
116-
across_partitions: self.across_partitions,
148+
across_partitions: self.across_partitions.clone(),
117149
})
118150
}
119151

@@ -143,14 +175,20 @@ impl ConstExpr {
143175
}
144176
}
145177

146-
/// Display implementation for `ConstExpr`
147-
///
148-
/// Example `c` or `c(across_partitions)`
149178
impl Display for ConstExpr {
150-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
179+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151180
write!(f, "{}", self.expr)?;
152-
if self.across_partitions {
153-
write!(f, "(across_partitions)")?;
181+
match &self.across_partitions {
182+
AcrossPartitions::Heterogeneous => {
183+
write!(f, "(heterogeneous)")?;
184+
}
185+
AcrossPartitions::Uniform(value) => {
186+
if let Some(val) = value {
187+
write!(f, "(uniform: {})", val)?;
188+
} else {
189+
write!(f, "(uniform: unknown)")?;
190+
}
191+
}
154192
}
155193
Ok(())
156194
}

datafusion/physical-expr/src/equivalence/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ mod ordering;
2727
mod projection;
2828
mod properties;
2929

30-
pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup};
30+
pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
3131
pub use ordering::OrderingEquivalenceClass;
3232
pub use projection::ProjectionMapping;
3333
pub use properties::{

datafusion/physical-expr/src/equivalence/ordering.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ mod tests {
262262
};
263263
use crate::expressions::{col, BinaryExpr, Column};
264264
use crate::utils::tests::TestScalarUDF;
265-
use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr};
265+
use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr};
266266

267267
use arrow::datatypes::{DataType, Field, Schema};
268268
use arrow_schema::SortOptions;
@@ -583,9 +583,10 @@ mod tests {
583583
let eq_group = EquivalenceGroup::new(eq_group);
584584
eq_properties.add_equivalence_group(eq_group);
585585

586-
let constants = constants
587-
.into_iter()
588-
.map(|expr| ConstExpr::from(expr).with_across_partitions(true));
586+
let constants = constants.into_iter().map(|expr| {
587+
ConstExpr::from(expr)
588+
.with_across_partitions(AcrossPartitions::Uniform(None))
589+
});
589590
eq_properties = eq_properties.with_constants(constants);
590591

591592
let reqs = convert_to_sort_exprs(&reqs);

0 commit comments

Comments
 (0)