Skip to content

Commit

Permalink
Black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Jan 2, 2025
1 parent 5ff2169 commit 7200c86
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 77 deletions.
6 changes: 2 additions & 4 deletions tests/test_acceptance.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,9 @@ def test_acceptance_tests(self, _, dirpath, rewriter):
cost_content = f.read()
if "error" in cost_content:
return
numbers = re.findall(r'(cpu|mem)\s*:\s*(\d+)\b', cost_content)
numbers = re.findall(r"(cpu|mem)\s*:\s*(\d+)\b", cost_content)
self.assertEqual(len(numbers), 2, "Could not parse cost pattern")
cost = {
k: int(v) for k, v in numbers
}
cost = {k: int(v) for k, v in numbers}
expected_spent_budget = Budget(cost["cpu"], cost["mem"])
if rewriter in (
pre_evaluation.PreEvaluationOptimizer,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_roundtrips.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def rec_expr_strategies(uplc_expr):
)


uplc_version = hst.sampled_from([(1,0,0), (1,1,0)])
uplc_version = hst.sampled_from([(1, 0, 0), (1, 1, 0)])
# This strategy also produces invalid programs (due to variables not being bound)
uplc_program_any = hst.builds(Program, uplc_version, uplc_expr)

Expand Down
120 changes: 81 additions & 39 deletions uplc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,27 @@
try:
from pyblst import BlstP1Element, BlstP2Element, BlstFP12Element
except ImportError:
BlstP1Element, BlstP2Element, BlstFP12Element = type("BlstP1Element"), type("BlstP2Element"), type("BlstFP12Element")
BlstP1Element, BlstP2Element, BlstFP12Element = (
type("BlstP1Element"),
type("BlstP2Element"),
type("BlstFP12Element"),
)


@functools.lru_cache()
def pyblst():
try:
import pyblst
except ImportError:
raise RuntimeError("BLS extensions not installed. Run 'pip install \"uplc[bls]\"', 'pip install pyblst' or 'poetry install --all-extras' for bls primitive support.")
raise RuntimeError(
"BLS extensions not installed. Run 'pip install \"uplc[bls]\"', 'pip install pyblst' or 'poetry install --all-extras' for bls primitive support."
)
return pyblst


sys.set_int_max_str_digits(32000)


class UPLCDialect(enum.Enum):
LegacyAiken = "legacy-aiken"
Plutus = "plutus"
Expand All @@ -72,11 +81,13 @@ class FrameApplyArg(Context):
term: "AST"
ctx: Context


@dataclass
class FrameApplyFunArg(Context):
arg: Any
ctx: Context


@dataclass
class FrameForce(Context):
ctx: Context
Expand All @@ -86,6 +97,7 @@ class FrameForce(Context):
class NoFrame(Context):
pass


@dataclass
class FrameConstr(Context):
env: frozendict.frozendict
Expand All @@ -94,12 +106,14 @@ class FrameConstr(Context):
resolved_fields: List["AST"]
ctx: Context


@dataclass
class FrameCases(Context):
env: frozendict.frozendict
branches: List["AST"]
ctx: Context


class Step:
pass

Expand Down Expand Up @@ -329,6 +343,7 @@ def __eq__(self, other):
def encode(self, *args):
return BuiltinByteString(self.value.encode())


@dataclass(frozen=True)
class BuiltinBLS12381G1Element(Constant):
value: BlstP1Element
Expand All @@ -346,7 +361,7 @@ def __setstate__(self, state):
object.__setattr__(self, "value", BlstP1Element.uncompress(state))

def ex_mem(self) -> int:
#TODO
# TODO
return len(self.value.compress())


Expand All @@ -361,9 +376,10 @@ def valuestring(self, dialect=UPLCDialect.Plutus):
return f"0x{self.value.compress().hex()}"

def ex_mem(self) -> int:
#TODO
# TODO
return len(self.value.compress())


@dataclass(frozen=True)
class BuiltinBLS12381Mlresult(Constant):
value: BlstFP12Element
Expand All @@ -375,9 +391,10 @@ def valuestring(self, dialect=UPLCDialect.Plutus):
return "<opaque>"

def ex_mem(self) -> int:
#TODO
# TODO
return 0


@dataclass(frozen=True)
class BuiltinPair(Constant):
l_value: Constant
Expand Down Expand Up @@ -550,7 +567,9 @@ def __post_init__(self):
object.__setattr__(self, "value", frozen_value)

def to_cbor(self):
return frozendict.frozendict({k.to_cbor(): v.to_cbor() for k, v in self.value.items()})
return frozendict.frozendict(
{k.to_cbor(): v.to_cbor() for k, v in self.value.items()}
)

def to_json(self):
return {
Expand Down Expand Up @@ -988,22 +1007,31 @@ def _MapData(x):
assert isinstance(x.sample_value, BuiltinPair), "Can only map over a list of pairs"
return PlutusMap({p.l_value: p.r_value for p in x.values})

def _int_to_bytestring_endianness(endianness: BuiltinBool, width: BuiltinInteger, integer: BuiltinInteger):

def _int_to_bytestring_endianness(
endianness: BuiltinBool, width: BuiltinInteger, integer: BuiltinInteger
):
width = width.value
if not 0 <= width <= 8192:
raise RuntimeError(f"Invalid width {width} (0 <= w < 8192)")
integer = integer.value
if not 0 <= integer < 256**8192:
raise RuntimeError(f"Too large value to convert {integer} (0 <= integer < 256**8192)")
raise RuntimeError(
f"Too large value to convert {integer} (0 <= integer < 256**8192)"
)
if width == 0:
width = (integer.bit_length() + 7) // 8
endianness = "big" if endianness.value else "little"
return BuiltinByteString(integer.to_bytes(width, byteorder=endianness))

def _bytestring_to_int_endianness(endianness: BuiltinBool, bytestring: BuiltinByteString):

def _bytestring_to_int_endianness(
endianness: BuiltinBool, bytestring: BuiltinByteString
):
endianness = "big" if endianness.value else "little"
return BuiltinInteger(int.from_bytes(bytestring.value, byteorder=endianness))


def _map_bytes_trunc(foo, fill):
# implements the extending/truncating of and/or/xor
def ext_trunc_logic(switch, x, y):
Expand Down Expand Up @@ -1225,9 +1253,9 @@ def _replicate_bytes(length: BuiltinInteger, val: BuiltinInteger):
BuiltInFun.Blake2b_224: single_bytestring(
lambda x: BuiltinByteString(hashlib.blake2b(x.value, digest_size=28).digest())
),
BuiltInFun.IntegerToByteString: typechecked(BuiltinBool, BuiltinInteger, BuiltinInteger)(
_int_to_bytestring_endianness
),
BuiltInFun.IntegerToByteString: typechecked(
BuiltinBool, BuiltinInteger, BuiltinInteger
)(_int_to_bytestring_endianness),
BuiltInFun.ByteStringToInteger: typechecked(BuiltinBool, BuiltinByteString)(
_bytestring_to_int_endianness
),
Expand Down Expand Up @@ -1267,19 +1295,25 @@ def _replicate_bytes(length: BuiltinInteger, val: BuiltinInteger):
BuiltInFun.Bls12_381_G1_Compress: typechecked(BuiltinBLS12381G1Element)(
lambda x: BuiltinByteString(x.value.compress())
),
BuiltInFun.Bls12_381_G1_Add: typechecked(BuiltinBLS12381G1Element, BuiltinBLS12381G1Element)(
lambda x, y: BuiltinBLS12381G1Element(x.value + y.value)
),
BuiltInFun.Bls12_381_G1_Add: typechecked(
BuiltinBLS12381G1Element, BuiltinBLS12381G1Element
)(lambda x, y: BuiltinBLS12381G1Element(x.value + y.value)),
BuiltInFun.Bls12_381_G1_Neg: typechecked(BuiltinBLS12381G1Element)(
lambda x: BuiltinBLS12381G1Element(-x.value)
),
BuiltInFun.Bls12_381_G1_ScalarMul: typechecked(BuiltinInteger, BuiltinBLS12381G1Element)(
lambda x, y: BuiltinBLS12381G1Element(y.value.scalar_mul(x.value))
),
BuiltInFun.Bls12_381_G1_HashToGroup: typechecked(BuiltinByteString, BuiltinByteString)(
lambda x, y: BuiltinBLS12381G1Element(pyblst().BlstP1Element.hash_to_group(x.value, y.value))
BuiltInFun.Bls12_381_G1_ScalarMul: typechecked(
BuiltinInteger, BuiltinBLS12381G1Element
)(lambda x, y: BuiltinBLS12381G1Element(y.value.scalar_mul(x.value))),
BuiltInFun.Bls12_381_G1_HashToGroup: typechecked(
BuiltinByteString, BuiltinByteString
)(
lambda x, y: BuiltinBLS12381G1Element(
pyblst().BlstP1Element.hash_to_group(x.value, y.value)
)
),
BuiltInFun.Bls12_381_G1_Equal: typechecked(BuiltinBLS12381G1Element, BuiltinBLS12381G1Element)(
BuiltInFun.Bls12_381_G1_Equal: typechecked(
BuiltinBLS12381G1Element, BuiltinBLS12381G1Element
)(
lambda x, y: BuiltinBool(x.value == y.value),
),
BuiltInFun.Bls12_381_G2_Uncompress: typechecked(BuiltinByteString)(
Expand All @@ -1288,30 +1322,36 @@ def _replicate_bytes(length: BuiltinInteger, val: BuiltinInteger):
BuiltInFun.Bls12_381_G2_Compress: typechecked(BuiltinBLS12381G2Element)(
lambda x: BuiltinByteString(x.value.compress())
),
BuiltInFun.Bls12_381_G2_Add: typechecked(BuiltinBLS12381G2Element, BuiltinBLS12381G2Element)(
lambda x, y: BuiltinBLS12381G2Element(x.value + y.value)
),
BuiltInFun.Bls12_381_G2_Add: typechecked(
BuiltinBLS12381G2Element, BuiltinBLS12381G2Element
)(lambda x, y: BuiltinBLS12381G2Element(x.value + y.value)),
BuiltInFun.Bls12_381_G2_Neg: typechecked(BuiltinBLS12381G2Element)(
lambda x: BuiltinBLS12381G2Element(-x.value)
),
BuiltInFun.Bls12_381_G2_ScalarMul: typechecked(BuiltinInteger, BuiltinBLS12381G2Element)(
lambda x, y: BuiltinBLS12381G2Element(y.value.scalar_mul(x.value))
),
BuiltInFun.Bls12_381_G2_HashToGroup: typechecked(BuiltinByteString, BuiltinByteString)(
lambda x, y: BuiltinBLS12381G2Element(pyblst().BlstP2Element.hash_to_group(x.value, y.value))
BuiltInFun.Bls12_381_G2_ScalarMul: typechecked(
BuiltinInteger, BuiltinBLS12381G2Element
)(lambda x, y: BuiltinBLS12381G2Element(y.value.scalar_mul(x.value))),
BuiltInFun.Bls12_381_G2_HashToGroup: typechecked(
BuiltinByteString, BuiltinByteString
)(
lambda x, y: BuiltinBLS12381G2Element(
pyblst().BlstP2Element.hash_to_group(x.value, y.value)
)
),
BuiltInFun.Bls12_381_G2_Equal: typechecked(BuiltinBLS12381G2Element, BuiltinBLS12381G2Element)(
BuiltInFun.Bls12_381_G2_Equal: typechecked(
BuiltinBLS12381G2Element, BuiltinBLS12381G2Element
)(
lambda x, y: BuiltinBool(x.value == y.value),
),
BuiltInFun.Bls12_381_MillerLoop: typechecked(BuiltinBLS12381G1Element, BuiltinBLS12381G2Element)(
lambda x, y: BuiltinBLS12381Mlresult(pyblst().miller_loop(x.value, y.value))
),
BuiltInFun.Bls12_381_MulMlResult: typechecked(BuiltinBLS12381Mlresult, BuiltinBLS12381Mlresult)(
lambda x, y: BuiltinBLS12381Mlresult(x.value * y.value)
),
BuiltInFun.Bls12_381_FinalVerify: typechecked(BuiltinBLS12381Mlresult, BuiltinBLS12381Mlresult)(
lambda x, y: BuiltinBool(pyblst().final_verify(x.value, y.value))
),
BuiltInFun.Bls12_381_MillerLoop: typechecked(
BuiltinBLS12381G1Element, BuiltinBLS12381G2Element
)(lambda x, y: BuiltinBLS12381Mlresult(pyblst().miller_loop(x.value, y.value))),
BuiltInFun.Bls12_381_MulMlResult: typechecked(
BuiltinBLS12381Mlresult, BuiltinBLS12381Mlresult
)(lambda x, y: BuiltinBLS12381Mlresult(x.value * y.value)),
BuiltInFun.Bls12_381_FinalVerify: typechecked(
BuiltinBLS12381Mlresult, BuiltinBLS12381Mlresult
)(lambda x, y: BuiltinBool(pyblst().final_verify(x.value, y.value))),
}

BuiltInFunForceMap = defaultdict(int)
Expand Down Expand Up @@ -1460,6 +1500,7 @@ class Apply(AST):
def dumps(self, dialect=UPLCDialect.Plutus) -> str:
return f"[{self.f.dumps(dialect=dialect)} {self.x.dumps(dialect=dialect)}]"


@dataclass
class Constr(AST):
tag: int
Expand All @@ -1481,6 +1522,7 @@ def __getattr__(self, item):
pass
return object.__getattribute__(self, item)


@dataclass
class Case(AST):
scrutinee: AST
Expand Down
Loading

0 comments on commit 7200c86

Please sign in to comment.