Browse Source

Ppsnark refactorings (#208)

* refactor: Refactor row/col vector construction for efficiency

- Optimized the creation of `row` and `col` in `R1CSShapeSparkRepr::new` using map and unzip methods.
- Updated `R1CSShapeSparkRepr::evaluation_oracles` to create `E_row` and `E_col` using the same logic for consistency.

* refactor: Refactor and optimize `R1CSShapeSparkRepr` initialization

- Updated method for zero padding in `val_B` and `val_C` using `std::iter::repeat`, to need one vector allocation instead of two
- Functionality and outputs remain unchanged.

* refactor: Refactor polynomial struct in SumCheck to use generic Scalar type

- Updated `CompressedUniPoly` and `UniPoly` structs in `sumcheck.rs` to use the generic `Scalar` type.
- Adapted all methods within these structs to accommodate the `Scalar` type instead of `G: Group` type.
- Modified the type of `cubic_polys` in `ppsnark.rs` to `CompressedUniPoly<G::Scalar>`.

* refactor: Eliminate most instances of resize

resize in Rust may cause reallocation of the memory, which is an expensive operation. This is particularly true when the vector is resized to a larger size.
poseidon-transcript
François Garillot 1 year ago
committed by GitHub
parent
commit
eeb3e470d5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 84 deletions
  1. +46
    -57
      src/spartan/ppsnark.rs
  2. +24
    -27
      src/spartan/sumcheck.rs

+ 46
- 57
src/spartan/ppsnark.rs

@ -119,51 +119,34 @@ impl R1CSShapeSparkRepr {
max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two() max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two()
}; };
let row = {
let mut r = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, _, _)| *r)
.collect::<Vec<usize>>();
r.resize(N, 0usize);
r
};
let (mut row, mut col) = (vec![0usize; N], vec![0usize; N]);
let col = {
let mut c = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(_, c, _)| *c)
.collect::<Vec<usize>>();
c.resize(N, 0usize);
c
};
for (i, (r, c, _)) in S.A.iter().chain(S.B.iter()).chain(S.C.iter()).enumerate() {
row[i] = *r;
col[i] = *c;
}
let val_A = { let val_A = {
let mut val = S.A.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>();
val.resize(N, G::Scalar::ZERO);
let mut val = vec![G::Scalar::ZERO; N];
for (i, (_, _, v)) in S.A.iter().enumerate() {
val[i] = *v;
}
val val
}; };
let val_B = { let val_B = {
// prepend zeros
let mut val = vec![G::Scalar::ZERO; S.A.len()];
val.extend(S.B.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>());
// append zeros
val.resize(N, G::Scalar::ZERO);
let mut val = vec![G::Scalar::ZERO; N];
for (i, (_, _, v)) in S.B.iter().enumerate() {
val[S.A.len() + i] = *v;
}
val val
}; };
let val_C = { let val_C = {
// prepend zeros
let mut val = vec![G::Scalar::ZERO; S.A.len() + S.B.len()];
val.extend(S.C.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>());
// append zeros
val.resize(N, G::Scalar::ZERO);
let mut val = vec![G::Scalar::ZERO; N];
for (i, (_, _, v)) in S.C.iter().enumerate() {
val[S.A.len() + S.B.len() + i] = *v;
}
val val
}; };
@ -265,29 +248,30 @@ impl R1CSShapeSparkRepr {
let mem_row = EqPolynomial::new(r_x_padded).evals(); let mem_row = EqPolynomial::new(r_x_padded).evals();
let mem_col = { let mem_col = {
let mut z = z.to_vec();
z.resize(self.N, G::Scalar::ZERO);
z
let mut val = vec![G::Scalar::ZERO; self.N];
for (i, v) in z.iter().enumerate() {
val[i] = *v;
}
val
}; };
let mut E_row = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, _, _)| mem_row[*r])
.collect::<Vec<G::Scalar>>();
let mut E_col = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(_, c, _)| mem_col[*c])
.collect::<Vec<G::Scalar>>();
let (E_row, E_col) = {
let mut E_row = vec![mem_row[0]; self.N]; // we place mem_row[0] since resized row is appended with 0s
let mut E_col = vec![mem_col[0]; self.N];
E_row.resize(self.N, mem_row[0]); // we place mem_row[0] since resized row is appended with 0s
E_col.resize(self.N, mem_col[0]);
for (i, (val_r, val_c)) in S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, c, _)| (mem_row[*r], mem_col[*c]))
.enumerate()
{
E_row[i] = val_r;
E_col[i] = val_c;
}
(E_row, E_col)
};
(mem_row, mem_col, E_row, E_col) (mem_row, mem_col, E_row, E_col)
} }
@ -862,7 +846,7 @@ impl> RelaxedR1CSSNARK
let mut e = claim; let mut e = claim;
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let num_rounds = mem.size().log_2(); let num_rounds = mem.size().log_2();
for _i in 0..num_rounds { for _i in 0..num_rounds {
let mut evals: Vec<Vec<G::Scalar>> = Vec::new(); let mut evals: Vec<Vec<G::Scalar>> = Vec::new();
@ -999,8 +983,13 @@ impl> RelaxedR1CSSNARKTrait
Bz.resize(pk.S_repr.N, G::Scalar::ZERO); Bz.resize(pk.S_repr.N, G::Scalar::ZERO);
Cz.resize(pk.S_repr.N, G::Scalar::ZERO); Cz.resize(pk.S_repr.N, G::Scalar::ZERO);
let mut E = W.E.clone();
E.resize(pk.S_repr.N, G::Scalar::ZERO);
let E = {
let mut val = vec![G::Scalar::ZERO; pk.S_repr.N];
for (i, w_e) in W.E.iter().enumerate() {
val[i] = *w_e;
}
val
};
(Az, Bz, Cz, E) (Az, Bz, Cz, E)
}; };

+ 24
- 27
src/spartan/sumcheck.rs

@ -3,19 +3,18 @@
use super::polynomial::MultilinearPolynomial; use super::polynomial::MultilinearPolynomial;
use crate::errors::NovaError; use crate::errors::NovaError;
use crate::traits::{Group, TranscriptEngineTrait, TranscriptReprTrait}; use crate::traits::{Group, TranscriptEngineTrait, TranscriptReprTrait};
use core::marker::PhantomData;
use ff::Field;
use ff::{Field, PrimeField};
use rayon::prelude::*; use rayon::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")] #[serde(bound = "")]
pub(crate) struct SumcheckProof<G: Group> { pub(crate) struct SumcheckProof<G: Group> {
compressed_polys: Vec<CompressedUniPoly<G>>,
compressed_polys: Vec<CompressedUniPoly<G::Scalar>>,
} }
impl<G: Group> SumcheckProof<G> { impl<G: Group> SumcheckProof<G> {
pub fn new(compressed_polys: Vec<CompressedUniPoly<G>>) -> Self {
pub fn new(compressed_polys: Vec<CompressedUniPoly<G::Scalar>>) -> Self {
Self { compressed_polys } Self { compressed_polys }
} }
@ -101,7 +100,7 @@ impl SumcheckProof {
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{ {
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let mut claim_per_round = *claim; let mut claim_per_round = *claim;
for _ in 0..num_rounds { for _ in 0..num_rounds {
let poly = { let poly = {
@ -150,7 +149,7 @@ impl SumcheckProof {
{ {
let mut e = *claim; let mut e = *claim;
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
for _j in 0..num_rounds { for _j in 0..num_rounds {
let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new();
@ -204,7 +203,7 @@ impl SumcheckProof {
F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync, F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{ {
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G>> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let mut claim_per_round = *claim; let mut claim_per_round = *claim;
for _ in 0..num_rounds { for _ in 0..num_rounds {
@ -288,25 +287,24 @@ impl SumcheckProof {
// ax^2 + bx + c stored as vec![a,b,c] // ax^2 + bx + c stored as vec![a,b,c]
// ax^3 + bx^2 + cx + d stored as vec![a,b,c,d] // ax^3 + bx^2 + cx + d stored as vec![a,b,c,d]
#[derive(Debug)] #[derive(Debug)]
pub struct UniPoly<G: Group> {
coeffs: Vec<G::Scalar>,
pub struct UniPoly<Scalar: PrimeField> {
coeffs: Vec<Scalar>,
} }
// ax^2 + bx + c stored as vec![a,c] // ax^2 + bx + c stored as vec![a,c]
// ax^3 + bx^2 + cx + d stored as vec![a,c,d] // ax^3 + bx^2 + cx + d stored as vec![a,c,d]
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompressedUniPoly<G: Group> {
coeffs_except_linear_term: Vec<G::Scalar>,
_p: PhantomData<G>,
pub struct CompressedUniPoly<Scalar: PrimeField> {
coeffs_except_linear_term: Vec<Scalar>,
} }
impl<G: Group> UniPoly<G> {
pub fn from_evals(evals: &[G::Scalar]) -> Self {
impl<Scalar: PrimeField> UniPoly<Scalar> {
pub fn from_evals(evals: &[Scalar]) -> Self {
// we only support degree-2 or degree-3 univariate polynomials // we only support degree-2 or degree-3 univariate polynomials
assert!(evals.len() == 3 || evals.len() == 4); assert!(evals.len() == 3 || evals.len() == 4);
let coeffs = if evals.len() == 3 { let coeffs = if evals.len() == 3 {
// ax^2 + bx + c // ax^2 + bx + c
let two_inv = G::Scalar::from(2).invert().unwrap();
let two_inv = Scalar::from(2).invert().unwrap();
let c = evals[0]; let c = evals[0];
let a = two_inv * (evals[2] - evals[1] - evals[1] + c); let a = two_inv * (evals[2] - evals[1] - evals[1] + c);
@ -314,8 +312,8 @@ impl UniPoly {
vec![c, b, a] vec![c, b, a]
} else { } else {
// ax^3 + bx^2 + cx + d // ax^3 + bx^2 + cx + d
let two_inv = G::Scalar::from(2).invert().unwrap();
let six_inv = G::Scalar::from(6).invert().unwrap();
let two_inv = Scalar::from(2).invert().unwrap();
let six_inv = Scalar::from(6).invert().unwrap();
let d = evals[0]; let d = evals[0];
let a = six_inv let a = six_inv
@ -338,18 +336,18 @@ impl UniPoly {
self.coeffs.len() - 1 self.coeffs.len() - 1
} }
pub fn eval_at_zero(&self) -> G::Scalar {
pub fn eval_at_zero(&self) -> Scalar {
self.coeffs[0] self.coeffs[0]
} }
pub fn eval_at_one(&self) -> G::Scalar {
pub fn eval_at_one(&self) -> Scalar {
(0..self.coeffs.len()) (0..self.coeffs.len())
.into_par_iter() .into_par_iter()
.map(|i| self.coeffs[i]) .map(|i| self.coeffs[i])
.reduce(|| G::Scalar::ZERO, |a, b| a + b)
.reduce(|| Scalar::ZERO, |a, b| a + b)
} }
pub fn evaluate(&self, r: &G::Scalar) -> G::Scalar {
pub fn evaluate(&self, r: &Scalar) -> Scalar {
let mut eval = self.coeffs[0]; let mut eval = self.coeffs[0];
let mut power = *r; let mut power = *r;
for coeff in self.coeffs.iter().skip(1) { for coeff in self.coeffs.iter().skip(1) {
@ -359,27 +357,26 @@ impl UniPoly {
eval eval
} }
pub fn compress(&self) -> CompressedUniPoly<G> {
pub fn compress(&self) -> CompressedUniPoly<Scalar> {
let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat(); let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat();
assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len());
CompressedUniPoly { CompressedUniPoly {
coeffs_except_linear_term, coeffs_except_linear_term,
_p: Default::default(),
} }
} }
} }
impl<G: Group> CompressedUniPoly<G> {
impl<Scalar: PrimeField> CompressedUniPoly<Scalar> {
// we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // we require eval(0) + eval(1) = hint, so we can solve for the linear term as:
// linear_term = hint - 2 * constant_term - deg2 term - deg3 term // linear_term = hint - 2 * constant_term - deg2 term - deg3 term
pub fn decompress(&self, hint: &G::Scalar) -> UniPoly<G> {
pub fn decompress(&self, hint: &Scalar) -> UniPoly<Scalar> {
let mut linear_term = let mut linear_term =
*hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0];
for i in 1..self.coeffs_except_linear_term.len() { for i in 1..self.coeffs_except_linear_term.len() {
linear_term -= self.coeffs_except_linear_term[i]; linear_term -= self.coeffs_except_linear_term[i];
} }
let mut coeffs: Vec<G::Scalar> = Vec::new();
let mut coeffs: Vec<Scalar> = Vec::new();
coeffs.push(self.coeffs_except_linear_term[0]); coeffs.push(self.coeffs_except_linear_term[0]);
coeffs.push(linear_term); coeffs.push(linear_term);
coeffs.extend(&self.coeffs_except_linear_term[1..]); coeffs.extend(&self.coeffs_except_linear_term[1..]);
@ -388,7 +385,7 @@ impl CompressedUniPoly {
} }
} }
impl<G: Group> TranscriptReprTrait<G> for UniPoly<G> {
impl<G: Group> TranscriptReprTrait<G> for UniPoly<G::Scalar> {
fn to_transcript_bytes(&self) -> Vec<u8> { fn to_transcript_bytes(&self) -> Vec<u8> {
let coeffs = self.compress().coeffs_except_linear_term; let coeffs = self.compress().coeffs_except_linear_term;
coeffs.as_slice().to_transcript_bytes() coeffs.as_slice().to_transcript_bytes()

Loading…
Cancel
Save