Browse Source

Small code organization improvements (#206)

* refactor: Deleted a redundant `ScalarMul` helper trait

* refactor: Refactor `to_transcript_bytes`

* refactor: refactor R1CS Shape checking in Spartan checks

- Introduced a new function `check_regular_shape` in `r1cs.rs` to enforce regularity conditions necessary for Spartan-class SNARKs.

* refactor: Refactor sumcheck.rs prove_quad_* for readability

- Extracted the calculation of evaluation points to its new function `compute_eval_points`, enhancing code reusability within `prove_quad` and `prove_quad_batch` functions.
main
François Garillot 1 year ago
committed by GitHub
parent
commit
a62bccf206
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 49 additions and 54 deletions
  1. +10
    -0
      src/r1cs.rs
  2. +1
    -4
      src/spartan/ppsnark.rs
  3. +1
    -4
      src/spartan/snark.rs
  4. +31
    -34
      src/spartan/sumcheck.rs
  5. +3
    -7
      src/traits/commitment.rs
  6. +3
    -5
      src/traits/mod.rs

+ 10
- 0
src/r1cs.rs

@ -133,6 +133,16 @@ impl R1CSShape {
}) })
} }
// Checks regularity conditions on the R1CSShape, required in Spartan-class SNARKs
// Panics if num_cons, num_vars, or num_io are not powers of two, or if num_io > num_vars
#[inline]
pub(crate) fn check_regular_shape(&self) {
assert_eq!(self.num_cons.next_power_of_two(), self.num_cons);
assert_eq!(self.num_vars.next_power_of_two(), self.num_vars);
assert_eq!(self.num_io.next_power_of_two(), self.num_io);
assert!(self.num_io < self.num_vars);
}
pub fn multiply_vec( pub fn multiply_vec(
&self, &self,
z: &[G::Scalar], z: &[G::Scalar],

+ 1
- 4
src/spartan/ppsnark.rs

@ -967,10 +967,7 @@ impl> RelaxedR1CSSNARKTrait
let mut w_u_vec = Vec::new(); let mut w_u_vec = Vec::new();
// sanity check that R1CSShape has certain size characteristics // sanity check that R1CSShape has certain size characteristics
assert_eq!(pk.S.num_cons.next_power_of_two(), pk.S.num_cons);
assert_eq!(pk.S.num_vars.next_power_of_two(), pk.S.num_vars);
assert_eq!(pk.S.num_io.next_power_of_two(), pk.S.num_io);
assert!(pk.S.num_io < pk.S.num_vars);
pk.S.check_regular_shape();
// append the verifier key (which includes commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript // append the verifier key (which includes commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript
transcript.absorb(b"vk", &pk.vk_digest); transcript.absorb(b"vk", &pk.vk_digest);

+ 1
- 4
src/spartan/snark.rs

@ -102,10 +102,7 @@ impl> RelaxedR1CSSNARKTrait
let mut transcript = G::TE::new(b"RelaxedR1CSSNARK"); let mut transcript = G::TE::new(b"RelaxedR1CSSNARK");
// sanity check that R1CSShape has certain size characteristics // sanity check that R1CSShape has certain size characteristics
assert_eq!(pk.S.num_cons.next_power_of_two(), pk.S.num_cons);
assert_eq!(pk.S.num_vars.next_power_of_two(), pk.S.num_vars);
assert_eq!(pk.S.num_io.next_power_of_two(), pk.S.num_io);
assert!(pk.S.num_io < pk.S.num_vars);
pk.S.check_regular_shape();
// append the digest of vk (which includes R1CS matrices) and the RelaxedR1CSInstance to the transcript // append the digest of vk (which includes R1CS matrices) and the RelaxedR1CSInstance to the transcript
transcript.absorb(b"vk", &pk.vk_digest); transcript.absorb(b"vk", &pk.vk_digest);

+ 31
- 34
src/spartan/sumcheck.rs

@ -61,6 +61,34 @@ impl SumcheckProof {
Ok((e, r)) Ok((e, r))
} }
#[inline]
fn compute_eval_points<F>(
poly_A: &MultilinearPolynomial<G::Scalar>,
poly_B: &MultilinearPolynomial<G::Scalar>,
comb_func: &F,
) -> (G::Scalar, G::Scalar)
where
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{
let len = poly_A.len() / 2;
(0..len)
.into_par_iter()
.map(|i| {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]);
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point);
(eval_point_0, eval_point_2)
})
.reduce(
|| (G::Scalar::ZERO, G::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
)
}
pub fn prove_quad<F>( pub fn prove_quad<F>(
claim: &G::Scalar, claim: &G::Scalar,
num_rounds: usize, num_rounds: usize,
@ -77,25 +105,7 @@ impl SumcheckProof {
let mut claim_per_round = *claim; let mut claim_per_round = *claim;
for _ in 0..num_rounds { for _ in 0..num_rounds {
let poly = { let poly = {
let len = poly_A.len() / 2;
// Make an iterator returning the contributions to the evaluations
let (eval_point_0, eval_point_2) = (0..len)
.into_par_iter()
.map(|i| {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]);
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point);
(eval_point_0, eval_point_2)
})
.reduce(
|| (G::Scalar::ZERO, G::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
);
let (eval_point_0, eval_point_2) = Self::compute_eval_points(poly_A, poly_B, &comb_func);
let evals = vec![eval_point_0, claim_per_round - eval_point_0, eval_point_2]; let evals = vec![eval_point_0, claim_per_round - eval_point_0, eval_point_2];
UniPoly::from_evals(&evals) UniPoly::from_evals(&evals)
@ -136,7 +146,7 @@ impl SumcheckProof {
transcript: &mut G::TE, transcript: &mut G::TE,
) -> Result<(Self, Vec<G::Scalar>, (Vec<G::Scalar>, Vec<G::Scalar>)), NovaError> ) -> Result<(Self, Vec<G::Scalar>, (Vec<G::Scalar>, Vec<G::Scalar>)), NovaError>
where where
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar,
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{ {
let mut e = *claim; let mut e = *claim;
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
@ -146,20 +156,7 @@ impl SumcheckProof {
let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new();
for (poly_A, poly_B) in poly_A_vec.iter().zip(poly_B_vec.iter()) { for (poly_A, poly_B) in poly_A_vec.iter().zip(poly_B_vec.iter()) {
let mut eval_point_0 = G::Scalar::ZERO;
let mut eval_point_2 = G::Scalar::ZERO;
let len = poly_A.len() / 2;
for i in 0..len {
// eval 0: bound_func is A(low)
eval_point_0 += comb_func(&poly_A[i], &poly_B[i]);
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
eval_point_2 += comb_func(&poly_A_bound_point, &poly_B_bound_point);
}
let (eval_point_0, eval_point_2) = Self::compute_eval_points(poly_A, poly_B, &comb_func);
evals.push((eval_point_0, eval_point_2)); evals.push((eval_point_0, eval_point_2));
} }

+ 3
- 7
src/traits/commitment.rs

@ -6,10 +6,12 @@ use crate::{
}; };
use core::{ use core::{
fmt::Debug, fmt::Debug,
ops::{Add, AddAssign, Mul, MulAssign},
ops::{Add, AddAssign},
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::ScalarMul;
/// Defines basic operations on commitments /// Defines basic operations on commitments
pub trait CommitmentOps<Rhs = Self, Output = Self>: pub trait CommitmentOps<Rhs = Self, Output = Self>:
Add<Rhs, Output = Output> + AddAssign<Rhs> Add<Rhs, Output = Output> + AddAssign<Rhs>
@ -31,12 +33,6 @@ impl CommitmentOpsOwned for T where
{ {
} }
/// A helper trait for types implementing a multiplication of a commitment with a scalar
pub trait ScalarMul<Rhs, Output = Self>: Mul<Rhs, Output = Output> + MulAssign<Rhs> {}
impl<T, Rhs, Output> ScalarMul<Rhs, Output> for T where T: Mul<Rhs, Output = Output> + MulAssign<Rhs>
{}
/// This trait defines the behavior of the commitment /// This trait defines the behavior of the commitment
pub trait CommitmentTrait<G: Group>: pub trait CommitmentTrait<G: Group>:
Clone Clone

+ 3
- 5
src/traits/mod.rs

@ -236,11 +236,9 @@ pub trait PrimeFieldExt: PrimeField {
impl<G: Group, T: TranscriptReprTrait<G>> TranscriptReprTrait<G> for &[T] { impl<G: Group, T: TranscriptReprTrait<G>> TranscriptReprTrait<G> for &[T] {
fn to_transcript_bytes(&self) -> Vec<u8> { fn to_transcript_bytes(&self) -> Vec<u8> {
(0..self.len())
.map(|i| self[i].to_transcript_bytes())
.collect::<Vec<_>>()
.into_iter()
.flatten()
self
.iter()
.flat_map(|t| t.to_transcript_bytes())
.collect::<Vec<u8>>() .collect::<Vec<u8>>()
} }
} }

Loading…
Cancel
Save