Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 55 additions & 67 deletions sema4ai/src/sema4ai_code/robo/collect_actions_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@


def _collect_py_files(root_path: Path) -> Iterator[Path]:
# TODO: Improve this to use the same heuristics from sema4ai.actions
# https://github.com/Sema4AI/actions/blob/master/actions/src/sema4ai/actions/_collect_actions.py
for item in root_path.iterdir():
if item.is_dir():
yield from _collect_py_files(item)
Expand Down Expand Up @@ -234,6 +236,9 @@ def _collect_actions_from_ast(ast: ast_module.AST) -> Iterator[_ActionInfo]:
"action",
"query",
"predict",
"tool",
"resource",
"prompt",
]:
yield {"node": node, "kind": decorator.id}

Expand Down Expand Up @@ -273,13 +278,6 @@ def _get_ast_node_range(
}


DEFAULT_ACTION_SEARCH_GLOB = (
"*action*.py|*query*.py|*queries*.py|*predict*.py|*datasource*.py|*data_source*.py"
)

globs = DEFAULT_ACTION_SEARCH_GLOB.split("|")


def iter_actions_and_datasources(
root_directory: Path,
collect_datasources: bool = False,
Expand All @@ -289,71 +287,61 @@ def iter_actions_and_datasources(
give complete information, rather, it is a fast way to provide just simple
metadata such as the action name and location).
"""
import fnmatch

f: Path
for f in _collect_py_files(root_directory):
for glob in globs:
if fnmatch.fnmatch(f.name, glob):
try:
action_contents_file = f.read_bytes()
ast = ast_module.parse(action_contents_file, "<string>")
uri = uris.from_fs_path(str(f))

for node_info_action in _collect_actions_from_ast(ast):
function_def_node = node_info_action["node"]
node_range = _get_ast_node_range(function_def_node)
yield ActionInfoTypedDict(
uri=uri,
try:
action_contents_file = f.read_bytes()
ast = ast_module.parse(action_contents_file, "<string>")
uri = uris.from_fs_path(str(f))

for node_info_action in _collect_actions_from_ast(ast):
function_def_node = node_info_action["node"]
node_range = _get_ast_node_range(function_def_node)
yield ActionInfoTypedDict(
uri=uri,
range=node_range,
name=function_def_node.name,
kind=node_info_action["kind"],
)

if collect_datasources:
variables = _collect_variables(ast)
# Note: Instead of iterating over all nodes to collect datasources, we
# try to find the following structure:
#
# DataSourceVarName = Annotated[DataSource, DataSourceSpec(name="my_datasource")]
#
# Note that the DataSourceSpec(...) is a Call node inside the Annotated[...]
# which in turn must be inside an Assign node.

if collect_datasources:
for node_info_datasource in _collect_datasources(ast, variables):
ast_node = node_info_datasource["node"]
node_range = _get_ast_node_range(ast_node)
yield DatasourceInfoTypedDict(
range=node_range,
name=function_def_node.name,
kind=node_info_action["kind"],
uri=uri,
name=node_info_datasource.get("name") or "<name not found>",
engine=node_info_datasource.get(
"engine",
)
or "<engine not found>",
model_name=node_info_datasource.get("model_name"),
created_table=node_info_datasource.get("created_table"),
kind="datasource",
python_variable_name=node_info_datasource.get(
"python_variable_name"
),
setup_sql=node_info_datasource.get("setup_sql"),
setup_sql_files=node_info_datasource.get("setup_sql_files"),
description=node_info_datasource.get("description"),
file=node_info_datasource.get("file"),
)

if collect_datasources:
variables = _collect_variables(ast)
# Note: Instead of iterating over all nodes to collect datasources, we
# try to find the following structure:
#
# DataSourceVarName = Annotated[DataSource, DataSourceSpec(name="my_datasource")]
#
# Note that the DataSourceSpec(...) is a Call node inside the Annotated[...]
# which in turn must be inside an Assign node.

if collect_datasources:
for node_info_datasource in _collect_datasources(
ast, variables
):
ast_node = node_info_datasource["node"]
node_range = _get_ast_node_range(ast_node)
yield DatasourceInfoTypedDict(
range=node_range,
uri=uri,
name=node_info_datasource.get("name")
or "<name not found>",
engine=node_info_datasource.get(
"engine",
)
or "<engine not found>",
model_name=node_info_datasource.get("model_name"),
created_table=node_info_datasource.get(
"created_table"
),
kind="datasource",
python_variable_name=node_info_datasource.get(
"python_variable_name"
),
setup_sql=node_info_datasource.get("setup_sql"),
setup_sql_files=node_info_datasource.get(
"setup_sql_files"
),
description=node_info_datasource.get("description"),
file=node_info_datasource.get("file"),
)
except Exception as e:
log.error(
f"Unable to collect @action/@query/@predict/datasources from {f}. Error: {e}"
)
except Exception as e:
log.error(
f"Unable to collect @action/@query/@predict/datasources from {f}. Error: {e}"
)


def get_action_signature(
Expand Down
Loading