Browse Source

Refactor to idiomatic Result/Option patterns (#25)

This:
- introduces a small [thiserror](https://github.com/dtolnay/thiserror)-powered enum to improve ProofVerifyError's messages,
- refactors point decompression errors into a variant of that enum, thereby suppressing the panics which occur when decompression fails,
- folds other panics into the Error cases of their enclosing `Result` return.
master
François Garillot 4 years ago
committed by GitHub
parent
commit
9e4c166edb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 84 additions and 90 deletions
  1. +1
    -0
      Cargo.toml
  2. +1
    -1
      src/commitments.rs
  3. +4
    -10
      src/dense_mlpoly.rs
  4. +11
    -10
      src/errors.rs
  5. +16
    -0
      src/group.rs
  6. +16
    -22
      src/lib.rs
  7. +4
    -4
      src/nizk/bullet.rs
  8. +16
    -22
      src/nizk/mod.rs
  9. +2
    -2
      src/r1csinstance.rs
  10. +8
    -12
      src/r1csproof.rs
  11. +3
    -3
      src/scalar/ristretto255.rs
  12. +1
    -3
      src/sparse_mlpoly.rs
  13. +1
    -1
      src/unipoly.rs

+ 1
- 0
Cargo.toml

@ -26,6 +26,7 @@ zeroize = { version = "1", default-features = false }
itertools = "0.9.0" itertools = "0.9.0"
colored = "1.9.3" colored = "1.9.3"
flate2 = "1.0.14" flate2 = "1.0.14"
thiserror = "1.0"
[dev-dependencies] [dev-dependencies]
criterion = "0.3.1" criterion = "0.3.1"

+ 1
- 1
src/commitments.rs

@ -27,7 +27,7 @@ impl MultiCommitGens {
MultiCommitGens { MultiCommitGens {
n, n,
G: gens[0..n].to_vec(),
G: gens[..n].to_vec(),
h: gens[n], h: gens[n],
} }
} }

+ 4
- 10
src/dense_mlpoly.rs

@ -90,7 +90,7 @@ impl EqPolynomial {
let ell = self.r.len(); let ell = self.r.len();
let (left_num_vars, _right_num_vars) = EqPolynomial::compute_factored_lens(ell); let (left_num_vars, _right_num_vars) = EqPolynomial::compute_factored_lens(ell);
let L = EqPolynomial::new(self.r[0..left_num_vars].to_vec()).evals();
let L = EqPolynomial::new(self.r[..left_num_vars].to_vec()).evals();
let R = EqPolynomial::new(self.r[left_num_vars..ell].to_vec()).evals(); let R = EqPolynomial::new(self.r[left_num_vars..ell].to_vec()).evals();
(L, R) (L, R)
@ -137,7 +137,7 @@ impl DensePolynomial {
pub fn split(&self, idx: usize) -> (DensePolynomial, DensePolynomial) { pub fn split(&self, idx: usize) -> (DensePolynomial, DensePolynomial) {
assert!(idx < self.len()); assert!(idx < self.len());
( (
DensePolynomial::new(self.Z[0..idx].to_vec()),
DensePolynomial::new(self.Z[..idx].to_vec()),
DensePolynomial::new(self.Z[idx..2 * idx].to_vec()), DensePolynomial::new(self.Z[idx..2 * idx].to_vec()),
) )
} }
@ -326,18 +326,12 @@ impl PolyEvalProof {
let default_blinds = PolyCommitmentBlinds { let default_blinds = PolyCommitmentBlinds {
blinds: vec![Scalar::zero(); L_size], blinds: vec![Scalar::zero(); L_size],
}; };
let blinds = match blinds_opt {
Some(p) => p,
None => &default_blinds,
};
let blinds = blinds_opt.map_or(&default_blinds, |p| p);
assert_eq!(blinds.blinds.len(), L_size); assert_eq!(blinds.blinds.len(), L_size);
let zero = Scalar::zero(); let zero = Scalar::zero();
let blind_Zr = match blind_Zr_opt {
Some(p) => p,
None => &zero,
};
let blind_Zr = blind_Zr_opt.map_or(&zero, |p| p);
// compute the L and R vectors // compute the L and R vectors
let eq = EqPolynomial::new(r.to_vec()); let eq = EqPolynomial::new(r.to_vec());

+ 11
- 10
src/errors.rs

@ -1,16 +1,17 @@
use core::fmt;
use core::fmt::Debug;
use thiserror::Error;
pub struct ProofVerifyError;
impl fmt::Display for ProofVerifyError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Proof verification failed")
}
#[derive(Error, Debug)]
pub enum ProofVerifyError {
#[error("Proof verification failed")]
InternalError,
#[error("Compressed group element failed to decompress: {0:?}")]
DecompressionError([u8; 32]),
} }
impl fmt::Debug for ProofVerifyError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{{ file: {}, line: {} }}", file!(), line!())
impl Default for ProofVerifyError {
fn default() -> Self {
ProofVerifyError::InternalError
} }
} }

+ 16
- 0
src/group.rs

@ -1,9 +1,25 @@
use super::errors::ProofVerifyError;
use super::scalar::{Scalar, ScalarBytes, ScalarBytesFromScalar}; use super::scalar::{Scalar, ScalarBytes, ScalarBytesFromScalar};
use core::borrow::Borrow; use core::borrow::Borrow;
use core::ops::{Mul, MulAssign}; use core::ops::{Mul, MulAssign};
pub type GroupElement = curve25519_dalek::ristretto::RistrettoPoint; pub type GroupElement = curve25519_dalek::ristretto::RistrettoPoint;
pub type CompressedGroup = curve25519_dalek::ristretto::CompressedRistretto; pub type CompressedGroup = curve25519_dalek::ristretto::CompressedRistretto;
pub trait CompressedGroupExt {
type Group;
fn unpack(&self) -> Result<Self::Group, ProofVerifyError>;
}
impl CompressedGroupExt for CompressedGroup {
type Group = curve25519_dalek::ristretto::RistrettoPoint;
fn unpack(&self) -> Result<Self::Group, ProofVerifyError> {
self
.decompress()
.ok_or_else(|| ProofVerifyError::DecompressionError(self.to_bytes()))
}
}
pub const GROUP_BASEPOINT_COMPRESSED: CompressedGroup = pub const GROUP_BASEPOINT_COMPRESSED: CompressedGroup =
curve25519_dalek::constants::RISTRETTO_BASEPOINT_COMPRESSED; curve25519_dalek::constants::RISTRETTO_BASEPOINT_COMPRESSED;

+ 16
- 22
src/lib.rs

@ -339,17 +339,14 @@ impl SNARK {
let timer_sat_proof = Timer::new("verify_sat_proof"); let timer_sat_proof = Timer::new("verify_sat_proof");
assert_eq!(input.assignment.len(), comm.comm.get_num_inputs()); assert_eq!(input.assignment.len(), comm.comm.get_num_inputs());
let (rx, ry) = self
.r1cs_sat_proof
.verify(
comm.comm.get_num_vars(),
comm.comm.get_num_cons(),
&input.assignment,
&self.inst_evals,
transcript,
&gens.gens_r1cs_sat,
)
.unwrap();
let (rx, ry) = self.r1cs_sat_proof.verify(
comm.comm.get_num_vars(),
comm.comm.get_num_cons(),
&input.assignment,
&self.inst_evals,
transcript,
&gens.gens_r1cs_sat,
)?;
timer_sat_proof.stop(); timer_sat_proof.stop();
let timer_eval_proof = Timer::new("verify_eval_proof"); let timer_eval_proof = Timer::new("verify_eval_proof");
@ -454,17 +451,14 @@ impl NIZK {
let timer_sat_proof = Timer::new("verify_sat_proof"); let timer_sat_proof = Timer::new("verify_sat_proof");
assert_eq!(input.assignment.len(), inst.inst.get_num_inputs()); assert_eq!(input.assignment.len(), inst.inst.get_num_inputs());
let (rx, ry) = self
.r1cs_sat_proof
.verify(
inst.inst.get_num_vars(),
inst.inst.get_num_cons(),
&input.assignment,
&inst_evals,
transcript,
&gens.gens_r1cs_sat,
)
.unwrap();
let (rx, ry) = self.r1cs_sat_proof.verify(
inst.inst.get_num_vars(),
inst.inst.get_num_cons(),
&input.assignment,
&inst_evals,
transcript,
&gens.gens_r1cs_sat,
)?;
// verify if claimed rx and ry are correct // verify if claimed rx and ry are correct
assert_eq!(rx, *claimed_rx); assert_eq!(rx, *claimed_rx);

+ 4
- 4
src/nizk/bullet.rs

@ -148,10 +148,10 @@ impl BulletReductionProof {
if lg_n >= 32 { if lg_n >= 32 {
// 4 billion multiplications should be enough for anyone // 4 billion multiplications should be enough for anyone
// and this check prevents overflow in 1<<lg_n below. // and this check prevents overflow in 1<<lg_n below.
return Err(ProofVerifyError);
return Err(ProofVerifyError::InternalError);
} }
if n != (1 << lg_n) { if n != (1 << lg_n) {
return Err(ProofVerifyError);
return Err(ProofVerifyError::InternalError);
} }
// 1. Recompute x_k,...,x_1 based on the proof transcript // 1. Recompute x_k,...,x_1 based on the proof transcript
@ -206,13 +206,13 @@ impl BulletReductionProof {
let Ls = self let Ls = self
.L_vec .L_vec
.iter() .iter()
.map(|p| p.decompress().ok_or(ProofVerifyError))
.map(|p| p.decompress().ok_or(ProofVerifyError::InternalError))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let Rs = self let Rs = self
.R_vec .R_vec
.iter() .iter()
.map(|p| p.decompress().ok_or(ProofVerifyError))
.map(|p| p.decompress().ok_or(ProofVerifyError::InternalError))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let G_hat = GroupElement::vartime_multiscalar_mul(s.iter(), G.iter()); let G_hat = GroupElement::vartime_multiscalar_mul(s.iter(), G.iter());

+ 16
- 22
src/nizk/mod.rs

@ -1,7 +1,7 @@
#![allow(clippy::too_many_arguments)] #![allow(clippy::too_many_arguments)]
use super::commitments::{Commitments, MultiCommitGens}; use super::commitments::{Commitments, MultiCommitGens};
use super::errors::ProofVerifyError; use super::errors::ProofVerifyError;
use super::group::CompressedGroup;
use super::group::{CompressedGroup, CompressedGroupExt};
use super::math::Math; use super::math::Math;
use super::random::RandomTape; use super::random::RandomTape;
use super::scalar::Scalar; use super::scalar::Scalar;
@ -64,17 +64,12 @@ impl KnowledgeProof {
let c = transcript.challenge_scalar(b"c"); let c = transcript.challenge_scalar(b"c");
let lhs = self.z1.commit(&self.z2, gens_n).compress(); let lhs = self.z1.commit(&self.z2, gens_n).compress();
let rhs = (c * C.decompress().expect("Could not decompress C")
+ self
.alpha
.decompress()
.expect("Could not decompress self.alpha"))
.compress();
let rhs = (c * C.unpack()? + self.alpha.unpack()?).compress();
if lhs == rhs { if lhs == rhs {
Ok(()) Ok(())
} else { } else {
Err(ProofVerifyError)
Err(ProofVerifyError::InternalError)
} }
} }
} }
@ -134,8 +129,8 @@ impl EqualityProof {
let c = transcript.challenge_scalar(b"c"); let c = transcript.challenge_scalar(b"c");
let rhs = { let rhs = {
let C = C1.decompress().unwrap() - C2.decompress().unwrap();
(c * C + self.alpha.decompress().unwrap()).compress()
let C = C1.unpack()? - C2.unpack()?;
(c * C + self.alpha.unpack()?).compress()
}; };
let lhs = (self.z * gens_n.h).compress(); let lhs = (self.z * gens_n.h).compress();
@ -143,7 +138,7 @@ impl EqualityProof {
if lhs == rhs { if lhs == rhs {
Ok(()) Ok(())
} else { } else {
Err(ProofVerifyError)
Err(ProofVerifyError::InternalError)
} }
} }
} }
@ -280,7 +275,7 @@ impl ProductProof {
&c, &c,
&MultiCommitGens { &MultiCommitGens {
n: 1, n: 1,
G: vec![X.decompress().unwrap()],
G: vec![X.unpack()?],
h: gens_n.h, h: gens_n.h,
}, },
&z3, &z3,
@ -289,7 +284,7 @@ impl ProductProof {
{ {
Ok(()) Ok(())
} else { } else {
Err(ProofVerifyError)
Err(ProofVerifyError::InternalError)
} }
} }
} }
@ -392,17 +387,16 @@ impl DotProductProof {
let c = transcript.challenge_scalar(b"c"); let c = transcript.challenge_scalar(b"c");
let mut result = c * Cx.decompress().unwrap() + self.delta.decompress().unwrap()
== self.z.commit(&self.z_delta, gens_n);
let mut result =
c * Cx.unpack()? + self.delta.unpack()? == self.z.commit(&self.z_delta, gens_n);
let dotproduct_z_a = DotProductProof::compute_dotproduct(&self.z, &a); let dotproduct_z_a = DotProductProof::compute_dotproduct(&self.z, &a);
result &= c * Cy.decompress().unwrap() + self.beta.decompress().unwrap()
== dotproduct_z_a.commit(&self.z_beta, gens_1);
result &= c * Cy.unpack()? + self.beta.unpack()? == dotproduct_z_a.commit(&self.z_beta, gens_1);
if result { if result {
Ok(()) Ok(())
} else { } else {
Err(ProofVerifyError)
Err(ProofVerifyError::InternalError)
} }
} }
} }
@ -534,7 +528,7 @@ impl DotProductProofLog {
Cx.append_to_transcript(b"Cx", transcript); Cx.append_to_transcript(b"Cx", transcript);
Cy.append_to_transcript(b"Cy", transcript); Cy.append_to_transcript(b"Cy", transcript);
let Gamma = Cx.decompress().unwrap() + Cy.decompress().unwrap();
let Gamma = Cx.unpack()? + Cy.unpack()?;
let (g_hat, Gamma_hat, a_hat) = self let (g_hat, Gamma_hat, a_hat) = self
.bullet_reduction_proof .bullet_reduction_proof
@ -547,9 +541,9 @@ impl DotProductProofLog {
let c = transcript.challenge_scalar(b"c"); let c = transcript.challenge_scalar(b"c");
let c_s = &c; let c_s = &c;
let beta_s = self.beta.decompress().unwrap();
let beta_s = self.beta.unpack()?;
let a_hat_s = &a_hat; let a_hat_s = &a_hat;
let delta_s = self.delta.decompress().unwrap();
let delta_s = self.delta.unpack()?;
let z1_s = &self.z1; let z1_s = &self.z1;
let z2_s = &self.z2; let z2_s = &self.z2;
@ -561,7 +555,7 @@ impl DotProductProofLog {
if lhs == rhs { if lhs == rhs {
Ok(()) Ok(())
} else { } else {
Err(ProofVerifyError)
Err(ProofVerifyError::InternalError)
} }
} }
} }

+ 2
- 2
src/r1csinstance.rs

@ -211,11 +211,11 @@ impl R1CSInstance {
}; };
assert_eq!( assert_eq!(
inst.is_sat(&Z[0..num_vars].to_vec(), &Z[num_vars + 1..].to_vec()),
inst.is_sat(&Z[..num_vars].to_vec(), &Z[num_vars + 1..].to_vec()),
true, true,
); );
(inst, Z[0..num_vars].to_vec(), Z[num_vars + 1..].to_vec())
(inst, Z[..num_vars].to_vec(), Z[num_vars + 1..].to_vec())
} }
pub fn is_sat(&self, vars: &[Scalar], input: &[Scalar]) -> bool { pub fn is_sat(&self, vars: &[Scalar], input: &[Scalar]) -> bool {

+ 8
- 12
src/r1csproof.rs

@ -370,18 +370,14 @@ impl R1CSProof {
let claim_phase1 = Scalar::zero() let claim_phase1 = Scalar::zero()
.commit(&Scalar::zero(), &gens.gens_sc.gens_1) .commit(&Scalar::zero(), &gens.gens_sc.gens_1)
.compress(); .compress();
let (comm_claim_post_phase1, rx) = self
.sc_proof_phase1
.verify(
&claim_phase1,
num_rounds_x,
3,
&gens.gens_sc.gens_1,
&gens.gens_sc.gens_4,
transcript,
)
.unwrap();
let (comm_claim_post_phase1, rx) = self.sc_proof_phase1.verify(
&claim_phase1,
num_rounds_x,
3,
&gens.gens_sc.gens_1,
&gens.gens_sc.gens_4,
transcript,
)?;
// perform the intermediate sum-check test with claimed Az, Bz, and Cz // perform the intermediate sum-check test with claimed Az, Bz, and Cz
let (comm_Az_claim, comm_Bz_claim, comm_Cz_claim, comm_prod_Az_Bz_claims) = &self.claims_phase2; let (comm_Az_claim, comm_Bz_claim, comm_Cz_claim, comm_prod_Az_Bz_claims) = &self.claims_phase2;
let (pok_Cz_claim, proof_prod) = &self.pok_claims_phase2; let (pok_Cz_claim, proof_prod) = &self.pok_claims_phase2;

+ 3
- 3
src/scalar/ristretto255.rs

@ -398,7 +398,7 @@ impl Scalar {
pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> { pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
let mut tmp = Scalar([0, 0, 0, 0]); let mut tmp = Scalar([0, 0, 0, 0]);
tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap());
tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()); tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()); tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()); tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
@ -429,7 +429,7 @@ impl Scalar {
let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0); let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
let mut res = [0; 32]; let mut res = [0; 32];
res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
res[..8].copy_from_slice(&tmp.0[0].to_le_bytes());
res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
@ -441,7 +441,7 @@ impl Scalar {
/// a `Scalar` by reducing by the modulus. /// a `Scalar` by reducing by the modulus.
pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar { pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
Scalar::from_u512([ Scalar::from_u512([
u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap()),
u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()), u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()), u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()), u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),

+ 1
- 3
src/sparse_mlpoly.rs

@ -1400,9 +1400,7 @@ impl PolyEvalNetworkProof {
let (claims_mem, rand_mem, mut claims_ops, claims_dotp, rand_ops) = self let (claims_mem, rand_mem, mut claims_ops, claims_dotp, rand_ops) = self
.proof_prod_layer .proof_prod_layer
.verify(num_ops, num_cells, evals, transcript)
.unwrap();
.verify(num_ops, num_cells, evals, transcript)?;
assert_eq!(claims_mem.len(), 4); assert_eq!(claims_mem.len(), 4);
assert_eq!(claims_ops.len(), 4 * num_instances); assert_eq!(claims_ops.len(), 4 * num_instances);
assert_eq!(claims_dotp.len(), 3 * num_instances); assert_eq!(claims_dotp.len(), 3 * num_instances);

+ 1
- 1
src/unipoly.rs

@ -80,7 +80,7 @@ impl UniPoly {
} }
pub fn compress(&self) -> CompressedUniPoly { pub fn compress(&self) -> CompressedUniPoly {
let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat();
let coeffs_except_linear_term = [&self.coeffs[..1], &self.coeffs[2..]].concat();
assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len()); assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len());
CompressedUniPoly { CompressedUniPoly {
coeffs_except_linear_term, coeffs_except_linear_term,

Loading…
Cancel
Save