Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Jan 2, 2025
1 parent 58e7841 commit 84ed868
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 14 deletions.
64 changes: 55 additions & 9 deletions tests/test_roundtrips.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from uplc.transformer import unique_variables, debrujin_variables, undebrujin_variables
from uplc.ast import *
from uplc import lexer
from uplc.transformer.plutus_version_enforcer import PlutusVersionEnforcer, UnsupportedTerm
from uplc.transformer.plutus_version_enforcer import (
PlutusVersionEnforcer,
UnsupportedTerm,
)
from uplc.util import NodeVisitor


Expand Down Expand Up @@ -107,6 +110,7 @@ def visit_Lambda(self, node: Lambda):
def visit_Variable(self, node: Variable):
self.check_bound(node.name)


@hst.composite
def uplc_expr_all_bound(draw, uplc_expr):
x = draw(uplc_expr)
Expand All @@ -128,7 +132,13 @@ def rec_expr_strategies(uplc_expr):
uplc_case = hst.builds(Case, uplc_expr, hst.lists(uplc_expr))
uplc_lambda_bound = uplc_expr_all_bound(uplc_expr)
return hst.one_of(
uplc_lambda, uplc_delay, uplc_force, uplc_apply, uplc_lambda_bound, uplc_case, uplc_constr
uplc_lambda,
uplc_delay,
uplc_force,
uplc_apply,
uplc_lambda_bound,
uplc_case,
uplc_constr,
)


Expand All @@ -143,6 +153,7 @@ def rec_expr_strategies(uplc_expr):
# This strategy also produces invalid programs (due to variables not being bound)
uplc_program_any = hst.builds(Program, uplc_version, uplc_expr)


class HasConstrCaseVisitor(NodeVisitor):
has_constr_or_case: bool = False

Expand All @@ -152,16 +163,18 @@ def visit_Case(self, _: Case):
def visit_Constr(self, _: Constr):
self.has_constr_or_case = True


@hst.composite
def uplc_program_correct_version(draw, uplc_expr, uplc_version):
x = draw(uplc_expr)
has_constr_or_case_visitor = HasConstrCaseVisitor()
has_constr_or_case_visitor.visit(x)
if has_constr_or_case_visitor.has_constr_or_case:
return Program((1,1,0), x)
return Program((1, 1, 0), x)
version = draw(uplc_version)
return Program(version, x)


uplc_expr_valid = uplc_expr_all_bound(uplc_expr)
# This strategy only produces valid programs (all variables are bound)
uplc_program_valid = uplc_program_correct_version(uplc_expr_valid, uplc_version)
Expand All @@ -180,6 +193,7 @@ def uplc_program_correct_version(draw, uplc_expr, uplc_version):

uplc_program = hst.one_of(uplc_program_any, uplc_program_valid)


def has_correct_version(x: Program):
v = PlutusVersionEnforcer()
try:
Expand All @@ -189,6 +203,17 @@ def has_correct_version(x: Program):
return True


class AllVarsNumbersVisitor(NodeVisitor):
all_vars_are_numbers: bool = True

def visit_Variable(self, node: Variable):
try:
int(node.name)
except ValueError:
self.all_vars_are_numbers = False
self.generic_visit(node)


class HypothesisTests(unittest.TestCase):
@hypothesis.given(uplc_program, hst.sampled_from(UPLCDialect))
@hypothesis.settings(max_examples=1000, deadline=None)
Expand Down Expand Up @@ -246,14 +271,16 @@ def test_rewrite_no_semantic_change(self, p):
code = dumps(p)
try:
rewrite_p = unique_variables.UniqueVariableTransformer().visit(parse(code))
except unique_variables.FreeVariableError:
except (unique_variables.FreeVariableError, SyntaxError):
return
try:
res = eval(p)
res = unique_variables.UniqueVariableTransformer().visit(res.result)
res = res.dumps()
except unique_variables.FreeVariableError:
self.fail(f"Free variable error occurred after evaluation in {code}")
except (unique_variables.FreeVariableError, SyntaxError):
self.fail(
f"Free variable/ Syntax error occurred after evaluation in {code}"
)
except Exception as e:
res = e.__class__
try:
Expand All @@ -262,7 +289,7 @@ def test_rewrite_no_semantic_change(self, p):
rewrite_res.result
)
rewrite_res = rewrite_res.dumps()
except unique_variables.FreeVariableError:
except (unique_variables.FreeVariableError, SyntaxError):
self.fail(f"Free variable error occurred after evaluation in {code}")
except Exception as e:
rewrite_res = e.__class__
Expand Down Expand Up @@ -317,7 +344,10 @@ def test_raises_syntaxerror(self, p):
)
def test_preeval_no_semantic_change(self, p):
code = dumps(p)
orig_p = parse(code).term
try:
orig_p = parse(code).term
except SyntaxError:
return
rewrite_p = pre_evaluation.PreEvaluationOptimizer().visit(p).term
params = []
try:
Expand Down Expand Up @@ -389,9 +419,21 @@ def test_debrujin_undebrujin(self, p: Program):
debrujin
)
self.assertEqual(p_unique, undebrujin, "incorrect flatten roundtrip")
all_vars_numbers_visitor = AllVarsNumbersVisitor()
all_vars_numbers_visitor.visit(debrujin)
self.assertTrue(
all_vars_numbers_visitor.all_vars_are_numbers,
"Some variable is not a number",
)

@hypothesis.given(uplc_program_valid)
@hypothesis.settings(max_examples=1000, deadline=datetime.timedelta(seconds=10))
@hypothesis.example(
Program(
version=(1, 1, 0),
term=Lambda(var_name="x", term=Constr(tag=0, fields=[Variable(name="x")])),
)
)
@hypothesis.example(
Program(version=(1, 0, 0), term=PlutusMap(value=frozendict.frozendict({})))
)
Expand All @@ -401,7 +443,11 @@ def test_debrujin_undebrujin(self, p: Program):
@hypothesis.example(Program(version=(1, 0, 0), term=BuiltinUnit()))
def test_flat_unflat_roundtrip(self, p: Program):
p_unique = unique_variables.UniqueVariableTransformer().visit(p)
self.assertEqual(p_unique, unflatten(flatten(p)), "incorrect flatten roundtrip")
self.assertEqual(
p_unique,
unflatten(flatten(p)),
"incorrect flatten roundtrip",
)

# TODO test invalid programs being detected with an free variable error

Expand Down
26 changes: 24 additions & 2 deletions uplc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,6 +1522,17 @@ def __getattr__(self, item):
pass
return object.__getattribute__(self, item)

def __setattr__(self, key, value):
try:
if key.startswith("fields_"):
index = int(key.split("_")[1])
newlist = self.fields.copy()
newlist[index] = value
self.fields = newlist
except:
pass
return object.__setattr__(self, key, value)


@dataclass
class Case(AST):
Expand All @@ -1533,7 +1544,7 @@ def dumps(self, dialect=UPLCDialect.Plutus) -> str:

@property
def _fields(self):
return [f"branches_{i}" for i in range(len(self.branches))]
return ["scrutinee"] + [f"branches_{i}" for i in range(len(self.branches))]

def __getattr__(self, item):
try:
Expand All @@ -1542,4 +1553,15 @@ def __getattr__(self, item):
return self.branches[index]
except:
pass
return super().__getattr__(item)
return super().__getattribute__(item)

def __setattr__(self, key, value):
try:
if key.startswith("branches_"):
index = int(key.split("_")[1])
newlist = self.branches.copy()
newlist[index] = value
self.branches = newlist
except:
pass
return object.__setattr__(self, key, value)
21 changes: 19 additions & 2 deletions uplc/flat_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Callable

from .ast import *
from .parser import PLUTUS_V3
from .transformer.plutus_version_enforcer import UnsupportedTerm

UPLC_TAG_WIDTHS = {
"term": 4,
Expand Down Expand Up @@ -69,6 +71,7 @@ class UplcDeserializer:
def __init__(self, bits: str):
self._bits = bits
self._pos = 0
self._version = None

def tag_width(self, category: str) -> int:
assert category in UPLC_TAG_WIDTHS, f"unknown tag category {category}"
Expand Down Expand Up @@ -108,6 +111,8 @@ def read_term(self) -> AST:
return self.read_builtin()
elif tag == 8:
return self.read_constr()
elif tag == 9:
return self.read_case()
else:
raise ValueError(f"term tag {tag} unhandled")

Expand Down Expand Up @@ -300,12 +305,23 @@ def read_builtin(self) -> BuiltIn:
return BuiltIn(builtin)

def read_constr(self) -> Constr:
if self._version < PLUTUS_V3:
raise UnsupportedTerm("Invalid term encoded (Constr in pre-PlutusV3)")
# in theory limited to 64 bits
id = self.read_integer(signed=False)

builtin = self.read_integer()
fields = self.read_list(self.read_term)

return BuiltIn(builtin)
return Constr(id, fields)

def read_case(self) -> Case:
if self._version < PLUTUS_V3:
raise UnsupportedTerm("Invalid term encoded (Case in pre-PlutusV3)")
scrutinee = self.read_term()

branches = self.read_list(self.read_term)

return Case(scrutinee, branches)

def finalize(self):
self.move_to_byte_boundary(True)
Expand Down Expand Up @@ -333,6 +349,7 @@ def read_program(self) -> Program:
self.read_integer(signed=False),
self.read_integer(signed=False),
)
self._version = version

expr = self.read_term()

Expand Down
16 changes: 16 additions & 0 deletions uplc/flat_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,22 @@ def visit_Force(self, n: Force):
def visit_Error(self, n: Error):
self.bit_writer.write("0110")

def visit_Constr(self, n: Constr):
self.bit_writer.write("1000")
self.bit_writer.write_int(n.tag, False)
self.visit_list(n.fields)

def visit_Case(self, n: Case):
self.bit_writer.write("1001")
self.visit(n.scrutinee)
self.visit_list(n.branches)

def visit_list(self, n: List[AST]):
for v in n:
self.bit_writer.write("1")
self.visit(v)
self.bit_writer.write("0")

def visit_BuiltIn(self, n: BuiltIn):
self.bit_writer.write("0111")
# write index of uplc builtin
Expand Down
1 change: 0 additions & 1 deletion uplc/transformer/plutus_version_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class UnsupportedTerm(ValueError):

def __init__(self, message):
self.message = message

Expand Down

0 comments on commit 84ed868

Please sign in to comment.