Skip to content

Simply BaseFold verifier #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: cyte/fix-query-phase
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 127 additions & 98 deletions Cargo.lock

Large diffs are not rendered by default.

37 changes: 18 additions & 19 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", bra
openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false }
openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false }
openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false }

openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false }
openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false }

rand = { version = "0.8.5", default-features = false }
itertools = { version = "0.13.0", default-features = false }
bincode = "1"
bincode = "1.3.3"
tracing = "0.1.40"

# Plonky3
Expand All @@ -39,26 +38,26 @@ ark-poly = "0.5"
ark-serialize = "0.5"

# Ceno
ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "multilinear_extensions" }
ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "sumcheck" }
ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "transcript" }
ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext", package = "witness" }
ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" }
ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" }
mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" }
ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/export_ff_ext" }
ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "multilinear_extensions" }
ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "sumcheck" }
ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "transcript" }
ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "witness" }
ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" }
ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" }
mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" }
ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

[features]
bench-metrics = ["openvm-circuit/bench-metrics"]

# [patch."https://github.com/scroll-tech/ceno.git"]
# ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" }
# ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" }
# ceno_transcript = { path = "../ceno/transcript", package = "transcript" }
# ceno_witness = { path = "../ceno/witness", package = "witness" }
# ceno_zkvm = { path = "../ceno/ceno_zkvm" }
# ceno_emul = { path = "../ceno/ceno_emul" }
# mpcs = { path = "../ceno/mpcs" }
# ff_ext = { path = "../ceno/ff_ext" }
[patch."https://github.com/scroll-tech/ceno.git"]
ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" }
ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" }
ceno_transcript = { path = "../ceno/transcript", package = "transcript" }
ceno_witness = { path = "../ceno/witness", package = "witness" }
ceno_zkvm = { path = "../ceno/ceno_zkvm" }
ceno_emul = { path = "../ceno/ceno_emul" }
mpcs = { path = "../ceno/mpcs" }
ff_ext = { path = "../ceno/ff_ext" }
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[toolchain]
channel = "nightly-2025-01-06"
channel = "nightly-2025-03-25"
targets = ["riscv32im-unknown-none-elf"]
components = ["clippy", "rustfmt", "rust-src"]
187 changes: 137 additions & 50 deletions src/arithmetics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::tower_verifier::binding::PointAndEvalVariable;
use crate::zkvm_verifier::binding::ZKVMOpcodeProofInputVariable;
use ceno_zkvm::expression::{Expression, Fixed, Instance};
use ceno_mle::expression::{Expression, Fixed, Instance};
use ceno_zkvm::structs::{ChallengeId, WitnessId};
use ff_ext::ExtensionField;
use ff_ext::{BabyBearExt4, SmallField};
use itertools::Either;
use openvm_native_compiler::prelude::*;
use openvm_native_compiler_derive::iter_zip;
use openvm_native_recursion::challenger::ChallengerVariable;
Expand All @@ -13,8 +14,9 @@ use openvm_native_recursion::challenger::{
use p3_field::{FieldAlgebra, FieldExtensionAlgebra};
type E = BabyBearExt4;
const HASH_RATE: usize = 8;
const MAX_NUM_VARS: usize = 25;

pub fn _print_ext_arr<C: Config>(builder: &mut Builder<C>, arr: &Array<C, Ext<C::F, C::EF>>) {
pub fn print_ext_arr<C: Config>(builder: &mut Builder<C>, arr: &Array<C, Ext<C::F, C::EF>>) {
iter_zip!(builder, arr).for_each(|ptr_vec, builder| {
let e = builder.iter_ptr_get(arr, ptr_vec[0]);
builder.print_e(e);
Expand All @@ -28,7 +30,7 @@ pub fn print_felt_arr<C: Config>(builder: &mut Builder<C>, arr: &Array<C, Felt<C
});
}

pub fn _print_usize_arr<C: Config>(builder: &mut Builder<C>, arr: &Array<C, Usize<C::N>>) {
pub fn print_usize_arr<C: Config>(builder: &mut Builder<C>, arr: &Array<C, Usize<C::N>>) {
iter_zip!(builder, arr).for_each(|ptr_vec, builder| {
let n = builder.iter_ptr_get(arr, ptr_vec[0]);
builder.print_v(n.get_var());
Expand Down Expand Up @@ -83,13 +85,11 @@ pub fn is_smaller_than<C: Config>(
RVar::from(v)
}

pub fn evaluate_at_point<C: Config>(
pub fn evaluate_at_point_degree_1<C: Config>(
builder: &mut Builder<C>,
evals: &Array<C, Ext<C::F, C::EF>>,
point: &Array<C, Ext<C::F, C::EF>>,
) -> Ext<C::F, C::EF> {
// TODO: Dynamic length
// TODO: Sanity checks
let left = builder.get(&evals, 0);
let right = builder.get(&evals, 1);
let r = builder.get(point, 0);
Expand All @@ -114,6 +114,80 @@ pub fn fixed_dot_product<C: Config>(
acc
}

pub struct PolyEvaluator<C: Config> {
powers_of_2: Array<C, Usize<C::N>>,
}

impl<C: Config> PolyEvaluator<C> {
pub fn new(builder: &mut Builder<C>) -> Self {
let powers_of_2: Array<C, Usize<C::N>> = builder.dyn_array(MAX_NUM_VARS);
builder.set(&powers_of_2, 0, Usize::from(16777216));
builder.set(&powers_of_2, 1, Usize::from(8388608));
builder.set(&powers_of_2, 2, Usize::from(4194304));
builder.set(&powers_of_2, 3, Usize::from(1048576));
builder.set(&powers_of_2, 4, Usize::from(2097152));
builder.set(&powers_of_2, 5, Usize::from(524288));
builder.set(&powers_of_2, 6, Usize::from(262144));
builder.set(&powers_of_2, 7, Usize::from(131072));
builder.set(&powers_of_2, 8, Usize::from(65536));
builder.set(&powers_of_2, 9, Usize::from(32768));
builder.set(&powers_of_2, 10, Usize::from(16384));
builder.set(&powers_of_2, 11, Usize::from(8192));
builder.set(&powers_of_2, 12, Usize::from(4096));
builder.set(&powers_of_2, 13, Usize::from(2048));
builder.set(&powers_of_2, 14, Usize::from(1024));
builder.set(&powers_of_2, 15, Usize::from(512));
builder.set(&powers_of_2, 16, Usize::from(256));
builder.set(&powers_of_2, 17, Usize::from(128));
builder.set(&powers_of_2, 18, Usize::from(64));
builder.set(&powers_of_2, 19, Usize::from(32));
builder.set(&powers_of_2, 20, Usize::from(16));
builder.set(&powers_of_2, 21, Usize::from(8));
builder.set(&powers_of_2, 22, Usize::from(4));
builder.set(&powers_of_2, 23, Usize::from(2));
builder.set(&powers_of_2, 24, Usize::from(1));

Self { powers_of_2 }
}

pub fn evaluate_base_poly_at_point(
&self,
builder: &mut Builder<C>,
evals: &Array<C, Felt<C::F>>,
point: &Array<C, Ext<C::F, C::EF>>,
) -> Ext<C::F, C::EF> {
let num_var = point.len();

let evals_ext: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(evals.len());
iter_zip!(builder, evals, evals_ext).for_each(|ptr_vec, builder| {
let f = builder.iter_ptr_get(&evals, ptr_vec[0]);
let e = builder.ext_from_base_slice(&[f]);
builder.iter_ptr_set(&evals_ext, ptr_vec[1], e);
});

let pwr_slice_idx: Usize<C::N> = builder.eval(Usize::from(25) - num_var);
let pwrs = self.powers_of_2.slice(builder, pwr_slice_idx, MAX_NUM_VARS);

iter_zip!(builder, point, pwrs).for_each(|ptr_vec, builder| {
let pt = builder.iter_ptr_get(&point, ptr_vec[0]);
let idx_bound = builder.iter_ptr_get(&pwrs, ptr_vec[1]);

builder.range(0, idx_bound).for_each(|idx_vec, builder| {
let left_idx: Usize<C::N> = builder.eval(idx_vec[0] * Usize::from(2));
let right_idx: Usize<C::N> =
builder.eval(idx_vec[0] * Usize::from(2) + Usize::from(1));
let left = builder.get(&evals_ext, left_idx);
let right = builder.get(&evals_ext, right_idx);

let e: Ext<C::F, C::EF> = builder.eval(pt * (right - left) + left);
builder.set(&evals_ext, idx_vec[0], e);
});
});

builder.get(&evals_ext, 0)
}
}

pub fn dot_product<C: Config>(
builder: &mut Builder<C>,
a: &Array<C, Ext<C::F, C::EF>>,
Expand Down Expand Up @@ -261,6 +335,24 @@ pub fn product<C: Config>(
acc
}

// Multiply all elements in a nested Array
pub fn nested_product<C: Config>(
builder: &mut Builder<C>,
arr: &Array<C, Array<C, Ext<C::F, C::EF>>>,
) -> Ext<C::F, C::EF> {
let acc = builder.constant(C::EF::ONE);
iter_zip!(builder, arr).for_each(|ptr_vec, builder| {
let inner_arr = builder.iter_ptr_get(arr, ptr_vec[0]);

iter_zip!(builder, inner_arr).for_each(|ptr_vec, builder| {
let el = builder.iter_ptr_get(&inner_arr, ptr_vec[0]);
builder.assign(&acc, acc * el);
});
});

acc
}

// Add all elements in the Array
pub fn sum<C: Config>(
builder: &mut Builder<C>,
Expand Down Expand Up @@ -334,12 +426,13 @@ pub fn eq_eval_less_or_equal_than<C: Config>(
a: &Array<C, Ext<C::F, C::EF>>,
b: &Array<C, Ext<C::F, C::EF>>,
) -> Ext<C::F, C::EF> {
builder.cycle_tracker_start("Compute eq_eval_less_or_equal_than");
let eq_bit_decomp: Array<C, Felt<C::F>> = opcode_proof
.num_instances_minus_one_bit_decomposition
.slice(builder, 0, b.len());

let one_ext: Ext<C::F, C::EF> = builder.constant(C::EF::ONE);
let rp_len = builder.eval_expr(RVar::from(b.len()) + RVar::from(1));
let rp_len = builder.eval_expr(b.len() + C::N::ONE);
let running_product: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(rp_len);
builder.set(&running_product, 0, one_ext);

Expand All @@ -353,49 +446,33 @@ pub fn eq_eval_less_or_equal_than<C: Config>(
builder.set(&running_product, next_idx, next_v);
});

let running_product2: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(rp_len);
builder.set(&running_product2, b.len(), one_ext);

let eq_bit_decomp_rev = reverse(builder, &eq_bit_decomp);
let idx_arr = gen_idx_arr(builder, b.len());
let idx_arr_rev = reverse(builder, &idx_arr);
builder.assert_usize_eq(eq_bit_decomp_rev.len(), idx_arr_rev.len());

iter_zip!(builder, idx_arr_rev, eq_bit_decomp_rev).for_each(|ptr_vec, builder| {
let i = builder.iter_ptr_get(&idx_arr_rev, ptr_vec[0]);
let bit = builder.iter_ptr_get(&eq_bit_decomp_rev, ptr_vec[1]);
let bit_ext = builder.ext_from_base_slice(&[bit]);
let last_idx = builder.eval_expr(i.clone() + RVar::from(1));

let v = builder.get(&running_product2, last_idx);
let a_i = builder.get(a, i.clone());
let b_i = builder.get(b, i.clone());

let next_v: Ext<C::F, C::EF> = builder.eval(
v * (a_i * b_i * bit_ext + (one_ext - a_i) * (one_ext - b_i) * (one_ext - bit_ext)),
);
builder.set(&running_product2, i, next_v);
});

// Here is an example of how this works:
// Suppose max_idx = (110101)_2
// Then ans = eq(a, b)
// - eq(11011, a[1..6], b[1..6])eq(a[0..1], b[0..1])
// - eq(111, a[3..6], b[3..6])eq(a[0..3], b[0..3])
let ans = builder.get(&running_product, b.len());
builder.range(0, b.len()).for_each(|idx_vec, builder| {
let bit = builder.get(&eq_bit_decomp, idx_vec[0]);
let running_product2: Ext<C::F, C::EF> = builder.constant(C::EF::ONE);
let idx: Var<C::N> = builder.uninit();
builder.assign(&idx, b.len() - C::N::ONE);
builder.range(0, b.len()).for_each(|_, builder| {
let bit = builder.get(&eq_bit_decomp, idx);
let bit_rvar = RVar::from(builder.cast_felt_to_var(bit));
let bit_ext: Ext<C::F, C::EF> = builder.eval(bit * SymbolicExt::from_f(C::EF::ONE));

builder.if_ne(bit_rvar, RVar::from(1)).then(|builder| {
let next_idx = builder.eval_expr(idx_vec[0] + RVar::from(1));
let v1 = builder.get(&running_product, idx_vec[0]);
let v2 = builder.get(&running_product2, next_idx);
let a_i = builder.get(a, idx_vec[0]);
let b_i = builder.get(b, idx_vec[0]);
let a_i = builder.get(a, idx);
let b_i = builder.get(b, idx);

builder.assign(&ans, ans - v1 * v2 * a_i * b_i);
// Suppose max_idx = (110101)_2
// Then ans = eq(a, b)
// - eq(11011, a[1..6], b[1..6])eq(a[0..1], b[0..1])
// - eq(111, a[3..6], b[3..6])eq(a[0..3], b[0..3])
builder.if_ne(bit_rvar, RVar::from(1)).then(|builder| {
let v1 = builder.get(&running_product, idx);
builder.assign(&ans, ans - v1 * running_product2 * a_i * b_i);
});

builder.assign(
&running_product2,
running_product2
* (a_i * b_i * bit_ext + (one_ext - a_i) * (one_ext - b_i) * (one_ext - bit_ext)),
);
builder.assign(&idx, idx - C::N::ONE);
});

let a_remainder_arr: Array<C, Ext<C::F, C::EF>> = a.slice(builder, b.len(), a.len());
Expand All @@ -404,6 +481,8 @@ pub fn eq_eval_less_or_equal_than<C: Config>(
builder.assign(&ans, ans * (one_ext - a));
});

builder.cycle_tracker_end("Compute eq_eval_less_or_equal_than");

ans
}

Expand Down Expand Up @@ -537,9 +616,14 @@ pub fn eval_ceno_expr_with_instance<C: Config>(
res
},
&|builder, scalar| {
let res: Ext<C::F, C::EF> =
builder.constant(C::EF::from_canonical_u32(scalar.to_canonical_u64() as u32));
res
let scalar_base_slice = scalar
.as_bases()
.iter()
.map(|b| C::F::from_canonical_u64(b.to_canonical_u64()))
.collect::<Vec<C::F>>();
let scalar_ext: Ext<C::F, C::EF> =
builder.constant(C::EF::from_base_slice(&scalar_base_slice));
scalar_ext
},
&|builder, challenge_id, pow, scalar, offset| {
let challenge = builder.get(&challenges, challenge_id as usize);
Expand Down Expand Up @@ -587,7 +671,7 @@ pub fn evaluate_ceno_expr<C: Config, T>(
wit_in: &impl Fn(&mut Builder<C>, WitnessId) -> T, // witin id
structural_wit_in: &impl Fn(&mut Builder<C>, WitnessId, usize, u32, usize) -> T,
instance: &impl Fn(&mut Builder<C>, Instance) -> T,
constant: &impl Fn(&mut Builder<C>, <E as ExtensionField>::BaseField) -> T,
constant: &impl Fn(&mut Builder<C>, E) -> T,
challenge: &impl Fn(&mut Builder<C>, ChallengeId, usize, E, E) -> T,
sum: &impl Fn(&mut Builder<C>, T, T) -> T,
product: &impl Fn(&mut Builder<C>, T, T) -> T,
Expand All @@ -600,7 +684,10 @@ pub fn evaluate_ceno_expr<C: Config, T>(
structural_wit_in(builder, *witness_id, *max_len, *offset, *multi_factor)
}
Expression::Instance(i) => instance(builder, *i),
Expression::Constant(scalar) => constant(builder, *scalar),
Expression::Constant(scalar) => match scalar {
Either::Left(s) => constant(builder, E::from_base(*s)),
Either::Right(s) => constant(builder, *s),
},
Expression::Sum(a, b) => {
let a = evaluate_ceno_expr(
builder,
Expand Down
Loading