Skip to content

Commit

Permalink
WHIR wrapper (#742)
Browse files Browse the repository at this point in the history
Add the implementation of WHIR from  #642

---------

Co-authored-by: Ming <[email protected]>
  • Loading branch information
yczhangsjtu and hero78119 authored Feb 26, 2025
1 parent de94155 commit 59e74dc
Show file tree
Hide file tree
Showing 15 changed files with 2,800 additions and 201 deletions.
939 changes: 757 additions & 182 deletions Cargo.lock

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ version = "0.1.0"

[workspace.dependencies]
anyhow = { version = "1.0", default-features = false }
ark-std = "0.4"
ark-std = "0.5"
cfg-if = "1.0"
criterion = { version = "0.5", features = ["html_reports"] }
crossbeam-channel = "0.5"
itertools = "0.13"
num-derive = "0.4"
num-traits = "0.2"
p3-challenger = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-field = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-goldilocks = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-mds = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-poseidon = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-poseidon2 = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-symmetric = { git = "https://github.com/scroll-tech/plonky3", branch = "feat/whir_field_wrapper" }
p3-challenger = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-field = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-goldilocks = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-mds = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-poseidon = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-poseidon2 = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
p3-symmetric = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" }
paste = "1"
plonky2 = "0.2"
poseidon = { path = "./poseidon" }
Expand Down
14 changes: 7 additions & 7 deletions ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use ceno_emul::{
};
use ff_ext::{ExtensionField, FieldInto, FromUniformBytes, GoldilocksExt2};
use itertools::Itertools;
use mpcs::{Basefold, BasefoldDefault, BasefoldRSParams, PolynomialCommitmentScheme};
use mpcs::{PolynomialCommitmentScheme, WhirDefault};
use multilinear_extensions::{
mle::IntoMLE, util::ceil_log2, virtual_poly::ArcMultilinearExtension,
};
Expand Down Expand Up @@ -88,11 +88,11 @@ impl<E: ExtensionField, const L: usize, const RW: usize> Instruction<E> for Test
fn test_rw_lk_expression_combination() {
fn test_rw_lk_expression_combination_inner<const L: usize, const RW: usize>() {
type E = GoldilocksExt2;
type Pcs = BasefoldDefault<E>;
type Pcs = WhirDefault<E>;

// pcs setup
let param = Pcs::setup(1 << 13).unwrap();
let (pp, vp) = Pcs::trim(param, 1 << 13).unwrap();
Pcs::setup(1 << 8).unwrap();
let (pp, vp) = Pcs::trim((), 1 << 8).unwrap();

// configure
let name = TestCircuit::<E, RW, L>::name();
Expand Down Expand Up @@ -200,7 +200,7 @@ const PROGRAM_CODE: [ceno_emul::Instruction; 4] = [
#[test]
fn test_single_add_instance_e2e() {
type E = GoldilocksExt2;
type Pcs = Basefold<GoldilocksExt2, BasefoldRSParams>;
type Pcs = WhirDefault<E>;

// set up program
let program = Program::new(
Expand All @@ -210,8 +210,8 @@ fn test_single_add_instance_e2e() {
Default::default(),
);

let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim((), 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
let mut zkvm_cs = ZKVMConstraintSystem::default();
// opcode circuits
let add_config = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
Expand Down
3 changes: 3 additions & 0 deletions clippy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ allowed-duplicate-crates = [
"regex-syntax",
"syn",
"windows-sys",
"tracing-subscriber",
"wasi",
"getrandom",
]
22 changes: 21 additions & 1 deletion ff_ext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ pub trait SmallField: Serialize + P3Field + FieldFrom<u64> + FieldInto<Self> {

pub trait ExtensionField: P3ExtensionField<Self::BaseField> + FromUniformBytes + Ord {
const DEGREE: usize;
const MULTIPLICATIVE_GENERATOR: Self;
const TWO_ADICITY: usize;
const BASE_TWO_ADIC_ROOT_OF_UNITY: Self::BaseField;
const TWO_ADIC_ROOT_OF_UNITY: Self;
const NONRESIDUE: Self::BaseField;

type BaseField: SmallField + Ord + PrimeField + FromUniformBytes + TwoAdicField + PoseidonField;

Expand All @@ -121,7 +126,10 @@ mod impl_goldilocks {
ExtensionField, FieldFrom, FieldInto, FromUniformBytes, GoldilocksExt2, SmallField,
poseidon::{PoseidonField, new_array},
};
use p3_field::{BasedVectorSpace, PrimeCharacteristicRing, PrimeField64};
use p3_field::{
BasedVectorSpace, Field, PrimeCharacteristicRing, PrimeField64, TwoAdicField,
extension::{BinomialExtensionField, BinomiallyExtendable},
};
use p3_goldilocks::{
Goldilocks, HL_GOLDILOCKS_8_EXTERNAL_ROUND_CONSTANTS,
HL_GOLDILOCKS_8_INTERNAL_ROUND_CONSTANTS, Poseidon2GoldilocksHL,
Expand Down Expand Up @@ -209,6 +217,18 @@ mod impl_goldilocks {

impl ExtensionField for GoldilocksExt2 {
const DEGREE: usize = 2;
const MULTIPLICATIVE_GENERATOR: Self = <GoldilocksExt2 as Field>::GENERATOR;
const TWO_ADICITY: usize = Goldilocks::TWO_ADICITY;
// Passing two-adacity itself to this function will get the root of unity
// with the largest order, i.e., order = 2^two-adacity.
const BASE_TWO_ADIC_ROOT_OF_UNITY: Self::BaseField =
Goldilocks::two_adic_generator_const(Goldilocks::TWO_ADICITY);
const TWO_ADIC_ROOT_OF_UNITY: Self = BinomialExtensionField::new_unchecked(
Goldilocks::ext_two_adic_generator_const(Goldilocks::TWO_ADICITY),
);
// non-residue is the value w such that the extension field is
// F[X]/(X^2 - w)
const NONRESIDUE: Self::BaseField = <Goldilocks as BinomiallyExtendable<2>>::W;

type BaseField = Goldilocks;

Expand Down
11 changes: 10 additions & 1 deletion mpcs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ bitvec = "1.0"
ctr = "0.9"
ff_ext = { path = "../ff_ext" }
# TODO: move to version 1, once our dependencies are updated
ark-ff = "0.5"
ark-serialize = { version = "0.5", features = ["derive"] }
bincode = "1.3.3"
generic-array = { version = "0.14", features = ["serde"] }
itertools.workspace = true
multilinear_extensions = { path = "../multilinear_extensions" }
Expand All @@ -32,6 +35,8 @@ rand_chacha.workspace = true
rayon = { workspace = true, optional = true }
serde.workspace = true
transcript = { path = "../transcript" }
whir = { git = "https://github.com/scroll-tech/whir", branch = "feat/ceno-binding-batch", features = ["ceno"] }
zeroize = "1.8"

[dev-dependencies]
criterion.workspace = true
Expand All @@ -40,7 +45,7 @@ criterion.workspace = true
benchmark = ["parallel"]
default = ["parallel"] # Add "sanity-check" to debug
parallel = ["dep:rayon"]
print-trace = ["ark-std/print-trace"]
print-trace = ["ark-std/print-trace", "whir/print-trace"]
sanity-check = []

[[bench]]
Expand All @@ -66,3 +71,7 @@ name = "fft"
[[bench]]
harness = false
name = "utils"

[[bench]]
harness = false
name = "whir"
192 changes: 192 additions & 0 deletions mpcs/benches/whir.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use std::time::Duration;

use criterion::*;
use ff_ext::GoldilocksExt2;

use itertools::Itertools;
use mpcs::{
PolynomialCommitmentScheme, WhirDefault,
test_util::{gen_rand_poly_base, gen_rand_polys, get_point_from_challenge, setup_pcs},
};

use multilinear_extensions::{mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension};
use transcript::{BasicTranscript, Transcript};

type T = BasicTranscript<GoldilocksExt2>;
type E = GoldilocksExt2;
type PcsGoldilocks = WhirDefault<E>;

const NUM_SAMPLES: usize = 10;
const NUM_VARS_START: usize = 20;
const NUM_VARS_END: usize = 20;
const BATCH_SIZE_LOG_START: usize = 6;
const BATCH_SIZE_LOG_END: usize = 6;

fn bench_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>(c: &mut Criterion) {
let mut group = c.benchmark_group("commit_open_verify_goldilocks".to_string());
group.sample_size(NUM_SAMPLES);
// Challenge is over extension field, poly over the base field
for num_vars in NUM_VARS_START..=NUM_VARS_END {
let (pp, vp) = {
let poly_size = 1 << num_vars;
let param = Pcs::setup(poly_size).unwrap();

group.bench_function(BenchmarkId::new("setup", format!("{}", num_vars)), |b| {
b.iter(|| {
Pcs::setup(poly_size).unwrap();
})
});
Pcs::trim(param, poly_size).unwrap()
};

let mut transcript = T::new(b"BaseFold");
let poly = gen_rand_poly_base(num_vars);
let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap();

group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| {
b.iter(|| {
Pcs::commit(&pp, &poly).unwrap();
})
});

let point = get_point_from_challenge(num_vars, &mut transcript);
let eval = poly.evaluate(point.as_slice());
transcript.append_field_element_ext(&eval);
let transcript_for_bench = transcript.clone();
let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap();

group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| {
b.iter_batched(
|| transcript_for_bench.clone(),
|mut transcript| {
Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap();
},
BatchSize::SmallInput,
);
});
// Verify
let comm = Pcs::get_pure_commitment(&comm);
let mut transcript = T::new(b"BaseFold");
Pcs::write_commitment(&comm, &mut transcript).unwrap();
let point = get_point_from_challenge(num_vars, &mut transcript);
transcript.append_field_element_ext(&eval);
let transcript_for_bench = transcript.clone();
Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap();
group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| {
b.iter_batched(
|| transcript_for_bench.clone(),
|mut transcript| {
Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap();
},
BatchSize::SmallInput,
);
});
}
}

fn bench_simple_batch_commit_open_verify_goldilocks<Pcs: PolynomialCommitmentScheme<E>>(
c: &mut Criterion,
) {
let mut group = c.benchmark_group("simple_batch_commit_open_verify_goldilocks".to_string());
group.sample_size(NUM_SAMPLES);
// Challenge is over extension field, poly over the base field
for num_vars in NUM_VARS_START..=NUM_VARS_END {
for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END {
let batch_size = 1 << batch_size_log;
let (pp, vp) = setup_pcs::<E, Pcs>(num_vars);
let mut transcript = T::new(b"BaseFold");
let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly_base);
let comm = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap();

group.bench_function(
BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)),
|b| {
b.iter(|| {
Pcs::batch_commit(&pp, &polys).unwrap();
})
},
);
let point = get_point_from_challenge(num_vars, &mut transcript);
let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec();
transcript.append_field_element_exts(&evals);
let transcript_for_bench = transcript.clone();
let polys = polys
.iter()
.map(|poly| ArcMultilinearExtension::from(poly.clone()))
.collect::<Vec<_>>();
let proof = Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript)
.unwrap();

group.bench_function(
BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)),
|b| {
b.iter_batched(
|| transcript_for_bench.clone(),
|mut transcript| {
Pcs::simple_batch_open(
&pp,
&polys,
&comm,
&point,
&evals,
&mut transcript,
)
.unwrap();
},
BatchSize::SmallInput,
);
},
);
let comm = Pcs::get_pure_commitment(&comm);

// Batch verify
let mut transcript = BasicTranscript::new(b"BaseFold");
Pcs::write_commitment(&comm, &mut transcript).unwrap();

let point = get_point_from_challenge(num_vars, &mut transcript);
transcript.append_field_element_exts(&evals);
let backup_transcript = transcript.clone();

Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript).unwrap();

group.bench_function(
BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)),
|b| {
b.iter_batched(
|| backup_transcript.clone(),
|mut transcript| {
Pcs::simple_batch_verify(
&vp,
&comm,
&point,
&evals,
&proof,
&mut transcript,
)
.unwrap();
},
BatchSize::SmallInput,
);
},
);
}
}
}

fn bench_commit_open_verify_goldilocks_base(c: &mut Criterion) {
bench_commit_open_verify_goldilocks::<PcsGoldilocks>(c);
}

fn bench_simple_batch_commit_open_verify_goldilocks_base(c: &mut Criterion) {
bench_simple_batch_commit_open_verify_goldilocks::<PcsGoldilocks>(c);
}

criterion_group! {
name = bench_whir;
config = Criterion::default().warm_up_time(Duration::from_millis(3000));
targets =
bench_simple_batch_commit_open_verify_goldilocks_base,
bench_commit_open_verify_goldilocks_base,
}

criterion_main!(bench_whir);
15 changes: 15 additions & 0 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,21 @@ mod test {
}
}

#[test]
#[ignore = "For benchmarking and profiling only"]
fn bench_basefold_simple_batch_commit_open_verify_goldilocks() {
{
let gen_rand_poly = gen_rand_poly_base;
run_commit_open_verify::<GoldilocksExt2, PcsGoldilocksRSCode>(gen_rand_poly, 20, 21);
run_simple_batch_commit_open_verify::<GoldilocksExt2, PcsGoldilocksRSCode>(
gen_rand_poly,
20,
21,
64,
);
}
}

#[test]
fn batch_commit_open_verify() {
for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] {
Expand Down
Loading

0 comments on commit 59e74dc

Please sign in to comment.