Skip to content

Commit b340997

Browse files
committed
async tests instrumentation
1 parent c373405 commit b340997

File tree

8 files changed

+621
-27
lines changed

8 files changed

+621
-27
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[tool.codeflash]
2+
# All paths are relative to this pyproject.toml's directory.
3+
module-root = "."
4+
tests-root = "tests"
5+
test-framework = "pytest"
6+
ignore-paths = []
7+
formatter-cmds = ["black $file"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from time import sleep
2+
3+
async def tasked():
4+
sleep(0.002)
5+
return "Tasked"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
import asyncio
3+
from shocker import tasked
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_tasked_basic():
8+
result = await tasked()
9+
assert result == "Tasked"
10+
11+
12+
@pytest.mark.asyncio
13+
async def test_tasked_gather():
14+
results = await asyncio.gather(*(tasked() for _ in range(5)))
15+
assert results == ["Tasked"] * 5
16+
17+
18+
def test_tasked_many_parallel_invocations():
19+
async def run_many():
20+
tasks = [tasked() for _ in range(1000)]
21+
results = await asyncio.gather(*tasks)
22+
return results
23+
24+
results = asyncio.run(run_many())
25+
assert len(results) == 1000, "Should return 1000 results"
26+
assert all(r == "Tasked" for r in results), "All results should be 'Tasked'"

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,39 @@ def is_argument_name(name: str, arguments_node: ast.arguments) -> bool:
4848
)
4949

5050

51+
class AsyncIOGatherRemover(ast.NodeTransformer):
52+
def _contains_asyncio_gather(self, node: ast.AST) -> bool:
53+
"""Check if a node contains asyncio.gather calls."""
54+
for child_node in ast.walk(node):
55+
if (
56+
isinstance(child_node, ast.Call)
57+
and isinstance(child_node.func, ast.Attribute)
58+
and isinstance(child_node.func.value, ast.Name)
59+
and child_node.func.value.id == "asyncio"
60+
and child_node.func.attr == "gather"
61+
):
62+
return True
63+
64+
if (
65+
isinstance(child_node, ast.Call)
66+
and isinstance(child_node.func, ast.Name)
67+
and child_node.func.id == "gather"
68+
):
69+
return True
70+
71+
return False
72+
73+
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef | None:
74+
if node.name.startswith("test_") and self._contains_asyncio_gather(node):
75+
return None
76+
return self.generic_visit(node)
77+
78+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef | None:
79+
if node.name.startswith("test_") and self._contains_asyncio_gather(node):
80+
return None
81+
return self.generic_visit(node)
82+
83+
5184
class InjectPerfOnly(ast.NodeTransformer):
5285
def __init__(
5386
self,
@@ -397,6 +430,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
397430
file_path=self.function.file_path,
398431
starting_line=self.function.starting_line,
399432
ending_line=self.function.ending_line,
433+
is_async=self.function.is_async,
400434
)
401435
else:
402436
self.imported_as = FunctionToOptimize(
@@ -405,6 +439,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
405439
file_path=self.function.file_path,
406440
starting_line=self.function.starting_line,
407441
ending_line=self.function.ending_line,
442+
is_async=self.function.is_async,
408443
)
409444

410445

@@ -415,7 +450,6 @@ def inject_profiling_into_existing_test(
415450
tests_project_root: Path,
416451
test_framework: str,
417452
mode: TestingMode = TestingMode.BEHAVIOR,
418-
is_async: bool = False,
419453
) -> tuple[bool, str | None]:
420454
with test_path.open(encoding="utf8") as f:
421455
test_code = f.read()
@@ -430,6 +464,13 @@ def inject_profiling_into_existing_test(
430464
import_visitor.visit(tree)
431465
func = import_visitor.imported_as
432466

467+
is_async = function_to_optimize.is_async
468+
logger.debug(f"Using async status from discovery phase for {function_to_optimize.function_name}: {is_async}")
469+
470+
if is_async:
471+
asyncio_gather_remover = AsyncIOGatherRemover()
472+
tree = asyncio_gather_remover.visit(tree)
473+
433474
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode, is_async=is_async).visit(
434475
tree
435476
)
@@ -444,11 +485,15 @@ def inject_profiling_into_existing_test(
444485
)
445486
if test_framework == "unittest":
446487
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
488+
if is_async:
489+
new_imports.append(ast.Import(names=[ast.alias(name="inspect")]))
447490
tree.body = [*new_imports, create_wrapper_function(mode, is_async), *tree.body]
448491
return True, isort.code(ast.unparse(tree), float_to_top=True)
449492

450493

451-
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, is_async: bool = False) -> ast.FunctionDef:
494+
def create_wrapper_function(
495+
mode: TestingMode = TestingMode.BEHAVIOR, is_async: bool = False
496+
) -> ast.FunctionDef | ast.AsyncFunctionDef:
452497
lineno = 1
453498
wrapper_body: list[ast.stmt] = [
454499
ast.Assign(
@@ -624,22 +669,70 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, is_async:
624669
),
625670
lineno=lineno + 11,
626671
),
627-
ast.Assign(
628-
targets=[ast.Name(id="return_value", ctx=ast.Store())],
629-
value=ast.Await(
630-
value=ast.Call(
631-
func=ast.Name(id="wrapped", ctx=ast.Load()),
632-
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
633-
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
634-
)
635-
)
672+
# For async wrappers
673+
# Call the wrapped function first, then check if result is awaitable before awaiting.
674+
# This handles mixed scenarios where async tests might call both sync and async functions.
675+
*(
676+
[
677+
ast.Assign(
678+
targets=[ast.Name(id="ret", ctx=ast.Store())],
679+
value=ast.Call(
680+
func=ast.Name(id="wrapped", ctx=ast.Load()),
681+
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
682+
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
683+
),
684+
lineno=lineno + 12,
685+
),
686+
ast.If(
687+
test=ast.Call(
688+
func=ast.Attribute(
689+
value=ast.Name(id="inspect", ctx=ast.Load()), attr="isawaitable", ctx=ast.Load()
690+
),
691+
args=[ast.Name(id="ret", ctx=ast.Load())],
692+
keywords=[],
693+
),
694+
body=[
695+
ast.Assign(
696+
targets=[ast.Name(id="counter", ctx=ast.Store())],
697+
value=ast.Call(
698+
func=ast.Attribute(
699+
value=ast.Name(id="time", ctx=ast.Load()),
700+
attr="perf_counter_ns",
701+
ctx=ast.Load(),
702+
),
703+
args=[],
704+
keywords=[],
705+
),
706+
lineno=lineno + 14,
707+
),
708+
ast.Assign(
709+
targets=[ast.Name(id="return_value", ctx=ast.Store())],
710+
value=ast.Await(value=ast.Name(id="ret", ctx=ast.Load())),
711+
lineno=lineno + 15,
712+
),
713+
],
714+
orelse=[
715+
ast.Assign(
716+
targets=[ast.Name(id="return_value", ctx=ast.Store())],
717+
value=ast.Name(id="ret", ctx=ast.Load()),
718+
lineno=lineno + 16,
719+
)
720+
],
721+
lineno=lineno + 13,
722+
),
723+
]
636724
if is_async
637-
else ast.Call(
638-
func=ast.Name(id="wrapped", ctx=ast.Load()),
639-
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
640-
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
641-
),
642-
lineno=lineno + 12,
725+
else [
726+
ast.Assign(
727+
targets=[ast.Name(id="return_value", ctx=ast.Store())],
728+
value=ast.Call(
729+
func=ast.Name(id="wrapped", ctx=ast.Load()),
730+
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
731+
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
732+
),
733+
lineno=lineno + 12,
734+
)
735+
]
643736
),
644737
ast.Assign(
645738
targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],

codeflash/discovery/functions_to_optimize.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
7575
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
7676
ast_parents: list[FunctionParent] = []
7777
while parents is not None:
78-
if isinstance(parents, (cst.FunctionDef, cst.ClassDef)):
78+
if isinstance(parents, (cst.FunctionDef, cst.AsyncFunctionDef, cst.ClassDef)):
7979
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
8080
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
8181
self.functions.append(
@@ -85,6 +85,29 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
8585
parents=list(reversed(ast_parents)),
8686
starting_line=pos.start.line,
8787
ending_line=pos.end.line,
88+
is_async=False,
89+
)
90+
)
91+
92+
def visit_AsyncFunctionDef(self, node: cst.AsyncFunctionDef) -> None:
93+
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
94+
node.visit(return_visitor)
95+
if return_visitor.has_return_statement:
96+
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
97+
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
98+
ast_parents: list[FunctionParent] = []
99+
while parents is not None:
100+
if isinstance(parents, (cst.FunctionDef, cst.AsyncFunctionDef, cst.ClassDef)):
101+
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
102+
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
103+
self.functions.append(
104+
FunctionToOptimize(
105+
function_name=node.name.value,
106+
file_path=self.file_path,
107+
parents=list(reversed(ast_parents)),
108+
starting_line=pos.start.line,
109+
ending_line=pos.end.line,
110+
is_async=True,
88111
)
89112
)
90113

tests/test_instrument_all_and_run.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@
6262
gc.disable()
6363
try:
6464
counter = time.perf_counter_ns()
65-
return_value = await wrapped(*args, **kwargs)
65+
ret = wrapped(*args, **kwargs)
66+
if inspect.isawaitable(ret):
67+
counter = time.perf_counter_ns()
68+
return_value = await ret
69+
else:
70+
return_value = ret
6671
codeflash_duration = time.perf_counter_ns() - counter
6772
except Exception as e:
6873
codeflash_duration = time.perf_counter_ns() - counter
@@ -266,14 +271,15 @@ async def test_async_add():
266271
assert result == 5"""
267272

268273
expected = (
269-
"""import gc
274+
"""import asyncio
275+
import gc
276+
import inspect
270277
import os
271278
import sqlite3
272279
import time
273280
274281
import dill as pickle
275282
276-
import asyncio
277283
from code_to_optimize.async_adder import async_add
278284
279285
@@ -286,7 +292,7 @@ async def test_async_add():
286292
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
287293
codeflash_cur = codeflash_con.cursor()
288294
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
289-
result = await codeflash_wrap(async_add, '{module_path}', None, 'test_async_add', 'async_add', '1', codeflash_loop_index, codeflash_cur, codeflash_con, 2, 3)
295+
result = await codeflash_wrap(async_add, '{module_path}', None, 'test_async_add', 'async_add', '0', codeflash_loop_index, codeflash_cur, codeflash_con, 2, 3)
290296
assert result == 5
291297
codeflash_con.close()
292298
"""
@@ -320,15 +326,14 @@ async def test_async_add():
320326
project_root_path,
321327
"pytest",
322328
mode=TestingMode.BEHAVIOR,
323-
is_async=True,
329+
324330
)
325331
os.chdir(original_cwd)
326332
assert success
327333
assert new_test is not None
328-
assert "await wrapped(*args, **kwargs)" in new_test
329-
assert "async def codeflash_wrap" in new_test
330-
assert "await codeflash_wrap(async_add" in new_test
331-
334+
assert new_test.replace('"', "'") == expected.format(
335+
module_path="code_to_optimize.tests.pytest.test_async_adder_behavior_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
336+
).replace('"', "'")
332337
finally:
333338
fto_path.write_text(original_code, "utf-8")
334339
test_path.unlink(missing_ok=True)

0 commit comments

Comments
 (0)