Skip to content
1 change: 1 addition & 0 deletions .idea/inspectionProfiles/project_inspections.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

123 changes: 120 additions & 3 deletions parser/astgen/ast_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

from dataclasses import dataclass
from enum import Enum
from typing import Callable, TypeAlias, Iterable
from typing import Callable, TypeAlias, Iterable, TypeVar

from util import flatten_force
from ..common import HasRegion, StrRegion

__all__ = [
"AstNode", "AstProgramNode", "VarDeclScope", "VarDeclType", "AstDeclNode",
"AstRepeat", "AstIf", "AstWhile", "AstAssign", "AstAugAssign", "AstDefine",
"AstNumber", "AstString", "AstAnyName", "AstIdent", "AstAttrName",
"AstListLiteral", "AstAttribute", "AstItem", "AstCall", "AstOp", "AstBinOp",
"AstUnaryOp", 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType'
"AstUnaryOp", 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType',
"FilteredWalker"
]


Expand Down Expand Up @@ -48,7 +50,6 @@ def _walk_obj_members(cls, o: WalkableT, fn: WalkerFnT):
if o is None:
return
if isinstance(o, AstNode):
# noinspection PyProtectedMember
return o._walk_members(fn)
try:
it = iter(o)
Expand All @@ -75,6 +76,122 @@ def walk_multiple_objects(cls, fn: WalkerFnT, objs: Iterable[WalkableT]):
walk_ast = AstNode.walk_obj


# region <FilteredWalker>
WT = TypeVar('WT', bound=WalkableT)
VT = TypeVar('VT')
SpecificCbT = Callable[[WT], bool | None]
SpecificCbsDict = dict[type[WT] | type, list[Callable[[WT], bool | None]]]
BothCbT = Callable[[WT, WalkerCallType], bool | None]
BothCbsDict = dict[type[WT] | type, list[Callable[[WT, WalkerCallType], bool | None]]]


class WalkerFilterRegistry:
def __init__(self, enter_cbs: SpecificCbsDict = (),
exit_cbs: SpecificCbsDict = (),
both_sbc: BothCbsDict = ()):
self.enter_cbs: SpecificCbsDict = dict(enter_cbs) # Copy them,
self.exit_cbs: SpecificCbsDict = dict(exit_cbs) # also converts default () -> {}
self.both_cbs: BothCbsDict = dict(both_sbc)

def copy(self):
return WalkerFilterRegistry(self.enter_cbs, self.exit_cbs, self.both_cbs)

def register_both(self, t: type[WT], fn: Callable[[WT, WalkerCallType], bool | None]):
self.both_cbs.setdefault(t, []).append(fn)
return self

def register_enter(self, t: type[WT], fn: Callable[[WT], bool | None]):
self.enter_cbs.setdefault(t, []).append(fn)
return self

def register_exit(self, t: type[WT], fn: Callable[[WT], bool | None]):
self.exit_cbs.setdefault(t, []).append(fn)
return self

def __call__(self, *args, **kwargs):
return self

def on_enter(self, *tps: type[WT] | type):
"""Decorator version of register_enter."""
def decor(fn: SpecificCbT):
for t in tps:
self.register_enter(t, fn)
return fn
return decor

def on_exit(self, *tps: type[WT] | type):
"""Decorator version of register_exit."""
def decor(fn: SpecificCbT):
for t in tps:
self.register_exit(t, fn)
return fn
return decor

def on_both(self, *tps: type[WT] | type):
"""Decorator version of register_both."""
def decor(fn: BothCbT):
for t in tps:
self.register_both(t, fn)
return fn
return decor


class FilteredWalker(WalkerFilterRegistry):
def __init__(self):
cls_reg = self.class_registry()
super().__init__(cls_reg.enter_cbs, cls_reg.exit_cbs, cls_reg.both_cbs)

@classmethod
def class_registry(cls) -> WalkerFilterRegistry:
return WalkerFilterRegistry()

@classmethod
def create_cls_registry(cls, fn=None):
"""Create a class-level registry that can be added to using decorators.

This can be used in two ways (at the top of your class)::

# MUST be this name
class_registry = FilteredWalker.create_cls_registry()

or::

@classmethod
@FilteredWalker.create_cls_registry
def class_registry(cls): # MUST be this name
pass

and when registering methods::

@class_registry.on_enter(AstDefine)
def enter_define(self, ...):
...

The restrictions on name are because we have no other way of detecting
it (without metaclass dark magic) as we can't refer to the class while
its namespace is being evaluated
"""
if fn is not None and (parent := fn(cls)) is not None:
return WalkerFilterRegistry.copy(parent)
return WalkerFilterRegistry()

def __call__(self, o: WalkableT, call_type: WalkerCallType):
result = None
# Call more specific ones first
specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs
for fn in self._get_funcs(specific_cbs, type(o)):
result = fn(o) or result
for fn in self._get_funcs(self.both_cbs, type(o)):
result = fn(o, call_type) or result
return result

@classmethod
def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]:
"""Also looks at superclasses/MRO"""
return flatten_force(mapping.get(sub, []) for sub in tp.mro())
# endregion


@dataclass
class AstProgramNode(AstNode):
name = 'program'
Expand Down
2 changes: 1 addition & 1 deletion parser/astgen/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from parser.common import BaseParseError, BaseLocatedError
from ..common import BaseParseError, BaseLocatedError


class AstParseError(BaseParseError):
Expand Down
2 changes: 1 addition & 1 deletion parser/lexer/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from parser.common import BaseParseError, BaseLocatedError
from ..common import BaseParseError, BaseLocatedError


class TokenizerError(BaseParseError):
Expand Down
98 changes: 45 additions & 53 deletions parser/typecheck/typecheck.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, TypeVar

from parser.astgen.ast_node import (
AstNode, walk_ast, WalkableT, WalkerCallType, AstIdent, AstDeclNode,
AstDefine, VarDeclType, VarDeclScope)
from parser.astgen.astgen import AstGen
from parser.common import BaseLocatedError, StrRegion
from util import flatten_force

WT = TypeVar('WT', bound=WalkableT)
VT = TypeVar('VT')


class FilteredWalker:
def __init__(self):
self.enter_cbs: dict[type[WT] | type, list[Callable[[WT], bool | None]]] = {}
self.exit_cbs: dict[type[WT] | type, list[Callable[[WT], bool | None]]] = {}
self.both_cbs: dict[type[WT] | type, list[
Callable[[WT, WalkerCallType], bool | None]]] = {}

def register_both(self, t: type[WT], fn: Callable[[WT, WalkerCallType], bool | None]):
self.both_cbs.setdefault(t, []).append(fn)
return self

def register_enter(self, t: type[WT], fn: Callable[[WT], bool | None]):
self.enter_cbs.setdefault(t, []).append(fn)
return self

def register_exit(self, t: type[WT], fn: Callable[[WT], bool | None]):
self.exit_cbs.setdefault(t, []).append(fn)
return self

def __call__(self, o: WalkableT, call_type: WalkerCallType):
result = None
# Call more specific ones first
specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs
for fn in self._get_funcs(specific_cbs, type(o)):
result = fn(o) or result
for fn in self._get_funcs(self.both_cbs, type(o)):
result = fn(o, call_type) or result
return result

@classmethod
def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]:
"""Also looks at superclasses/MRO"""
return flatten_force(mapping.get(sub, []) for sub in tp.mro())
from util.recursive_eq import recursive_eq
from ..astgen.ast_node import (
AstNode, walk_ast, AstIdent, AstDeclNode, AstDefine, VarDeclType,
VarDeclScope, FilteredWalker)
from ..astgen.astgen import AstGen
from ..common import BaseLocatedError, StrRegion


@dataclass
Expand Down Expand Up @@ -124,6 +85,9 @@ class Scope:
(so type codegen/type-checker knows what each AstIdent refers to)"""


Scope.__eq__ = recursive_eq(Scope.__eq__)


class NameResolutionError(BaseLocatedError):
pass

Expand Down Expand Up @@ -191,14 +155,15 @@ def enter_fn_decl(fn: AstDefine):
raise self.err("Function already declared", fn.ident.region)
subscope = Scope()
params: list[ParamInfo] = []
for tp, param in fn.params:
if tp.id not in PARAM_TYPES:
raise self.err("Unknown parameter type", tp.region)
if param.id in subscope.declared:
raise self.err("There is already a parameter of this name", param.region)
tp = BoolType() if param.id == 'bool' else ValType()
subscope.declared[param.id] = NameInfo(subscope, param.id, tp, is_param=True)
params.append(ParamInfo(param.id, tp))
for tp_node, name_node in fn.params:
if tp_node.id not in PARAM_TYPES:
raise self.err("Unknown parameter type", tp_node.region)
if (name := name_node.id) in subscope.declared:
raise self.err("There is already a parameter of this name",
name_node.region)
tp = BoolType() if tp_node.id == 'bool' else ValType()
subscope.declared[name] = NameInfo(subscope, name, tp, is_param=True)
params.append(ParamInfo(name, tp))
curr_scope.declared[ident] = info = FuncInfo.from_param_info(
curr_scope, ident, params,
ret_type=VoidType(), subscope=subscope)
Expand Down Expand Up @@ -226,3 +191,30 @@ def enter_fn_decl(fn: AstDefine):

def err(self, msg: str, region: StrRegion):
return NameResolutionError(msg, region, self.src)


class Typechecker:
def __init__(self, name_resolver: NameResolver):
self.resolver = name_resolver
self.src = self.resolver.src
self.is_ok: bool | None = None

def _init(self):
self.resolver.run()
self.ast = self.resolver.ast
self.top_scope = self.resolver.top_scope

def run(self):
if self.is_ok is None:
return self.is_ok
self._typecheck()
self.is_ok = True
return self.is_ok

def _typecheck(self):
walker = FilteredWalker()

self.ast.walk(walker)
...


28 changes: 27 additions & 1 deletion test/test_typecheck/test_name_resolve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import Mock, patch

from parser.common import StrRegion
from parser.typecheck.typecheck import NameResolver
from parser.typecheck.typecheck import (
NameResolver, Scope, NameInfo, BoolType, ValType, VoidType,
FuncInfo, ParamInfo)
from test.common import CommonTestCase


Expand Down Expand Up @@ -48,6 +50,30 @@ def test_top_scope_attr(self):
self.assertIs(v2, nr.top_scope)
m.assert_called_once() # Still only once

def test_params(self):
src = ('def f1(bool b0, val v0, string s0, number n0) {let L0=s0..v0;};'
'def f2() {}')
sc = Scope()
f1_scope = Scope()
f1_scope.declared = {
'b0': NameInfo(f1_scope, 'b0', BoolType(), is_param=True),
'v0': (v0 := NameInfo(f1_scope, 'v0', ValType(), is_param=True)),
's0': (s0 := NameInfo(f1_scope, 's0', ValType(), is_param=True)),
'n0': NameInfo(f1_scope, 'n0', ValType(), is_param=True),
'L0': NameInfo(f1_scope, 'L0', ValType())
}
f1_scope.used = {'v0': v0, 's0': s0}
sc.declared = {
'f1': FuncInfo.from_param_info(sc, 'f1', [
ParamInfo('b0', BoolType()),
ParamInfo('v0', ValType()),
ParamInfo('s0', ValType()), # val == string == number for now
ParamInfo('n0', ValType()),
], VoidType(), f1_scope),
'f2': FuncInfo.from_param_info(sc, 'f2', [], VoidType(), Scope())
}
self.assertEqual(sc, self.getNameResolver(src).run())


class TestNameResolveErrors(CommonTestCase):
def test_undefined_var(self):
Expand Down
1 change: 1 addition & 0 deletions util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from os import PathLike
from typing import TypeVar, Any, overload, Iterable

from .recursive_eq import recursive_eq
from .simple_process_pool import *
from .timeouts import *

Expand Down
26 changes: 26 additions & 0 deletions util/recursive_eq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Safe recursive equality comparison."""
import functools


def recursive_eq(fn):
"""Must be used as decorator, like reprlib.recursive_repr.
Works by hypothesising that 2 ids are equal. Then, it tries to compare
them. If it encounters one of them again, it checks that the corresponding
value is the hypothesised value. If so, they're equal. If not, they're
unequal."""
hypotheses: dict[int, int] = {} # int <-> int (should be undirected)

@functools.wraps(fn)
def eq(a, b):
if (bid_exp := hypotheses.get(id(a))) is not None:
return bid_exp == id(b)
if (aid_exp := hypotheses.get(id(b))) is not None:
return aid_exp == id(a)
hypotheses[id(a)] = id(b)
hypotheses[id(b)] = id(a)
try:
return fn(a, b) # Will call this function again
finally:
del hypotheses[id(a)]
del hypotheses[id(b)]
return eq