59
59
//!
60
60
61
61
use std:: fmt:: Debug ;
62
+ use std:: hash:: Hash ;
62
63
use std:: task:: { Context , Poll } ;
63
64
use std:: { any:: Any , collections:: BTreeMap , fmt, sync:: Arc } ;
64
65
@@ -93,7 +94,7 @@ use datafusion::{
93
94
use datafusion_common:: config:: ConfigOptions ;
94
95
use datafusion_common:: tree_node:: { Transformed , TransformedResult , TreeNode } ;
95
96
use datafusion_common:: ScalarValue ;
96
- use datafusion_expr:: { FetchType , Projection , SortExpr } ;
97
+ use datafusion_expr:: { FetchType , Invariant , InvariantLevel , Projection , SortExpr } ;
97
98
use datafusion_optimizer:: optimizer:: ApplyOrder ;
98
99
use datafusion_optimizer:: AnalyzerRule ;
99
100
use datafusion_physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
@@ -295,15 +296,71 @@ async fn topk_plan() -> Result<()> {
295
296
Ok ( ( ) )
296
297
}
297
298
299
+ #[ tokio:: test]
300
+ /// Run invariant checks on the logical plan extension [`TopKPlanNode`].
301
+ async fn topk_invariants ( ) -> Result < ( ) > {
302
+ // Test: pass an InvariantLevel::Always
303
+ let pass = InvariantMock {
304
+ should_fail_invariant : false ,
305
+ kind : InvariantLevel :: Always ,
306
+ } ;
307
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( pass) ) ) . await ?;
308
+ run_and_compare_query ( ctx, "Topk context" ) . await ?;
309
+
310
+ // Test: fail an InvariantLevel::Always
311
+ let fail = InvariantMock {
312
+ should_fail_invariant : true ,
313
+ kind : InvariantLevel :: Always ,
314
+ } ;
315
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( fail) ) ) . await ?;
316
+ matches ! (
317
+ & * run_and_compare_query( ctx, "Topk context" )
318
+ . await
319
+ . unwrap_err( )
320
+ . message( ) ,
321
+ "node fails check, such as improper inputs"
322
+ ) ;
323
+
324
+ // Test: pass an InvariantLevel::Executable
325
+ let pass = InvariantMock {
326
+ should_fail_invariant : false ,
327
+ kind : InvariantLevel :: Executable ,
328
+ } ;
329
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( pass) ) ) . await ?;
330
+ run_and_compare_query ( ctx, "Topk context" ) . await ?;
331
+
332
+ // Test: fail an InvariantLevel::Executable
333
+ let fail = InvariantMock {
334
+ should_fail_invariant : true ,
335
+ kind : InvariantLevel :: Executable ,
336
+ } ;
337
+ let ctx = setup_table ( make_topk_context_with_invariants ( Some ( fail) ) ) . await ?;
338
+ matches ! (
339
+ & * run_and_compare_query( ctx, "Topk context" )
340
+ . await
341
+ . unwrap_err( )
342
+ . message( ) ,
343
+ "node fails check, such as improper inputs"
344
+ ) ;
345
+
346
+ Ok ( ( ) )
347
+ }
348
+
298
349
fn make_topk_context ( ) -> SessionContext {
350
+ make_topk_context_with_invariants ( None )
351
+ }
352
+
353
+ fn make_topk_context_with_invariants (
354
+ invariant_mock : Option < InvariantMock > ,
355
+ ) -> SessionContext {
299
356
let config = SessionConfig :: new ( ) . with_target_partitions ( 48 ) ;
300
357
let runtime = Arc :: new ( RuntimeEnv :: default ( ) ) ;
301
358
let state = SessionStateBuilder :: new ( )
302
359
. with_config ( config)
303
360
. with_runtime_env ( runtime)
304
361
. with_default_features ( )
305
362
. with_query_planner ( Arc :: new ( TopKQueryPlanner { } ) )
306
- . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule { } ) )
363
+ . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule { invariant_mock } ) )
307
364
. with_analyzer_rule ( Arc :: new ( MyAnalyzerRule { } ) )
308
365
. build ( ) ;
309
366
SessionContext :: new_with_state ( state)
@@ -336,7 +393,10 @@ impl QueryPlanner for TopKQueryPlanner {
336
393
}
337
394
338
395
#[ derive( Default , Debug ) ]
339
- struct TopKOptimizerRule { }
396
+ struct TopKOptimizerRule {
397
+ /// A testing-only hashable fixture.
398
+ invariant_mock : Option < InvariantMock > ,
399
+ }
340
400
341
401
impl OptimizerRule for TopKOptimizerRule {
342
402
fn name ( & self ) -> & str {
@@ -380,6 +440,7 @@ impl OptimizerRule for TopKOptimizerRule {
380
440
k : fetch,
381
441
input : input. as_ref ( ) . clone ( ) ,
382
442
expr : expr[ 0 ] . clone ( ) ,
443
+ invariant_mock : self . invariant_mock . clone ( ) ,
383
444
} ) ,
384
445
} ) ) ) ;
385
446
}
@@ -396,6 +457,10 @@ struct TopKPlanNode {
396
457
/// The sort expression (this example only supports a single sort
397
458
/// expr)
398
459
expr : SortExpr ,
460
+
461
+ /// A testing-only hashable fixture.
462
+ /// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`].
463
+ invariant_mock : Option < InvariantMock > ,
399
464
}
400
465
401
466
impl Debug for TopKPlanNode {
@@ -406,6 +471,20 @@ impl Debug for TopKPlanNode {
406
471
}
407
472
}
408
473
474
+ #[ derive( Debug , Clone , PartialEq , Eq , PartialOrd , Hash ) ]
475
+ struct InvariantMock {
476
+ should_fail_invariant : bool ,
477
+ kind : InvariantLevel ,
478
+ }
479
+
480
+ fn invariant_helper_mock_ok ( _: & LogicalPlan ) -> Result < ( ) > {
481
+ Ok ( ( ) )
482
+ }
483
+
484
+ fn invariant_helper_mock_fails ( _: & LogicalPlan ) -> Result < ( ) > {
485
+ internal_err ! ( "node fails check, such as improper inputs" )
486
+ }
487
+
409
488
impl UserDefinedLogicalNodeCore for TopKPlanNode {
410
489
fn name ( & self ) -> & str {
411
490
"TopK"
@@ -420,6 +499,26 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
420
499
self . input . schema ( )
421
500
}
422
501
502
+ fn invariants ( & self ) -> Vec < Invariant > {
503
+ if let Some ( InvariantMock {
504
+ should_fail_invariant,
505
+ kind,
506
+ } ) = self . invariant_mock . clone ( )
507
+ {
508
+ if should_fail_invariant {
509
+ return vec ! [ Invariant {
510
+ kind,
511
+ fun: Arc :: new( invariant_helper_mock_fails) ,
512
+ } ] ;
513
+ }
514
+ return vec ! [ Invariant {
515
+ kind,
516
+ fun: Arc :: new( invariant_helper_mock_ok) ,
517
+ } ] ;
518
+ }
519
+ vec ! [ ] // same as default impl
520
+ }
521
+
423
522
fn expressions ( & self ) -> Vec < Expr > {
424
523
vec ! [ self . expr. expr. clone( ) ]
425
524
}
@@ -440,6 +539,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
440
539
k : self . k ,
441
540
input : inputs. swap_remove ( 0 ) ,
442
541
expr : self . expr . with_expr ( exprs. swap_remove ( 0 ) ) ,
542
+ invariant_mock : self . invariant_mock . clone ( ) ,
443
543
} )
444
544
}
445
545
0 commit comments