Code loading chunk size #15

Open · wants to merge 4 commits into base: main
3 changes: 3 additions & 0 deletions src/config.yaml
@@ -20,7 +20,9 @@ vectorstore_settings:
- { repo: "lvhuyen/SparkAid", branch: "master"}
from_json: False
from_store: True
chunk_size: 1000
type: connect
filter_in_iteration: False

api_ref:
vector_store_path: "/raid/shared/masterproject2024/vector_stores/api/nv_split512"
@@ -37,6 +39,7 @@ iteration_limit: 3
# Linter Settings
linter_config :
enabled_linters: # List of linters to use. Possible values: 'pylint', 'mypy', 'flake8', 'spark_connect'
- syntax
- pylint
- spark_connect
feedback_types: # Return only these severities. Possible values: 'error', 'warning', 'convention' (maybe more)
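For reference, the keys shown in this hunk (`chunk_size`, `type`, `filter_in_iteration`, …) live under `vectorstore_settings` in `src/config.yaml`. The project loads its configuration through Hydra elsewhere; the sketch below uses plain PyYAML only to illustrate the key layout, and the helper name, default values, and the assumption that these keys sit directly under `vectorstore_settings` (nesting is not visible in the rendered diff) are illustrative, not part of this PR.

```python
# Illustrative sketch only: the project reads config via Hydra, but the key
# layout touched by this PR can be inspected with plain PyYAML just as well.
import yaml

def load_vectorstore_settings(path: str = "src/config.yaml") -> dict:
    with open(path, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)
    settings = config["vectorstore_settings"]
    return {
        "chunk_size": settings.get("chunk_size", 1000),
        "type": settings.get("type", "connect"),
        "filter_in_iteration": settings.get("filter_in_iteration", False),
    }
```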
29 changes: 15 additions & 14 deletions src/linter/python_linter/__main__.py
@@ -2,7 +2,7 @@
import json
import subprocess
import hydra
from linter.python_linter.linter import PythonLinter
from linter.python_linter.linter import PythonLinter, TreeSitterLinter
from linter.python_linter.spark_connect_matcher import (
RDDApiMatcher,
MapPartitionsMatcher,
@@ -172,25 +172,26 @@ def run_flake8(code):
return []


def lint_codestring(code, lint_config):
def lint_codestring(code, lint_config, with_compile: bool = True):
"""
Lints the given code string and returns the diagnostics as a JSON object.
"""
diagnostics = []

# check if "code" does actually contain code and is parseable with ast
try:
compile(code, "temp_lint_code.py", "exec")
except Exception as e:
return [
{
"message": "Syntax error: " + str(e),
"line": 0,
"col": 0,
"type": "syntax_error",
"linter": "syntax",
}
]
if "syntax" in lint_config.enabled_linters:
try:
compile(code, "temp_lint_code.py", "exec")
except Exception as e:
return [
{
"message": "Syntax error: " + str(e),
"line": 0,
"col": 0,
"type": "syntax_error",
"linter": "syntax",
}
]

if "spark_connect" in lint_config.enabled_linters:
diagnostics += format_diagnostics(
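With this change the compile-based syntax check only runs when `syntax` is listed in `enabled_linters`, matching the new entry added to `config.yaml`. A minimal sketch of calling `lint_codestring` directly is shown below; the `SimpleNamespace` stand-in for the Hydra-provided lint config and the sample snippet are assumptions for illustration, and the real config may require additional fields.

```python
# Minimal sketch; SimpleNamespace stands in for the Hydra lint config and may
# need more fields (e.g. feedback_types) depending on which linters are enabled.
from types import SimpleNamespace

from linter.python_linter.__main__ import lint_codestring

lint_config = SimpleNamespace(
    enabled_linters=["syntax", "spark_connect"],
    feedback_types=["error", "warning"],
)

snippet = "result = df.rdd.collect()\n"  # should trip the spark_connect matchers
for diag in lint_codestring(snippet, lint_config):
    print(diag)
```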
29 changes: 29 additions & 0 deletions src/linter/python_linter/linter.py
@@ -13,6 +13,35 @@
Log4JMatcher,
CommandContextMatcher,
)
from tree_sitter import Language, Parser
import tree_sitter_python as tspython

PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)


class TreeSitterLinter:
def __init__(self):
self.matchers: List[Matcher] = []

def add_matcher(self, matcher: Matcher):
"""Add a matcher to the linter."""
self.matchers.append(matcher)

def lint(self, code: str) -> List[Dict]:
"""Parse code with Tree-Sitter and run all matchers."""
tree = parser.parse(bytes(code, "utf-8"))
diagnostics = []

# Traverse the tree and apply matchers
def traverse(node):
for matcher in self.matchers:
diagnostics.extend(matcher.lint(node))
for child in node.children:
traverse(child)

traverse(tree.root_node)
return diagnostics


class PythonLinter:
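A minimal usage sketch of the new `TreeSitterLinter`, reusing one of the matchers from `spark_connect_matcher.py`; the sample snippet is an assumption for illustration.

```python
# Minimal sketch: wire one matcher into the new Tree-Sitter based linter.
from linter.python_linter.linter import TreeSitterLinter
from linter.python_linter.spark_connect_matcher import JvmAccessMatcher

linter = TreeSitterLinter()
linter.add_matcher(JvmAccessMatcher())

snippet = 'logger = spark._jvm.org.apache.log4j.LogManager.getLogger("app")\n'
for diagnostic in linter.lint(snippet):
    print(diagnostic)  # expected: E9010 diagnostics for the '_jvm' access
```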
171 changes: 141 additions & 30 deletions src/linter/python_linter/spark_connect_matcher.py
@@ -1,27 +1,18 @@
from typing import Iterator, List, Dict
from typing import Iterator, Dict
from tree_sitter import Node
import ast


class Matcher:
def lint(self, node: ast.AST) -> Iterator[Dict]:

def lint(self, node) -> Iterator[Dict]:
"""
Analyze a node and yield diagnostics as dictionaries.
Analyze a Tree-Sitter node and yield diagnostics as dictionaries.
Subclasses must override this method.
"""
raise NotImplementedError("Subclasses must implement 'lint'")


class RddAttributeMatcher(Matcher):
def lint(self, node: ast.AST) -> Iterator[Dict]:
if isinstance(node, ast.Attribute) and node.attr == "rdd":
yield {
"message_id": "E9009",
"message": "Accessing 'rdd' from a DataFrame is not allowed. Use DataFrame APIs instead.",
"line": node.lineno,
"col": node.col_offset,
}


class JvmAccessMatcher(Matcher):
_FIELDS = [
"_jvm",
@@ -31,7 +22,8 @@ class JvmAccessMatcher(Matcher):
"_jsparkSession",
]

def lint(self, node: ast.AST) -> Iterator[Dict]:
def lint(self, node) -> Iterator[Dict]:
# Original AST logic for checking Spark Driver JVM access
if isinstance(node, ast.Attribute) and node.attr in self._FIELDS:
yield {
"message_id": "E9010",
@@ -40,6 +32,17 @@ def lint(self, node: ast.AST) -> Iterator[Dict]:
"col": node.col_offset,
}

# Tree-Sitter logic for checking Spark Driver JVM access
if isinstance(node, Node) and node.type == "attribute":
for field in self._FIELDS:
if field in node.text.decode("utf-8"):
yield {
"message_id": "E9010",
"message": f"Cannot access Spark Driver JVM field '{field}' in shared clusters.",
"line": node.start_point[0] + 1,
"col": node.start_point[1],
}


class RDDApiMatcher(Matcher):
_SC_METHODS = [
@@ -61,7 +64,8 @@ class RDDApiMatcher(Matcher):
"wholeTextFiles",
]

def lint(self, node: ast.AST) -> Iterator[Dict]:
def lint(self, node) -> Iterator[Dict]:
# Original AST logic for detecting SparkContext methods
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
if node.func.attr in self._SC_METHODS:
yield {
@@ -72,29 +76,87 @@ def lint(self, node: ast.AST) -> Iterator[Dict]:
"col": node.col_offset,
}

# Tree-Sitter logic: Detecting calls to SparkContext methods
if isinstance(node, Node) and node.type == "call":
function_node = node.child_by_field_name("function")
if function_node:
func_name = function_node.text.decode("utf-8").split(".")[-1]
if func_name in self._SC_METHODS:
yield {
"message_id": "E9008",
"message": f"Usage of SparkContext method '{func_name}' is not allowed. "
f"Use DataFrame APIs instead.",
"line": node.start_point[0] + 1,
"col": node.start_point[1],
}


class RddAttributeMatcher(Matcher):

def lint(self, node) -> Iterator[Dict]:
# Original AST logic for 'rdd' access
if isinstance(node, ast.Attribute) and node.attr == "rdd":
yield {
"message_id": "E9009",
"message": "Accessing 'rdd' from a DataFrame is not allowed. Use DataFrame APIs instead.",
"line": node.lineno,
"col": node.col_offset,
}

# Tree-Sitter logic: Checking for '.rdd' attribute in Tree-Sitter nodes
if isinstance(node, Node) and node.type == "attribute":
if node.text.decode("utf-8").endswith(".rdd"):
yield {
"message_id": "E9009",
"message": "Accessing 'rdd' from a DataFrame is not allowed. Use DataFrame APIs instead.",
"line": node.start_point[0] + 1, # 0-based to 1-based
"col": node.start_point[1],
}


class SparkSqlContextMatcher(Matcher):
_ATTRIBUTES = ["sc", "sqlContext", "sparkContext"]
_KNOWN_REPLACEMENTS = {"getConf": "conf", "_conf": "conf"}

def lint(self, node: ast.AST) -> Iterator[Dict]:
def lint(self, node) -> Iterator[Dict]:
"""
Analyze the AST node and yield diagnostics for legacy context usage.
Analyze the AST or Tree-Sitter node and yield diagnostics for legacy context usage.
"""
# Check for direct usage of `sc`, `sqlContext`, or `sparkContext`
# Handle Python's AST nodes
if isinstance(node, ast.Attribute):
# Case: sc.getConf or sqlContext.getConf
if isinstance(node.value, ast.Name) and node.value.id in self._ATTRIBUTES:
yield self._get_advice(node, node.value.id, node.attr)
yield self._get_advice(
node, node.value.id, node.attr, node.lineno, node.col_offset
)

# Case: df.sparkContext.getConf
if (
isinstance(node.value, ast.Attribute)
and node.value.attr in self._ATTRIBUTES
):
yield self._get_advice(node, node.value.attr, node.attr)
yield self._get_advice(
node, node.value.attr, node.attr, node.lineno, node.col_offset
)

# Handle Tree-Sitter nodes
if isinstance(node, Node) and node.type == "attribute":
# Decode the base and attribute from Tree-Sitter node
base_node = node.child_by_field_name("object")
attr_node = node.child_by_field_name("attribute")
if base_node and attr_node:
base = base_node.text.decode("utf-8")
attr = attr_node.text.decode("utf-8")
if base in self._ATTRIBUTES:
yield self._get_advice(
node,
base,
attr,
node.start_point[0] + 1, # Tree-Sitter's 0-based to 1-based
node.start_point[1],
)

def _get_advice(self, node: ast.Attribute, base: str, attr: str) -> Dict:
def _get_advice(self, node, base: str, attr: str, line: int, col: int) -> Dict:
"""
Generate advice message for prohibited usage.
"""
@@ -103,19 +165,21 @@ def _get_advice(self, node: ast.Attribute, base: str, attr: str) -> Dict:
return {
"message_id": "E9011",
"message": f"'{base}.{attr}' is not supported. Rewrite it using 'spark.{replacement}' instead.",
"line": node.lineno,
"col": node.col_offset,
"line": line,
"col": col,
}
return {
"message_id": "E9011",
"message": f"'{base}' is not supported. Rewrite it using 'spark' instead.",
"line": node.lineno,
"col": node.col_offset,
"line": line,
"col": col,
}


class MapPartitionsMatcher(Matcher):
def lint(self, node: ast.AST) -> Iterator[Dict]:

def lint(self, node) -> Iterator[Dict]:
# Original AST logic for detecting 'mapPartitions'
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
if node.func.attr == "mapPartitions":
yield {
@@ -125,19 +189,44 @@ def lint(self, node: ast.AST) -> Iterator[Dict]:
"col": node.col_offset,
}

# Tree-Sitter logic: Detecting 'mapPartitions' usage
if isinstance(node, Node) and node.type == "call":
function_node = node.child_by_field_name("function")
if function_node:
if function_node.text.decode("utf-8").endswith("mapPartitions"):
yield {
"message_id": "E9002",
"message": "Usage of 'mapPartitions' is not allowed. Use 'mapInArrow' or Pandas UDFs instead.",
"line": node.start_point[0] + 1,
"col": node.start_point[1],
}


class SetLogLevelMatcher(Matcher):
def lint(self, node: ast.AST) -> Iterator[Dict]:

def lint(self, node) -> Iterator[Dict]:
# Original AST logic for detecting 'setLogLevel'
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
if node.func.attr == "setLogLevel":
yield {
"message_id": "E9004",
"message": "Setting Spark log level from code is not allowed. "
"Use Spark configuration instead.",
"message": "Setting Spark log level from code is not allowed. Use Spark configuration instead.",
"line": node.lineno,
"col": node.col_offset,
}

# Tree-Sitter logic: Detecting 'setLogLevel' usage
if isinstance(node, Node) and node.type == "call":
function_node = node.child_by_field_name("function")
if function_node:
if function_node.text.decode("utf-8").endswith("setLogLevel"):
yield {
"message_id": "E9004",
"message": "Setting Spark log level from code is not allowed. Use Spark configuration instead.",
"line": node.start_point[0] + 1,
"col": node.start_point[1],
}


class Log4JMatcher(Matcher):
def lint(self, node: ast.AST) -> Iterator[Dict]:
@@ -151,6 +240,16 @@ def lint(self, node: ast.AST) -> Iterator[Dict]:
"col": node.col_offset,
}

if isinstance(node, Node) and node.type == "attribute":
if node.text.decode("utf-8") == "org.apache.log4j":
yield {
"message_id": "E9005",
"message": "Accessing Log4J logger from Spark JVM is not allowed. "
"Use Python logging.getLogger() instead.",
"line": node.start_point[0] + 1, # Tree-Sitter's 0-based to 1-based
"col": node.start_point[1],
}


class CommandContextMatcher(Matcher):
def lint(self, node: ast.AST) -> Iterator[Dict]:
@@ -162,3 +261,15 @@ def lint(self, node: ast.AST) -> Iterator[Dict]:
"line": node.lineno,
"col": node.col_offset,
}

if isinstance(node, Node) and node.type == "call":
# Traverse children to find if it's a `toJson` call
for child in node.children:
if child.type == "attribute" and child.text.decode("utf-8") == "toJson":
yield {
"message_id": "E9007",
"message": "Usage of 'toJson' is not allowed. Use 'toSafeJson' where supported.",
"line": child.start_point[0]
+ 1, # Tree-Sitter's 0-based to 1-based
"col": child.start_point[1],
}