Skip to content

Commit d8bc49f

Browse files
wiedld and alamb
authored
Provide user-defined invariants for logical node extensions. (#14329)
* feat(13525): permit user-defined invariants on logical plan extensions * test(13525): demonstrate extension node invariants catching improper mutation during an optimizer pass * chore: update docs * refactor: remove the extra Invariant interface around an FnMut, since it doesn't make sense for the extension node's checks --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent cfc7c60 commit d8bc49f

File tree

5 files changed

+263
-10
lines changed

5 files changed

+263
-10
lines changed

datafusion/core/tests/user_defined/user_defined_plan.rs

Lines changed: 187 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
//!
6060
6161
use std::fmt::Debug;
62+
use std::hash::Hash;
6263
use std::task::{Context, Poll};
6364
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};
6465

@@ -93,7 +94,7 @@ use datafusion::{
9394
use datafusion_common::config::ConfigOptions;
9495
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
9596
use datafusion_common::ScalarValue;
96-
use datafusion_expr::{FetchType, Projection, SortExpr};
97+
use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr};
9798
use datafusion_optimizer::optimizer::ApplyOrder;
9899
use datafusion_optimizer::AnalyzerRule;
99100
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
@@ -295,20 +296,175 @@ async fn topk_plan() -> Result<()> {
295296
Ok(())
296297
}
297298

299+
#[tokio::test]
300+
/// Run invariant checks on the logical plan extension [`TopKPlanNode`].
301+
async fn topk_invariants() -> Result<()> {
302+
// Test: pass an InvariantLevel::Always
303+
let pass = InvariantMock {
304+
should_fail_invariant: false,
305+
kind: InvariantLevel::Always,
306+
};
307+
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
308+
run_and_compare_query(ctx, "Topk context").await?;
309+
310+
// Test: fail an InvariantLevel::Always
311+
let fail = InvariantMock {
312+
should_fail_invariant: true,
313+
kind: InvariantLevel::Always,
314+
};
315+
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
316+
matches!(
317+
&*run_and_compare_query(ctx, "Topk context")
318+
.await
319+
.unwrap_err()
320+
.message(),
321+
"node fails check, such as improper inputs"
322+
);
323+
324+
// Test: pass an InvariantLevel::Executable
325+
let pass = InvariantMock {
326+
should_fail_invariant: false,
327+
kind: InvariantLevel::Executable,
328+
};
329+
let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?;
330+
run_and_compare_query(ctx, "Topk context").await?;
331+
332+
// Test: fail an InvariantLevel::Executable
333+
let fail = InvariantMock {
334+
should_fail_invariant: true,
335+
kind: InvariantLevel::Executable,
336+
};
337+
let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?;
338+
matches!(
339+
&*run_and_compare_query(ctx, "Topk context")
340+
.await
341+
.unwrap_err()
342+
.message(),
343+
"node fails check, such as improper inputs"
344+
);
345+
346+
Ok(())
347+
}
348+
349+
#[tokio::test]
350+
async fn topk_invariants_after_invalid_mutation() -> Result<()> {
351+
// CONTROL
352+
// Build a valid topK plan.
353+
let config = SessionConfig::new().with_target_partitions(48);
354+
let runtime = Arc::new(RuntimeEnv::default());
355+
let state = SessionStateBuilder::new()
356+
.with_config(config)
357+
.with_runtime_env(runtime)
358+
.with_default_features()
359+
.with_query_planner(Arc::new(TopKQueryPlanner {}))
360+
// 1. adds a valid TopKPlanNode
361+
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
362+
invariant_mock: Some(InvariantMock {
363+
should_fail_invariant: false,
364+
kind: InvariantLevel::Always,
365+
}),
366+
}))
367+
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
368+
.build();
369+
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
370+
run_and_compare_query(ctx, "Topk context").await?;
371+
372+
// Test
373+
// Build a valid topK plan.
374+
// Then have an invalid mutation in an optimizer run.
375+
let config = SessionConfig::new().with_target_partitions(48);
376+
let runtime = Arc::new(RuntimeEnv::default());
377+
let state = SessionStateBuilder::new()
378+
.with_config(config)
379+
.with_runtime_env(runtime)
380+
.with_default_features()
381+
.with_query_planner(Arc::new(TopKQueryPlanner {}))
382+
// 1. adds a valid TopKPlanNode
383+
.with_optimizer_rule(Arc::new(TopKOptimizerRule {
384+
invariant_mock: Some(InvariantMock {
385+
should_fail_invariant: false,
386+
kind: InvariantLevel::Always,
387+
}),
388+
}))
389+
// 2. break the TopKPlanNode
390+
.with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {}))
391+
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
392+
.build();
393+
let ctx = setup_table(SessionContext::new_with_state(state)).await?;
394+
matches!(
395+
&*run_and_compare_query(ctx, "Topk context")
396+
.await
397+
.unwrap_err()
398+
.message(),
399+
"node fails check, such as improper inputs"
400+
);
401+
402+
Ok(())
403+
}
404+
298405
fn make_topk_context() -> SessionContext {
406+
make_topk_context_with_invariants(None)
407+
}
408+
409+
fn make_topk_context_with_invariants(
410+
invariant_mock: Option<InvariantMock>,
411+
) -> SessionContext {
299412
let config = SessionConfig::new().with_target_partitions(48);
300413
let runtime = Arc::new(RuntimeEnv::default());
301414
let state = SessionStateBuilder::new()
302415
.with_config(config)
303416
.with_runtime_env(runtime)
304417
.with_default_features()
305418
.with_query_planner(Arc::new(TopKQueryPlanner {}))
306-
.with_optimizer_rule(Arc::new(TopKOptimizerRule {}))
419+
.with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock }))
307420
.with_analyzer_rule(Arc::new(MyAnalyzerRule {}))
308421
.build();
309422
SessionContext::new_with_state(state)
310423
}
311424

425+
#[derive(Debug)]
426+
struct OptimizerMakeExtensionNodeInvalid;
427+
428+
impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
429+
fn name(&self) -> &str {
430+
"OptimizerMakeExtensionNodeInvalid"
431+
}
432+
433+
fn apply_order(&self) -> Option<ApplyOrder> {
434+
Some(ApplyOrder::TopDown)
435+
}
436+
437+
fn supports_rewrite(&self) -> bool {
438+
true
439+
}
440+
441+
// Example rewrite pass which impacts validity of the extension node.
442+
fn rewrite(
443+
&self,
444+
plan: LogicalPlan,
445+
_config: &dyn OptimizerConfig,
446+
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
447+
if let LogicalPlan::Extension(Extension { node }) = &plan {
448+
if let Some(prev) = node.as_any().downcast_ref::<TopKPlanNode>() {
449+
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
450+
node: Arc::new(TopKPlanNode {
451+
k: prev.k,
452+
input: prev.input.clone(),
453+
expr: prev.expr.clone(),
454+
// In a real use case, this rewriter could have change the number of inputs, etc
455+
invariant_mock: Some(InvariantMock {
456+
should_fail_invariant: true,
457+
kind: InvariantLevel::Always,
458+
}),
459+
}),
460+
})));
461+
}
462+
};
463+
464+
Ok(Transformed::no(plan))
465+
}
466+
}
467+
312468
// ------ The implementation of the TopK code follows -----
313469

314470
#[derive(Debug)]
@@ -336,7 +492,10 @@ impl QueryPlanner for TopKQueryPlanner {
336492
}
337493

338494
#[derive(Default, Debug)]
339-
struct TopKOptimizerRule {}
495+
/// Optimizer rule used by the TopK example.
///
/// The `invariant_mock` is cloned onto each [`TopKPlanNode`] this rule
/// creates, so tests can control whether the node's invariant check fails.
struct TopKOptimizerRule {
    /// A testing-only hashable fixture.
    invariant_mock: Option<InvariantMock>,
}
340499

341500
impl OptimizerRule for TopKOptimizerRule {
342501
fn name(&self) -> &str {
@@ -380,6 +539,7 @@ impl OptimizerRule for TopKOptimizerRule {
380539
k: fetch,
381540
input: input.as_ref().clone(),
382541
expr: expr[0].clone(),
542+
invariant_mock: self.invariant_mock.clone(),
383543
}),
384544
})));
385545
}
@@ -396,6 +556,10 @@ struct TopKPlanNode {
396556
/// The sort expression (this example only supports a single sort
397557
/// expr)
398558
expr: SortExpr,
559+
560+
/// A testing-only hashable fixture.
561+
/// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
562+
invariant_mock: Option<InvariantMock>,
399563
}
400564

401565
impl Debug for TopKPlanNode {
@@ -406,6 +570,12 @@ impl Debug for TopKPlanNode {
406570
}
407571
}
408572

573+
/// Testing-only fixture that scripts the outcome of
/// `TopKPlanNode::check_invariants`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
struct InvariantMock {
    /// When `true`, the check returns an error for the level in `kind`.
    should_fail_invariant: bool,
    /// The invariant level at which the mocked failure is reported.
    kind: InvariantLevel,
}
578+
409579
impl UserDefinedLogicalNodeCore for TopKPlanNode {
410580
fn name(&self) -> &str {
411581
"TopK"
@@ -420,6 +590,19 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
420590
self.input.schema()
421591
}
422592

593+
fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> {
594+
if let Some(InvariantMock {
595+
should_fail_invariant,
596+
kind,
597+
}) = self.invariant_mock.clone()
598+
{
599+
if should_fail_invariant && check == kind {
600+
return internal_err!("node fails check, such as improper inputs");
601+
}
602+
}
603+
Ok(())
604+
}
605+
423606
fn expressions(&self) -> Vec<Expr> {
424607
vec![self.expr.expr.clone()]
425608
}
@@ -440,6 +623,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
440623
k: self.k,
441624
input: inputs.swap_remove(0),
442625
expr: self.expr.with_expr(exprs.swap_remove(0)),
626+
invariant_mock: self.invariant_mock.clone(),
443627
})
444628
}
445629

datafusion/expr/src/logical_plan/extension.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ use std::cmp::Ordering;
2222
use std::hash::{Hash, Hasher};
2323
use std::{any::Any, collections::HashSet, fmt, sync::Arc};
2424

25+
use super::InvariantLevel;
26+
2527
/// This defines the interface for [`LogicalPlan`] nodes that can be
2628
/// used to extend DataFusion with custom relational operators.
2729
///
@@ -54,6 +56,9 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
5456
/// Return the output schema of this logical plan node.
5557
fn schema(&self) -> &DFSchemaRef;
5658

59+
/// Perform check of invariants for the extension node.
///
/// `check` is the [`InvariantLevel`] being verified. Returning an error
/// reports that this node violates an invariant at that level.
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>;
61+
5762
/// Returns all expressions in the current logical plan node. This should
5863
/// not include expressions of any inputs (aka non-recursively).
5964
///
@@ -244,6 +249,17 @@ pub trait UserDefinedLogicalNodeCore:
244249
/// Return the output schema of this logical plan node.
245250
fn schema(&self) -> &DFSchemaRef;
246251

252+
/// Perform check of invariants for the extension node.
///
/// This is the default implementation for extension nodes. It ignores both
/// arguments and always returns `Ok(())`; override it to enforce
/// node-specific invariants for the requested [`InvariantLevel`].
fn check_invariants(
    &self,
    _check: InvariantLevel,
    _plan: &LogicalPlan,
) -> Result<()> {
    Ok(())
}
262+
247263
/// Returns all expressions in the current logical plan node. This
248264
/// should not include expressions of any inputs (aka
249265
/// non-recursively). These expressions are used for optimizer
@@ -336,6 +352,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
336352
self.schema()
337353
}
338354

355+
fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
356+
self.check_invariants(check, plan)
357+
}
358+
339359
fn expressions(&self) -> Vec<Expr> {
340360
self.expressions()
341361
}

datafusion/expr/src/logical_plan/invariants.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ use crate::{
2828
Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
2929
};
3030

31+
use super::Extension;
32+
33+
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
3134
pub enum InvariantLevel {
3235
/// Invariants that are always true in DataFusion `LogicalPlan`s
3336
/// such as the number of expected children and no duplicated output fields
@@ -41,19 +44,56 @@ pub enum InvariantLevel {
4144
Executable,
4245
}
4346

44-
pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> {
47+
/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
48+
///
49+
/// This does not recurs to any child nodes.
50+
pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
4551
// Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
4652
assert_unique_field_names(plan)?;
4753

4854
Ok(())
4955
}
5056

57+
/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
58+
/// as well as the less stringent [`InvariantLevel::Always`] checks.
5159
pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
52-
assert_always_invariants(plan)?;
60+
// Always invariants
61+
assert_always_invariants_at_current_node(plan)?;
62+
assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
63+
64+
// Executable invariants
65+
assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
5366
assert_valid_semantic_plan(plan)?;
5467
Ok(())
5568
}
5669

70+
/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
71+
///
72+
/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
73+
/// for more details of user-provided extension node invariants.
74+
fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
75+
plan.apply_with_subqueries(|plan: &LogicalPlan| {
76+
if let LogicalPlan::Extension(Extension { node }) = plan {
77+
node.check_invariants(check, plan)?;
78+
}
79+
plan.apply_expressions(|expr| {
80+
// recursively look for subqueries
81+
expr.apply(|expr| {
82+
match expr {
83+
Expr::Exists(Exists { subquery, .. })
84+
| Expr::InSubquery(InSubquery { subquery, .. })
85+
| Expr::ScalarSubquery(subquery) => {
86+
assert_valid_extension_nodes(&subquery.subquery, check)?;
87+
}
88+
_ => {}
89+
};
90+
Ok(TreeNodeRecursion::Continue)
91+
})
92+
})
93+
})
94+
.map(|_| ())
95+
}
96+
5797
/// Returns an error if plan, and subplans, do not have unique fields.
5898
///
5999
/// This invariant is subject to change.
@@ -87,7 +127,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul
87127

88128
/// Asserts that the subqueries are structured properly with valid node placement.
89129
///
90-
/// Refer to [`check_subquery_expr`] for more details.
130+
/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
91131
fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
92132
plan.apply_with_subqueries(|plan: &LogicalPlan| {
93133
plan.apply_expressions(|expr| {

0 commit comments

Comments
 (0)