Skip to content

rewrite #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: feat/simplify-basefold-verifier
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/arithmetics/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::tower_verifier::binding::PointAndEvalVariable;
use crate::zkvm_verifier::binding::ZKVMOpcodeProofInputVariable;
use ceno_mle::expression::{Expression, Fixed, Instance};
use ceno_mle::{Expression, Fixed, Instance};
use ceno_zkvm::structs::{ChallengeId, WitnessId};
use ff_ext::ExtensionField;
use ff_ext::{BabyBearExt4, SmallField};
Expand Down
8 changes: 1 addition & 7 deletions src/basefold_verifier/basefold.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use mpcs::BasefoldProof as InnerBasefoldProof;
use mpcs::basefold::BasefoldProof as InnerBasefoldProof;
use openvm_native_compiler::{asm::AsmConfig, prelude::*};
use openvm_native_recursion::hints::{Hintable, VecAutoHintable};
use openvm_stark_sdk::p3_baby_bear::BabyBear;
Expand Down Expand Up @@ -26,7 +26,6 @@ pub type HashDigest = MmcsCommitment;
pub struct BasefoldCommitment {
pub commit: HashDigest,
pub log2_max_codeword_size: usize,
pub trivial_commits: Vec<(usize, HashDigest)>,
}

use mpcs::BasefoldCommitment as InnerBasefoldCommitment;
Expand All @@ -38,11 +37,6 @@ impl From<InnerBasefoldCommitment<E>> for BasefoldCommitment {
value: value.commit().into(),
},
log2_max_codeword_size: value.log2_max_codeword_size,
trivial_commits: value
.trivial_commits
.into_iter()
.map(|(i, c)| (i, c.into()))
.collect(),
}
}
}
Expand Down
194 changes: 28 additions & 166 deletions src/basefold_verifier/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField};
use mpcs::QueryOpeningProof as InnerQueryOpeningProof;
use openvm_native_compiler::{asm::AsmConfig, prelude::*};
use openvm_native_compiler_derive::iter_zip;
use openvm_native_recursion::{
hints::{Hintable, VecAutoHintable},
vars::HintSlice,
Expand Down Expand Up @@ -315,18 +316,19 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
}

// encode_small
let final_rmm_values_len = builder.get(&input.final_message, 0).len();
let final_message = &input.proof.final_message;
let final_rmm_values_len = builder.get(final_message, 0).len();
let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone());

builder
.range(0, final_rmm_values_len.clone())
.for_each(|i_vec, builder| {
let i = i_vec[0];
let row_len = input.final_message.len();
let row_len = final_message.len();
let sum = builder.constant(C::EF::ZERO);
builder.range(0, row_len).for_each(|j_vec, builder| {
let j = j_vec[0];
let row = builder.get(&input.final_message, j);
let row = builder.get(final_message, j);
let row_j = builder.get(&row, i);
builder.assign(&sum, sum + row_j);
});
Expand All @@ -342,52 +344,35 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
let log2_witin_max_codeword_size: Var<C::N> =
builder.eval(input.max_num_var.clone() + get_rate_log::<C>());

// Nondeterministically supply the index folding_sorted_order
// Check that:
// 1. It has the same length as input.circuit_meta (checked by requesting folding_len hints)
// 2. It does not contain the same index twice (checked via a correspondence array)
// 3. Indexed witin_num_vars are sorted in decreasing order
// Infer witin_num_vars through index
let folding_len = input.circuit_meta.len();
let zero: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
let folding_sort_surjective: Array<C, Ext<C::F, C::EF>> =
builder.dyn_array(folding_len.clone());
builder
.range(0, folding_len.clone())
.for_each(|i_vec, builder| {
let i = i_vec[0];
builder.set(&folding_sort_surjective, i, zero.clone());
});

// an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index
// for reverse lookup when retrieving next base codeword to involve into batching
// Sort input.dimensions by height, returns
// 1. height_order: after sorting by decreasing height, the original index of each entry
// 2. num_unique_height: number of different heights
// 3. count_per_unique_height: for each unique height, number of dimensions of that height
// builder.assert_nonzero(&Usize::from(0));
let (
folding_sorted_order_index,
num_unique_num_vars,
count_per_unique_num_var,
sorted_unique_num_vars,
) = sort_with_count(
builder,
&input.circuit_meta,
|m: CircuitIndexMetaVariable<C>| m.witin_num_vars,
);

builder
.range(0, input.indices.len())
.for_each(|i_vec, builder| {
let i = i_vec[0];
// i >>= 1;
let idx = builder.get(&input.indices, i);
let query = builder.get(&input.queries, i);
let witin_opened_values = query.witin_base_proof.opened_values;
let query = builder.get(&input.proof.query_opening_proof, i);

iter_zip!(query.input_proofs, input.rounds,).for_each(|ptr_vec| {
let batch_opening = builder.iter_ptr_get(&query.input_proofs, ptr_vec[0]);
let round = builder.iter_ptr_get(&input.rounds, ptr_vec[1]);
let opened_values = batch_opening.opened_values;
let opening_proof = batch_opening.opening_proof;
// get dimension

// i >>= 1

// verify input mmcs
let mmcs_verifier_input = MmcsVerifierInputVariable {
commit: round.commit.commit.clone(),
dimensions: dimensions,
index_bits: idx_bits.clone().slice(builder, 1, idx_len), // Remove the first bit because two entries are grouped into one leaf in the Merkle tree
opened_values: opened_values,
proof: opening_proof,
};

mmcs_verify_batch(builder, mmcs_verifier_input);
});

let witin_opening_proof = query.witin_base_proof.opening_proof;
let fixed_is_some = query.fixed_is_some;
let fixed_commit = query.fixed_base_proof;
let opening_ext = query.commit_phase_openings;

// verify base oracle query proof
Expand All @@ -412,75 +397,6 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
builder.assign(&idx_len, idx_len_minus_one);
builder.assign(&idx, idx_half);

let (witin_dimensions, fixed_dimensions) =
get_base_codeword_dimensions(builder, input.circuit_meta.clone());
// verify witness
let mmcs_verifier_input = MmcsVerifierInputVariable {
commit: input.witin_comm.commit.clone(),
dimensions: witin_dimensions,
index_bits: idx_bits.clone().slice(builder, 1, idx_len), // Remove the first bit because two entries are grouped into one leaf in the Merkle tree
opened_values: witin_opened_values.clone(),
proof: witin_opening_proof,
};
mmcs_verify_batch(builder, mmcs_verifier_input);

// verify fixed
let fixed_commit_leafs = builder.dyn_array(0);
builder
.if_eq(fixed_is_some.clone(), Usize::from(1))
.then(|builder| {
let fixed_opened_values = fixed_commit.opened_values.clone();

let fixed_opening_proof = fixed_commit.opening_proof.clone();
// new_idx used by mmcs proof
let new_idx: Var<C::N> = builder.eval(idx);
// Nondeterministically supply a hint:
// 0: input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size
// 1: >=
let branch_le = builder.hint_var();
builder.if_eq(branch_le, Usize::from(0)).then(|builder| {
// input.fixed_comm.log2_max_codeword_size < log2_witin_max_codeword_size
builder.assert_less_than_slow_small_rhs(
input.fixed_comm.log2_max_codeword_size.clone(),
log2_witin_max_codeword_size,
);
// idx >> idx_shift
let idx_shift_remain: Var<C::N> = builder.eval(
idx_len
- (log2_witin_max_codeword_size
- input.fixed_comm.log2_max_codeword_size.clone()),
);
let tmp_idx = bin_to_dec(builder, &idx_bits, idx_shift_remain);
builder.assign(&new_idx, tmp_idx);
});
builder.if_ne(branch_le, Usize::from(0)).then(|builder| {
// input.fixed_comm.log2_max_codeword_size >= log2_witin_max_codeword_size
let input_codeword_size_plus_one: Var<C::N> = builder
.eval(input.fixed_comm.log2_max_codeword_size.clone() + Usize::from(1));
builder.assert_less_than_slow_small_rhs(
log2_witin_max_codeword_size,
input_codeword_size_plus_one,
);
// idx << -idx_shift
let idx_shift = builder.eval(
input.fixed_comm.log2_max_codeword_size.clone()
- log2_witin_max_codeword_size,
);
let idx_factor = pow_2(builder, idx_shift);
builder.assign(&new_idx, new_idx * idx_factor);
});
// verify witness
let mmcs_verifier_input = MmcsVerifierInputVariable {
commit: input.fixed_comm.commit.clone(),
dimensions: fixed_dimensions.clone(),
index_bits: idx_bits.clone().slice(builder, 1, idx_len),
opened_values: fixed_opened_values.clone(),
proof: fixed_opening_proof,
};
mmcs_verify_batch(builder, mmcs_verifier_input);
builder.assign(&fixed_commit_leafs, fixed_opened_values);
});

// base_codeword_lo_hi
let base_codeword_lo = builder.dyn_array(folding_len.clone());
let base_codeword_hi = builder.dyn_array(folding_len.clone());
Expand All @@ -489,9 +405,6 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
.for_each(|j_vec, builder| {
let j = j_vec[0];
let circuit_meta = builder.get(&input.circuit_meta, j);
let witin_num_polys = circuit_meta.witin_num_polys;
let fixed_num_vars = circuit_meta.fixed_num_vars;
let fixed_num_polys = circuit_meta.fixed_num_polys;
let witin_leafs = builder.get(&witin_opened_values, j);
// lo_wit, hi_wit
let leafs_len_div_2 = builder.hint_var();
Expand Down Expand Up @@ -600,57 +513,6 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
// next folding challenges
let is_interpolate_to_right_index = builder.get(&idx_bits, j_plus_one);
let new_involved_codewords: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
builder
.if_ne(next_unique_num_vars_index, num_unique_num_vars)
.then(|builder| {
let next_unique_num_vars: Var<C::N> =
builder.get(&sorted_unique_num_vars, next_unique_num_vars_index);
builder
.if_eq(next_unique_num_vars, cur_num_var)
.then(|builder| {
let next_unique_num_vars_count: Var<C::N> = builder
.get(&count_per_unique_num_var, next_unique_num_vars_index);
builder.range(0, next_unique_num_vars_count).for_each(
|k_vec, builder| {
let k =
builder.eval_expr(k_vec[0] + cumul_num_vars_count);
let index = builder.get(&folding_sorted_order_index, k);
let lo = builder.get(&base_codeword_lo, index.clone());
let hi = builder.get(&base_codeword_hi, index.clone());
builder
.if_eq(
is_interpolate_to_right_index,
Usize::from(1),
)
.then(|builder| {
builder.assign(
&new_involved_codewords,
new_involved_codewords + hi,
);
});
builder
.if_ne(
is_interpolate_to_right_index,
Usize::from(1),
)
.then(|builder| {
builder.assign(
&new_involved_codewords,
new_involved_codewords + lo,
);
});
},
);
builder.assign(
&cumul_num_vars_count,
cumul_num_vars_count + next_unique_num_vars_count,
);
builder.assign(
&next_unique_num_vars_index,
next_unique_num_vars_index + Usize::from(1),
);
});
});

// leafs
let leafs = builder.dyn_array(2);
Expand Down
Loading