From 9e4c166edb73f62b813db58eb8ada3ff354b27e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Tue, 29 Sep 2020 18:18:43 -0400 Subject: [PATCH] Refactor to idiomatic Result/Option patterns (#25) This: - introduces a small [thiserror](https://github.com/dtolnay/thiserror)-powered enum to improve ProofVerifyError's messages, - refactors point decompression errors into a variant of that enum, thereby suppressing the panics which occur when decompresison fails. - folds other panics into the Error cases of their enclosing `Result` return --- Cargo.toml | 1 + src/commitments.rs | 2 +- src/dense_mlpoly.rs | 14 ++++---------- src/errors.rs | 21 +++++++++++---------- src/group.rs | 16 ++++++++++++++++ src/lib.rs | 38 ++++++++++++++++---------------------- src/nizk/bullet.rs | 8 ++++---- src/nizk/mod.rs | 38 ++++++++++++++++---------------------- src/r1csinstance.rs | 4 ++-- src/r1csproof.rs | 20 ++++++++------------ src/scalar/ristretto255.rs | 6 +++--- src/sparse_mlpoly.rs | 4 +--- src/unipoly.rs | 2 +- 13 files changed, 84 insertions(+), 90 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ac18dd8..0c83867 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ zeroize = { version = "1", default-features = false } itertools = "0.9.0" colored = "1.9.3" flate2 = "1.0.14" +thiserror = "1.0" [dev-dependencies] criterion = "0.3.1" diff --git a/src/commitments.rs b/src/commitments.rs index 8cc9112..d3caf7f 100644 --- a/src/commitments.rs +++ b/src/commitments.rs @@ -27,7 +27,7 @@ impl MultiCommitGens { MultiCommitGens { n, - G: gens[0..n].to_vec(), + G: gens[..n].to_vec(), h: gens[n], } } diff --git a/src/dense_mlpoly.rs b/src/dense_mlpoly.rs index 3c836c1..f802f49 100644 --- a/src/dense_mlpoly.rs +++ b/src/dense_mlpoly.rs @@ -90,7 +90,7 @@ impl EqPolynomial { let ell = self.r.len(); let (left_num_vars, _right_num_vars) = EqPolynomial::compute_factored_lens(ell); - let L = EqPolynomial::new(self.r[0..left_num_vars].to_vec()).evals(); + let L = EqPolynomial::new(self.r[..left_num_vars].to_vec()).evals(); let R = EqPolynomial::new(self.r[left_num_vars..ell].to_vec()).evals(); (L, R) @@ -137,7 +137,7 @@ impl DensePolynomial { pub fn split(&self, idx: usize) -> (DensePolynomial, DensePolynomial) { assert!(idx < self.len()); ( - DensePolynomial::new(self.Z[0..idx].to_vec()), + DensePolynomial::new(self.Z[..idx].to_vec()), DensePolynomial::new(self.Z[idx..2 * idx].to_vec()), ) } @@ -326,18 +326,12 @@ impl PolyEvalProof { let default_blinds = PolyCommitmentBlinds { blinds: vec![Scalar::zero(); L_size], }; - let blinds = match blinds_opt { - Some(p) => p, - None => &default_blinds, - }; + let blinds = blinds_opt.map_or(&default_blinds, |p| p); assert_eq!(blinds.blinds.len(), L_size); let zero = Scalar::zero(); - let blind_Zr = match blind_Zr_opt { - Some(p) => p, - None => &zero, - }; + let blind_Zr = blind_Zr_opt.map_or(&zero, |p| p); // compute the L and R vectors let eq = EqPolynomial::new(r.to_vec()); diff --git a/src/errors.rs b/src/errors.rs index 4917979..97eeb44 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,16 +1,17 @@ -use core::fmt; +use core::fmt::Debug; +use thiserror::Error; -pub struct ProofVerifyError; - -impl fmt::Display for ProofVerifyError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Proof verification failed") - } +#[derive(Error, Debug)] +pub enum ProofVerifyError { + #[error("Proof verification failed")] + InternalError, + #[error("Compressed group element failed to decompress: {0:?}")] + DecompressionError([u8; 32]), } -impl fmt::Debug for ProofVerifyError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{{ file: {}, line: {} }}", file!(), line!()) +impl Default for ProofVerifyError { + fn default() -> Self { + ProofVerifyError::InternalError } } diff --git a/src/group.rs b/src/group.rs index 7a60966..ee8b770 100644 --- a/src/group.rs +++ b/src/group.rs @@ -1,9 +1,25 @@ +use super::errors::ProofVerifyError; use super::scalar::{Scalar, ScalarBytes, ScalarBytesFromScalar}; use core::borrow::Borrow; use core::ops::{Mul, MulAssign}; pub type GroupElement = curve25519_dalek::ristretto::RistrettoPoint; pub type CompressedGroup = curve25519_dalek::ristretto::CompressedRistretto; + +pub trait CompressedGroupExt { + type Group; + fn unpack(&self) -> Result; +} + +impl CompressedGroupExt for CompressedGroup { + type Group = curve25519_dalek::ristretto::RistrettoPoint; + fn unpack(&self) -> Result { + self + .decompress() + .ok_or_else(|| ProofVerifyError::DecompressionError(self.to_bytes())) + } +} + pub const GROUP_BASEPOINT_COMPRESSED: CompressedGroup = curve25519_dalek::constants::RISTRETTO_BASEPOINT_COMPRESSED; diff --git a/src/lib.rs b/src/lib.rs index be781d3..55cfdf6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -339,17 +339,14 @@ impl SNARK { let timer_sat_proof = Timer::new("verify_sat_proof"); assert_eq!(input.assignment.len(), comm.comm.get_num_inputs()); - let (rx, ry) = self - .r1cs_sat_proof - .verify( - comm.comm.get_num_vars(), - comm.comm.get_num_cons(), - &input.assignment, - &self.inst_evals, - transcript, - &gens.gens_r1cs_sat, - ) - .unwrap(); + let (rx, ry) = self.r1cs_sat_proof.verify( + comm.comm.get_num_vars(), + comm.comm.get_num_cons(), + &input.assignment, + &self.inst_evals, + transcript, + &gens.gens_r1cs_sat, + )?; timer_sat_proof.stop(); let timer_eval_proof = Timer::new("verify_eval_proof"); @@ -454,17 +451,14 @@ impl NIZK { let timer_sat_proof = Timer::new("verify_sat_proof"); assert_eq!(input.assignment.len(), inst.inst.get_num_inputs()); - let (rx, ry) = self - .r1cs_sat_proof - .verify( - inst.inst.get_num_vars(), - inst.inst.get_num_cons(), - &input.assignment, - &inst_evals, - transcript, - &gens.gens_r1cs_sat, - ) - .unwrap(); + let (rx, ry) = self.r1cs_sat_proof.verify( + inst.inst.get_num_vars(), + inst.inst.get_num_cons(), + &input.assignment, + &inst_evals, + transcript, + &gens.gens_r1cs_sat, + )?; // verify if claimed rx and ry are correct assert_eq!(rx, *claimed_rx); diff --git a/src/nizk/bullet.rs b/src/nizk/bullet.rs index d33f878..0405059 100644 --- a/src/nizk/bullet.rs +++ b/src/nizk/bullet.rs @@ -148,10 +148,10 @@ impl BulletReductionProof { if lg_n >= 32 { // 4 billion multiplications should be enough for anyone // and this check prevents overflow in 1<, _>>()?; let Rs = self .R_vec .iter() - .map(|p| p.decompress().ok_or(ProofVerifyError)) + .map(|p| p.decompress().ok_or(ProofVerifyError::InternalError)) .collect::, _>>()?; let G_hat = GroupElement::vartime_multiscalar_mul(s.iter(), G.iter()); diff --git a/src/nizk/mod.rs b/src/nizk/mod.rs index 0cfb592..ddd2408 100644 --- a/src/nizk/mod.rs +++ b/src/nizk/mod.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use super::commitments::{Commitments, MultiCommitGens}; use super::errors::ProofVerifyError; -use super::group::CompressedGroup; +use super::group::{CompressedGroup, CompressedGroupExt}; use super::math::Math; use super::random::RandomTape; use super::scalar::Scalar; @@ -64,17 +64,12 @@ impl KnowledgeProof { let c = transcript.challenge_scalar(b"c"); let lhs = self.z1.commit(&self.z2, gens_n).compress(); - let rhs = (c * C.decompress().expect("Could not decompress C") - + self - .alpha - .decompress() - .expect("Could not decompress self.alpha")) - .compress(); + let rhs = (c * C.unpack()? + self.alpha.unpack()?).compress(); if lhs == rhs { Ok(()) } else { - Err(ProofVerifyError) + Err(ProofVerifyError::InternalError) } } } @@ -134,8 +129,8 @@ impl EqualityProof { let c = transcript.challenge_scalar(b"c"); let rhs = { - let C = C1.decompress().unwrap() - C2.decompress().unwrap(); - (c * C + self.alpha.decompress().unwrap()).compress() + let C = C1.unpack()? - C2.unpack()?; + (c * C + self.alpha.unpack()?).compress() }; let lhs = (self.z * gens_n.h).compress(); @@ -143,7 +138,7 @@ impl EqualityProof { if lhs == rhs { Ok(()) } else { - Err(ProofVerifyError) + Err(ProofVerifyError::InternalError) } } } @@ -280,7 +275,7 @@ impl ProductProof { &c, &MultiCommitGens { n: 1, - G: vec![X.decompress().unwrap()], + G: vec![X.unpack()?], h: gens_n.h, }, &z3, @@ -289,7 +284,7 @@ impl ProductProof { { Ok(()) } else { - Err(ProofVerifyError) + Err(ProofVerifyError::InternalError) } } } @@ -392,17 +387,16 @@ impl DotProductProof { let c = transcript.challenge_scalar(b"c"); - let mut result = c * Cx.decompress().unwrap() + self.delta.decompress().unwrap() - == self.z.commit(&self.z_delta, gens_n); + let mut result = + c * Cx.unpack()? + self.delta.unpack()? == self.z.commit(&self.z_delta, gens_n); let dotproduct_z_a = DotProductProof::compute_dotproduct(&self.z, &a); - result &= c * Cy.decompress().unwrap() + self.beta.decompress().unwrap() - == dotproduct_z_a.commit(&self.z_beta, gens_1); + result &= c * Cy.unpack()? + self.beta.unpack()? == dotproduct_z_a.commit(&self.z_beta, gens_1); if result { Ok(()) } else { - Err(ProofVerifyError) + Err(ProofVerifyError::InternalError) } } } @@ -534,7 +528,7 @@ impl DotProductProofLog { Cx.append_to_transcript(b"Cx", transcript); Cy.append_to_transcript(b"Cy", transcript); - let Gamma = Cx.decompress().unwrap() + Cy.decompress().unwrap(); + let Gamma = Cx.unpack()? + Cy.unpack()?; let (g_hat, Gamma_hat, a_hat) = self .bullet_reduction_proof @@ -547,9 +541,9 @@ impl DotProductProofLog { let c = transcript.challenge_scalar(b"c"); let c_s = &c; - let beta_s = self.beta.decompress().unwrap(); + let beta_s = self.beta.unpack()?; let a_hat_s = &a_hat; - let delta_s = self.delta.decompress().unwrap(); + let delta_s = self.delta.unpack()?; let z1_s = &self.z1; let z2_s = &self.z2; @@ -561,7 +555,7 @@ impl DotProductProofLog { if lhs == rhs { Ok(()) } else { - Err(ProofVerifyError) + Err(ProofVerifyError::InternalError) } } } diff --git a/src/r1csinstance.rs b/src/r1csinstance.rs index c0bd21c..4c3d0e4 100644 --- a/src/r1csinstance.rs +++ b/src/r1csinstance.rs @@ -211,11 +211,11 @@ impl R1CSInstance { }; assert_eq!( - inst.is_sat(&Z[0..num_vars].to_vec(), &Z[num_vars + 1..].to_vec()), + inst.is_sat(&Z[..num_vars].to_vec(), &Z[num_vars + 1..].to_vec()), true, ); - (inst, Z[0..num_vars].to_vec(), Z[num_vars + 1..].to_vec()) + (inst, Z[..num_vars].to_vec(), Z[num_vars + 1..].to_vec()) } pub fn is_sat(&self, vars: &[Scalar], input: &[Scalar]) -> bool { diff --git a/src/r1csproof.rs b/src/r1csproof.rs index 1fbecef..e66cdd9 100644 --- a/src/r1csproof.rs +++ b/src/r1csproof.rs @@ -370,18 +370,14 @@ impl R1CSProof { let claim_phase1 = Scalar::zero() .commit(&Scalar::zero(), &gens.gens_sc.gens_1) .compress(); - let (comm_claim_post_phase1, rx) = self - .sc_proof_phase1 - .verify( - &claim_phase1, - num_rounds_x, - 3, - &gens.gens_sc.gens_1, - &gens.gens_sc.gens_4, - transcript, - ) - .unwrap(); - + let (comm_claim_post_phase1, rx) = self.sc_proof_phase1.verify( + &claim_phase1, + num_rounds_x, + 3, + &gens.gens_sc.gens_1, + &gens.gens_sc.gens_4, + transcript, + )?; // perform the intermediate sum-check test with claimed Az, Bz, and Cz let (comm_Az_claim, comm_Bz_claim, comm_Cz_claim, comm_prod_Az_Bz_claims) = &self.claims_phase2; let (pok_Cz_claim, proof_prod) = &self.pok_claims_phase2; diff --git a/src/scalar/ristretto255.rs b/src/scalar/ristretto255.rs index 5696917..e8e33c8 100755 --- a/src/scalar/ristretto255.rs +++ b/src/scalar/ristretto255.rs @@ -398,7 +398,7 @@ impl Scalar { pub fn from_bytes(bytes: &[u8; 32]) -> CtOption { let mut tmp = Scalar([0, 0, 0, 0]); - tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()); + tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap()); tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()); tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()); tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()); @@ -429,7 +429,7 @@ impl Scalar { let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0); let mut res = [0; 32]; - res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes()); + res[..8].copy_from_slice(&tmp.0[0].to_le_bytes()); res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); @@ -441,7 +441,7 @@ impl Scalar { /// a `Scalar` by reducing by the modulus. pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar { Scalar::from_u512([ - u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap()), u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()), u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()), u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()), diff --git a/src/sparse_mlpoly.rs b/src/sparse_mlpoly.rs index d242757..389dab0 100644 --- a/src/sparse_mlpoly.rs +++ b/src/sparse_mlpoly.rs @@ -1400,9 +1400,7 @@ impl PolyEvalNetworkProof { let (claims_mem, rand_mem, mut claims_ops, claims_dotp, rand_ops) = self .proof_prod_layer - .verify(num_ops, num_cells, evals, transcript) - .unwrap(); - + .verify(num_ops, num_cells, evals, transcript)?; assert_eq!(claims_mem.len(), 4); assert_eq!(claims_ops.len(), 4 * num_instances); assert_eq!(claims_dotp.len(), 3 * num_instances); diff --git a/src/unipoly.rs b/src/unipoly.rs index 895b24d..0a0549f 100644 --- a/src/unipoly.rs +++ b/src/unipoly.rs @@ -80,7 +80,7 @@ impl UniPoly { } pub fn compress(&self) -> CompressedUniPoly { - let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat(); + let coeffs_except_linear_term = [&self.coeffs[..1], &self.coeffs[2..]].concat(); assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); CompressedUniPoly { coeffs_except_linear_term,