Skip to content
Open
10 changes: 10 additions & 0 deletions code_to_optimize/async_adder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import asyncio


async def async_add(a, b):
"""Simple async function that adds two numbers."""
await asyncio.sleep(0.001) # Simulate some async work
print(f"codeflash stdout: Adding {a} + {b}")
result = a + b
print(f"result: {result}")
return result
63 changes: 63 additions & 0 deletions code_to_optimize/async_examples/concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import asyncio
import time
import random


async def fake_api_call(delay, data):
await asyncio.sleep(delay)
return f"Processed: {data}"


async def cpu_bound_task(n):
result = 0
for i in range(n):
result += i ** 2
return result


async def some_api_call(urls):
tasks = [
fake_api_call(random.uniform(0.5, 2.0), url)
for i, url in enumerate(urls)
]
return await asyncio.gather(*tasks)


async def inefficient_task_creation():
results = []
for i in range(10):
task = asyncio.create_task(fake_api_call(0.5, f"data_{i}"))
result = await task
results.append(result)

return results


async def manga():
results = []

for i in range(5):
async_result = await fake_api_call(0.3, f"async_{i}")
results.append(async_result)

time.sleep(0.5)
cpu_result = sum(range(100000))
results.append(f"CPU result: {cpu_result}")

return results


if __name__ == "__main__":
async def main():
print("Running inefficient concurrency examples...")

urls = [f"https://api.example.com/data/{i}" for i in range(5)]
start_time = time.time()
results = await sequential_api_calls(urls)
print(f"Sequential calls took: {time.time() - start_time:.2f}s")

start_time = time.time()
results = await inefficient_task_creation()
print(f"Inefficient tasks took: {time.time() - start_time:.2f}s")

asyncio.run(main())
118 changes: 113 additions & 5 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
test_framework: str,
call_positions: list[CodePosition],
mode: TestingMode = TestingMode.BEHAVIOR,
is_async: bool = False,
) -> None:
self.mode: TestingMode = mode
self.function_object = function
Expand All @@ -64,13 +65,16 @@ def __init__(
self.module_path = module_path
self.test_framework = test_framework
self.call_positions = call_positions
self.is_async = is_async
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name

def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
call_node = None
await_node = None

for node in ast.walk(test_node):
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
call_node = node
Expand Down Expand Up @@ -120,6 +124,64 @@ def find_and_update_line_node(
node.keywords = call_node.keywords
break

# Check for awaited function calls
elif (
isinstance(node, ast.Await)
and isinstance(node.value, ast.Call)
and node_in_call_position(node.value, self.call_positions)
):
call_node = node.value
await_node = node
if isinstance(call_node.func, ast.Name):
function_name = call_node.func.id
call_node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
call_node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=self.module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=self.function_object.qualified_name),
ast.Constant(value=index),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
*(
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
]
call_node.keywords = call_node.keywords
# Keep the await wrapper around the modified call
await_node.value = call_node
break
if isinstance(call_node.func, ast.Attribute):
function_to_test = call_node.func.attr
if function_to_test == self.function_object.function_name:
function_name = ast.unparse(call_node.func)
call_node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
call_node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=self.module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=self.function_object.qualified_name),
ast.Constant(value=index),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
*(
[
ast.Name(id="codeflash_cur", ctx=ast.Load()),
ast.Name(id="codeflash_con", ctx=ast.Load()),
]
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
]
call_node.keywords = call_node.keywords
# Keep the await wrapper around the modified call
await_node.value = call_node
break

if call_node is None:
return None
return [test_node]
Expand All @@ -129,9 +191,34 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
for inner_node in ast.walk(node):
if isinstance(inner_node, ast.FunctionDef):
self.visit_FunctionDef(inner_node, node.name)
elif isinstance(inner_node, ast.AsyncFunctionDef):
self.visit_AsyncFunctionDef(inner_node, node.name)

return node

def visit_AsyncFunctionDef(
self, node: ast.AsyncFunctionDef, test_class_name: str | None = None
) -> ast.AsyncFunctionDef:
sync_node = ast.FunctionDef(
name=node.name,
args=node.args,
body=node.body,
decorator_list=node.decorator_list,
returns=node.returns,
lineno=node.lineno,
col_offset=node.col_offset if hasattr(node, "col_offset") else 0,
)
processed_sync = self.visit_FunctionDef(sync_node, test_class_name)
return ast.AsyncFunctionDef(
name=processed_sync.name,
args=processed_sync.args,
body=processed_sync.body,
decorator_list=processed_sync.decorator_list,
returns=processed_sync.returns,
lineno=processed_sync.lineno,
col_offset=processed_sync.col_offset if hasattr(processed_sync, "col_offset") else 0,
)

def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
did_update = False
Expand Down Expand Up @@ -328,6 +415,7 @@ def inject_profiling_into_existing_test(
tests_project_root: Path,
test_framework: str,
mode: TestingMode = TestingMode.BEHAVIOR,
is_async: bool = False,
) -> tuple[bool, str | None]:
with test_path.open(encoding="utf8") as f:
test_code = f.read()
Expand All @@ -342,7 +430,9 @@ def inject_profiling_into_existing_test(
import_visitor.visit(tree)
func = import_visitor.imported_as

tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode, is_async=is_async).visit(
tree
)
new_imports = [
ast.Import(names=[ast.alias(name="time")]),
ast.Import(names=[ast.alias(name="gc")]),
Expand All @@ -354,11 +444,11 @@ def inject_profiling_into_existing_test(
)
if test_framework == "unittest":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
tree.body = [*new_imports, create_wrapper_function(mode, is_async), *tree.body]
return True, isort.code(ast.unparse(tree), float_to_top=True)


def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, is_async: bool = False) -> ast.FunctionDef:
lineno = 1
wrapper_body: list[ast.stmt] = [
ast.Assign(
Expand Down Expand Up @@ -536,7 +626,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
ast.Assign(
targets=[ast.Name(id="return_value", ctx=ast.Store())],
value=ast.Call(
value=ast.Await(
value=ast.Call(
func=ast.Name(id="wrapped", ctx=ast.Load()),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
)
)
if is_async
else ast.Call(
func=ast.Name(id="wrapped", ctx=ast.Load()),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
Expand Down Expand Up @@ -703,7 +801,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
ast.Return(value=ast.Name(id="return_value", ctx=ast.Load()), lineno=lineno + 19),
]
return ast.FunctionDef(
func_def = ast.FunctionDef(
name="codeflash_wrap",
args=ast.arguments(
args=[
Expand All @@ -729,3 +827,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
decorator_list=[],
returns=None,
)
if is_async:
return ast.AsyncFunctionDef(
name="codeflash_wrap",
args=func_def.args,
body=func_def.body,
lineno=func_def.lineno,
decorator_list=func_def.decorator_list,
returns=func_def.returns,
)
return func_def
12 changes: 9 additions & 3 deletions codeflash/code_utils/static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(

def get_first_top_level_function_or_method_ast(
function_name: str, parents: list[FunctionParent], node: ast.AST
) -> ast.FunctionDef | None:
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
if not parents:
return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
if result is None:
result = get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
return result
if parents[0].type == "ClassDef" and (
class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
):
return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
if result is None:
result = get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
return result
return None


Expand Down
32 changes: 30 additions & 2 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
# Check if the async function has a return statement and add it to the list
if function_has_return_statement(node) and not function_is_a_property(node):
self.functions.append(
FunctionToOptimize(
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
)
)

def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
Expand All @@ -121,6 +130,7 @@ class FunctionToOptimize:
parents: A list of parent scopes, which could be classes or functions.
starting_line: The starting line number of the function in the file.
ending_line: The ending line number of the function in the file.
is_async: Whether this function is defined as async.

The qualified_name property provides the full name of the function, including
any parent class or function names. The qualified_name_with_modules_from_root
Expand All @@ -133,6 +143,7 @@ class FunctionToOptimize:
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
is_async: bool = False

@property
def top_level_parent_name(self) -> str:
Expand Down Expand Up @@ -221,6 +232,7 @@ def get_functions_to_optimize(
f"It might take about {humanize_runtime(functions_count * three_min_in_ns)} to fully optimize this project. Codeflash "
f"will keep opening pull requests as it finds optimizations."
)
console.rule()
return filtered_modified_functions, functions_count, trace_file_path


Expand Down Expand Up @@ -396,11 +408,27 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
)
)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
if self.class_name is None and node.name == self.function_name:
self.is_top_level = True
self.function_has_args = any(
(
bool(node.args.args),
bool(node.args.kwonlyargs),
bool(node.args.kwarg),
bool(node.args.posonlyargs),
bool(node.args.vararg),
)
)

def visit_ClassDef(self, node: ast.ClassDef) -> None:
# iterate over the class methods
if node.name == self.class_name:
for body_node in node.body:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
if (
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
):
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
Expand All @@ -418,7 +446,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
Expand Down
Loading
Loading