From 4163b8765afdc80aafa21a3a2044576c65b290a8 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Sat, 2 Apr 2022 12:28:46 -0600 Subject: [PATCH 1/4] Reduce depth of AST by special casing the application of Horner's rule. The existing code will fold together a very deep AST that applies Horner's rule to each gate in a proof -- which could include multiple circuits and so for some applications will quickly grow such that when we recursively descend later during evaluation the stack will easily overflow. This change special cases the application of Horner's rule to a "DistributePowers" AST node to keep the tree depth from exploding in size. --- halo2_proofs/CHANGELOG.md | 4 +++ halo2_proofs/src/plonk/vanishing/prover.rs | 4 +-- halo2_proofs/src/poly/evaluator.rs | 40 ++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/halo2_proofs/CHANGELOG.md b/halo2_proofs/CHANGELOG.md index a3d0d3f467..01e6000f78 100644 --- a/halo2_proofs/CHANGELOG.md +++ b/halo2_proofs/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to Rust's notion of ## [Unreleased] +### Changed +- PLONK prover was improved to avoid stack overflows when large numbers of gates + are involved in a proof. + ## [0.1.0-beta.3] - 2022-03-22 ### Added - `halo2_proofs::circuit`: diff --git a/halo2_proofs/src/plonk/vanishing/prover.rs b/halo2_proofs/src/plonk/vanishing/prover.rs index e5bddb4886..c794070be3 100644 --- a/halo2_proofs/src/plonk/vanishing/prover.rs +++ b/halo2_proofs/src/plonk/vanishing/prover.rs @@ -77,9 +77,7 @@ impl Committed { transcript: &mut T, ) -> Result, Error> { // Evaluate the h(X) polynomial's constraint system expressions for the constraints provided - let h_poly = expressions - .reduce(|h_poly, v| &(&h_poly * *y) + &v) // Fold the gates together with the y challenge - .unwrap_or_else(|| poly::Ast::ConstantTerm(C::Scalar::zero())); + let h_poly = poly::Ast::distribute_powers(expressions, *y); // Fold the gates together with the y challenge let h_poly = evaluator.evaluate(&h_poly, domain); // Evaluate the h(X) polynomial // Divide by t(X) = X^{params.n} - 1. diff --git a/halo2_proofs/src/poly/evaluator.rs b/halo2_proofs/src/poly/evaluator.rs index b9abfc9174..871f7f3962 100644 --- a/halo2_proofs/src/poly/evaluator.rs +++ b/halo2_proofs/src/poly/evaluator.rs @@ -150,6 +150,11 @@ impl Evaluator { lhs.union(&rhs).cloned().collect() } Ast::Scale(a, _) => collect_rotations(a), + Ast::DistributePowers(terms, _) => terms + .iter() + .map(|term| collect_rotations(term)) + .reduce(|a, b| a.union(&b).cloned().collect()) + .unwrap_or(HashSet::new()), Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(), } } @@ -196,6 +201,14 @@ impl Evaluator { leaves: &'a HashMap, &'a [F]>, } + impl<'a, E, F: FieldExt, B: Basis> AstContext<'a, E, F, B> { + /// Returns the actual size of the chunk we're operating on in this + /// context, which may be smaller than `chunk_size`. + fn local_chunk_size(&self) -> usize { + cmp::min(self.chunk_size, self.poly_len - self.chunk_size * self.chunk_index) + } + } + fn recurse( ast: &Ast, ctx: &AstContext<'_, E, F, B>, @@ -225,6 +238,19 @@ impl Evaluator { } lhs } + Ast::DistributePowers(terms, base) => { + let mut acc = vec![F::zero(); ctx.local_chunk_size()]; + + for term in terms.iter() { + let term = recurse(term, ctx); + for (acc, term) in acc.iter_mut().zip(term) { + *acc *= base; + *acc += term; + } + } + + acc + } Ast::LinearTerm(scalar) => B::linear_term( ctx.domain, ctx.poly_len, @@ -285,6 +311,9 @@ pub(crate) enum Ast { Add(Arc>, Arc>), Mul(AstMul), Scale(Arc>, F), + /// Represents a linear combination of a vector of nodes and the powers of a + /// field element. + DistributePowers(Arc>>, F), /// The degree-1 term of a polynomial. /// /// The field element is the coefficient of the term in the standard basis, not the @@ -296,6 +325,12 @@ pub(crate) enum Ast { ConstantTerm(F), } +impl Ast { + pub fn distribute_powers>(i: I, base: F) -> Self { + Ast::DistributePowers(Arc::new(i.into_iter().collect()), base) + } +} + impl fmt::Debug for Ast { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -303,6 +338,11 @@ impl fmt::Debug for Ast { Self::Add(lhs, rhs) => f.debug_tuple("Add").field(lhs).field(rhs).finish(), Self::Mul(x) => f.debug_tuple("Mul").field(x).finish(), Self::Scale(base, scalar) => f.debug_tuple("Scale").field(base).field(scalar).finish(), + Self::DistributePowers(terms, base) => f + .debug_tuple("DistributePowers") + .field(terms) + .field(base) + .finish(), Self::LinearTerm(x) => f.debug_tuple("LinearTerm").field(x).finish(), Self::ConstantTerm(x) => f.debug_tuple("ConstantTerm").field(x).finish(), } From fd7e9ddbb078bb701f5ebf3d5e7728604a81f7c8 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Sat, 2 Apr 2022 15:38:46 -0600 Subject: [PATCH 2/4] rustfmt --- halo2_proofs/src/poly/evaluator.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/halo2_proofs/src/poly/evaluator.rs b/halo2_proofs/src/poly/evaluator.rs index 871f7f3962..aa91cf4f82 100644 --- a/halo2_proofs/src/poly/evaluator.rs +++ b/halo2_proofs/src/poly/evaluator.rs @@ -205,7 +205,10 @@ impl Evaluator { /// Returns the actual size of the chunk we're operating on in this /// context, which may be smaller than `chunk_size`. fn local_chunk_size(&self) -> usize { - cmp::min(self.chunk_size, self.poly_len - self.chunk_size * self.chunk_index) + cmp::min( + self.chunk_size, + self.poly_len - self.chunk_size * self.chunk_index, + ) } } From fa069a745541ca188c0e1d8adb2082f6f0d6670d Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Sun, 3 Apr 2022 10:06:19 -0600 Subject: [PATCH 3/4] Use unwrap_or_default() instead of unwrap_or(HashMap::new()) --- halo2_proofs/src/poly/evaluator.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2_proofs/src/poly/evaluator.rs b/halo2_proofs/src/poly/evaluator.rs index aa91cf4f82..04b9bfe0a9 100644 --- a/halo2_proofs/src/poly/evaluator.rs +++ b/halo2_proofs/src/poly/evaluator.rs @@ -154,7 +154,7 @@ impl Evaluator { .iter() .map(|term| collect_rotations(term)) .reduce(|a, b| a.union(&b).cloned().collect()) - .unwrap_or(HashSet::new()), + .unwrap_or_default(), Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(), } } From 6a31a0e6a1ab56f43108ccde95c60052d1e660ea Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Mon, 4 Apr 2022 14:07:31 -0600 Subject: [PATCH 4/4] Apply @str4d's review suggestions. --- halo2_proofs/src/poly/evaluator.rs | 33 +++++++++--------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/halo2_proofs/src/poly/evaluator.rs b/halo2_proofs/src/poly/evaluator.rs index 04b9bfe0a9..cda6aac963 100644 --- a/halo2_proofs/src/poly/evaluator.rs +++ b/halo2_proofs/src/poly/evaluator.rs @@ -152,9 +152,8 @@ impl Evaluator { Ast::Scale(a, _) => collect_rotations(a), Ast::DistributePowers(terms, _) => terms .iter() - .map(|term| collect_rotations(term)) - .reduce(|a, b| a.union(&b).cloned().collect()) - .unwrap_or_default(), + .flat_map(|term| collect_rotations(term).into_iter()) + .collect(), Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(), } } @@ -201,17 +200,6 @@ impl Evaluator { leaves: &'a HashMap, &'a [F]>, } - impl<'a, E, F: FieldExt, B: Basis> AstContext<'a, E, F, B> { - /// Returns the actual size of the chunk we're operating on in this - /// context, which may be smaller than `chunk_size`. - fn local_chunk_size(&self) -> usize { - cmp::min( - self.chunk_size, - self.poly_len - self.chunk_size * self.chunk_index, - ) - } - } - fn recurse( ast: &Ast, ctx: &AstContext<'_, E, F, B>, @@ -241,19 +229,17 @@ impl Evaluator { } lhs } - Ast::DistributePowers(terms, base) => { - let mut acc = vec![F::zero(); ctx.local_chunk_size()]; - - for term in terms.iter() { + Ast::DistributePowers(terms, base) => terms.iter().fold( + B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, F::zero()), + |mut acc, term| { let term = recurse(term, ctx); for (acc, term) in acc.iter_mut().zip(term) { *acc *= base; *acc += term; } - } - - acc - } + acc + }, + ), Ast::LinearTerm(scalar) => B::linear_term( ctx.domain, ctx.poly_len, @@ -315,7 +301,8 @@ pub(crate) enum Ast { Mul(AstMul), Scale(Arc>, F), /// Represents a linear combination of a vector of nodes and the powers of a - /// field element. + /// field element, where the nodes are ordered from highest to lowest degree + /// terms. DistributePowers(Arc>>, F), /// The degree-1 term of a polynomial. ///