Ppsnark refactorings (#208)

* refactor: Refactor row/col vector construction for efficiency

- Optimized the creation of `row` and `col` in `R1CSShapeSparkRepr::new` using map and unzip methods.
- Updated `R1CSShapeSparkRepr::evaluation_oracles` to create `E_row` and `E_col` using the same logic for consistency.

* refactor: Refactor and optimize `R1CSShapeSparkRepr` initialization

- Updated method for zero padding in `val_B` and `val_C` using `std::iter::repeat`, to need one vector allocation instead of two
- Functionality and outputs remain unchanged.

* refactor: Refactor polynomial struct in SumCheck to use generic Scalar type

- Updated `CompressedUniPoly` and `UniPoly` structs in `sumcheck.rs` to use the generic `Scalar` type.
- Adapted all methods within these structs to accommodate the `Scalar` type instead of `G: Group` type.
- Modified the type of `cubic_polys` in `ppsnark.rs` to `CompressedUniPoly<G::Scalar>`.

* refactor: Eliminate most instances of resize

resize in Rust may cause reallocation of the memory, which is an expensive operation. This is particularly true when the vector is resized to a larger size.
This commit is contained in:
François Garillot
2023-07-28 13:03:34 -04:00
committed by GitHub
parent cdab40357a
commit eeb3e470d5
2 changed files with 70 additions and 84 deletions

View File

@@ -119,51 +119,34 @@ impl<G: Group> R1CSShapeSparkRepr<G> {
max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two() max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two()
}; };
let row = { let (mut row, mut col) = (vec![0usize; N], vec![0usize; N]);
let mut r = S
.A
.iter()
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, _, _)| *r)
.collect::<Vec<usize>>();
r.resize(N, 0usize);
r
};
let col = { for (i, (r, c, _)) in S.A.iter().chain(S.B.iter()).chain(S.C.iter()).enumerate() {
let mut c = S row[i] = *r;
.A col[i] = *c;
.iter() }
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(_, c, _)| *c)
.collect::<Vec<usize>>();
c.resize(N, 0usize);
c
};
let val_A = { let val_A = {
let mut val = S.A.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>(); let mut val = vec![G::Scalar::ZERO; N];
val.resize(N, G::Scalar::ZERO); for (i, (_, _, v)) in S.A.iter().enumerate() {
val[i] = *v;
}
val val
}; };
let val_B = { let val_B = {
// prepend zeros let mut val = vec![G::Scalar::ZERO; N];
let mut val = vec![G::Scalar::ZERO; S.A.len()]; for (i, (_, _, v)) in S.B.iter().enumerate() {
val.extend(S.B.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>()); val[S.A.len() + i] = *v;
// append zeros }
val.resize(N, G::Scalar::ZERO);
val val
}; };
let val_C = { let val_C = {
// prepend zeros let mut val = vec![G::Scalar::ZERO; N];
let mut val = vec![G::Scalar::ZERO; S.A.len() + S.B.len()]; for (i, (_, _, v)) in S.C.iter().enumerate() {
val.extend(S.C.iter().map(|(_, _, v)| *v).collect::<Vec<G::Scalar>>()); val[S.A.len() + S.B.len() + i] = *v;
// append zeros }
val.resize(N, G::Scalar::ZERO);
val val
}; };
@@ -265,29 +248,30 @@ impl<G: Group> R1CSShapeSparkRepr<G> {
let mem_row = EqPolynomial::new(r_x_padded).evals(); let mem_row = EqPolynomial::new(r_x_padded).evals();
let mem_col = { let mem_col = {
let mut z = z.to_vec(); let mut val = vec![G::Scalar::ZERO; self.N];
z.resize(self.N, G::Scalar::ZERO); for (i, v) in z.iter().enumerate() {
z val[i] = *v;
}
val
}; };
let mut E_row = S let (E_row, E_col) = {
.A let mut E_row = vec![mem_row[0]; self.N]; // we place mem_row[0] since resized row is appended with 0s
.iter() let mut E_col = vec![mem_col[0]; self.N];
.chain(S.B.iter())
.chain(S.C.iter())
.map(|(r, _, _)| mem_row[*r])
.collect::<Vec<G::Scalar>>();
let mut E_col = S for (i, (val_r, val_c)) in S
.A .A
.iter() .iter()
.chain(S.B.iter()) .chain(S.B.iter())
.chain(S.C.iter()) .chain(S.C.iter())
.map(|(_, c, _)| mem_col[*c]) .map(|(r, c, _)| (mem_row[*r], mem_col[*c]))
.collect::<Vec<G::Scalar>>(); .enumerate()
{
E_row.resize(self.N, mem_row[0]); // we place mem_row[0] since resized row is appended with 0s E_row[i] = val_r;
E_col.resize(self.N, mem_col[0]); E_col[i] = val_c;
}
(E_row, E_col)
};
(mem_row, mem_col, E_row, E_col) (mem_row, mem_col, E_row, E_col)
} }
@@ -862,7 +846,7 @@ impl<G: Group, EE: EvaluationEngineTrait<G, CE = G::CE>> RelaxedR1CSSNARK<G, EE>
let mut e = claim; let mut e = claim;
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly<G>> = Vec::new(); let mut cubic_polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let num_rounds = mem.size().log_2(); let num_rounds = mem.size().log_2();
for _i in 0..num_rounds { for _i in 0..num_rounds {
let mut evals: Vec<Vec<G::Scalar>> = Vec::new(); let mut evals: Vec<Vec<G::Scalar>> = Vec::new();
@@ -999,8 +983,13 @@ impl<G: Group, EE: EvaluationEngineTrait<G, CE = G::CE>> RelaxedR1CSSNARKTrait<G
Bz.resize(pk.S_repr.N, G::Scalar::ZERO); Bz.resize(pk.S_repr.N, G::Scalar::ZERO);
Cz.resize(pk.S_repr.N, G::Scalar::ZERO); Cz.resize(pk.S_repr.N, G::Scalar::ZERO);
let mut E = W.E.clone(); let E = {
E.resize(pk.S_repr.N, G::Scalar::ZERO); let mut val = vec![G::Scalar::ZERO; pk.S_repr.N];
for (i, w_e) in W.E.iter().enumerate() {
val[i] = *w_e;
}
val
};
(Az, Bz, Cz, E) (Az, Bz, Cz, E)
}; };

View File

@@ -3,19 +3,18 @@
use super::polynomial::MultilinearPolynomial; use super::polynomial::MultilinearPolynomial;
use crate::errors::NovaError; use crate::errors::NovaError;
use crate::traits::{Group, TranscriptEngineTrait, TranscriptReprTrait}; use crate::traits::{Group, TranscriptEngineTrait, TranscriptReprTrait};
use core::marker::PhantomData; use ff::{Field, PrimeField};
use ff::Field;
use rayon::prelude::*; use rayon::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")] #[serde(bound = "")]
pub(crate) struct SumcheckProof<G: Group> { pub(crate) struct SumcheckProof<G: Group> {
compressed_polys: Vec<CompressedUniPoly<G>>, compressed_polys: Vec<CompressedUniPoly<G::Scalar>>,
} }
impl<G: Group> SumcheckProof<G> { impl<G: Group> SumcheckProof<G> {
pub fn new(compressed_polys: Vec<CompressedUniPoly<G>>) -> Self { pub fn new(compressed_polys: Vec<CompressedUniPoly<G::Scalar>>) -> Self {
Self { compressed_polys } Self { compressed_polys }
} }
@@ -101,7 +100,7 @@ impl<G: Group> SumcheckProof<G> {
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{ {
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G>> = Vec::new(); let mut polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let mut claim_per_round = *claim; let mut claim_per_round = *claim;
for _ in 0..num_rounds { for _ in 0..num_rounds {
let poly = { let poly = {
@@ -150,7 +149,7 @@ impl<G: Group> SumcheckProof<G> {
{ {
let mut e = *claim; let mut e = *claim;
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly<G>> = Vec::new(); let mut quad_polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
for _j in 0..num_rounds { for _j in 0..num_rounds {
let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new();
@@ -204,7 +203,7 @@ impl<G: Group> SumcheckProof<G> {
F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync, F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{ {
let mut r: Vec<G::Scalar> = Vec::new(); let mut r: Vec<G::Scalar> = Vec::new();
let mut polys: Vec<CompressedUniPoly<G>> = Vec::new(); let mut polys: Vec<CompressedUniPoly<G::Scalar>> = Vec::new();
let mut claim_per_round = *claim; let mut claim_per_round = *claim;
for _ in 0..num_rounds { for _ in 0..num_rounds {
@@ -288,25 +287,24 @@ impl<G: Group> SumcheckProof<G> {
// ax^2 + bx + c stored as vec![a,b,c] // ax^2 + bx + c stored as vec![a,b,c]
// ax^3 + bx^2 + cx + d stored as vec![a,b,c,d] // ax^3 + bx^2 + cx + d stored as vec![a,b,c,d]
#[derive(Debug)] #[derive(Debug)]
pub struct UniPoly<G: Group> { pub struct UniPoly<Scalar: PrimeField> {
coeffs: Vec<G::Scalar>, coeffs: Vec<Scalar>,
} }
// ax^2 + bx + c stored as vec![a,c] // ax^2 + bx + c stored as vec![a,c]
// ax^3 + bx^2 + cx + d stored as vec![a,c,d] // ax^3 + bx^2 + cx + d stored as vec![a,c,d]
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompressedUniPoly<G: Group> { pub struct CompressedUniPoly<Scalar: PrimeField> {
coeffs_except_linear_term: Vec<G::Scalar>, coeffs_except_linear_term: Vec<Scalar>,
_p: PhantomData<G>,
} }
impl<G: Group> UniPoly<G> { impl<Scalar: PrimeField> UniPoly<Scalar> {
pub fn from_evals(evals: &[G::Scalar]) -> Self { pub fn from_evals(evals: &[Scalar]) -> Self {
// we only support degree-2 or degree-3 univariate polynomials // we only support degree-2 or degree-3 univariate polynomials
assert!(evals.len() == 3 || evals.len() == 4); assert!(evals.len() == 3 || evals.len() == 4);
let coeffs = if evals.len() == 3 { let coeffs = if evals.len() == 3 {
// ax^2 + bx + c // ax^2 + bx + c
let two_inv = G::Scalar::from(2).invert().unwrap(); let two_inv = Scalar::from(2).invert().unwrap();
let c = evals[0]; let c = evals[0];
let a = two_inv * (evals[2] - evals[1] - evals[1] + c); let a = two_inv * (evals[2] - evals[1] - evals[1] + c);
@@ -314,8 +312,8 @@ impl<G: Group> UniPoly<G> {
vec![c, b, a] vec![c, b, a]
} else { } else {
// ax^3 + bx^2 + cx + d // ax^3 + bx^2 + cx + d
let two_inv = G::Scalar::from(2).invert().unwrap(); let two_inv = Scalar::from(2).invert().unwrap();
let six_inv = G::Scalar::from(6).invert().unwrap(); let six_inv = Scalar::from(6).invert().unwrap();
let d = evals[0]; let d = evals[0];
let a = six_inv let a = six_inv
@@ -338,18 +336,18 @@ impl<G: Group> UniPoly<G> {
self.coeffs.len() - 1 self.coeffs.len() - 1
} }
pub fn eval_at_zero(&self) -> G::Scalar { pub fn eval_at_zero(&self) -> Scalar {
self.coeffs[0] self.coeffs[0]
} }
pub fn eval_at_one(&self) -> G::Scalar { pub fn eval_at_one(&self) -> Scalar {
(0..self.coeffs.len()) (0..self.coeffs.len())
.into_par_iter() .into_par_iter()
.map(|i| self.coeffs[i]) .map(|i| self.coeffs[i])
.reduce(|| G::Scalar::ZERO, |a, b| a + b) .reduce(|| Scalar::ZERO, |a, b| a + b)
} }
pub fn evaluate(&self, r: &G::Scalar) -> G::Scalar { pub fn evaluate(&self, r: &Scalar) -> Scalar {
let mut eval = self.coeffs[0]; let mut eval = self.coeffs[0];
let mut power = *r; let mut power = *r;
for coeff in self.coeffs.iter().skip(1) { for coeff in self.coeffs.iter().skip(1) {
@@ -359,27 +357,26 @@ impl<G: Group> UniPoly<G> {
eval eval
} }
pub fn compress(&self) -> CompressedUniPoly<G> { pub fn compress(&self) -> CompressedUniPoly<Scalar> {
let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat(); let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat();
assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len());
CompressedUniPoly { CompressedUniPoly {
coeffs_except_linear_term, coeffs_except_linear_term,
_p: Default::default(),
} }
} }
} }
impl<G: Group> CompressedUniPoly<G> { impl<Scalar: PrimeField> CompressedUniPoly<Scalar> {
// we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // we require eval(0) + eval(1) = hint, so we can solve for the linear term as:
// linear_term = hint - 2 * constant_term - deg2 term - deg3 term // linear_term = hint - 2 * constant_term - deg2 term - deg3 term
pub fn decompress(&self, hint: &G::Scalar) -> UniPoly<G> { pub fn decompress(&self, hint: &Scalar) -> UniPoly<Scalar> {
let mut linear_term = let mut linear_term =
*hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0];
for i in 1..self.coeffs_except_linear_term.len() { for i in 1..self.coeffs_except_linear_term.len() {
linear_term -= self.coeffs_except_linear_term[i]; linear_term -= self.coeffs_except_linear_term[i];
} }
let mut coeffs: Vec<G::Scalar> = Vec::new(); let mut coeffs: Vec<Scalar> = Vec::new();
coeffs.push(self.coeffs_except_linear_term[0]); coeffs.push(self.coeffs_except_linear_term[0]);
coeffs.push(linear_term); coeffs.push(linear_term);
coeffs.extend(&self.coeffs_except_linear_term[1..]); coeffs.extend(&self.coeffs_except_linear_term[1..]);
@@ -388,7 +385,7 @@ impl<G: Group> CompressedUniPoly<G> {
} }
} }
impl<G: Group> TranscriptReprTrait<G> for UniPoly<G> { impl<G: Group> TranscriptReprTrait<G> for UniPoly<G::Scalar> {
fn to_transcript_bytes(&self) -> Vec<u8> { fn to_transcript_bytes(&self) -> Vec<u8> {
let coeffs = self.compress().coeffs_except_linear_term; let coeffs = self.compress().coeffs_except_linear_term;
coeffs.as_slice().to_transcript_bytes() coeffs.as_slice().to_transcript_bytes()