@@ -346,6 +346,62 @@ async fn topk_invariants() -> Result<()> {
346
346
Ok ( ( ) )
347
347
}
348
348
349
+ #[ tokio:: test]
350
+ async fn topk_invariants_after_invalid_mutation ( ) -> Result < ( ) > {
351
+ // CONTROL
352
+ // Build a valid topK plan.
353
+ let config = SessionConfig :: new ( ) . with_target_partitions ( 48 ) ;
354
+ let runtime = Arc :: new ( RuntimeEnv :: default ( ) ) ;
355
+ let state = SessionStateBuilder :: new ( )
356
+ . with_config ( config)
357
+ . with_runtime_env ( runtime)
358
+ . with_default_features ( )
359
+ . with_query_planner ( Arc :: new ( TopKQueryPlanner { } ) )
360
+ // 1. adds a valid TopKPlanNode
361
+ . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule {
362
+ invariant_mock : Some ( InvariantMock {
363
+ should_fail_invariant : false ,
364
+ kind : InvariantLevel :: Always ,
365
+ } ) ,
366
+ } ) )
367
+ . with_analyzer_rule ( Arc :: new ( MyAnalyzerRule { } ) )
368
+ . build ( ) ;
369
+ let ctx = setup_table ( SessionContext :: new_with_state ( state) ) . await ?;
370
+ run_and_compare_query ( ctx, "Topk context" ) . await ?;
371
+
372
+ // Test
373
+ // Build a valid topK plan.
374
+ // Then have an invalid mutation in an optimizer run.
375
+ let config = SessionConfig :: new ( ) . with_target_partitions ( 48 ) ;
376
+ let runtime = Arc :: new ( RuntimeEnv :: default ( ) ) ;
377
+ let state = SessionStateBuilder :: new ( )
378
+ . with_config ( config)
379
+ . with_runtime_env ( runtime)
380
+ . with_default_features ( )
381
+ . with_query_planner ( Arc :: new ( TopKQueryPlanner { } ) )
382
+ // 1. adds a valid TopKPlanNode
383
+ . with_optimizer_rule ( Arc :: new ( TopKOptimizerRule {
384
+ invariant_mock : Some ( InvariantMock {
385
+ should_fail_invariant : false ,
386
+ kind : InvariantLevel :: Always ,
387
+ } ) ,
388
+ } ) )
389
+ // 2. break the TopKPlanNode
390
+ . with_optimizer_rule ( Arc :: new ( OptimizerMakeExtensionNodeInvalid { } ) )
391
+ . with_analyzer_rule ( Arc :: new ( MyAnalyzerRule { } ) )
392
+ . build ( ) ;
393
+ let ctx = setup_table ( SessionContext :: new_with_state ( state) ) . await ?;
394
+ matches ! (
395
+ & * run_and_compare_query( ctx, "Topk context" )
396
+ . await
397
+ . unwrap_err( )
398
+ . message( ) ,
399
+ "node fails check, such as improper inputs"
400
+ ) ;
401
+
402
+ Ok ( ( ) )
403
+ }
404
+
349
405
fn make_topk_context ( ) -> SessionContext {
350
406
make_topk_context_with_invariants ( None )
351
407
}
@@ -366,6 +422,49 @@ fn make_topk_context_with_invariants(
366
422
SessionContext :: new_with_state ( state)
367
423
}
368
424
425
+ #[ derive( Debug ) ]
426
+ struct OptimizerMakeExtensionNodeInvalid ;
427
+
428
+ impl OptimizerRule for OptimizerMakeExtensionNodeInvalid {
429
+ fn name ( & self ) -> & str {
430
+ "OptimizerMakeExtensionNodeInvalid"
431
+ }
432
+
433
+ fn apply_order ( & self ) -> Option < ApplyOrder > {
434
+ Some ( ApplyOrder :: TopDown )
435
+ }
436
+
437
+ fn supports_rewrite ( & self ) -> bool {
438
+ true
439
+ }
440
+
441
+ // Example rewrite pass which impacts validity of the extension node.
442
+ fn rewrite (
443
+ & self ,
444
+ plan : LogicalPlan ,
445
+ _config : & dyn OptimizerConfig ,
446
+ ) -> Result < Transformed < LogicalPlan > , DataFusionError > {
447
+ if let LogicalPlan :: Extension ( Extension { node } ) = & plan {
448
+ if let Some ( prev) = node. as_any ( ) . downcast_ref :: < TopKPlanNode > ( ) {
449
+ return Ok ( Transformed :: yes ( LogicalPlan :: Extension ( Extension {
450
+ node : Arc :: new ( TopKPlanNode {
451
+ k : prev. k ,
452
+ input : prev. input . clone ( ) ,
453
+ expr : prev. expr . clone ( ) ,
454
+ // In a real use case, this rewriter could have change the number of inputs, etc
455
+ invariant_mock : Some ( InvariantMock {
456
+ should_fail_invariant : true ,
457
+ kind : InvariantLevel :: Always ,
458
+ } ) ,
459
+ } ) ,
460
+ } ) ) ) ;
461
+ }
462
+ } ;
463
+
464
+ Ok ( Transformed :: no ( plan) )
465
+ }
466
+ }
467
+
369
468
// ------ The implementation of the TopK code follows -----
370
469
371
470
#[ derive( Debug ) ]
0 commit comments