Skip to content

fix: improve hover and navigation features #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions vyper_lsp/analyzer/AstAnalyzer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import List, Optional
from typing import List, Optional, Tuple
import warnings
from packaging.version import Version
from lsprotocol.types import (
Expand All @@ -26,6 +26,7 @@
diagnostic_from_exception,
is_internal_fn,
is_state_var,
range_from_node,
)
from lsprotocol.types import (
CompletionItem,
Expand Down Expand Up @@ -124,7 +125,6 @@ def get_completions_in_doc(
if element == "self":
for fn in self.ast.get_internal_functions():
items.append(CompletionItem(label=fn))
# TODO: This should exclude constants and immutables
for var in self.ast.get_state_variables():
items.append(CompletionItem(label=var))
else:
Expand Down Expand Up @@ -194,7 +194,9 @@ def _format_fn_signature(self, node: nodes.FunctionDef) -> str:
function_def = match.group()
return f"(Internal Function) {function_def}"

def hover_info(self, document: Document, pos: Position) -> Optional[str]:
def hover_info(
self, document: Document, pos: Position
) -> Optional[Tuple[str, Range]]:
if len(document.lines) < pos.line:
return None

Expand All @@ -204,34 +206,33 @@ def hover_info(self, document: Document, pos: Position) -> Optional[str]:

if is_internal_fn(full_word):
node = self.ast.find_function_declaration_node_for_name(word)
return node and self._format_fn_signature(node)
return node and (self._format_fn_signature(node), range_from_node(node))

if is_state_var(full_word):
node = self.ast.find_state_variable_declaration_node_for_name(word)
if not node:
return None
variable_type = node.annotation.id
return f"(State Variable) **{word}** : **{variable_type}**"
return node and (
f"(State Variable) **{word}** : **{node.annotation.id}**",
range_from_node(node),
)

if word in self.ast.get_structs():
node = self.ast.find_type_declaration_node_for_name(word)
return node and f"(Struct) **{word}**"
return node and (f"(Struct) **{word}**", range_from_node(node))

if word in self.ast.get_enums():
node = self.ast.find_type_declaration_node_for_name(word)
return node and f"(Enum) **{word}**"
return node and (f"(Enum) **{word}**", range_from_node(node))

if word in self.ast.get_events():
node = self.ast.find_type_declaration_node_for_name(word)
return node and f"(Event) **{word}**"
return node and (f"(Event) **{word}**", range_from_node(node))

if word in self.ast.get_constants():
node = self.ast.find_state_variable_declaration_node_for_name(word)
if not node:
return None

variable_type = node.annotation.id
return f"(Constant) **{word}** : **{variable_type}**"
return node and (
f"(Constant) **{word}** : **{node.annotation.id}**",
range_from_node(node),
)

return None

Expand Down
136 changes: 88 additions & 48 deletions vyper_lsp/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ class AST:
ast_data_folded = None
ast_data_unfolded = None

custom_type_node_types = (nodes.StructDef, nodes.EnumDef)
custom_type_node_types = (
nodes.StructDef,
nodes.EnumDef,
nodes.InterfaceDef,
nodes.EventDef,
)

@classmethod
def from_node(cls, node: VyperNode):
Expand Down Expand Up @@ -67,16 +72,21 @@ def get_top_level_nodes(self, *args, **kwargs):
return self.best_ast.get_children(*args, **kwargs)

def get_enums(self) -> List[str]:
return [node.name for node in self.get_descendants(nodes.EnumDef)]
return [node.name for node in self.get_top_level_nodes(nodes.EnumDef)]

def get_structs(self) -> List[str]:
return [node.name for node in self.get_descendants(nodes.StructDef)]
return [node.name for node in self.get_top_level_nodes(nodes.StructDef)]

def get_events(self) -> List[str]:
return [node.name for node in self.get_descendants(nodes.EventDef)]
return [node.name for node in self.get_top_level_nodes(nodes.EventDef)]

def get_interfaces(self):
return [node.name for node in self.get_top_level_nodes(nodes.InterfaceDef)]

def get_user_defined_types(self):
return [node.name for node in self.get_descendants(self.custom_type_node_types)]
return [
node.name for node in self.get_top_level_nodes(self.custom_type_node_types)
]

def get_constants(self):
# NOTE: Constants should be fetched from self.ast_data, they are missing
Expand All @@ -86,10 +96,29 @@ def get_constants(self):

return [
node.target.id
for node in self.ast_data.get_children(nodes.VariableDecl)
for node in self.get_top_level_nodes(nodes.VariableDecl)
if node.is_constant
]

def get_immutables(self):
return [
node.target.id
for node in self.get_top_level_nodes(nodes.VariableDecl)
if node.is_immutable
]

def get_state_variables(self):
# NOTE: The state variables should be fetched from self.ast_data, they are
# missing from self.ast_data_unfolded and self.ast_data_folded when constants
if self.ast_data is None:
return []

return [
node.target.id
for node in self.get_top_level_nodes(nodes.VariableDecl)
if not node.is_constant and not node.is_immutable
]

def get_enum_variants(self, enum: str):
enum_node = self.find_type_declaration_node_for_name(enum)
if enum_node is None:
Expand All @@ -104,16 +133,6 @@ def get_struct_fields(self, struct: str):

return [node.target.id for node in struct_node.get_children(nodes.AnnAssign)]

def get_state_variables(self):
# NOTE: The state variables should be fetched from self.ast_data, they are
# missing from self.ast_data_unfolded and self.ast_data_folded when constants
if self.ast_data is None:
return []

return [
node.target.id for node in self.ast_data.get_descendants(nodes.VariableDecl)
]

def get_internal_function_nodes(self):
function_nodes = self.get_descendants(nodes.FunctionDef)
internal_nodes = []
Expand All @@ -138,14 +157,20 @@ def find_nodes_referencing_state_variable(self, variable: str):
nodes.Attribute, {"value.id": "self", "attr": variable}
)

def find_nodes_referencing_constant(self, constant: str):
name_nodes = self.get_descendants(nodes.Name, {"id": constant})
def find_nodes_referencing_constant_or_immutable(self, name: str):
name_nodes = self.get_descendants(nodes.Name, {"id": name})
return [
node
for node in name_nodes
if not isinstance(node.get_ancestor(), nodes.VariableDecl)
]

def find_nodes_referencing_constant(self, constant: str):
return self.find_nodes_referencing_constant_or_immutable(constant)

def find_nodes_referencing_immutable(self, immutable: str):
return self.find_nodes_referencing_constant_or_immutable(immutable)

def get_attributes_for_symbol(self, symbol: str):
node = self.find_type_declaration_node_for_name(symbol)
if node is None:
Expand All @@ -159,12 +184,8 @@ def get_attributes_for_symbol(self, symbol: str):
return []

def find_function_declaration_node_for_name(self, function: str):
for node in self.get_descendants(nodes.FunctionDef):
name_match = node.name == function
not_interface_declaration = not isinstance(
node.get_ancestor(), nodes.InterfaceDef
)
if name_match and not_interface_declaration:
for node in self.get_top_level_nodes(nodes.FunctionDef):
if node.name == function:
return node

return None
Expand All @@ -175,15 +196,15 @@ def find_state_variable_declaration_node_for_name(self, variable: str):
if self.ast_data is None:
return None

for node in self.ast_data.get_descendants(nodes.VariableDecl):
for node in self.get_top_level_nodes(nodes.VariableDecl):
if node.target.id == variable:
return node

return None

def find_type_declaration_node_for_name(self, symbol: str):
searchable_types = self.custom_type_node_types + (nodes.EventDef,)
for node in self.get_descendants(searchable_types):
searchable_types = self.custom_type_node_types
for node in self.get_top_level_nodes(searchable_types):
if node.name == symbol:
return node
if isinstance(node, nodes.EnumDef):
Expand All @@ -193,17 +214,44 @@ def find_type_declaration_node_for_name(self, symbol: str):

return None

def find_nodes_referencing_enum(self, enum: str):
def find_nodes_referencing_type(self, type_name: str):
return_nodes = []

for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": enum}):
return_nodes.append(node)
for node in self.get_descendants(nodes.Attribute, {"value.id": enum}):
return_nodes.append(node)
for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": enum}):
return_nodes.append(node)
for node in self.get_descendants(nodes.FunctionDef, {"returns.id": enum}):
return_nodes.append(node)
type_expressions = set()

for node in self.get_descendants():
if hasattr(node, "annotation"):
type_expressions.add(node.annotation)
elif hasattr(node, "returns") and node.returns:
type_expressions.add(node.returns)

# TODO cover more builtin
for node in self.get_descendants(nodes.Call, {"func.id": "empty"}):
type_expressions.add(node.args[0])

for node in type_expressions:
for subnode in node.get_descendants(include_self=True):
if isinstance(subnode, nodes.Name) and subnode.id == type_name:
return_nodes.append(subnode)

return return_nodes

def find_nodes_referencing_callable_type(self, type_name: str):
return_nodes = self.find_nodes_referencing_type(type_name)

for node in self.get_descendants(nodes.Call, {"func.id": type_name}):
# ERC20(foo)
# my_struct({x:0})
return_nodes.append(node.func)

return return_nodes

def find_nodes_referencing_enum(self, type_name: str):
return_nodes = self.find_nodes_referencing_type(type_name)

for node in self.get_descendants(nodes.Attribute, {"value.id": type_name}):
# A.o
return_nodes.append(node.value)

return return_nodes

Expand All @@ -212,19 +260,11 @@ def find_nodes_referencing_enum_variant(self, enum: str, variant: str):
nodes.Attribute, {"attr": variant, "value.id": enum}
)

def find_nodes_referencing_struct(self, struct: str):
return_nodes = []
def find_nodes_referencing_struct(self, type_name: str):
return self.find_nodes_referencing_callable_type(type_name)

for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": struct}):
return_nodes.append(node)
for node in self.get_descendants(nodes.Call, {"func.id": struct}):
return_nodes.append(node)
for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": struct}):
return_nodes.append(node)
for node in self.get_descendants(nodes.FunctionDef, {"returns.id": struct}):
return_nodes.append(node)

return return_nodes
def find_nodes_referencing_interfaces(self, type_name: str):
return self.find_nodes_referencing_callable_type(type_name)

def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]:
for node in self.get_top_level_nodes():
Expand Down
3 changes: 2 additions & 1 deletion vyper_lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def hover(ls: LanguageServer, params: HoverParams):
document = ls.workspace.get_text_document(params.text_document.uri)
hover_info = ast_analyzer.hover_info(document, params.position)
if hover_info:
return Hover(contents=hover_info, range=None)
hover_content, range = hover_info
return Hover(contents=hover_content, range=range)


@server.feature(
Expand Down
18 changes: 14 additions & 4 deletions vyper_lsp/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,14 @@ def _is_state_var_decl(self, line, word):
return is_top_level and is_state_variable

def _is_constant_decl(self, line, word):
is_top_level = not line[0].isspace()
is_constant = "constant(" in line
return is_constant and self._is_state_var_decl(line, word)
return is_top_level and is_constant and word in self.ast.get_constants()

def _is_immutable_decl(self, line, word):
is_top_level = not line[0].isspace()
is_immutable = "immutable(" in line
return is_top_level and is_immutable and word in self.ast.get_immutables()

def _is_internal_fn(self, line, word, expression):
is_def = line.startswith("def")
Expand All @@ -89,12 +95,18 @@ def finalize(refs):
if word in self.ast.get_structs() or word in self.ast.get_events():
return finalize(self.ast.find_nodes_referencing_struct(word))

if word in self.ast.get_interfaces():
return finalize(self.ast.find_nodes_referencing_interfaces(word))

if self._is_internal_fn(og_line, word, expression):
return finalize(self.ast.find_nodes_referencing_internal_function(word))

if self._is_constant_decl(og_line, word):
return finalize(self.ast.find_nodes_referencing_constant(word))

if self._is_immutable_decl(og_line, word):
return finalize(self.ast.find_nodes_referencing_immutable(word))

if self._is_state_var_decl(og_line, word):
return finalize(self.ast.find_nodes_referencing_state_variable(word))

Expand Down Expand Up @@ -139,9 +151,7 @@ def find_declaration(self, document: Document, pos: Position) -> Optional[Range]
return self._find_state_variable_declaration(word)
elif word in self.ast.get_user_defined_types():
return self.find_type_declaration(word)
elif word in self.ast.get_events():
return self.find_type_declaration(word)
elif word in self.ast.get_constants():
elif word in self.ast.get_constants() or word in self.ast.get_immutables():
return self._find_state_variable_declaration(word)
elif isinstance(top_level_node, FunctionDef):
range_ = self._find_variable_declaration_under_node(top_level_node, word)
Expand Down