From a62bccf2063da5f8843352f5d1acb242252ac775 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Fri, 21 Jul 2023 14:24:47 -0400 Subject: [PATCH] Small code organization improvements (#206) * refactor: Deleted a redundant `ScalarMul` helper trait * refactor: Refactor `to_transcript_bytes` * refactor: refactor R1CS Shape checking in Spartan checks - Introduced a new function `check_regular_shape` in `r1cs.rs` to enforce regularity conditions necessary for Spartan-class SNARKs. * refactor: Refactor sumcheck.rs prove_quad_* for readability - Extracted the calculation of evaluation points to its new function `compute_eval_points`, enhancing code reusability within `prove_quad` and `prove_quad_batch` functions. --- src/r1cs.rs | 10 +++++++ src/spartan/ppsnark.rs | 5 +--- src/spartan/snark.rs | 5 +--- src/spartan/sumcheck.rs | 65 +++++++++++++++++++--------------------- src/traits/commitment.rs | 10 ++----- src/traits/mod.rs | 8 ++--- 6 files changed, 49 insertions(+), 54 deletions(-) diff --git a/src/r1cs.rs b/src/r1cs.rs index 4b204e9..4894e70 100644 --- a/src/r1cs.rs +++ b/src/r1cs.rs @@ -133,6 +133,16 @@ impl R1CSShape { }) } + // Checks regularity conditions on the R1CSShape, required in Spartan-class SNARKs + // Panics if num_cons, num_vars, or num_io are not powers of two, or if num_io > num_vars + #[inline] + pub(crate) fn check_regular_shape(&self) { + assert_eq!(self.num_cons.next_power_of_two(), self.num_cons); + assert_eq!(self.num_vars.next_power_of_two(), self.num_vars); + assert_eq!(self.num_io.next_power_of_two(), self.num_io); + assert!(self.num_io < self.num_vars); + } + pub fn multiply_vec( &self, z: &[G::Scalar], diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 0011f46..aebd5c1 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -967,10 +967,7 @@ impl> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait SumcheckProof { Ok((e, r)) } + #[inline] + fn compute_eval_points( + poly_A: &MultilinearPolynomial, + poly_B: &MultilinearPolynomial, + comb_func: &F, + ) -> (G::Scalar, G::Scalar) + where + F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, + { + let len = poly_A.len() / 2; + (0..len) + .into_par_iter() + .map(|i| { + // eval 0: bound_func is A(low) + let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]); + + // eval 2: bound_func is -A(low) + 2*A(high) + let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; + let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; + let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point); + (eval_point_0, eval_point_2) + }) + .reduce( + || (G::Scalar::ZERO, G::Scalar::ZERO), + |a, b| (a.0 + b.0, a.1 + b.1), + ) + } + pub fn prove_quad( claim: &G::Scalar, num_rounds: usize, @@ -77,25 +105,7 @@ impl SumcheckProof { let mut claim_per_round = *claim; for _ in 0..num_rounds { let poly = { - let len = poly_A.len() / 2; - - // Make an iterator returning the contributions to the evaluations - let (eval_point_0, eval_point_2) = (0..len) - .into_par_iter() - .map(|i| { - // eval 0: bound_func is A(low) - let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point); - (eval_point_0, eval_point_2) - }) - .reduce( - || (G::Scalar::ZERO, G::Scalar::ZERO), - |a, b| (a.0 + b.0, a.1 + b.1), - ); + let (eval_point_0, eval_point_2) = Self::compute_eval_points(poly_A, poly_B, &comb_func); let evals = vec![eval_point_0, claim_per_round - eval_point_0, eval_point_2]; UniPoly::from_evals(&evals) @@ -136,7 +146,7 @@ impl SumcheckProof { transcript: &mut G::TE, ) -> Result<(Self, Vec, (Vec, Vec)), NovaError> where - F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar, + F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, { let mut e = *claim; let mut r: Vec = Vec::new(); @@ -146,20 +156,7 @@ impl SumcheckProof { let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); for (poly_A, poly_B) in poly_A_vec.iter().zip(poly_B_vec.iter()) { - let mut eval_point_0 = G::Scalar::ZERO; - let mut eval_point_2 = G::Scalar::ZERO; - - let len = poly_A.len() / 2; - for i in 0..len { - // eval 0: bound_func is A(low) - eval_point_0 += comb_func(&poly_A[i], &poly_B[i]); - - // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - eval_point_2 += comb_func(&poly_A_bound_point, &poly_B_bound_point); - } - + let (eval_point_0, eval_point_2) = Self::compute_eval_points(poly_A, poly_B, &comb_func); evals.push((eval_point_0, eval_point_2)); } diff --git a/src/traits/commitment.rs b/src/traits/commitment.rs index 9b4725f..4ac8349 100644 --- a/src/traits/commitment.rs +++ b/src/traits/commitment.rs @@ -6,10 +6,12 @@ use crate::{ }; use core::{ fmt::Debug, - ops::{Add, AddAssign, Mul, MulAssign}, + ops::{Add, AddAssign}, }; use serde::{Deserialize, Serialize}; +use super::ScalarMul; + /// Defines basic operations on commitments pub trait CommitmentOps: Add + AddAssign @@ -31,12 +33,6 @@ impl CommitmentOpsOwned for T where { } -/// A helper trait for types implementing a multiplication of a commitment with a scalar -pub trait ScalarMul: Mul + MulAssign {} - -impl ScalarMul for T where T: Mul + MulAssign -{} - /// This trait defines the behavior of the commitment pub trait CommitmentTrait: Clone diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 5138cea..ffad871 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -236,11 +236,9 @@ pub trait PrimeFieldExt: PrimeField { impl> TranscriptReprTrait for &[T] { fn to_transcript_bytes(&self) -> Vec { - (0..self.len()) - .map(|i| self[i].to_transcript_bytes()) - .collect::>() - .into_iter() - .flatten() + self + .iter() + .flat_map(|t| t.to_transcript_bytes()) .collect::>() } }