diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index aebd5c1..24caece 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -119,51 +119,34 @@ impl R1CSShapeSparkRepr { max(total_nz, max(2 * S.num_vars, S.num_cons)).next_power_of_two() }; - let row = { - let mut r = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(r, _, _)| *r) - .collect::>(); - r.resize(N, 0usize); - r - }; + let (mut row, mut col) = (vec![0usize; N], vec![0usize; N]); - let col = { - let mut c = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(_, c, _)| *c) - .collect::>(); - c.resize(N, 0usize); - c - }; + for (i, (r, c, _)) in S.A.iter().chain(S.B.iter()).chain(S.C.iter()).enumerate() { + row[i] = *r; + col[i] = *c; + } let val_A = { - let mut val = S.A.iter().map(|(_, _, v)| *v).collect::>(); - val.resize(N, G::Scalar::ZERO); + let mut val = vec![G::Scalar::ZERO; N]; + for (i, (_, _, v)) in S.A.iter().enumerate() { + val[i] = *v; + } val }; let val_B = { - // prepend zeros - let mut val = vec![G::Scalar::ZERO; S.A.len()]; - val.extend(S.B.iter().map(|(_, _, v)| *v).collect::>()); - // append zeros - val.resize(N, G::Scalar::ZERO); + let mut val = vec![G::Scalar::ZERO; N]; + for (i, (_, _, v)) in S.B.iter().enumerate() { + val[S.A.len() + i] = *v; + } val }; let val_C = { - // prepend zeros - let mut val = vec![G::Scalar::ZERO; S.A.len() + S.B.len()]; - val.extend(S.C.iter().map(|(_, _, v)| *v).collect::>()); - // append zeros - val.resize(N, G::Scalar::ZERO); + let mut val = vec![G::Scalar::ZERO; N]; + for (i, (_, _, v)) in S.C.iter().enumerate() { + val[S.A.len() + S.B.len() + i] = *v; + } val }; @@ -265,29 +248,30 @@ impl R1CSShapeSparkRepr { let mem_row = EqPolynomial::new(r_x_padded).evals(); let mem_col = { - let mut z = z.to_vec(); - z.resize(self.N, G::Scalar::ZERO); - z + let mut val = vec![G::Scalar::ZERO; self.N]; + for (i, v) in z.iter().enumerate() { + val[i] = *v; + } + val }; - let mut E_row = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(r, _, _)| mem_row[*r]) - .collect::>(); - - let mut E_col = S - .A - .iter() - .chain(S.B.iter()) - .chain(S.C.iter()) - .map(|(_, c, _)| mem_col[*c]) - .collect::>(); + let (E_row, E_col) = { + let mut E_row = vec![mem_row[0]; self.N]; // we place mem_row[0] since resized row is appended with 0s + let mut E_col = vec![mem_col[0]; self.N]; - E_row.resize(self.N, mem_row[0]); // we place mem_row[0] since resized row is appended with 0s - E_col.resize(self.N, mem_col[0]); + for (i, (val_r, val_c)) in S + .A + .iter() + .chain(S.B.iter()) + .chain(S.C.iter()) + .map(|(r, c, _)| (mem_row[*r], mem_col[*c])) + .enumerate() + { + E_row[i] = val_r; + E_col[i] = val_c; + } + (E_row, E_col) + }; (mem_row, mem_col, E_row, E_col) } @@ -862,7 +846,7 @@ impl> RelaxedR1CSSNARK let mut e = claim; let mut r: Vec = Vec::new(); - let mut cubic_polys: Vec> = Vec::new(); + let mut cubic_polys: Vec> = Vec::new(); let num_rounds = mem.size().log_2(); for _i in 0..num_rounds { let mut evals: Vec> = Vec::new(); @@ -999,8 +983,13 @@ impl> RelaxedR1CSSNARKTrait { - compressed_polys: Vec>, + compressed_polys: Vec>, } impl SumcheckProof { - pub fn new(compressed_polys: Vec>) -> Self { + pub fn new(compressed_polys: Vec>) -> Self { Self { compressed_polys } } @@ -101,7 +100,7 @@ impl SumcheckProof { F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync, { let mut r: Vec = Vec::new(); - let mut polys: Vec> = Vec::new(); + let mut polys: Vec> = Vec::new(); let mut claim_per_round = *claim; for _ in 0..num_rounds { let poly = { @@ -150,7 +149,7 @@ impl SumcheckProof { { let mut e = *claim; let mut r: Vec = Vec::new(); - let mut quad_polys: Vec> = Vec::new(); + let mut quad_polys: Vec> = Vec::new(); for _j in 0..num_rounds { let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); @@ -204,7 +203,7 @@ impl SumcheckProof { F: Fn(&G::Scalar, &G::Scalar, &G::Scalar, &G::Scalar) -> G::Scalar + Sync, { let mut r: Vec = Vec::new(); - let mut polys: Vec> = Vec::new(); + let mut polys: Vec> = Vec::new(); let mut claim_per_round = *claim; for _ in 0..num_rounds { @@ -288,25 +287,24 @@ impl SumcheckProof { // ax^2 + bx + c stored as vec![a,b,c] // ax^3 + bx^2 + cx + d stored as vec![a,b,c,d] #[derive(Debug)] -pub struct UniPoly { - coeffs: Vec, +pub struct UniPoly { + coeffs: Vec, } // ax^2 + bx + c stored as vec![a,c] // ax^3 + bx^2 + cx + d stored as vec![a,c,d] #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct CompressedUniPoly { - coeffs_except_linear_term: Vec, - _p: PhantomData, +pub struct CompressedUniPoly { + coeffs_except_linear_term: Vec, } -impl UniPoly { - pub fn from_evals(evals: &[G::Scalar]) -> Self { +impl UniPoly { + pub fn from_evals(evals: &[Scalar]) -> Self { // we only support degree-2 or degree-3 univariate polynomials assert!(evals.len() == 3 || evals.len() == 4); let coeffs = if evals.len() == 3 { // ax^2 + bx + c - let two_inv = G::Scalar::from(2).invert().unwrap(); + let two_inv = Scalar::from(2).invert().unwrap(); let c = evals[0]; let a = two_inv * (evals[2] - evals[1] - evals[1] + c); @@ -314,8 +312,8 @@ impl UniPoly { vec![c, b, a] } else { // ax^3 + bx^2 + cx + d - let two_inv = G::Scalar::from(2).invert().unwrap(); - let six_inv = G::Scalar::from(6).invert().unwrap(); + let two_inv = Scalar::from(2).invert().unwrap(); + let six_inv = Scalar::from(6).invert().unwrap(); let d = evals[0]; let a = six_inv @@ -338,18 +336,18 @@ impl UniPoly { self.coeffs.len() - 1 } - pub fn eval_at_zero(&self) -> G::Scalar { + pub fn eval_at_zero(&self) -> Scalar { self.coeffs[0] } - pub fn eval_at_one(&self) -> G::Scalar { + pub fn eval_at_one(&self) -> Scalar { (0..self.coeffs.len()) .into_par_iter() .map(|i| self.coeffs[i]) - .reduce(|| G::Scalar::ZERO, |a, b| a + b) + .reduce(|| Scalar::ZERO, |a, b| a + b) } - pub fn evaluate(&self, r: &G::Scalar) -> G::Scalar { + pub fn evaluate(&self, r: &Scalar) -> Scalar { let mut eval = self.coeffs[0]; let mut power = *r; for coeff in self.coeffs.iter().skip(1) { @@ -359,27 +357,26 @@ impl UniPoly { eval } - pub fn compress(&self) -> CompressedUniPoly { + pub fn compress(&self) -> CompressedUniPoly { let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat(); assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); CompressedUniPoly { coeffs_except_linear_term, - _p: Default::default(), } } } -impl CompressedUniPoly { +impl CompressedUniPoly { // we require eval(0) + eval(1) = hint, so we can solve for the linear term as: // linear_term = hint - 2 * constant_term - deg2 term - deg3 term - pub fn decompress(&self, hint: &G::Scalar) -> UniPoly { + pub fn decompress(&self, hint: &Scalar) -> UniPoly { let mut linear_term = *hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0]; for i in 1..self.coeffs_except_linear_term.len() { linear_term -= self.coeffs_except_linear_term[i]; } - let mut coeffs: Vec = Vec::new(); + let mut coeffs: Vec = Vec::new(); coeffs.push(self.coeffs_except_linear_term[0]); coeffs.push(linear_term); coeffs.extend(&self.coeffs_except_linear_term[1..]); @@ -388,7 +385,7 @@ impl CompressedUniPoly { } } -impl TranscriptReprTrait for UniPoly { +impl TranscriptReprTrait for UniPoly { fn to_transcript_bytes(&self) -> Vec { let coeffs = self.compress().coeffs_except_linear_term; coeffs.as_slice().to_transcript_bytes()