59
59
//!
60
60
61
61
use std:: fmt:: Debug ;
62
+ use std:: hash:: Hash ;
62
63
use std:: task:: { Context , Poll } ;
63
64
use std:: { any:: Any , collections:: BTreeMap , fmt, sync:: Arc } ;
64
65
@@ -93,7 +94,7 @@ use datafusion::{
93
94
use datafusion_common:: config:: ConfigOptions ;
94
95
use datafusion_common:: tree_node:: { Transformed , TransformedResult , TreeNode } ;
95
96
use datafusion_common:: ScalarValue ;
96
- use datafusion_expr:: { FetchType , Projection , SortExpr } ;
97
+ use datafusion_expr:: { FetchType , InvariantLevel , Projection , SortExpr } ;
97
98
use datafusion_optimizer:: optimizer:: ApplyOrder ;
98
99
use datafusion_optimizer:: AnalyzerRule ;
99
100
use datafusion_physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
@@ -295,20 +296,175 @@ async fn topk_plan() -> Result<()> {
295
296
Ok ( ( ) )
296
297
}
297
298
299
+ #[ tokio:: test]
300
+ /// Run invariant checks on the logical plan extension [`TopKPlanNode`].
301
+ async fn topk_invariants ( ) -> Result < ( ) > {
302
+ // Test: pass an InvariantLevel::Always
303
+ let pass = InvariantMock {
304
+ should_fail_invariant : false ,
305
+ kind : InvariantLevel :: Always ,
306
+ } ;
307
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( pass) ) ) . await ?;
308
+ run_and_compare_query ( ctx, "Topk context" ) . await ?;
309
+
310
+ // Test: fail an InvariantLevel::Always
311
+ let fail = InvariantMock {
312
+ should_fail_invariant : true ,
313
+ kind : InvariantLevel :: Always ,
314
+ } ;
315
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( fail) ) ) . await ?;
316
+ matches ! (
317
+ & * run_and_compare_query( ctx, "Topk context" )
318
+ . await
319
+ . unwrap_err( )
320
+ . message( ) ,
321
+ "node fails check, such as improper inputs"
322
+ ) ;
323
+
324
+ // Test: pass an InvariantLevel::Executable
325
+ let pass = InvariantMock {
326
+ should_fail_invariant : false ,
327
+ kind : InvariantLevel :: Executable ,
328
+ } ;
329
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( pass) ) ) . await ?;
330
+ run_and_compare_query ( ctx, "Topk context" ) . await ?;
331
+
332
+ // Test: fail an InvariantLevel::Executable
333
+ let fail = InvariantMock {
334
+ should_fail_invariant : true ,
335
+ kind : InvariantLevel :: Executable ,
336
+ } ;
337
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( fail) ) ) . await ?;
338
+ matches ! (
339
+ & * run_and_compare_query( ctx, "Topk context" )
340
+ . await
341
+ . unwrap_err( )
342
+ . message( ) ,
343
+ "node fails check, such as improper inputs"
344
+ ) ;
345
+
346
+ Ok ( ( ) )
347
+ }
348
+
349
+ #[ tokio:: test]
350
+ async fn topk_invariants_after_invalid_mutation ( ) -> Result < ( ) > {
351
+ // CONTROL
352
+ // Build a valid topK plan.
353
+ let config = SessionConfig :: new ( ) . with_target_partitions ( 48 ) ;
354
+ let runtime = Arc :: new ( RuntimeEnv :: default ( ) ) ;
355
+ let state = SessionStateBuilder :: new ( )
356
+ . with_config ( config)
357
+ . with_runtime_env ( runtime)
358
+ . with_default_features ( )
359
+ . with_query_planner ( Arc :: new ( TopKQueryPlanner { } ) )
360
+ // 1. adds a valid TopKPlanNode
361
+ . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule {
362
+ invariant_mock : Some ( InvariantMock {
363
+ should_fail_invariant : false ,
364
+ kind : InvariantLevel :: Always ,
365
+ } ) ,
366
+ } ) )
367
+ . with_analyzer_rule ( Arc :: new ( MyAnalyzerRule { } ) )
368
+ . build ( ) ;
369
+ let ctx = setup_table ( SessionContext :: new_with_state ( state) ) . await ?;
370
+ run_and_compare_query ( ctx, "Topk context" ) . await ?;
371
+
372
+ // Test
373
+ // Build a valid topK plan.
374
+ // Then have an invalid mutation in an optimizer run.
375
+ let config = SessionConfig :: new ( ) . with_target_partitions ( 48 ) ;
376
+ let runtime = Arc :: new ( RuntimeEnv :: default ( ) ) ;
377
+ let state = SessionStateBuilder :: new ( )
378
+ . with_config ( config)
379
+ . with_runtime_env ( runtime)
380
+ . with_default_features ( )
381
+ . with_query_planner ( Arc :: new ( TopKQueryPlanner { } ) )
382
+ // 1. adds a valid TopKPlanNode
383
+ . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule {
384
+ invariant_mock : Some ( InvariantMock {
385
+ should_fail_invariant : false ,
386
+ kind : InvariantLevel :: Always ,
387
+ } ) ,
388
+ } ) )
389
+ // 2. break the TopKPlanNode
390
+ . with_optimizer_rule ( Arc :: new ( OptimizerMakeExtensionNodeInvalid { } ) )
391
+ . with_analyzer_rule ( Arc :: new ( MyAnalyzerRule { } ) )
392
+ . build ( ) ;
393
+ let ctx = setup_table ( SessionContext :: new_with_state ( state) ) . await ?;
394
+ matches ! (
395
+ & * run_and_compare_query( ctx, "Topk context" )
396
+ . await
397
+ . unwrap_err( )
398
+ . message( ) ,
399
+ "node fails check, such as improper inputs"
400
+ ) ;
401
+
402
+ Ok ( ( ) )
403
+ }
404
+
298
405
fn make_topk_context ( ) -> SessionContext {
406
+ make_topk_context_with_invariants ( None )
407
+ }
408
+
409
+ fn make_topk_context_with_invariants (
410
+ invariant_mock : Option < InvariantMock > ,
411
+ ) -> SessionContext {
299
412
let config = SessionConfig :: new ( ) . with_target_partitions ( 48 ) ;
300
413
let runtime = Arc :: new ( RuntimeEnv :: default ( ) ) ;
301
414
let state = SessionStateBuilder :: new ( )
302
415
. with_config ( config)
303
416
. with_runtime_env ( runtime)
304
417
. with_default_features ( )
305
418
. with_query_planner ( Arc :: new ( TopKQueryPlanner { } ) )
306
- . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule { } ) )
419
+ . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule { invariant_mock } ) )
307
420
. with_analyzer_rule ( Arc :: new ( MyAnalyzerRule { } ) )
308
421
. build ( ) ;
309
422
SessionContext :: new_with_state ( state)
310
423
}
311
424
425
+ #[ derive( Debug ) ]
426
+ struct OptimizerMakeExtensionNodeInvalid ;
427
+
428
+ impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
429
+ fn name ( & self ) -> & str {
430
+ "OptimizerMakeExtensionNodeInvalid"
431
+ }
432
+
433
+ fn apply_order ( & self ) -> Option < ApplyOrder > {
434
+ Some ( ApplyOrder :: TopDown )
435
+ }
436
+
437
+ fn supports_rewrite ( & self ) -> bool {
438
+ true
439
+ }
440
+
441
+ // Example rewrite pass which impacts validity of the extension node.
442
+ fn rewrite (
443
+ & self ,
444
+ plan : LogicalPlan ,
445
+ _config : & dyn OptimizerConfig ,
446
+ ) -> Result < Transformed < LogicalPlan > , DataFusionError > {
447
+ if let LogicalPlan :: Extension ( Extension { node } ) = & plan {
448
+ if let Some ( prev) = node. as_any ( ) . downcast_ref :: < TopKPlanNode > ( ) {
449
+ return Ok ( Transformed :: yes ( LogicalPlan :: Extension ( Extension {
450
+ node : Arc :: new ( TopKPlanNode {
451
+ k : prev. k ,
452
+ input : prev. input . clone ( ) ,
453
+ expr : prev. expr . clone ( ) ,
454
+ // In a real use case, this rewriter could have change the number of inputs, etc
455
+ invariant_mock : Some ( InvariantMock {
456
+ should_fail_invariant : true ,
457
+ kind : InvariantLevel :: Always ,
458
+ } ) ,
459
+ } ) ,
460
+ } ) ) ) ;
461
+ }
462
+ } ;
463
+
464
+ Ok ( Transformed :: no ( plan) )
465
+ }
466
+ }
467
+
312
468
// ------ The implementation of the TopK code follows -----
313
469
314
470
#[ derive( Debug ) ]
@@ -336,7 +492,10 @@ impl QueryPlanner for TopKQueryPlanner {
336
492
}
337
493
338
494
#[ derive( Default , Debug ) ]
339
- struct TopKOptimizerRule { }
495
+ struct TopKOptimizerRule {
496
+ /// A testing-only hashable fixture.
497
+ invariant_mock : Option < InvariantMock > ,
498
+ }
340
499
341
500
impl OptimizerRule for TopKOptimizerRule {
342
501
fn name ( & self ) -> & str {
@@ -380,6 +539,7 @@ impl OptimizerRule for TopKOptimizerRule {
380
539
k : fetch,
381
540
input : input. as_ref ( ) . clone ( ) ,
382
541
expr : expr[ 0 ] . clone ( ) ,
542
+ invariant_mock : self . invariant_mock . clone ( ) ,
383
543
} ) ,
384
544
} ) ) ) ;
385
545
}
@@ -396,6 +556,10 @@ struct TopKPlanNode {
396
556
/// The sort expression (this example only supports a single sort
397
557
/// expr)
398
558
expr : SortExpr ,
559
+
560
+ /// A testing-only hashable fixture.
561
+ /// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
562
+ invariant_mock : Option < InvariantMock > ,
399
563
}
400
564
401
565
impl Debug for TopKPlanNode {
@@ -406,6 +570,12 @@ impl Debug for TopKPlanNode {
406
570
}
407
571
}
408
572
573
+ #[ derive( Debug , Clone , PartialEq , Eq , PartialOrd , Hash ) ]
574
+ struct InvariantMock {
575
+ should_fail_invariant : bool ,
576
+ kind : InvariantLevel ,
577
+ }
578
+
409
579
impl UserDefinedLogicalNodeCore for TopKPlanNode {
410
580
fn name ( & self ) -> & str {
411
581
"TopK"
@@ -420,6 +590,19 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
420
590
self . input . schema ( )
421
591
}
422
592
593
+ fn check_invariants ( & self , check : InvariantLevel , _plan : & LogicalPlan ) -> Result < ( ) > {
594
+ if let Some ( InvariantMock {
595
+ should_fail_invariant,
596
+ kind,
597
+ } ) = self . invariant_mock . clone ( )
598
+ {
599
+ if should_fail_invariant && check == kind {
600
+ return internal_err ! ( "node fails check, such as improper inputs" ) ;
601
+ }
602
+ }
603
+ Ok ( ( ) )
604
+ }
605
+
423
606
fn expressions ( & self ) -> Vec < Expr > {
424
607
vec ! [ self . expr. expr. clone( ) ]
425
608
}
@@ -440,6 +623,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
440
623
k : self . k ,
441
624
input : inputs. swap_remove ( 0 ) ,
442
625
expr : self . expr . with_expr ( exprs. swap_remove ( 0 ) ) ,
626
+ invariant_mock : self . invariant_mock . clone ( ) ,
443
627
} )
444
628
}
445
629
0 commit comments