@@ -48,6 +48,39 @@ def is_argument_name(name: str, arguments_node: ast.arguments) -> bool:
48
48
)
49
49
50
50
51
+ class AsyncIOGatherRemover (ast .NodeTransformer ):
52
+ def _contains_asyncio_gather (self , node : ast .AST ) -> bool :
53
+ """Check if a node contains asyncio.gather calls."""
54
+ for child_node in ast .walk (node ):
55
+ if (
56
+ isinstance (child_node , ast .Call )
57
+ and isinstance (child_node .func , ast .Attribute )
58
+ and isinstance (child_node .func .value , ast .Name )
59
+ and child_node .func .value .id == "asyncio"
60
+ and child_node .func .attr == "gather"
61
+ ):
62
+ return True
63
+
64
+ if (
65
+ isinstance (child_node , ast .Call )
66
+ and isinstance (child_node .func , ast .Name )
67
+ and child_node .func .id == "gather"
68
+ ):
69
+ return True
70
+
71
+ return False
72
+
73
+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> ast .FunctionDef | None :
74
+ if node .name .startswith ("test_" ) and self ._contains_asyncio_gather (node ):
75
+ return None
76
+ return self .generic_visit (node )
77
+
78
+ def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ) -> ast .AsyncFunctionDef | None :
79
+ if node .name .startswith ("test_" ) and self ._contains_asyncio_gather (node ):
80
+ return None
81
+ return self .generic_visit (node )
82
+
83
+
51
84
class InjectPerfOnly (ast .NodeTransformer ):
52
85
def __init__ (
53
86
self ,
@@ -397,6 +430,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
397
430
file_path = self .function .file_path ,
398
431
starting_line = self .function .starting_line ,
399
432
ending_line = self .function .ending_line ,
433
+ is_async = self .function .is_async ,
400
434
)
401
435
else :
402
436
self .imported_as = FunctionToOptimize (
@@ -405,6 +439,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
405
439
file_path = self .function .file_path ,
406
440
starting_line = self .function .starting_line ,
407
441
ending_line = self .function .ending_line ,
442
+ is_async = self .function .is_async ,
408
443
)
409
444
410
445
@@ -415,7 +450,6 @@ def inject_profiling_into_existing_test(
415
450
tests_project_root : Path ,
416
451
test_framework : str ,
417
452
mode : TestingMode = TestingMode .BEHAVIOR ,
418
- is_async : bool = False ,
419
453
) -> tuple [bool , str | None ]:
420
454
with test_path .open (encoding = "utf8" ) as f :
421
455
test_code = f .read ()
@@ -430,6 +464,13 @@ def inject_profiling_into_existing_test(
430
464
import_visitor .visit (tree )
431
465
func = import_visitor .imported_as
432
466
467
+ is_async = function_to_optimize .is_async
468
+ logger .debug (f"Using async status from discovery phase for { function_to_optimize .function_name } : { is_async } " )
469
+
470
+ if is_async :
471
+ asyncio_gather_remover = AsyncIOGatherRemover ()
472
+ tree = asyncio_gather_remover .visit (tree )
473
+
433
474
tree = InjectPerfOnly (func , test_module_path , test_framework , call_positions , mode = mode , is_async = is_async ).visit (
434
475
tree
435
476
)
@@ -444,11 +485,15 @@ def inject_profiling_into_existing_test(
444
485
)
445
486
if test_framework == "unittest" :
446
487
new_imports .append (ast .Import (names = [ast .alias (name = "timeout_decorator" )]))
488
+ if is_async :
489
+ new_imports .append (ast .Import (names = [ast .alias (name = "inspect" )]))
447
490
tree .body = [* new_imports , create_wrapper_function (mode , is_async ), * tree .body ]
448
491
return True , isort .code (ast .unparse (tree ), float_to_top = True )
449
492
450
493
451
- def create_wrapper_function (mode : TestingMode = TestingMode .BEHAVIOR , is_async : bool = False ) -> ast .FunctionDef :
494
+ def create_wrapper_function (
495
+ mode : TestingMode = TestingMode .BEHAVIOR , is_async : bool = False
496
+ ) -> ast .FunctionDef | ast .AsyncFunctionDef :
452
497
lineno = 1
453
498
wrapper_body : list [ast .stmt ] = [
454
499
ast .Assign (
@@ -624,22 +669,70 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, is_async:
624
669
),
625
670
lineno = lineno + 11 ,
626
671
),
627
- ast .Assign (
628
- targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
629
- value = ast .Await (
630
- value = ast .Call (
631
- func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
632
- args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
633
- keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
634
- )
635
- )
672
+ # For async wrappers
673
+ # Call the wrapped function first, then check if result is awaitable before awaiting.
674
+ # This handles mixed scenarios where async tests might call both sync and async functions.
675
+ * (
676
+ [
677
+ ast .Assign (
678
+ targets = [ast .Name (id = "ret" , ctx = ast .Store ())],
679
+ value = ast .Call (
680
+ func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
681
+ args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
682
+ keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
683
+ ),
684
+ lineno = lineno + 12 ,
685
+ ),
686
+ ast .If (
687
+ test = ast .Call (
688
+ func = ast .Attribute (
689
+ value = ast .Name (id = "inspect" , ctx = ast .Load ()), attr = "isawaitable" , ctx = ast .Load ()
690
+ ),
691
+ args = [ast .Name (id = "ret" , ctx = ast .Load ())],
692
+ keywords = [],
693
+ ),
694
+ body = [
695
+ ast .Assign (
696
+ targets = [ast .Name (id = "counter" , ctx = ast .Store ())],
697
+ value = ast .Call (
698
+ func = ast .Attribute (
699
+ value = ast .Name (id = "time" , ctx = ast .Load ()),
700
+ attr = "perf_counter_ns" ,
701
+ ctx = ast .Load (),
702
+ ),
703
+ args = [],
704
+ keywords = [],
705
+ ),
706
+ lineno = lineno + 14 ,
707
+ ),
708
+ ast .Assign (
709
+ targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
710
+ value = ast .Await (value = ast .Name (id = "ret" , ctx = ast .Load ())),
711
+ lineno = lineno + 15 ,
712
+ ),
713
+ ],
714
+ orelse = [
715
+ ast .Assign (
716
+ targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
717
+ value = ast .Name (id = "ret" , ctx = ast .Load ()),
718
+ lineno = lineno + 16 ,
719
+ )
720
+ ],
721
+ lineno = lineno + 13 ,
722
+ ),
723
+ ]
636
724
if is_async
637
- else ast .Call (
638
- func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
639
- args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
640
- keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
641
- ),
642
- lineno = lineno + 12 ,
725
+ else [
726
+ ast .Assign (
727
+ targets = [ast .Name (id = "return_value" , ctx = ast .Store ())],
728
+ value = ast .Call (
729
+ func = ast .Name (id = "wrapped" , ctx = ast .Load ()),
730
+ args = [ast .Starred (value = ast .Name (id = "args" , ctx = ast .Load ()), ctx = ast .Load ())],
731
+ keywords = [ast .keyword (arg = None , value = ast .Name (id = "kwargs" , ctx = ast .Load ()))],
732
+ ),
733
+ lineno = lineno + 12 ,
734
+ )
735
+ ]
643
736
),
644
737
ast .Assign (
645
738
targets = [ast .Name (id = "codeflash_duration" , ctx = ast .Store ())],
0 commit comments