Skip to content

Commit 876959d

Browse files
committed
feat(13525): permit user-defined invariants on logical plan extensions
1 parent 2aff98e commit 876959d

File tree

4 files changed

+193
-5
lines changed

4 files changed

+193
-5
lines changed

datafusion/core/tests/user_defined/user_defined_plan.rs

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
//!
6060
6161
use std::fmt::Debug;
62+
use std::hash::Hash;
6263
use std::task::{Context, Poll};
6364
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};
6465

@@ -93,7 +94,7 @@ use datafusion::{
9394
use datafusion_common::config::ConfigOptions;
9495
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
9596
use datafusion_common::ScalarValue;
96-
use datafusion_expr::{FetchType, Projection, SortExpr};
97+
use datafusion_expr::{FetchType, Invariant, InvariantLevel, Projection, SortExpr};
9798
use datafusion_optimizer::optimizer::ApplyOrder;
9899
use datafusion_optimizer::AnalyzerRule;
99100
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
@@ -295,15 +296,71 @@ async fn topk_plan() -> Result<()> {
295296
Ok(())
296297
}
297298

299+
#[tokio::test]
300+
/// Run invariant checks on the logical plan extension [`TopKPlanNode`].
301+
async fn topk_invariants() -> Result<()> {
302+
// Test: pass an InvariantLevel::Always
303+
let pass = InvariantMock {
304+
should_fail_invariant: false,
305+
kind: InvariantLevel::Always,
306+
};
307+
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
308+
run_and_compare_query(ctx, "Topk context").await?;
309+
310+
// Test: fail an InvariantLevel::Always
311+
let fail = InvariantMock {
312+
should_fail_invariant: true,
313+
kind: InvariantLevel::Always,
314+
};
315+
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
316+
matches!(
317+
&*run_and_compare_query(ctx, "Topk context")
318+
.await
319+
.unwrap_err()
320+
.message(),
321+
"node fails check, such as improper inputs"
322+
);
323+
324+
// Test: pass an InvariantLevel::Executable
325+
let pass = InvariantMock {
326+
should_fail_invariant: false,
327+
kind: InvariantLevel::Executable,
328+
};
329+
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
330+
run_and_compare_query(ctx, "Topk context").await?;
331+
332+
// Test: fail an InvariantLevel::Executable
333+
let fail = InvariantMock {
334+
should_fail_invariant: true,
335+
kind: InvariantLevel::Executable,
336+
};
337+
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
338+
matches!(
339+
&*run_and_compare_query(ctx, "Topk context")
340+
.await
341+
.unwrap_err()
342+
.message(),
343+
"node fails check, such as improper inputs"
344+
);
345+
346+
Ok(())
347+
}
348+
298349
fn make_topk_context() -> SessionContext {
350+
make_topk_context_with_invariants(None)
351+
}
352+
353+
fn make_topk_context_with_invariants(
354+
invariant_mock: Option<InvariantMock>,
355+
) -> SessionContext {
299356
let config = SessionConfig::new().with_target_partitions(48);
300357
let runtime = Arc::new(RuntimeEnv::default());
301358
let state = SessionStateBuilder::new()
302359
.with_config(config)
303360
.with_runtime_env(runtime)
304361
.with_default_features()
305362
.with_query_planner(Arc::new(TopKQueryPlanner {}))
306-
.with_optimizer_rule(Arc::new(TopKOptimizerRule {}))
363+
.with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock }))
307364
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
308365
.build();
309366
SessionContext::new_with_state(state)
@@ -336,7 +393,10 @@ impl QueryPlanner for TopKQueryPlanner {
336393
}
337394

338395
#[derive(Default, Debug)]
339-
struct TopKOptimizerRule {}
396+
struct TopKOptimizerRule {
397+
/// A testing-only hashable fixture.
398+
invariant_mock: Option<InvariantMock>,
399+
}
340400

341401
impl OptimizerRule for TopKOptimizerRule {
342402
fn name(&self) -> &str {
@@ -380,6 +440,7 @@ impl OptimizerRule for TopKOptimizerRule {
380440
k: fetch,
381441
input: input.as_ref().clone(),
382442
expr: expr[0].clone(),
443+
invariant_mock: self.invariant_mock.clone(),
383444
}),
384445
})));
385446
}
@@ -396,6 +457,10 @@ struct TopKPlanNode {
396457
/// The sort expression (this example only supports a single sort
397458
/// expr)
398459
expr: SortExpr,
460+
461+
/// A testing-only hashable fixture.
462+
/// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
463+
invariant_mock: Option<InvariantMock>,
399464
}
400465

401466
impl Debug for TopKPlanNode {
@@ -406,6 +471,20 @@ impl Debug for TopKPlanNode {
406471
}
407472
}
408473

474+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
475+
struct InvariantMock {
476+
should_fail_invariant: bool,
477+
kind: InvariantLevel,
478+
}
479+
480+
fn invariant_helper_mock_ok(_: &LogicalPlan) -> Result<()> {
481+
Ok(())
482+
}
483+
484+
fn invariant_helper_mock_fails(_: &LogicalPlan) -> Result<()> {
485+
internal_err!("node fails check, such as improper inputs")
486+
}
487+
409488
impl UserDefinedLogicalNodeCore for TopKPlanNode {
410489
fn name(&self) -> &str {
411490
"TopK"
@@ -420,6 +499,26 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
420499
self.input.schema()
421500
}
422501

502+
fn invariants(&self) -> Vec<Invariant> {
503+
if let Some(InvariantMock {
504+
should_fail_invariant,
505+
kind,
506+
}) = self.invariant_mock.clone()
507+
{
508+
if should_fail_invariant {
509+
return vec![Invariant {
510+
kind,
511+
fun: Arc::new(invariant_helper_mock_fails),
512+
}];
513+
}
514+
return vec![Invariant {
515+
kind,
516+
fun: Arc::new(invariant_helper_mock_ok),
517+
}];
518+
}
519+
vec![] // same as default impl
520+
}
521+
423522
fn expressions(&self) -> Vec<Expr> {
424523
vec![self.expr.expr.clone()]
425524
}
@@ -440,6 +539,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
440539
k: self.k,
441540
input: inputs.swap_remove(0),
442541
expr: self.expr.with_expr(exprs.swap_remove(0)),
542+
invariant_mock: self.invariant_mock.clone(),
443543
})
444544
}
445545

datafusion/expr/src/logical_plan/extension.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ use std::cmp::Ordering;
2222
use std::hash::{Hash, Hasher};
2323
use std::{any::Any, collections::HashSet, fmt, sync::Arc};
2424

25+
use super::invariants::Invariant;
26+
use super::InvariantLevel;
27+
2528
/// This defines the interface for [`LogicalPlan`] nodes that can be
2629
/// used to extend DataFusion with custom relational operators.
2730
///
@@ -54,6 +57,22 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
5457
/// Return the output schema of this logical plan node.
5558
fn schema(&self) -> &DFSchemaRef;
5659

60+
/// Return the list of invariants.
61+
///
62+
/// Implementing this function enables the user to define the
63+
/// invariants for a given logical plan extension.
64+
fn invariants(&self) -> Vec<Invariant> {
65+
vec![]
66+
}
67+
68+
/// Perform check of invariants for the extension node.
69+
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
70+
self.invariants()
71+
.into_iter()
72+
.filter(|inv| check == inv.kind)
73+
.try_for_each(|inv| inv.check(plan))
74+
}
75+
5776
/// Returns all expressions in the current logical plan node. This should
5877
/// not include expressions of any inputs (aka non-recursively).
5978
///
@@ -244,6 +263,14 @@ pub trait UserDefinedLogicalNodeCore:
244263
/// Return the output schema of this logical plan node.
245264
fn schema(&self) -> &DFSchemaRef;
246265

266+
/// Return the list of invariants.
267+
///
268+
/// Implementing this function enables the user to define the
269+
/// invariants for a given logical plan extension.
270+
fn invariants(&self) -> Vec<Invariant> {
271+
vec![]
272+
}
273+
247274
/// Returns all expressions in the current logical plan node. This
248275
/// should not include expressions of any inputs (aka
249276
/// non-recursively). These expressions are used for optimizer
@@ -336,6 +363,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
336363
self.schema()
337364
}
338365

366+
fn invariants(&self) -> Vec<Invariant> {
367+
self.invariants()
368+
}
369+
339370
fn expressions(&self) -> Vec<Expr> {
340371
self.expressions()
341372
}

datafusion/expr/src/logical_plan/invariants.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use datafusion_common::{
1921
internal_err, plan_err,
2022
tree_node::{TreeNode, TreeNodeRecursion},
@@ -28,6 +30,24 @@ use crate::{
2830
Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
2931
};
3032

33+
use super::Extension;
34+
35+
pub type InvariantFn = Arc<dyn Fn(&LogicalPlan) -> Result<()> + Send + Sync>;
36+
37+
#[derive(Clone)]
38+
pub struct Invariant {
39+
pub kind: InvariantLevel,
40+
pub fun: InvariantFn,
41+
}
42+
43+
impl Invariant {
44+
/// Return an error if invariant does not hold true.
45+
pub fn check(&self, plan: &LogicalPlan) -> Result<()> {
46+
(self.fun)(plan)
47+
}
48+
}
49+
50+
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
3151
pub enum InvariantLevel {
3252
/// Invariants that are always true in DataFusion `LogicalPlan`s
3353
/// such as the number of expected children and no duplicated output fields
@@ -41,19 +61,54 @@ pub enum InvariantLevel {
4161
Executable,
4262
}
4363

64+
/// Apply the [`InvariantLevel::Always`] check at the root plan node only.
4465
pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
4566
// Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
4667
assert_unique_field_names(plan)?;
4768

4869
Ok(())
4970
}
5071

72+
/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
73+
/// as well as the less stringent [`InvariantLevel::Always`] checks.
5174
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
75+
// Always invariants
5276
assert_always_invariants(plan)?;
77+
assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
78+
79+
// Executable invariants
80+
assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
5381
assert_valid_semantic_plan(plan)?;
5482
Ok(())
5583
}
5684

85+
/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
86+
///
87+
/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
88+
/// for more details of user-provided extension node invariants.
89+
fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
90+
plan.apply_with_subqueries(|plan: &LogicalPlan| {
91+
if let LogicalPlan::Extension(Extension { node }) = plan {
92+
node.check_invariants(check, plan)?;
93+
}
94+
plan.apply_expressions(|expr| {
95+
// recursively look for subqueries
96+
expr.apply(|expr| {
97+
match expr {
98+
Expr::Exists(Exists { subquery, .. })
99+
| Expr::InSubquery(InSubquery { subquery, .. })
100+
| Expr::ScalarSubquery(subquery) => {
101+
assert_valid_extension_nodes(&subquery.subquery, check)?;
102+
}
103+
_ => {}
104+
};
105+
Ok(TreeNodeRecursion::Continue)
106+
})
107+
})
108+
})
109+
.map(|_| ())
110+
}
111+
57112
/// Returns an error if plan, and subplans, do not have unique fields.
58113
///
59114
/// This invariant is subject to change.
@@ -87,7 +142,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul
87142

88143
/// Asserts that the subqueries are structured properly with valid node placement.
89144
///
90-
/// Refer to [`check_subquery_expr`] for more details.
145+
/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
91146
fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
92147
plan.apply_with_subqueries(|plan: &LogicalPlan| {
93148
plan.apply_expressions(|expr| {

datafusion/expr/src/logical_plan/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ pub mod display;
2121
pub mod dml;
2222
mod extension;
2323
pub(crate) mod invariants;
24-
pub use invariants::{assert_expected_schema, check_subquery_expr, InvariantLevel};
24+
pub use invariants::{
25+
assert_expected_schema, check_subquery_expr, Invariant, InvariantFn, InvariantLevel,
26+
};
2527
mod plan;
2628
mod statement;
2729
pub mod tree_node;

0 commit comments

Comments
 (0)