diff --git a/verkle_trie_pedersen/_blst.so b/verkle_trie_pedersen/_blst.so new file mode 100755 index 00000000..3bc7193c Binary files /dev/null and b/verkle_trie_pedersen/_blst.so differ diff --git a/verkle_trie_pedersen/blst.py b/verkle_trie_pedersen/blst.py new file mode 100644 index 00000000..0dd6f080 --- /dev/null +++ b/verkle_trie_pedersen/blst.py @@ -0,0 +1,237 @@ +# This file was automatically generated by SWIG (http://www.swig.org). +# Version 4.0.1 +# +# Do not make changes to this file unless you know what you are doing--modify +# the SWIG interface file instead. + +from sys import version_info as _swig_python_version_info +if _swig_python_version_info < (2, 7, 0): + raise RuntimeError("Python 2.7 or later required") + +# Import the low-level C/C++ module +if __package__ or "." in __name__: + from . import _blst +else: + import _blst + +try: + import builtins as __builtin__ +except ImportError: + import __builtin__ + +_swig_new_instance_method = _blst.SWIG_PyInstanceMethod_New +_swig_new_static_method = _blst.SWIG_PyStaticMethod_New + +def _swig_repr(self): + try: + strthis = "proxy of " + self.this.__repr__() + except __builtin__.Exception: + strthis = "" + return "<%s.%s; %s >" % (self.__class__.__module__, self.__class__.__name__, strthis,) + + +def _swig_setattr_nondynamic_instance_variable(set): + def set_instance_attr(self, name, value): + if name == "thisown": + self.this.own(value) + elif name == "this": + set(self, name, value) + elif hasattr(self, name) and isinstance(getattr(type(self), name), property): + set(self, name, value) + else: + raise AttributeError("You cannot add instance attributes to %s" % self) + return set_instance_attr + + +def _swig_setattr_nondynamic_class_variable(set): + def set_class_attr(cls, name, value): + if hasattr(cls, name) and not isinstance(getattr(cls, name), property): + set(cls, name, value) + else: + raise AttributeError("You cannot add class attributes to %s" % cls) + return set_class_attr + + +def _swig_add_metaclass(metaclass): + """Class decorator for adding a metaclass to a SWIG wrapped class - a slimmed down version of six.add_metaclass""" + def wrapper(cls): + return metaclass(cls.__name__, cls.__bases__, cls.__dict__.copy()) + return wrapper + + +class _SwigNonDynamicMeta(type): + """Meta class to enforce nondynamic attributes (no new attributes) for a class""" + __setattr__ = _swig_setattr_nondynamic_class_variable(type.__setattr__) + + +BLST_SUCCESS = _blst.BLST_SUCCESS +BLST_BAD_ENCODING = _blst.BLST_BAD_ENCODING +BLST_POINT_NOT_ON_CURVE = _blst.BLST_POINT_NOT_ON_CURVE +BLST_POINT_NOT_IN_GROUP = _blst.BLST_POINT_NOT_IN_GROUP +BLST_AGGR_TYPE_MISMATCH = _blst.BLST_AGGR_TYPE_MISMATCH +BLST_VERIFY_FAIL = _blst.BLST_VERIFY_FAIL +BLST_PK_IS_INFINITY = _blst.BLST_PK_IS_INFINITY +class SecretKey(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + keygen = _swig_new_instance_method(_blst.SecretKey_keygen) + from_bendian = _swig_new_instance_method(_blst.SecretKey_from_bendian) + from_lendian = _swig_new_instance_method(_blst.SecretKey_from_lendian) + to_bendian = _swig_new_instance_method(_blst.SecretKey_to_bendian) + to_lendian = _swig_new_instance_method(_blst.SecretKey_to_lendian) + + def __init__(self): + _blst.SecretKey_swiginit(self, _blst.new_SecretKey()) + __swig_destroy__ = _blst.delete_SecretKey + +# Register SecretKey in _blst: +_blst.SecretKey_swigregister(SecretKey) + +class P1_Affine(object): + thisown = property(lambda x: x.this.own(), 
lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + + def __init__(self, *args): + _blst.P1_Affine_swiginit(self, _blst.new_P1_Affine(*args)) + dup = _swig_new_instance_method(_blst.P1_Affine_dup) + to_jacobian = _swig_new_instance_method(_blst.P1_Affine_to_jacobian) + serialize = _swig_new_instance_method(_blst.P1_Affine_serialize) + compress = _swig_new_instance_method(_blst.P1_Affine_compress) + on_curve = _swig_new_instance_method(_blst.P1_Affine_on_curve) + in_group = _swig_new_instance_method(_blst.P1_Affine_in_group) + is_inf = _swig_new_instance_method(_blst.P1_Affine_is_inf) + is_equal = _swig_new_instance_method(_blst.P1_Affine_is_equal) + core_verify = _swig_new_instance_method(_blst.P1_Affine_core_verify) + generator = _swig_new_static_method(_blst.P1_Affine_generator) + __swig_destroy__ = _blst.delete_P1_Affine + +# Register P1_Affine in _blst: +_blst.P1_Affine_swigregister(P1_Affine) +P1_Affine_generator = _blst.P1_Affine_generator + +class P1(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + + def __init__(self, *args): + _blst.P1_swiginit(self, _blst.new_P1(*args)) + dup = _swig_new_instance_method(_blst.P1_dup) + to_affine = _swig_new_instance_method(_blst.P1_to_affine) + serialize = _swig_new_instance_method(_blst.P1_serialize) + compress = _swig_new_instance_method(_blst.P1_compress) + on_curve = _swig_new_instance_method(_blst.P1_on_curve) + in_group = _swig_new_instance_method(_blst.P1_in_group) + is_inf = _swig_new_instance_method(_blst.P1_is_inf) + is_equal = _swig_new_instance_method(_blst.P1_is_equal) + aggregate = _swig_new_instance_method(_blst.P1_aggregate) + sign_with = _swig_new_instance_method(_blst.P1_sign_with) + hash_to = _swig_new_instance_method(_blst.P1_hash_to) + encode_to = _swig_new_instance_method(_blst.P1_encode_to) + mult = _swig_new_instance_method(_blst.P1_mult) + cneg = _swig_new_instance_method(_blst.P1_cneg) + neg = _swig_new_instance_method(_blst.P1_neg) + add = _swig_new_instance_method(_blst.P1_add) + dbl = _swig_new_instance_method(_blst.P1_dbl) + generator = _swig_new_static_method(_blst.P1_generator) + __swig_destroy__ = _blst.delete_P1 + +# Register P1 in _blst: +_blst.P1_swigregister(P1) +P1_generator = _blst.P1_generator + +class P2_Affine(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + + def __init__(self, *args): + _blst.P2_Affine_swiginit(self, _blst.new_P2_Affine(*args)) + dup = _swig_new_instance_method(_blst.P2_Affine_dup) + to_jacobian = _swig_new_instance_method(_blst.P2_Affine_to_jacobian) + serialize = _swig_new_instance_method(_blst.P2_Affine_serialize) + compress = _swig_new_instance_method(_blst.P2_Affine_compress) + on_curve = _swig_new_instance_method(_blst.P2_Affine_on_curve) + in_group = _swig_new_instance_method(_blst.P2_Affine_in_group) + is_inf = _swig_new_instance_method(_blst.P2_Affine_is_inf) + is_equal = _swig_new_instance_method(_blst.P2_Affine_is_equal) + core_verify = _swig_new_instance_method(_blst.P2_Affine_core_verify) + generator = _swig_new_static_method(_blst.P2_Affine_generator) + __swig_destroy__ = _blst.delete_P2_Affine + +# Register P2_Affine in _blst: +_blst.P2_Affine_swigregister(P2_Affine) +P2_Affine_generator = _blst.P2_Affine_generator + +class P2(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + + def 
__init__(self, *args): + _blst.P2_swiginit(self, _blst.new_P2(*args)) + dup = _swig_new_instance_method(_blst.P2_dup) + to_affine = _swig_new_instance_method(_blst.P2_to_affine) + serialize = _swig_new_instance_method(_blst.P2_serialize) + compress = _swig_new_instance_method(_blst.P2_compress) + on_curve = _swig_new_instance_method(_blst.P2_on_curve) + in_group = _swig_new_instance_method(_blst.P2_in_group) + is_inf = _swig_new_instance_method(_blst.P2_is_inf) + is_equal = _swig_new_instance_method(_blst.P2_is_equal) + aggregate = _swig_new_instance_method(_blst.P2_aggregate) + sign_with = _swig_new_instance_method(_blst.P2_sign_with) + hash_to = _swig_new_instance_method(_blst.P2_hash_to) + encode_to = _swig_new_instance_method(_blst.P2_encode_to) + mult = _swig_new_instance_method(_blst.P2_mult) + cneg = _swig_new_instance_method(_blst.P2_cneg) + neg = _swig_new_instance_method(_blst.P2_neg) + add = _swig_new_instance_method(_blst.P2_add) + dbl = _swig_new_instance_method(_blst.P2_dbl) + generator = _swig_new_static_method(_blst.P2_generator) + __swig_destroy__ = _blst.delete_P2 + +# Register P2 in _blst: +_blst.P2_swigregister(P2) +P2_generator = _blst.P2_generator + +G1 = _blst.G1 +G2 = _blst.G2 +class PT(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + + def __init__(self, *args): + _blst.PT_swiginit(self, _blst.new_PT(*args)) + dup = _swig_new_instance_method(_blst.PT_dup) + is_one = _swig_new_instance_method(_blst.PT_is_one) + is_equal = _swig_new_instance_method(_blst.PT_is_equal) + sqr = _swig_new_instance_method(_blst.PT_sqr) + mul = _swig_new_instance_method(_blst.PT_mul) + final_exp = _swig_new_instance_method(_blst.PT_final_exp) + __swig_destroy__ = _blst.delete_PT + +# Register PT in _blst: +_blst.PT_swigregister(PT) + +class Pairing(object): + thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag") + __repr__ = _swig_repr + + def __init__(self, *args): + _blst.Pairing_swiginit(self, _blst.new_Pairing(*args)) + __swig_destroy__ = _blst.delete_Pairing + aggregate = _swig_new_instance_method(_blst.Pairing_aggregate) + mul_n_aggregate = _swig_new_instance_method(_blst.Pairing_mul_n_aggregate) + commit = _swig_new_instance_method(_blst.Pairing_commit) + merge = _swig_new_instance_method(_blst.Pairing_merge) + finalverify = _swig_new_instance_method(_blst.Pairing_finalverify) + +# Register Pairing in _blst: +_blst.Pairing_swigregister(Pairing) + +cdata = _blst.cdata +memmove = _blst.memmove + +cvar = _blst.cvar +BLS12_381_G1 = cvar.BLS12_381_G1 +BLS12_381_NEG_G1 = cvar.BLS12_381_NEG_G1 +BLS12_381_G2 = cvar.BLS12_381_G2 +BLS12_381_NEG_G2 = cvar.BLS12_381_NEG_G2 + diff --git a/verkle_trie_pedersen/compute_stats.sh b/verkle_trie_pedersen/compute_stats.sh new file mode 100755 index 00000000..e1b16812 --- /dev/null +++ b/verkle_trie_pedersen/compute_stats.sh @@ -0,0 +1,15 @@ +echo -e "WIDTH_BITS\tWIDTH\tNUMBER_INITIAL_KEYS\tNUMBER_KEYS_PROOF\taverage_depth\tproof_size\tproof_time\tcheck_time" > stats.txt + + +python verkle_trie.py 5 65536 500 >> stats.txt +python verkle_trie.py 6 65536 500 >> stats.txt +python verkle_trie.py 7 65536 500 >> stats.txt +python verkle_trie.py 8 65536 500 >> stats.txt +python verkle_trie.py 9 65536 500 >> stats.txt +python verkle_trie.py 10 65536 500 >> stats.txt +python verkle_trie.py 11 65536 500 >> stats.txt +python verkle_trie.py 12 65536 500 >> stats.txt +python verkle_trie.py 13 65536 500 >> stats.txt +python verkle_trie.py 
14 65536 500 >> stats.txt
+python verkle_trie.py 15 65536 500 >> stats.txt
+python verkle_trie.py 16 65536 500 >> stats.txt
diff --git a/verkle_trie_pedersen/kzg_utils.py b/verkle_trie_pedersen/kzg_utils.py
new file mode 100644
index 00000000..4997dc30
--- /dev/null
+++ b/verkle_trie_pedersen/kzg_utils.py
@@ -0,0 +1,114 @@
+import blst
+import pippenger
+
+#
+# Utilities for dealing with polynomials in evaluation form
+#
+# A polynomial in evaluation form is defined by its values on DOMAIN,
+# where DOMAIN is [omega**0, omega**1, omega**2, ..., omega**(WIDTH-1)]
+# and omega is a WIDTH-th root of unity, i.e. omega**WIDTH % MODULUS == 1
+#
+# Any polynomial of degree < WIDTH can be represented uniquely in this form,
+# and many operations (such as multiplication and exact division) are more
+# efficient.
+#
+# By precomputing the trusted setup in Lagrange basis, we can also easily
+# commit to a polynomial in evaluation form.
+#
+
+class KzgUtils():
+
+    """
+    Class that defines helper functions for Kate (KZG) proofs in evaluation form (Lagrange basis)
+    """
+    def __init__(self, MODULUS, WIDTH, DOMAIN, SETUP, primefield):
+        self.MODULUS = MODULUS
+        self.WIDTH = WIDTH
+        self.DOMAIN = DOMAIN
+        self.SETUP = SETUP
+        self.primefield = primefield
+        # Precomputed inverses 1 / (1 - DOMAIN[i]) for i > 0 (index 0 is a placeholder)
+        self.inverses = [0] + [primefield.inv(1 - DOMAIN[i]) for i in range(1, WIDTH)]
+        self.inverse_width = primefield.inv(self.WIDTH)
+
+
+    def evaluate_polynomial_in_evaluation_form(self, f, z):
+        """
+        Takes a polynomial in evaluation form and evaluates it at one point outside the domain.
+        Uses the barycentric formula:
+        f(z) = (z**WIDTH - 1) / WIDTH * sum_(i=0)^(WIDTH-1) f(DOMAIN[i]) * DOMAIN[i] / (z - DOMAIN[i])
+        """
+        r = 0
+        for i in range(self.WIDTH):
+            r += self.primefield.div(f[i] * self.DOMAIN[i], (z - self.DOMAIN[i]))
+        r = r * (pow(z, self.WIDTH, self.MODULUS) - 1) * self.inverse_width % self.MODULUS
+
+        return r
+
+
+    def compute_inner_quotient_in_evaluation_form(self, f, index):
+        """
+        Compute the quotient q(X) = (f(X) - f(DOMAIN[index])) / (X - DOMAIN[index]) in evaluation form.
+
+        Inner means that the value z = DOMAIN[index] is one of the points at which f is evaluated -- so unlike an outer
+        quotient (where z is not in DOMAIN), we need to do some extra work to compute q[index] where the formula above
+        is 0 / 0
+        """
+        q = [0] * self.WIDTH
+        y = f[index]
+        for i in range(self.WIDTH):
+            if i != index:
+                q[i] = (f[i] - y) * self.DOMAIN[-i] * self.inverses[index - i] % self.MODULUS
+                q[index] += - self.DOMAIN[(i - index) % self.WIDTH] * q[i] % self.MODULUS
+
+        return q
+
+
+    def compute_outer_quotient_in_evaluation_form(self, f, z, y):
+        """
+        Compute the quotient q(X) = (f(X) - y) / (X - z) in evaluation form. Note that this only works if the quotient
+        is exact, i.e. f(z) = y, and otherwise returns garbage
+        """
+        q = [0] * self.WIDTH
+        for i in range(self.WIDTH):
+            q[i] = self.primefield.div(f[i] - y, self.DOMAIN[i] - z)
+
+        return q
+
+
+    def check_kzg_proof(self, C, z, y, pi):
+        """
+        Check the KZG proof
+        e(C - [y], [1]) = e(pi, [s - z])
+        which is equivalent to
+        e(C - [y], [1]) * e(-pi, [s - z]) == 1
+        """
+        pairing = blst.PT(blst.G2().to_affine(), C.dup().add(blst.G1().mult(y).neg()).to_affine())
+        pairing.mul(blst.PT(self.SETUP["g2"][1].dup().add(blst.G2().mult(z).neg()).to_affine(), pi.dup().neg().to_affine()))
+
+        return pairing.final_exp().is_one()
+
+
+    def evaluate_and_compute_kzg_proof(self, f, z):
+        """
+        Evaluates a function f (given in evaluation form) at a point z (which can be in the DOMAIN or not)
+        and gives y = f(z) as well as a Kate proof that this is the correct result
+        """
+        if z in self.DOMAIN:
+            index = self.DOMAIN.index(z)
+            y = f[index]
+            q = self.compute_inner_quotient_in_evaluation_form(f, index)
+        else:
+            y = self.evaluate_polynomial_in_evaluation_form(f, z)
+            q = self.compute_outer_quotient_in_evaluation_form(f, z, y)
+
+        return y, pippenger.pippenger_simple(self.SETUP["g1_lagrange"], q)
+
+
+    def compute_commitment_lagrange(self, values):
+        """
+        Computes a commitment for a function given in evaluation form.
+        'values' is a dictionary and can have missing indices, which improves efficiency.
+        """
+        commitment = pippenger.pippenger_simple([self.SETUP["g1_lagrange"][i] for i in values.keys()], values.values())
+        return commitment
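+
+
+# A minimal, self-contained sanity check for the barycentric evaluation above.
+# It uses a toy field (MODULUS = 337, where 85 is a primitive 8th root of
+# unity) purely for illustration -- the trie itself uses the BLS12-381 scalar
+# field. SETUP is not needed for plain evaluation, so None is passed.
+if __name__ == "__main__":
+    class ToyField:
+        MODULUS = 337
+        def inv(self, a):
+            return pow(a % self.MODULUS, self.MODULUS - 2, self.MODULUS)
+        def div(self, x, y):
+            return x * self.inv(y) % self.MODULUS
+
+    WIDTH = 8
+    DOMAIN = [pow(85, i, 337) for i in range(WIDTH)]
+    kzg = KzgUtils(337, WIDTH, DOMAIN, None, ToyField())
+
+    coeffs = [3, 1, 4, 1, 5, 9, 2, 6]
+    evals = [sum(c * pow(x, j, 337) for j, c in enumerate(coeffs)) % 337 for x in DOMAIN]
+    z = 11  # a point outside DOMAIN
+    direct = sum(c * pow(z, j, 337) for j, c in enumerate(coeffs)) % 337
+    assert kzg.evaluate_polynomial_in_evaluation_form(evals, z) == direct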
diff --git a/verkle_trie_pedersen/pippenger.py b/verkle_trie_pedersen/pippenger.py
new file mode 100644
index 00000000..51d72030
--- /dev/null
+++ b/verkle_trie_pedersen/pippenger.py
@@ -0,0 +1,68 @@
+import blst
+from itertools import zip_longest
+from collections import defaultdict
+from random import randint
+from time import time
+
+def integer_in_base(i, b):
+    r = []
+    while i > 0:
+        r.append(i % b)
+        i //= b
+    return r
+
+def pippenger_simple(group_elements, factors):
+    """
+    A naive implementation of a Pippenger-like multiexponentiation algorithm. Don't use this
+    in practice; a native implementation in the blst library will perform much better.
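+
+    Sketch of the idea: write each scalar in base b = 2**d and process one digit
+    position at a time. For each position, group the points by digit value into
+    "buckets", add up each bucket, multiply the bucket sum by its digit value and
+    accumulate; between digit positions, the running total is multiplied by b
+    (a Horner step in the exponent).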
+    """
+    assert len(group_elements) == len(factors)
+    n = len(group_elements)
+    d = 1
+    while (d + 2) * 2**(d + 2) < n:
+        d += 1
+    b = 2**d
+    factors_decomposed = [integer_in_base(factor, b) for factor in factors]
+    result = blst.P1_generator().mult(0)
+    for bases in reversed(list(zip_longest(*factors_decomposed, fillvalue=0))):
+        total = blst.P1_generator().mult(0)
+        base_elements_dict = defaultdict(list)
+        for index, base in enumerate(bases):
+            if base > 0:
+                base_elements_dict[base].append(group_elements[index])
+        for base, base_elements in base_elements_dict.items():
+            if len(base_elements) > 0:
+                sum_base_elements = base_elements[0].dup()
+                for x in base_elements[1:]:
+                    sum_base_elements.add(x)
+                sum_base_elements.mult(base)
+                total.add(sum_base_elements)
+        result.mult(b).add(total)
+    return result
+
+def lincomb_naive(group_elements, factors):
+    """
+    Direct linear combination
+    """
+    assert len(group_elements) == len(factors)
+    result = blst.P1_generator().mult(0)
+    for g, f in zip(group_elements, factors):
+        result.add(g.dup().mult(f))
+    return result
+
+def test_pippenger(group_elements, factors):
+    """
+    Test and time pippenger_simple
+    """
+    time_a = time()
+    naive_result = lincomb_naive(group_elements, factors)
+    time_b = time()
+    print("n = {0} multiexp".format(len(group_elements)))
+    print("Naive linear combination: {0:.6f} s".format(time_b - time_a))
+    pippenger_result = pippenger_simple(group_elements, factors)
+    time_c = time()
+    print("Using simple Pippenger algorithm: {0:.6f} s".format(time_c - time_b))
+    assert naive_result.is_equal(pippenger_result)
+
+if __name__ == "__main__":
+    test_pippenger([blst.P1_generator()]*16384, [randint(0, 2**255) for i in range(16384)])
\ No newline at end of file
diff --git a/verkle_trie_pedersen/poly_utils.py b/verkle_trie_pedersen/poly_utils.py
new file mode 100644
index 00000000..eb8972d3
--- /dev/null
+++ b/verkle_trie_pedersen/poly_utils.py
@@ -0,0 +1,298 @@
+# Creates an object that includes convenience operations for numbers
+# and polynomials in some prime field
+
+# Also added interpolation over an arbitrary DOMAIN (not roots of unity)
+#
+
+class PrimeField():
+    def __init__(self, MODULUS, WIDTH=None):
+        assert pow(2, MODULUS, MODULUS) == 2
+        self.MODULUS = MODULUS
+        self.WIDTH = WIDTH
+        if WIDTH is None:
+            # Only field arithmetic is needed (verkle_trie.py constructs
+            # PrimeField(MODULUS)); skip the domain precomputations
+            return
+        self.DOMAIN = list(range(WIDTH))
+
+        self.A = self.zpoly(self.DOMAIN)
+        self.Aprime = self.formal_derivative(self.A)
+
+        # i-th Lagrange polynomial
+        self.lagrange_polys = []
+
+        # Aprime evaluated on the DOMAIN
+        self.Aprime_DOMAIN = []
+
+        # Aprime on the DOMAIN, inverted
+        self.Aprime_DOMAIN_inv = []
+        for i, x in enumerate(self.DOMAIN):
+            self.Aprime_DOMAIN.append(self.eval_poly_at(self.Aprime, x))
+            self.Aprime_DOMAIN_inv.append(self.inv(self.Aprime_DOMAIN[-1]))
+            self.lagrange_polys.append(self.mul_polys([self.Aprime_DOMAIN_inv[-1]],
+                                       self.div_polys(self.A, [-x, 1])))
+
+        # Inverses needed for quotients
+        self.inverses = [self.inv(x) for x in list(range(WIDTH)) + list(range(-WIDTH + 1, 0))]
+
+
+    def formal_derivative(self, f):
+        return [(n + 1) * c % self.MODULUS for n, c in enumerate(f[1:])]
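+
+    # Example: formal_derivative([5, 3, 2]) == [3, 4], i.e.
+    # d/dX (2*X**2 + 3*X + 5) = 4*X + 3 (coefficients mod MODULUS).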
+
+
+    def evaluate_polynomial_in_evaluation_form(self, f, z):
+        """
+        Takes a polynomial in evaluation form and evaluates it at one point outside the DOMAIN.
+        Uses the barycentric formula:
+        f(z) = A(z) * sum_(i=0)^(WIDTH-1) f(DOMAIN[i]) / A'(DOMAIN[i]) * 1 / (z - DOMAIN[i])
+        """
+        r = 0
+        for i, x in enumerate(self.DOMAIN):
+            r += self.div(f[i], self.Aprime_DOMAIN[i] * (z - x))
+        r = r * self.eval_poly_at(self.A, z) % self.MODULUS
+
+        return r
+
+
+    def compute_inner_quotient_in_evaluation_form(self, f, index):
+        """
+        Compute the quotient q(X) = (f(X) - f(DOMAIN[index])) / (X - DOMAIN[index]) in evaluation form.
+
+        Inner means that the value z = DOMAIN[index] is one of the points at which f is evaluated -- so unlike an outer
+        quotient (where z is not in DOMAIN), we need to do some extra work to compute q[index] where the formula above
+        is 0 / 0
+        """
+        q = [0] * self.WIDTH
+        y = f[index]
+        for i in range(self.WIDTH):
+            if i != index:
+                q[i] = (f[i] - y) * self.inverses[i - index] % self.MODULUS
+                q[index] += (f[i] - y) * self.inverses[index - i] * self.Aprime_DOMAIN[index] * self.Aprime_DOMAIN_inv[i] % self.MODULUS
+        q[index] %= self.MODULUS
+
+        return q
+
+
+    def compute_outer_quotient_in_evaluation_form(self, f, z, y):
+        """
+        Compute the quotient q(X) = (f(X) - y) / (X - z) in evaluation form. Note that this only works if the quotient
+        is exact, i.e. f(z) = y, and otherwise returns garbage
+        """
+        q = [0] * self.WIDTH
+        for i in range(self.WIDTH):
+            q[i] = self.div(f[i] - y, self.DOMAIN[i] - z)
+
+        return q
+
+    def add(self, x, y):
+        return (x+y) % self.MODULUS
+
+    def sub(self, x, y):
+        return (x-y) % self.MODULUS
+
+    def mul(self, x, y):
+        return (x*y) % self.MODULUS
+
+    def exp(self, x, p):
+        return pow(x, p, self.MODULUS)
+
+    # Modular inverse using the extended Euclidean algorithm
+    def inv(self, a):
+        if a == 0:
+            return 0
+        lm, hm = 1, 0
+        low, high = a % self.MODULUS, self.MODULUS
+        while low > 1:
+            r = high//low
+            nm, new = hm-lm*r, high-low*r
+            lm, low, hm, high = nm, new, lm, low
+        return lm % self.MODULUS
+
+    def multi_inv(self, values):
+        partials = [1]
+        for i in range(len(values)):
+            partials.append(self.mul(partials[-1], values[i] or 1))
+        inv = self.inv(partials[-1])
+        outputs = [0] * len(values)
+        for i in range(len(values), 0, -1):
+            outputs[i-1] = self.mul(partials[i-1], inv) if values[i-1] else 0
+            inv = self.mul(inv, values[i-1] or 1)
+        return outputs
+
+    def div(self, x, y):
+        return self.mul(x, self.inv(y))
+
+    # Evaluate a polynomial at a point
+    def eval_poly_at(self, p, x):
+        y = 0
+        power_of_x = 1
+        for i, p_coeff in enumerate(p):
+            y += power_of_x * p_coeff
+            power_of_x = (power_of_x * x) % self.MODULUS
+        return y % self.MODULUS
+
+    # Arithmetic for polynomials
+    def add_polys(self, a, b):
+        return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0))
+                % self.MODULUS for i in range(max(len(a), len(b)))]
+
+    def sub_polys(self, a, b):
+        return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0))
+                % self.MODULUS for i in range(max(len(a), len(b)))]
+
+    def mul_by_const(self, a, c):
+        return [(x*c) % self.MODULUS for x in a]
+
+    def mul_polys(self, a, b):
+        o = [0] * (len(a) + len(b) - 1)
+        for i, aval in enumerate(a):
+            for j, bval in enumerate(b):
+                o[i+j] += aval * bval
+        return [x % self.MODULUS for x in o]
+
+    def div_polys(self, a, b):
+        assert len(a) >= len(b)
+        a = [x for x in a]
+        o = []
+        apos = len(a) - 1
+        bpos = len(b) - 1
+        diff = apos - bpos
+        while diff >= 0:
+            quot = self.div(a[apos], b[bpos])
+            o.insert(0, quot)
+            for i in range(bpos, -1, -1):
+                a[diff+i] -= b[i] * quot
+            apos -= 1
+            diff -= 1
+        return [x % self.MODULUS for x in o]
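+
+    # Example: in GF(11), div_polys([1, 0, 1], [10, 1]) == [1, 1], i.e.
+    # (X**2 + 1) // (X - 1) == X + 1 (the remainder, 2, is discarded).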
+
+    def mod_polys(self, a, b):
+        return self.sub_polys(a, self.mul_polys(b, self.div_polys(a, b)))[:len(b)-1]
+
+    # Build a polynomial from a few coefficients
+    def sparse(self, coeff_dict):
+        o = [0] * (max(coeff_dict.keys()) + 1)
+        for k, v in coeff_dict.items():
+            o[k] = v % self.MODULUS
+        return o
+
+    # Build a polynomial that returns 0 at all specified xs
+    def zpoly(self, xs):
+        root = [1]
+        for x in xs:
+            root.insert(0, 0)
+            for j in range(len(root)-1):
+                root[j] -= root[j+1] * x
+        return [x % self.MODULUS for x in root]
+
+    # Given n distinct x values and their y values, recovers the unique
+    # polynomial of degree < n passing through those points.
+    # Lagrange interpolation works roughly in the following way.
+    # 1. Suppose you have a set of points, eg. x = [1, 2, 3], y = [2, 5, 10]
+    # 2. For each x, generate a polynomial which equals its corresponding
+    #    y coordinate at that point and 0 at all other points provided.
+    # 3. Add these polynomials together.
+
+    def lagrange_interp(self, xs, ys):
+        # Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
+        root = self.zpoly(xs)
+        assert len(root) == len(ys) + 1
+        # Generate per-value numerator polynomials, eg. for x=x2,
+        # (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
+        # polynomial back by each x coordinate
+        nums = [self.div_polys(root, [-x, 1]) for x in xs]
+        # Generate denominators by evaluating numerator polys at each x
+        denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
+        invdenoms = self.multi_inv(denoms)
+        # Generate output polynomial, which is the sum of the per-value numerator
+        # polynomials rescaled to have the right y values
+        b = [0 for y in ys]
+        for i in range(len(xs)):
+            yslice = self.mul(ys[i], invdenoms[i])
+            for j in range(len(ys)):
+                if nums[i][j] and ys[i]:
+                    b[j] += nums[i][j] * yslice
+        return [x % self.MODULUS for x in b]
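+
+    # Example: in GF(11), lagrange_interp([1, 2, 3], [2, 5, 10]) == [1, 0, 1],
+    # i.e. the parabola X**2 + 1 passes through all three points.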
+
+    # Optimized evaluation for polynomials with 4 coefficients (degree < 4)
+    def eval_quartic(self, p, x):
+        xsq = x * x % self.MODULUS
+        xcb = xsq * x
+        return (p[0] + p[1] * x + p[2] * xsq + p[3] * xcb) % self.MODULUS
+
+    # Optimized version of lagrange_interp restricted to 4 points (degree < 4)
+    def lagrange_interp_4(self, xs, ys):
+        x01, x02, x03, x12, x13, x23 = \
+            xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
+        m = self.MODULUS
+        eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
+        eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
+        eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
+        eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
+        e0 = self.eval_poly_at(eq0, xs[0])
+        e1 = self.eval_poly_at(eq1, xs[1])
+        e2 = self.eval_poly_at(eq2, xs[2])
+        e3 = self.eval_poly_at(eq3, xs[3])
+        e01 = e0 * e1
+        e23 = e2 * e3
+        invall = self.inv(e01 * e23)
+        inv_y0 = ys[0] * invall * e1 * e23 % m
+        inv_y1 = ys[1] * invall * e0 * e23 % m
+        inv_y2 = ys[2] * invall * e01 * e3 % m
+        inv_y3 = ys[3] * invall * e01 * e2 % m
+        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)]
+
+    # Optimized version of lagrange_interp restricted to 2 points (degree < 2)
+    def lagrange_interp_2(self, xs, ys):
+        m = self.MODULUS
+        eq0 = [-xs[1] % m, 1]
+        eq1 = [-xs[0] % m, 1]
+        e0 = self.eval_poly_at(eq0, xs[0])
+        e1 = self.eval_poly_at(eq1, xs[1])
+        invall = self.inv(e0 * e1)
+        inv_y0 = ys[0] * invall * e1
+        inv_y1 = ys[1] * invall * e0
+        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]
+
+    # Batched version of lagrange_interp_4, amortizing the field inversions via multi_inv
+    def multi_interp_4(self, xsets, ysets):
+        data = []
+        invtargets = []
+        for xs, ys in zip(xsets, ysets):
+            x01, x02, x03, x12, x13, x23 = \
+                xs[0] * xs[1], xs[0] * xs[2], xs[0] * xs[3], xs[1] * xs[2], xs[1] * xs[3], xs[2] * xs[3]
+            m = self.MODULUS
+            eq0 = [-x12 * xs[3] % m, (x12 + x13 + x23), -xs[1]-xs[2]-xs[3], 1]
+            eq1 = [-x02 * xs[3] % m, (x02 + x03 + x23), -xs[0]-xs[2]-xs[3], 1]
+            eq2 = [-x01 * xs[3] % m, (x01 + x03 + x13), -xs[0]-xs[1]-xs[3], 1]
+            eq3 = [-x01 * xs[2] % m, (x01 + x02 + x12), -xs[0]-xs[1]-xs[2], 1]
+            e0 = self.eval_quartic(eq0, xs[0])
+            e1 = self.eval_quartic(eq1, xs[1])
+            e2 = self.eval_quartic(eq2, xs[2])
+            e3 = self.eval_quartic(eq3, xs[3])
+            data.append([ys, eq0, eq1, eq2, eq3])
+            invtargets.extend([e0, e1, e2, e3])
+        invalls = self.multi_inv(invtargets)
+        o = []
+        for (i, (ys, eq0, eq1, eq2, eq3)) in enumerate(data):
+            invallz = invalls[i*4:i*4+4]
+            inv_y0 = ys[0] * invallz[0] % m
+            inv_y1 = ys[1] * invallz[1] % m
+            inv_y2 = ys[2] * invallz[2] % m
+            inv_y3 = ys[3] * invallz[3] % m
+            o.append([(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m for i in range(4)])
+        # assert o == [self.lagrange_interp_4(xs, ys) for xs, ys in zip(xsets, ysets)]
+        return o
+
+
+if __name__ == "__main__":
+    primefield = PrimeField(11, 4)
+    for i, x in enumerate(primefield.DOMAIN):
+        assert primefield.eval_poly_at(primefield.lagrange_polys[i], x) == 1
+        for y in primefield.DOMAIN[:i] + primefield.DOMAIN[i+1:]:
+            assert primefield.eval_poly_at(primefield.lagrange_polys[i], y) == 0
+
+    poly = [3, 4, 3, 2]
+    poly_eval = [primefield.eval_poly_at(poly, x) for x in primefield.DOMAIN]
+
+    assert primefield.eval_poly_at(poly, 5) == primefield.evaluate_polynomial_in_evaluation_form(poly_eval, 5)
+
+    poly_eval_quotient = primefield.compute_inner_quotient_in_evaluation_form(poly_eval, 2)
+
+    poly_quotient = primefield.div_polys([poly[0] - poly_eval[2]] + poly[1:], [-2, 1])
+
+    assert poly_eval_quotient == [primefield.eval_poly_at(poly_quotient, x) for x in primefield.DOMAIN]
\ No newline at end of file
diff --git a/verkle_trie_pedersen/verkle_trie.py b/verkle_trie_pedersen/verkle_trie.py
new file mode 100644
index 00000000..5d3158d1
--- /dev/null
+++ b/verkle_trie_pedersen/verkle_trie.py
@@ -0,0 +1,677 @@
+import pippenger
+import blst
+import hashlib
+from random import randint, shuffle
+from poly_utils import PrimeField
+from time import time
+from kzg_utils import KzgUtils
+from fft import fft
+import sys
+
+#
+# Proof of concept implementation for verkle tries
+#
+# All polynomials in this implementation are represented in evaluation form, i.e. by their values
+# on DOMAIN. See kzg_utils.py for more explanation
+#
+
+# BLS12-381 scalar field modulus (the curve's group order)
+MODULUS = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
+
+# Primitive root for the field
+PRIMITIVE_ROOT = 5
+
+assert pow(PRIMITIVE_ROOT, (MODULUS - 1) // 2, MODULUS) != 1
+assert pow(PRIMITIVE_ROOT, MODULUS - 1, MODULUS) == 1
+
+primefield = PrimeField(MODULUS)
+
+# Verkle trie parameters
+KEY_LENGTH = 256 # bits
+WIDTH_BITS = 10
+WIDTH = 2**WIDTH_BITS
+
+ROOT_OF_UNITY = pow(PRIMITIVE_ROOT, (MODULUS - 1) // WIDTH, MODULUS)
+DOMAIN = [pow(ROOT_OF_UNITY, i, MODULUS) for i in range(WIDTH)]
+
+# Number of key-value pairs to insert
+NUMBER_INITIAL_KEYS = 2**15
+
+# Number of keys to insert after computing initial tree
+NUMBER_ADDED_KEYS = 512
+
+# Number of keys to delete
+NUMBER_DELETED_KEYS = 512
+
+# Number of key/value pairs in proof
+NUMBER_KEYS_PROOF = 5000
+
+def generate_setup(size, secret):
+    """
+    Generates a setup in the G1 group and G2 group, as well as the Lagrange polynomials in G1 (via FFT)
+    """
+    g1_setup = [blst.G1().mult(pow(secret, i, MODULUS)) for i in range(size)]
+    g2_setup = [blst.G2().mult(pow(secret, i, MODULUS)) for i in range(size)]
+    g1_lagrange = fft(g1_setup, MODULUS, ROOT_OF_UNITY, inv=True)
+    return {"g1": g1_setup, "g2": g2_setup, "g1_lagrange": g1_lagrange}
+
+
+def get_verkle_indices(key):
+    """
+    Generates the list of verkle indices for key
+    """
+    x = int.from_bytes(key, "big")
+    last_index_bits = KEY_LENGTH % WIDTH_BITS
+    index = (x % (2**last_index_bits)) << (WIDTH_BITS - last_index_bits)
+    x //= 2**last_index_bits
+    indices = [index]
+    for i in range((KEY_LENGTH - 1) // WIDTH_BITS):
+        index = x % WIDTH
+        x //= WIDTH
+        indices.append(index)
+    return tuple(reversed(indices))
+
+
+def hash(x):
+    if isinstance(x, bytes):
+        return hashlib.sha256(x).digest()
+    elif isinstance(x, blst.P1):
+        return hash(x.compress())
+    b = b""
+    for a in x:
+        if isinstance(a, bytes):
+            b += a
+        elif isinstance(a, int):
+            b += a.to_bytes(32, "little")
+        elif isinstance(a, blst.P1):
+            b += hash(a.compress())
+    return hash(b)
+
+
+def hash_to_int(x):
+    return int.from_bytes(hash(x), "little")
+
+
+def insert_verkle_node(root, key, value):
+    """
+    Insert node without updating hashes/commitments (useful for building a full trie)
+    """
+    current_node = root
+    indices = iter(get_verkle_indices(key))
+    index = None
+    while current_node["node_type"] == "inner":
+        previous_node = current_node
+        index = next(indices)
+        if index in current_node:
+            current_node = current_node[index]
+        else:
+            current_node[index] = {"node_type": "leaf", "key": key, "value": value}
+            return
+    if current_node["key"] == key:
+        current_node["value"] = value
+    else:
+        previous_node[index] = {"node_type": "inner", "commitment": blst.G1().mult(0)}
+        insert_verkle_node(root, key, value)
+        insert_verkle_node(root, current_node["key"], current_node["value"])
+
+
+def update_verkle_node(root, key, value):
+    """
+    Update or insert node and update all commitments and hashes
+    """
+    current_node = root
+    indices = iter(get_verkle_indices(key))
+    index = None
+    path = []
+
+    new_node = {"node_type": "leaf", "key": key, "value": value}
+    add_node_hash(new_node)
+
+    while True:
+        index = next(indices)
+        path.append((index, current_node))
+        if index in current_node:
+            if current_node[index]["node_type"] == "leaf":
+                old_node = current_node[index]
+                if current_node[index]["key"] == key:
+                    current_node[index] = new_node
+                    value_change = (MODULUS + int.from_bytes(new_node["hash"], "little")
+                                    - int.from_bytes(old_node["hash"], "little")) % MODULUS
+                    break
+                else:
+                    new_inner_node = {"node_type": "inner"}
+                    new_index = next(indices)
+                    old_index = get_verkle_indices(old_node["key"])[len(path)]
+                    # TODO! Handle old_index == new_index
+                    assert old_index != new_index
+                    new_inner_node[new_index] = new_node
+                    new_inner_node[old_index] = old_node
+                    add_node_hash(new_inner_node)
+                    current_node[index] = new_inner_node
+                    value_change = (MODULUS + int.from_bytes(new_inner_node["hash"], "little")
+                                    - int.from_bytes(old_node["hash"], "little")) % MODULUS
+                    break
+            current_node = current_node[index]
+        else:
+            current_node[index] = new_node
+            value_change = int.from_bytes(new_node["hash"], "little") % MODULUS
+            break
+
+    # Update all the parent commitments along 'path'
+    for index, node in reversed(path):
+        node["commitment"].add(SETUP["g1_lagrange"][index].dup().mult(value_change))
+        old_hash = node["hash"]
+        new_hash = hash(node["commitment"])
+        node["hash"] = new_hash
+        value_change = (MODULUS + int.from_bytes(new_hash, "little")
+                        - int.from_bytes(old_hash, "little")) % MODULUS
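+
+# Note: the commitment updates above are homomorphic. If the child hash at
+# position i changes by delta = new - old (mod MODULUS), the parent commitment
+# moves from C to C + delta * L_i, where L_i = SETUP["g1_lagrange"][i] is the
+# i-th Lagrange basis point, so no full recommitment over WIDTH children is needed.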
+
+
+def get_only_child(node):
+    """
+    Returns the only child of a node which has only one child. Returns 'None' if node has 0 or >1 children
+    """
+    child_count = 0
+    only_child = None
+    for key in node:
+        if isinstance(key, int):
+            child_count += 1
+            only_child = node[key]
+    return only_child if child_count == 1 else None
+
+
+def delete_verkle_node(root, key):
+    """
+    Delete node and update all commitments and hashes
+    """
+    current_node = root
+    indices = iter(get_verkle_indices(key))
+    index = None
+    path = []
+
+    while True:
+        index = next(indices)
+        path.append((index, current_node))
+        assert index in current_node, "Tried to delete non-existent key"
+        if current_node[index]["node_type"] == "leaf":
+            deleted_node = current_node[index]
+            assert deleted_node["key"] == key, "Tried to delete non-existent key"
+            del current_node[index]
+            value_change = (MODULUS - int.from_bytes(deleted_node["hash"], "little")) % MODULUS
+            break
+        current_node = current_node[index]
+
+    # Update all the parent commitments along 'path'
+    replacement_node = None
+    for index, node in reversed(path):
+        if replacement_node is not None:
+            node[index] = replacement_node
+            replacement_node = None
+        only_child = get_only_child(node)
+        if only_child is not None and only_child["node_type"] == "leaf" and node != root:
+            replacement_node = only_child
+            value_change = (MODULUS + int.from_bytes(only_child["hash"], "little")
+                            - int.from_bytes(node["hash"], "little")) % MODULUS
+        else:
+            node["commitment"].add(SETUP["g1_lagrange"][index].dup().mult(value_change))
+            old_hash = node["hash"]
+            new_hash = hash(node["commitment"])
+            node["hash"] = new_hash
+            value_change = (MODULUS + int.from_bytes(new_hash, "little")
+                            - int.from_bytes(old_hash, "little")) % MODULUS
+
+
+def add_node_hash(node):
+    """
+    Recursively adds all missing commitments and hashes to a verkle trie structure.
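+    A leaf's hash commits to its (key, value) pair; an inner node's hash is the
+    hash of its compressed KZG commitment over the children's hashes.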
+    """
+    if node["node_type"] == "leaf":
+        node["hash"] = hash([node["key"], node["value"]])
+    if node["node_type"] == "inner":
+        values = {}
+        for i in range(WIDTH):
+            if i in node:
+                if "hash" not in node[i]:
+                    add_node_hash(node[i])
+                values[i] = int.from_bytes(node[i]["hash"], "little")
+        commitment = kzg_utils.compute_commitment_lagrange(values)
+        node["commitment"] = commitment
+        node["hash"] = hash(commitment.compress())
+
+
+def get_total_depth(root):
+    """
+    Computes the total depth (sum of the depth of all nodes) of a verkle trie
+    """
+    if root["node_type"] == "inner":
+        total_depth = 0
+        num_nodes = 0
+        for i in range(WIDTH):
+            if i in root:
+                depth, nodes = get_total_depth(root[i])
+                num_nodes += nodes
+                total_depth += nodes + depth
+        return total_depth, num_nodes
+    else:
+        return 0, 1
+
+
+def check_valid_tree(root, is_trie_root=True):
+    """
+    Checks that the tree is valid
+    """
+    if root["node_type"] == "inner":
+        if not is_trie_root:
+            only_child = get_only_child(root)
+            if only_child is not None:
+                assert only_child["node_type"] == "inner"
+
+        values = {}
+        for i in range(WIDTH):
+            if i in root:
+                if "hash" not in root[i]:
+                    add_node_hash(root[i])
+                values[i] = int.from_bytes(root[i]["hash"], "little")
+        commitment = kzg_utils.compute_commitment_lagrange(values)
+        assert root["commitment"].is_equal(commitment)
+        assert root["hash"] == hash(commitment.compress())
+
+        for i in range(WIDTH):
+            if i in root:
+                check_valid_tree(root[i], False)
+    else:
+        assert root["hash"] == hash([root["key"], root["value"]])
+
+
+def get_average_depth(trie):
+    """
+    Get the average depth of nodes in a verkle trie
+    """
+    depth, nodes = get_total_depth(trie)
+    return depth / nodes
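+
+# With N random keys and width W, the average node depth should come out to
+# roughly ceil(log_W(N)) (close to 2 for the defaults N = 2**15, W = 2**10),
+# which is what the benchmark below reports.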
+
+
+def find_node(root, key):
+    """
+    Finds 'key' in verkle trie. Returns the full node (not just the value) or None if not present
+    """
+    current_node = root
+    indices = iter(get_verkle_indices(key))
+    while current_node["node_type"] == "inner":
+        index = next(indices)
+        if index in current_node:
+            current_node = current_node[index]
+        else:
+            return None
+    if current_node["key"] == key:
+        return current_node
+    return None
+
+
+def find_node_with_path(root, key):
+    """
+    As 'find_node', but returns the path of all nodes on the way to 'key' as well as their index
+    """
+    current_node = root
+    indices = iter(get_verkle_indices(key))
+    path = []
+    current_index_path = []
+    while current_node["node_type"] == "inner":
+        index = next(indices)
+        path.append((tuple(current_index_path), index, current_node))
+        current_index_path.append(index)
+        if index in current_node:
+            current_node = current_node[index]
+        else:
+            return path, None
+    if current_node["key"] == key:
+        return path, current_node
+    return path, None
+
+
+def get_proof_size(proof):
+    depths, commitments_sorted_by_index_serialized, D_serialized, y, sigma_serialized = proof
+    size = len(depths) # assume 8 bit integer to represent the depth
+    size += 48 * len(commitments_sorted_by_index_serialized)
+    size += 48 + 32 + 48 # D, y, sigma
+    return size
+
+lasttime = [0]
+
+
+def start_logging_time_if_eligible(string, eligible):
+    if eligible:
+        print(string, file=sys.stderr)
+        lasttime[0] = time()
+
+
+def log_time_if_eligible(string, width, eligible):
+    if eligible:
+        print(string + ' ' * max(1, width - len(string)) + "{0:7.3f} s".format(time() - lasttime[0]), file=sys.stderr)
+        lasttime[0] = time()
+
+
+def make_kzg_multiproof(Cs, fs, indices, ys, display_times=True):
+    """
+    Computes a KZG multiproof according to the scheme described here:
+    https://notes.ethereum.org/nrQqhVpQRi6acQckwm1Ryg
+
+    zs[i] = DOMAIN[indices[i]]
+    """
+
+    # Step 1: Construct g(X) polynomial in evaluation form
+    r = hash_to_int([hash(C) for C in Cs] + indices + ys) % MODULUS
+
+    log_time_if_eligible("   Hashed to r", 30, display_times)
+
+    g = [0 for i in range(WIDTH)]
+    power_of_r = 1
+    for f, index in zip(fs, indices):
+        quotient = kzg_utils.compute_inner_quotient_in_evaluation_form(f, index)
+        for i in range(WIDTH):
+            g[i] += power_of_r * quotient[i]
+
+        power_of_r = power_of_r * r % MODULUS
+
+    log_time_if_eligible("   Computed g polynomial", 30, display_times)
+
+    D = kzg_utils.compute_commitment_lagrange({i: v for i, v in enumerate(g)})
+
+    log_time_if_eligible("   Computed commitment D", 30, display_times)
+
+    # Step 2: Compute h(X) = sum_i r**i * f_i(X) / (t - z_i) in evaluation form
+
+    t = hash_to_int([r, D]) % MODULUS
+
+    h = [0 for i in range(WIDTH)]
+    power_of_r = 1
+
+    for f, index in zip(fs, indices):
+        denominator_inv = primefield.inv(t - DOMAIN[index])
+        for i in range(WIDTH):
+            h[i] += power_of_r * f[i] * denominator_inv % MODULUS
+
+        power_of_r = power_of_r * r % MODULUS
+
+    log_time_if_eligible("   Computed h polynomial", 30, display_times)
+
+    # Step 3: Evaluate and compute KZG proofs
+
+    y, pi = kzg_utils.evaluate_and_compute_kzg_proof(h, t)
+    w, rho = kzg_utils.evaluate_and_compute_kzg_proof(g, t)
+
+    # Compress both proofs into one
+
+    E = kzg_utils.compute_commitment_lagrange({i: v for i, v in enumerate(h)})
+    q = hash_to_int([E, D, y, w])
+    sigma = pi.dup().add(rho.dup().mult(q))
+
+    log_time_if_eligible("   Computed KZG proofs", 30, display_times)
+
+    return D.compress(), y, sigma.compress()
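+
+# Verifier-side sketch: from the path commitments C_i the verifier recomputes
+# E = sum_i r**i / (t - z_i) * C_i, then a single KZG check of E + q*D opened
+# at t against y + q*w covers all claimed evaluations at once.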
+
+
+def check_kzg_multiproof(Cs, indices, ys, proof, display_times=True):
+    """
+    Verifies a KZG multiproof according to the scheme described here:
+    https://notes.ethereum.org/nrQqhVpQRi6acQckwm1Ryg
+    """
+
+    D_serialized, y, sigma_serialized = proof
+    D = blst.P1(D_serialized)
+    sigma = blst.P1(sigma_serialized)
+
+    # Step 1 (r must be computed exactly as the prover did, reduced mod MODULUS)
+    r = hash_to_int([hash(C) for C in Cs] + indices + ys) % MODULUS
+
+    log_time_if_eligible("   Computed r hash", 30, display_times)
+
+    # Step 2
+    t = hash_to_int([r, D]) % MODULUS
+    E_coefficients = []
+    g_2_of_t = 0
+    power_of_r = 1
+
+    for index, y_i in zip(indices, ys):
+        E_coefficient = primefield.div(power_of_r, t - DOMAIN[index])
+        E_coefficients.append(E_coefficient)
+        g_2_of_t += E_coefficient * y_i % MODULUS
+
+        power_of_r = power_of_r * r % MODULUS
+
+    log_time_if_eligible("   Computed g2 and e coeffs", 30, display_times)
+
+    E = pippenger.pippenger_simple(Cs, E_coefficients)
+
+    log_time_if_eligible("   Computed E commitment", 30, display_times)
+
+    # Step 3 (Check KZG proofs)
+    w = (y - g_2_of_t) % MODULUS
+
+    q = hash_to_int([E, D, y, w])
+
+    if not kzg_utils.check_kzg_proof(E.dup().add(D.dup().mult(q)), t, y + q * w, sigma):
+        return False
+
+    log_time_if_eligible("   Checked KZG proofs", 30, display_times)
+
+    return True
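+
+# A verkle proof consists of: one depth (8-bit) per key, the compressed
+# commitments of all inner nodes on the access paths except the root (which
+# the verifier already holds), and the KZG multiproof (D, y, sigma) from
+# make_kzg_multiproof above.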
+
+
+def make_verkle_proof(trie, keys, display_times=True):
+    """
+    Creates a proof for the 'keys' in the verkle trie given by 'trie'
+    """
+
+    start_logging_time_if_eligible("   Starting proof computation", display_times)
+
+    # Step 0: Find all keys in the trie
+    nodes_by_index = {}
+    nodes_by_index_and_subindex = {}
+    values = []
+    depths = []
+    for key in keys:
+        path, node = find_node_with_path(trie, key)
+        depths.append(len(path))
+        values.append(node["value"])
+        for index, subindex, node in path:
+            nodes_by_index[index] = node
+            nodes_by_index_and_subindex[(index, subindex)] = node
+
+    log_time_if_eligible("   Computed key paths", 30, display_times)
+
+    # All commitments, but without any duplications. These are for sending over the wire as part of the proof
+    nodes_sorted_by_index = list(map(lambda x: x[1], sorted(nodes_by_index.items())))
+
+    # Nodes sorted by (index, subindex), aligned with 'indices' and 'ys' below
+    nodes_sorted_by_index_and_subindex = list(map(lambda x: x[1], sorted(nodes_by_index_and_subindex.items())))
+
+    indices = list(map(lambda x: x[0][1], sorted(nodes_by_index_and_subindex.items())))
+
+    ys = list(map(lambda x: int.from_bytes(x[1][x[0][1]]["hash"], "little"), sorted(nodes_by_index_and_subindex.items())))
+
+    log_time_if_eligible("   Sorted all commitments", 30, display_times)
+
+    fs = []
+    Cs = [x["commitment"] for x in nodes_sorted_by_index_and_subindex]
+
+    for node in nodes_sorted_by_index_and_subindex:
+        fs.append([int.from_bytes(node[i]["hash"], "little") if i in node else 0 for i in range(WIDTH)])
+
+    D, y, sigma = make_kzg_multiproof(Cs, fs, indices, ys, display_times)
+
+    commitments_sorted_by_index_serialized = [x["commitment"].compress() for x in nodes_sorted_by_index[1:]]
+
+    log_time_if_eligible("   Serialized commitments", 30, display_times)
+
+    return depths, commitments_sorted_by_index_serialized, D, y, sigma
+
+
+def check_verkle_proof(root_commitment_serialized, keys, values, proof, display_times=True):
+    """
+    Checks a verkle trie proof for 'keys' and 'values' against the serialized root commitment, following
+    https://notes.ethereum.org/nrQqhVpQRi6acQckwm1Ryg?both
+    """
+
+    start_logging_time_if_eligible("   Starting proof check", display_times)
+
+    # Unpack the proof
+    depths, commitments_sorted_by_index_serialized, D_serialized, y, sigma_serialized = proof
+    commitments_sorted_by_index = [blst.P1(root_commitment_serialized)] + [blst.P1(x) for x in commitments_sorted_by_index_serialized]
+
+    all_indices = set()
+    all_indices_and_subindices = set()
+
+    leaf_values_by_index_and_subindex = {}
+
+    # Find all required indices
+    for key, value, depth in zip(keys, values, depths):
+        verkle_indices = get_verkle_indices(key)
+        for i in range(depth):
+            all_indices.add(verkle_indices[:i])
+            all_indices_and_subindices.add((verkle_indices[:i], verkle_indices[i]))
+        leaf_values_by_index_and_subindex[(verkle_indices[:depth - 1], verkle_indices[depth - 1])] = hash([key, value])
+
+    all_indices = sorted(all_indices)
+    all_indices_and_subindices = sorted(all_indices_and_subindices)
+
+    log_time_if_eligible("   Computed indices", 30, display_times)
+
+    # Step 0: recreate the commitment list sorted by indices
+    commitments_by_index = {index: commitment for index, commitment in zip(all_indices, commitments_sorted_by_index)}
+    commitments_by_index_and_subindex = {index_and_subindex: commitments_by_index[index_and_subindex[0]]
+                                         for index_and_subindex in all_indices_and_subindices}
+
+    subhashes_by_index_and_subindex = {}
+    for index_and_subindex in all_indices_and_subindices:
+        full_subindex = index_and_subindex[0] + (index_and_subindex[1],)
+        if full_subindex in commitments_by_index:
+            subhashes_by_index_and_subindex[index_and_subindex] = hash(commitments_by_index[full_subindex])
+        else:
+            subhashes_by_index_and_subindex[index_and_subindex] = leaf_values_by_index_and_subindex[index_and_subindex]
+
+    Cs = list(map(lambda x: x[1], sorted(commitments_by_index_and_subindex.items())))
+
+    indices = list(map(lambda x: x[1], sorted(all_indices_and_subindices)))
+
+    ys = list(map(lambda x: int.from_bytes(x[1], "little"), sorted(subhashes_by_index_and_subindex.items())))
+
+    log_time_if_eligible("   Recreated commitment lists", 30, display_times)
+
+    return check_kzg_multiproof(Cs, indices, ys, [D_serialized, y, sigma_serialized], display_times)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        WIDTH_BITS = int(sys.argv[1])
+        WIDTH =
2 ** WIDTH_BITS + ROOT_OF_UNITY = pow(PRIMITIVE_ROOT, (MODULUS - 1) // WIDTH, MODULUS) + DOMAIN = [pow(ROOT_OF_UNITY, i, MODULUS) for i in range(WIDTH)] + + NUMBER_INITIAL_KEYS = int(sys.argv[2]) + + NUMBER_KEYS_PROOF = int(sys.argv[3]) + + NUMBER_DELETED_KEYS = 0 + NUMBER_ADDED_KEYS = 0 + + SETUP = generate_setup(WIDTH, 8927347823478352432985) + kzg_utils = KzgUtils(MODULUS, WIDTH, DOMAIN, SETUP, primefield) + + + # Build a random verkle trie + root = {"node_type": "inner", "commitment": blst.G1().mult(0)} + + values = {} + + for i in range(NUMBER_INITIAL_KEYS): + key = randint(0, 2**256-1).to_bytes(32, "little") + value = randint(0, 2**256-1).to_bytes(32, "little") + insert_verkle_node(root, key, value) + values[key] = value + + average_depth = get_average_depth(root) + + print("Inserted {0} elements for an average depth of {1:.3f}".format(NUMBER_INITIAL_KEYS, average_depth), file=sys.stderr) + + time_a = time() + add_node_hash(root) + time_b = time() + + print("Computed verkle root in {0:.3f} s".format(time_b - time_a), file=sys.stderr) + + if NUMBER_ADDED_KEYS > 0: + + time_a = time() + check_valid_tree(root) + time_b = time() + + print("[Checked tree valid: {0:.3f} s]".format(time_b - time_a), file=sys.stderr) + + time_x = time() + for i in range(NUMBER_ADDED_KEYS): + key = randint(0, 2**256-1).to_bytes(32, "little") + value = randint(0, 2**256-1).to_bytes(32, "little") + update_verkle_node(root, key, value) + values[key] = value + time_y = time() + + print("Additionally inserted {0} elements in {1:.3f} s".format(NUMBER_ADDED_KEYS, time_y - time_x), file=sys.stderr) + print("Keys in tree now: {0}, average depth: {1:.3f}".format(get_total_depth(root)[1], get_average_depth(root)), file=sys.stderr) + + time_a = time() + check_valid_tree(root) + time_b = time() + + print("[Checked tree valid: {0:.3f} s]".format(time_b - time_a), file=sys.stderr) + + if NUMBER_DELETED_KEYS > 0: + + all_keys = list(values.keys()) + shuffle(all_keys) + + keys_to_delete = all_keys[:NUMBER_DELETED_KEYS] + + time_a = time() + for key in keys_to_delete: + delete_verkle_node(root, key) + del values[key] + time_b = time() + + print("Deleted {0} elements in {1:.3f} s".format(NUMBER_DELETED_KEYS, time_b - time_a), file=sys.stderr) + print("Keys in tree now: {0}, average depth: {1:.3f}".format(get_total_depth(root)[1], get_average_depth(root)), file=sys.stderr) + + + time_a = time() + check_valid_tree(root) + time_b = time() + + print("[Checked tree valid: {0:.3f} s]".format(time_b - time_a), file=sys.stderr) + + + all_keys = list(values.keys()) + shuffle(all_keys) + + keys_in_proof = all_keys[:NUMBER_KEYS_PROOF] + + time_a = time() + proof = make_verkle_proof(root, keys_in_proof) + time_b = time() + + proof_size = get_proof_size(proof) + proof_time = time_b - time_a + + print("Computed proof for {0} keys (size = {1} bytes) in {2:.3f} s".format(NUMBER_KEYS_PROOF, proof_size, time_b - time_a), file=sys.stderr) + + time_a = time() + check_verkle_proof(root["commitment"].compress(), keys_in_proof, [values[key] for key in keys_in_proof], proof) + time_b = time() + check_time = time_b - time_a + + print("Checked proof in {0:.3f} s".format(time_b - time_a), file=sys.stderr) + + print("{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\t{7}".format(WIDTH_BITS, WIDTH, NUMBER_INITIAL_KEYS, NUMBER_KEYS_PROOF, average_depth, proof_size, proof_time, check_time)) \ No newline at end of file