Skip to content

Commit

Permalink
Works with new zorch
Browse files Browse the repository at this point in the history
  • Loading branch information
vbuterin committed Aug 6, 2024
1 parent 1662796 commit ad7e7d0
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 523 deletions.
70 changes: 25 additions & 45 deletions circlestark/fast_fft.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,45 @@
from utils import (
np, array, zeros, tobytes, arange, append, log2,
reverse_bit_order,
to_ext_if_needed, to_extension_field
cp, reverse_bit_order, log2
)

from zorch.m31 import (
zeros, array, arange, append, tobytes, add, sub, mul, cp as np,
mul_ext, modinv_ext, sum as m31_sum, eq, iszero, M31
M31, ExtendedM31, Point, modulus, zeros_like, Z, G
)
from zorch import m31
from precomputes import rbos, invx, invy, sub_domains

# Converts a list of evaluations to a list of coefficients. Note that the
# coefficients are in a "weird" basis: 1, y, x, xy, 2x^2-1...
def fft(vals, is_top_level=True):
vals = np.copy(vals)
vals = vals.copy()
shape_suffix = vals.shape[1:]
size = vals.shape[0]
for i in range(log2(size)):
vals = np.reshape(vals, (1 << i, size >> i) + shape_suffix)
vals = vals.reshape((1 << i, size >> i) + shape_suffix)
full_len = vals.shape[1]
half_len = full_len >> 1
L = vals[:, :half_len]
R = np.flip(vals[:, half_len:], (1,))
f0 = m31.add(L, R)
R = vals[:, half_len:][:, ::-1, ...] # flip along axis 1
f0 = L + R
if i==0 and is_top_level:
twiddle = invy[full_len: full_len + half_len]
else:
twiddle = invx[full_len*2: full_len*2 + half_len]
twiddle_box = twiddle.reshape((1, half_len) + (1,) * (L.ndim - 2))
f1 = m31.mul(m31.sub(L, R), twiddle_box)
f1 = (L - R) * twiddle_box
vals[:, :half_len] = f0
vals[:, half_len:] = f1
inv_size = np.array((1 << (31-log2(size))) % m31.M31, dtype=np.uint32)
return m31.mul(
(vals.reshape((size,) + shape_suffix))[rbos[size:size*2]],
inv_size
return (
(vals.reshape((size,) + shape_suffix))[rbos[size:size*2]] / size
)

# Converts a list of coefficients into a list of evaluations
def inv_fft(vals):
vals = np.copy(vals)
vals = vals.copy()
shape_suffix = vals.shape[1:]
size = vals.shape[0]
vals = reverse_bit_order(vals)
for i in range(log2(size)-1, -1, -1):
vals = np.reshape(vals, (1 << i, size >> i) + shape_suffix)
vals = vals.reshape((1 << i, size >> i) + shape_suffix)
full_len = vals.shape[1]
half_len = full_len >> 1
f0 = vals[:, :half_len]
Expand All @@ -53,35 +48,27 @@ def inv_fft(vals):
twiddle = sub_domains[full_len: full_len + half_len].y
else:
twiddle = sub_domains[full_len*2: full_len*2 + half_len].x
f1_times_twiddle = m31.mul(
f1,
twiddle.reshape((1, half_len) + (1,) * (f0.ndim - 2))
f1_times_twiddle = (
f1 * twiddle.reshape((1, half_len) + (1,) * (f0.ndim - 2))
)
L = m31.add(f0, f1_times_twiddle)
R = m31.sub(f0, f1_times_twiddle)
L = f0 + f1_times_twiddle
R = f0 - f1_times_twiddle
vals[:, :half_len] = L
vals[:, half_len:] = np.flip(R, (1,))
return np.reshape(vals, (size,) + shape_suffix)
vals[:, half_len:] = R[:, ::-1, ...]
return vals.reshape((size,) + shape_suffix)

# Given a list of evaluations, computes the evaluation of that polynomial at
# one point. The point can be in the base field or extension field
def bary_eval(vals, pt, is_extended=False, first_round_optimize=False):
def bary_eval(vals, pt):
shape_suffix = vals.shape[1:]
size = vals.shape[0]
if is_extended:
pt = pt.to_extended()
mul = m31.mul_ext
one = array([1,0,0,0])
else:
mul = m31.mul
one = array(1)
for i in range(log2(size)):
#vals = np.reshape(vals, (1 << i, size >> i) + shape_suffix)
full_len = vals.shape[0]
half_len = full_len >> 1
L = vals[:half_len]
R = np.flip(vals[half_len:], (0,))
f0 = m31.add(L, R)
R = vals[half_len:][::-1]
f0 = L + R
if i == 0:
twiddle = invy[full_len: full_len + half_len]
baryfac = pt.y
Expand All @@ -90,15 +77,8 @@ def bary_eval(vals, pt, is_extended=False, first_round_optimize=False):
if i == 1:
baryfac = pt.x
else:
baryfac = m31.sub(2 * mul(baryfac, baryfac) % M31, one)
baryfac = baryfac * baryfac * 2 - 1
twiddle_box = twiddle.reshape((half_len,) + (1,) * (L.ndim - 1))
f1 = m31.mul(m31.sub(L, R), twiddle_box)
if first_round_optimize and i==0 and one.ndim==1:
vals = m31.add(
to_extension_field(f0),
m31.mul(baryfac, f1.reshape(f1.shape+(1,)))
)
else:
vals = m31.add(f0, mul(baryfac, f1))
inv_size = (1 << (31-log2(size))) % M31
return m31.mul(vals[0], np.array(inv_size, dtype=np.uint32))
f1 = (L - R) * twiddle_box
vals = f0 + baryfac * f1
return vals[0] / size
60 changes: 30 additions & 30 deletions circlestark/fast_fri.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from zorch.m31 import (
zeros, array, arange, append, tobytes, add, sub, mul, cp as np,
mul_ext, modinv_ext, sum as m31_sum, eq, M31, iszero
M31, ExtendedM31, Point, modulus, zeros_like, Z, G
)
from utils import (
log2, M31, M31SQ, HALF, to_extension_field,
np, tobytes, reverse_bit_order,
log2, HALF, cp, reverse_bit_order,
merkelize_top_dimension, get_challenges, rbo_index_to_original
)
from precomputes import folded_rbos, invx, invy
Expand All @@ -22,9 +20,9 @@
# by 8x
def fold(values, coeff, first_round):
for i in range(FOLDS_PER_ROUND):
full_len, half_len = values.shape[-2], values.shape[-2]//2
full_len, half_len = values.shape[-1], values.shape[-1]//2
left, right = values[::2], values[1::2]
f0 = mul(add(left, right), HALF)
f0 = (left + right) * HALF
if i == 0 and first_round:
twiddle = (
invy[full_len: full_len * 2]
Expand All @@ -35,10 +33,10 @@ def fold(values, coeff, first_round):
invx[full_len*2: full_len * 3]
[folded_rbos[full_len:full_len*2:2]]
)
twiddle_box = np.zeros_like(left)
twiddle_box = zeros_like(left)
twiddle_box[:] = twiddle.reshape((half_len,) + (1,) * (left.ndim-1))
f1 = mul(mul(sub(left, right), HALF), twiddle_box)
values = add(f0, mul_ext(f1, coeff))
f1 = (left - right) * HALF * twiddle_box
values = f0 + f1 * coeff
return values

# This performs the same folding step as above, but at a pre-supplied list
Expand All @@ -49,7 +47,7 @@ def fold_with_positions(values, domain_size, positions, coeff, first_round):
positions = positions[::2]
for i in range(FOLDS_PER_ROUND):
left, right = values[::2], values[1::2]
f0 = mul(add(left, right), HALF)
f0 = (left + right) * HALF
if i == 0 and first_round:
unrbo_positions = rbo_index_to_original(domain_size, positions)
twiddle = invy[domain_size + unrbo_positions]
Expand All @@ -59,17 +57,17 @@ def fold_with_positions(values, domain_size, positions, coeff, first_round):
(positions << 1) >> i
)
twiddle = invx[domain_size * 2 + unrbo_positions]
twiddle_box = np.zeros_like(left)
twiddle_box = zeros_like(left)
twiddle_box[:] = twiddle.reshape((left.shape[0],) + (1,)*(left.ndim-1))
f1 = mul(mul(sub(left, right), HALF), twiddle_box)
values = add(f0, mul_ext(f1, coeff))
f1 = (left - right) * HALF * twiddle_box
values = f0 + f1 * coeff
positions = positions[::2]
domain_size //= 2
return values

# Generate a FRI proof
def prove_low_degree(evaluations, extra_entropy=b''):
assert len(evaluations.shape) == 2 and evaluations.shape[-1] == 4
assert evaluations.ndim == 1 and isinstance(evaluations, ExtendedM31)
# Commit Merkle root
values = evaluations[folded_rbos[len(evaluations):len(evaluations)*2]]
leaves = []
Expand All @@ -87,16 +85,16 @@ def prove_low_degree(evaluations, extra_entropy=b''):
roots.append(trees[-1][1])
print('Root: 0x{}'.format(roots[-1].hex()))
print("Descent round {}: {} values".format(i+1, len(values)))
fold_factor = get_challenges(b''.join(roots), M31, 4)
fold_factor = ExtendedM31(get_challenges(b''.join(roots), modulus, 4))
print("Fold factor: {}".format(fold_factor))
values = fold(values, fold_factor, i==0)
entropy = extra_entropy + b''.join(roots) + tobytes(values)
entropy = extra_entropy + b''.join(roots) + values.tobytes()
challenges = get_challenges(
entropy, len(evaluations) >> FOLDS_PER_ROUND, NUM_CHALLENGES
)
round_challenges = (
challenges.reshape((1,)+challenges.shape)
>> arange(0, rounds * FOLDS_PER_ROUND, FOLDS_PER_ROUND)
>> cp.arange(0, rounds * FOLDS_PER_ROUND, FOLDS_PER_ROUND)
.reshape((rounds,) + (1,) * challenges.ndim)
)

Expand All @@ -106,7 +104,7 @@ def prove_low_degree(evaluations, extra_entropy=b''):
]
round_challenges_xfold = (
round_challenges.reshape(round_challenges.shape + (1,)) * 8
+ arange(FOLD_SIZE_RATIO).reshape(1, 1, FOLD_SIZE_RATIO)
+ cp.arange(FOLD_SIZE_RATIO).reshape(1, 1, FOLD_SIZE_RATIO)
)

leaf_values = [
Expand All @@ -128,23 +126,25 @@ def verify_low_degree(proof, extra_entropy=b''):
final_values = proof["final_values"]
len_evaluations = final_values.shape[0] << (FOLDS_PER_ROUND * len(roots))
print("Verifying FRI proof")
entropy = extra_entropy + b''.join(roots) + tobytes(final_values)
entropy = extra_entropy + b''.join(roots) + final_values.tobytes()
challenges = get_challenges(
entropy, len_evaluations >> FOLDS_PER_ROUND, NUM_CHALLENGES
)
# Re-run the descent at the pseudorandomly-chosen set of points, and
# verify consistency at each step
for i in range(len(roots)):
print("Descent round {}".format(i+1))
fold_factor = get_challenges(b''.join(roots[:i+1]), M31, 4)
fold_factor = ExtendedM31(
get_challenges(b''.join(roots[:i+1]), modulus, 4)
)
print("Fold factor: {}".format(fold_factor))
evaluation_size = len_evaluations >> (i * FOLDS_PER_ROUND)
positions = (
challenges.reshape((NUM_CHALLENGES, 1)) * FOLD_SIZE_RATIO
+ arange(FOLD_SIZE_RATIO)
+ cp.arange(FOLD_SIZE_RATIO)
).reshape((NUM_CHALLENGES * FOLD_SIZE_RATIO))
folded_values = fold_with_positions(
leaf_values[i].reshape((-1,4)),
leaf_values[i].reshape((-1,)),
evaluation_size,
positions,
fold_factor,
Expand All @@ -153,22 +153,22 @@ def verify_low_degree(proof, extra_entropy=b''):
if i < len(roots) - 1:
expected_values = (
leaf_values[i+1][
arange(NUM_CHALLENGES),
cp.arange(NUM_CHALLENGES),
challenges % FOLD_SIZE_RATIO
]
)
else:
expected_values = final_values[challenges]
assert np.array_equal(folded_values, expected_values)
assert folded_values == expected_values
# Also verify the Merkle branches
for j, c in enumerate(np.copy(challenges)):
for j, c in enumerate(cp.copy(challenges)):
assert verify_branch(
roots[i], c, tobytes(leaf_values[i][j]), branches[i][j]
roots[i], c, leaf_values[i][j].tobytes(), branches[i][j]
)
challenges >>= FOLDS_PER_ROUND
o = np.zeros_like(final_values)
o = zeros_like(final_values)
N = final_values.shape[0]
o[rbo_index_to_original(N, arange(N))] = final_values
coeffs = fft(o, is_top_level=False) % M31
assert iszero(coeffs[N//2:])
o[rbo_index_to_original(N, cp.arange(N))] = final_values
coeffs = fft(o, is_top_level=False)
assert coeffs[N//2:] == 0
return True
Loading

0 comments on commit ad7e7d0

Please sign in to comment.