Browse Source

additional error checking

master
Srinath Setty 3 years ago
parent
commit
1bb98a36b1
3 changed files with 79 additions and 23 deletions
  1. +1
    -1
      Cargo.toml
  2. +3
    -12
      src/errors.rs
  3. +75
    -10
      src/lib.rs

+ 1
- 1
Cargo.toml

@ -1,6 +1,6 @@
[package]
name = "spartan"
version = "0.2.0"
version = "0.2.1"
authors = ["Srinath Setty <srinath@microsoft.com>"]
edition = "2018"
description = "High-speed zkSNARKs without trusted setup"

+ 3
- 12
src/errors.rs

@ -14,6 +14,7 @@ impl fmt::Debug for ProofVerifyError {
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum R1CSError {
/// returned if the number of constraints is not a power of 2
NonPowerOfTwoCons,
@ -25,16 +26,6 @@ pub enum R1CSError {
InvalidNumberOfVars,
/// returned if a [u8;32] does not parse into a valid Scalar in the field of ristretto255
InvalidScalar,
}
impl fmt::Display for R1CSError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "R1CSError")
}
}
impl fmt::Debug for R1CSError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{{ file: {}, line: {} }}", file!(), line!())
}
/// returned if the supplied row or col in (row,col,val) tuple is out of range
InvalidIndex,
}

+ 75
- 10
src/lib.rs

@ -129,6 +129,17 @@ impl Instance {
let mut mat: Vec<(usize, usize, Scalar)> = Vec::new();
for i in 0..tups.len() {
let (row, col, val_bytes) = tups[i];
// row must be smaller than num_cons
if row >= num_cons {
return Err(R1CSError::InvalidIndex);
}
// col must be smaller than num_vars + 1 + num_inputs
if col >= num_vars + 1 + num_inputs {
return Err(R1CSError::InvalidIndex);
}
let val = Scalar::from_bytes(&val_bytes);
if val.is_some().unwrap_u8() == 1 {
mat.push((row, col, val.unwrap()));
@ -140,12 +151,18 @@ impl Instance {
};
let A_scalar = bytes_to_scalar(A);
if A_scalar.is_err() {
return Err(A_scalar.err().unwrap());
}
let B_scalar = bytes_to_scalar(B);
let C_scalar = bytes_to_scalar(C);
if B_scalar.is_err() {
return Err(B_scalar.err().unwrap());
}
// check for any parsing errors
if A_scalar.is_err() || B_scalar.is_err() || C_scalar.is_err() {
return Err(R1CSError::InvalidScalar);
let C_scalar = bytes_to_scalar(C);
if C_scalar.is_err() {
return Err(C_scalar.err().unwrap());
}
let inst = R1CSInstance::new(
@ -161,16 +178,19 @@ impl Instance {
}
/// Checks if a given R1CSInstance is satisfiable with a given variables and inputs assignments
pub fn is_sat(&self, vars: &VarsAssignment, inputs: &InputsAssignment) -> Result<bool, R1CSError> {
pub fn is_sat(
&self,
vars: &VarsAssignment,
inputs: &InputsAssignment,
) -> Result<bool, R1CSError> {
if vars.assignment.len() != self.inst.get_num_vars() {
return Err(R1CSError::InvalidNumberOfVars)
return Err(R1CSError::InvalidNumberOfVars);
}
if inputs.assignment.len() != self.inst.get_num_inputs() {
return Err(R1CSError::InvalidNumberOfInputs)
return Err(R1CSError::InvalidNumberOfInputs);
}
Ok(self.inst.is_sat(&vars.assignment, &inputs.assignment))
}
@ -485,4 +505,49 @@ mod tests {
.verify(&comm, &inputs, &mut verifier_transcript, &gens)
.is_ok());
}
#[test]
pub fn check_r1cs_invalid_index() {
let num_cons = 4;
let num_vars = 8;
let num_inputs = 1;
let zero: [u8; 32] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0,
];
let A = vec![(0, 0, zero)];
let B = vec![(100, 1, zero)];
let C = vec![(1, 1, zero)];
let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C);
assert_eq!(inst.is_err(), true);
assert_eq!(inst.err(), Some(R1CSError::InvalidIndex));
}
#[test]
pub fn check_r1cs_invalid_scalar() {
let num_cons = 4;
let num_vars = 8;
let num_inputs = 1;
let zero: [u8; 32] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0,
];
let larger_than_mod = [
3, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8, 216,
57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
];
let A = vec![(0, 0, zero)];
let B = vec![(1, 1, larger_than_mod)];
let C = vec![(1, 1, zero)];
let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C);
assert_eq!(inst.is_err(), true);
assert_eq!(inst.err(), Some(R1CSError::InvalidScalar));
}
}

Loading…
Cancel
Save