-
-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy path_data.py
77 lines (55 loc) · 2.04 KB
/
_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations
import ast
import collections
import pkgutil
from collections.abc import Iterable
from typing import Callable
from typing import NamedTuple
from typing import Protocol
from typing import TypeVar
from tokenize_rt import Offset
from tokenize_rt import Token
from add_trailing_comma import _plugins
class State(NamedTuple):
    """Immutable context carried alongside each node during the AST walk."""

    # Set to True by ``visit`` for every node beneath an
    # ``ast.FormattedValue`` (a replacement field of an f-string);
    # defaults to False.
    in_fstring: bool = False
# Type variable standing for one concrete AST node class (any ast.AST
# subclass); it ties a registered node type to its callbacks' node argument.
AST_T = TypeVar('AST_T', bound=ast.AST)

# Token-level callback: receives an int and the token list and returns None.
# NOTE(review): presumably the int is the index of the token at the
# registered offset and the list is edited in place -- confirm with plugins.
TokenFunc = Callable[[int, list[Token]], None]

# AST-level callback: given the walk State and a node of type AST_T, yields
# (Offset, TokenFunc) pairs for the token positions it wants to act on.
ASTFunc = Callable[
    [State, AST_T],
    Iterable[tuple[Offset, TokenFunc]],
]

# Registry filled by ``register``: AST node class -> list of ASTFuncs, in
# the order they were registered.
FUNCS = collections.defaultdict(list)
def register(tp: type[AST_T]) -> Callable[[ASTFunc[AST_T]], ASTFunc[AST_T]]:
    """Build a decorator that files an AST callback under node class ``tp``.

    The decorated function is appended to ``FUNCS[tp]`` and handed back
    unchanged, so it stays usable under its own name.

    :param tp: the AST node class to register the callback for.
    :returns: a decorator that takes and returns an ``ASTFunc``.
    """
    def register_decorator(func: ASTFunc[AST_T]) -> ASTFunc[AST_T]:
        registered = FUNCS[tp]
        registered.append(func)
        return func

    return register_decorator
class ASTCallbackMapping(Protocol):
    """Structural type for a lookup from AST node class to its callbacks.

    Any object with a compatible ``__getitem__`` qualifies, e.g. a
    ``collections.defaultdict(list)`` keyed by node class.  A Protocol is
    used because its generic ``__getitem__`` ties the element type of the
    returned list to the requested key, which ``Mapping[K, V]`` cannot
    express.
    """

    def __getitem__(self, tp: type[AST_T]) -> list[ASTFunc[AST_T]]:
        """Return the callbacks registered for node class ``tp``."""
def visit(
        funcs: ASTCallbackMapping,
        tree: ast.AST,
) -> dict[Offset, list[TokenFunc]]:
    """Walk ``tree`` depth-first and gather token callbacks keyed by offset.

    Each node is passed, with the current ``State``, to every callback that
    ``funcs`` returns for the node's exact type (``type(node)``; base
    classes are not consulted).  Each callback yields
    ``(offset, token_func)`` pairs.  The token funcs are appended to a list
    per offset, in the order they were produced.

    Nodes are visited in pre-order, with children in field order.  All
    descendants of an ``ast.FormattedValue`` see ``State.in_fstring`` set
    to True.

    :returns: a ``collections.defaultdict(list)`` of offset -> token funcs.
    """
    callbacks_by_offset = collections.defaultdict(list)
    pending = [(tree, State())]

    while pending:
        current, state = pending.pop()
        node_type = type(current)

        for ast_func in funcs[node_type]:
            for offset, token_func in ast_func(state, current):
                callbacks_by_offset[offset].append(token_func)

        if node_type is ast.FormattedValue:
            state = state._replace(in_fstring=True)

        # Gather child nodes in field order (list-valued fields contribute
        # their AST elements in order), then push them reversed so that the
        # LIFO stack pops them back out in field order.
        children = []
        for field_name in current._fields:
            field_value = getattr(current, field_name)
            if isinstance(field_value, ast.AST):
                children.append(field_value)
            elif isinstance(field_value, list):
                children.extend(
                    item for item in field_value
                    if isinstance(item, ast.AST)
                )
        pending.extend((child, state) for child in reversed(children))

    return callbacks_by_offset
def _import_plugins() -> None:
    """Import every module found under the ``_plugins`` package.

    The modules are imported only for their import-time side effects and
    nothing is returned.  NOTE(review): presumably each plugin registers its
    callbacks via ``register`` when imported -- confirm in the plugin
    modules.
    """
    prefix = f'{_plugins.__name__}.'
    for module_info in pkgutil.walk_packages(_plugins.__path__, prefix):
        # A non-empty fromlist makes __import__ resolve the full dotted name
        # to the submodule itself; the return value is unused, only the
        # import matters.
        __import__(module_info.name, fromlist=['_trash'])


# Load all plugins as soon as this module is imported.
_import_plugins()