diff --git a/src/circuit.rs b/src/circuit.rs index 927dee7..95d45a9 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -38,7 +38,6 @@ pub struct NIFSVerifierCircuitParams { } impl NIFSVerifierCircuitParams { - #[allow(dead_code)] pub fn new(limb_width: usize, n_limbs: usize, is_primary_circuit: bool) -> Self { Self { limb_width, @@ -64,7 +63,7 @@ where G: Group, { /// Create new inputs/witness for the verification circuit - #[allow(dead_code, clippy::too_many_arguments)] + #[allow(clippy::too_many_arguments)] pub fn new( params: G::Scalar, i: G::Base, @@ -104,7 +103,6 @@ where SC: StepCircuit, { /// Create a new verification circuit for the input relaxed r1cs instances - #[allow(dead_code)] pub fn new( params: NIFSVerifierCircuitParams, inputs: Option>, @@ -143,12 +141,15 @@ where // Allocate i let i = AllocatedNum::alloc(cs.namespace(|| "i"), || Ok(self.inputs.get()?.i))?; + // Allocate z0 let z_0 = AllocatedNum::alloc(cs.namespace(|| "z0"), || Ok(self.inputs.get()?.z0))?; + // Allocate zi. If inputs.zi is not provided (base case) allocate default value 0 let z_i = AllocatedNum::alloc(cs.namespace(|| "zi"), || { Ok(self.inputs.get()?.zi.unwrap_or_else(G::Base::zero)) })?; + // Allocate the running instance let U: AllocatedRelaxedR1CSInstance = AllocatedRelaxedR1CSInstance::alloc( cs.namespace(|| "Allocate U"), @@ -158,6 +159,7 @@ where self.params.limb_width, self.params.n_limbs, )?; + // Allocate the instance to be folded in let u = AllocatedR1CSInstance::alloc( cs.namespace(|| "allocate instance u to fold"), @@ -176,6 +178,7 @@ where .map_or(None, |T| Some(T.comm.to_coordinates())) }), )?; + Ok((params, i, z_0, z_i, U, u, T)) } @@ -320,6 +323,7 @@ where &z_i, &Boolean::from(is_base_case), )?; + let z_next = self .step_circuit .synthesize(&mut cs.namespace(|| "F"), z_input)?; diff --git a/src/errors.rs b/src/errors.rs index 7038121..ca9bf0f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -16,4 +16,8 @@ pub enum NovaError { UnSat, /// returned when the supplied compressed commitment cannot be decompressed DecompressionError, + /// returned if proof verification fails + ProofVerifyError, + /// returned if the provided number of steps is zero + InvalidNumSteps, } diff --git a/src/gadgets/utils.rs b/src/gadgets/utils.rs index 2034544..dce3cbf 100644 --- a/src/gadgets/utils.rs +++ b/src/gadgets/utils.rs @@ -13,7 +13,6 @@ use ff::{Field, PrimeField, PrimeFieldBits}; use num_bigint::BigInt; /// Gets as input the little indian representation of a number and spits out the number -#[allow(dead_code)] pub fn le_bits_to_num( mut cs: CS, bits: Vec, diff --git a/src/lib.rs b/src/lib.rs index 792ea21..00e4cb9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,12 +26,13 @@ use constants::{BN_LIMB_WIDTH, BN_N_LIMBS}; use core::marker::PhantomData; use errors::NovaError; use ff::Field; +use gadgets::utils::scalar_as_base; use nifs::NIFS; use poseidon::ROConstantsCircuit; // TODO: make this a trait so we can use it without the concrete implementation use r1cs::{ R1CSGens, R1CSInstance, R1CSShape, R1CSWitness, RelaxedR1CSInstance, RelaxedR1CSWitness, }; -use traits::{Group, HashFuncConstantsTrait, HashFuncTrait, StepCircuit}; +use traits::{AbsorbInROTrait, Group, HashFuncConstantsTrait, HashFuncTrait, StepCircuit}; type ROConstants = <::HashFunc as HashFuncTrait<::Base, ::Scalar>>::Constants; @@ -133,6 +134,8 @@ where r_U_secondary: RelaxedR1CSInstance, l_w_secondary: R1CSWitness, l_u_secondary: R1CSInstance, + zn_primary: G1::Scalar, + zn_secondary: G2::Scalar, _p_c1: PhantomData, _p_c2: PhantomData, } @@ -147,15 +150,19 @@ where /// Create a new `RecursiveSNARK` pub fn prove( pp: &PublicParams, + num_steps: usize, z0_primary: G1::Scalar, z0_secondary: G2::Scalar, - num_steps: usize, ) -> Result { + if num_steps == 0 { + return Err(NovaError::InvalidNumSteps); + } + // Execute the base case for the primary let mut cs_primary: SatisfyingAssignment = SatisfyingAssignment::new(); let inputs_primary: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new( pp.r1cs_shape_secondary.get_digest(), - ::Scalar::zero(), + G1::Scalar::zero(), z0_primary, None, None, @@ -177,7 +184,7 @@ where let mut cs_secondary: SatisfyingAssignment = SatisfyingAssignment::new(); let inputs_secondary: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new( pp.r1cs_shape_primary.get_digest(), - ::Scalar::zero(), + G2::Scalar::zero(), z0_secondary, None, None, @@ -214,6 +221,8 @@ where let mut z_next_primary = z0_primary; let mut z_next_secondary = z0_secondary; + z_next_primary = pp.c_primary.compute(&z_next_primary); + z_next_secondary = pp.c_secondary.compute(&z_next_secondary); for i in 1..num_steps { // fold the secondary circuit's instance @@ -227,13 +236,10 @@ where &l_w_secondary, )?; - z_next_primary = pp.c_primary.compute(&z_next_primary); - z_next_secondary = pp.c_secondary.compute(&z_next_secondary); - let mut cs_primary: SatisfyingAssignment = SatisfyingAssignment::new(); let inputs_primary: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new( pp.r1cs_shape_secondary.get_digest(), - ::Scalar::from(i as u64), + G1::Scalar::from(i as u64), z0_primary, Some(z_next_primary), Some(r_U_secondary), @@ -267,7 +273,7 @@ where let mut cs_secondary: SatisfyingAssignment = SatisfyingAssignment::new(); let inputs_secondary: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new( pp.r1cs_shape_primary.get_digest(), - ::Scalar::from(i as u64), + G2::Scalar::from(i as u64), z0_secondary, Some(z_next_secondary), Some(r_U_primary.clone()), @@ -292,6 +298,8 @@ where r_W_secondary = r_W_next_secondary; r_U_primary = r_U_next_primary; r_W_primary = r_W_next_primary; + z_next_primary = pp.c_primary.compute(&z_next_primary); + z_next_secondary = pp.c_secondary.compute(&z_next_secondary); } Ok(Self { @@ -303,15 +311,61 @@ where r_U_secondary, l_w_secondary, l_u_secondary, + zn_primary: z_next_primary, + zn_secondary: z_next_secondary, _p_c1: Default::default(), _p_c2: Default::default(), }) } /// Verify the correctness of the `RecursiveSNARK` - pub fn verify(&self, pp: &PublicParams) -> Result<(), NovaError> { - // TODO: perform additional checks on whether (shape_digest, z_0, z_i, i) are correct + pub fn verify( + &self, + pp: &PublicParams, + num_steps: usize, + z0_primary: G1::Scalar, + z0_secondary: G2::Scalar, + ) -> Result<(G1::Scalar, G2::Scalar), NovaError> { + // number of steps cannot be zero + if num_steps == 0 { + return Err(NovaError::ProofVerifyError); + } + + // check if the (relaxed) R1CS instances have two public outputs + if self.l_u_primary.X.len() != 2 + || self.l_u_secondary.X.len() != 2 + || self.r_U_primary.X.len() != 2 + || self.r_U_secondary.X.len() != 2 + { + return Err(NovaError::ProofVerifyError); + } + // check if the output hashes in R1CS instances point to the right running instances + let (hash_primary, hash_secondary) = { + let mut hasher = ::HashFunc::new(pp.ro_consts_secondary.clone()); + hasher.absorb(scalar_as_base::(pp.r1cs_shape_secondary.get_digest())); + hasher.absorb(G1::Scalar::from(num_steps as u64)); + hasher.absorb(z0_primary); + hasher.absorb(self.zn_primary); + self.r_U_secondary.absorb_in_ro(&mut hasher); + + let mut hasher2 = ::HashFunc::new(pp.ro_consts_primary.clone()); + hasher2.absorb(scalar_as_base::(pp.r1cs_shape_primary.get_digest())); + hasher2.absorb(G2::Scalar::from(num_steps as u64)); + hasher2.absorb(z0_secondary); + hasher2.absorb(self.zn_secondary); + self.r_U_primary.absorb_in_ro(&mut hasher2); + + (hasher.get_hash(), hasher2.get_hash()) + }; + + if hash_primary != scalar_as_base::(self.l_u_primary.X[1]) + || hash_secondary != scalar_as_base::(self.l_u_secondary.X[1]) + { + return Err(NovaError::ProofVerifyError); + } + + // check the satisfiability of the provided instances pp.r1cs_shape_primary.is_sat_relaxed( &pp.r1cs_gens_primary, &self.r_U_primary, @@ -333,7 +387,7 @@ where &self.l_w_secondary, )?; - Ok(()) + Ok((self.zn_primary, self.zn_secondary)) } } @@ -433,20 +487,78 @@ mod tests { // produce a recursive SNARK let res = RecursiveSNARK::prove( &pp, + 3, ::Scalar::zero(), ::Scalar::zero(), + ); + assert!(res.is_ok()); + let recursive_snark = res.unwrap(); + + // verify the recursive SNARK + let res = recursive_snark.verify( + &pp, 3, + ::Scalar::zero(), + ::Scalar::zero(), + ); + assert!(res.is_ok()); + } + + #[test] + fn test_ivc_nontrivial() { + // produce public parameters + let pp = PublicParams::< + G1, + G2, + TrivialTestCircuit<::Scalar>, + CubicCircuit<::Scalar>, + >::setup( + TrivialTestCircuit { + _p: Default::default(), + }, + CubicCircuit { + _p: Default::default(), + }, + ); + + let num_steps = 3; + + // produce a recursive SNARK + let res = RecursiveSNARK::prove( + &pp, + num_steps, + ::Scalar::one(), + ::Scalar::zero(), ); assert!(res.is_ok()); let recursive_snark = res.unwrap(); // verify the recursive SNARK - let res = recursive_snark.verify(&pp); + let res = recursive_snark.verify( + &pp, + num_steps, + ::Scalar::one(), + ::Scalar::zero(), + ); assert!(res.is_ok()); + + let (zn_primary, zn_secondary) = res.unwrap(); + + // sanity: check the claimed output with a direct computation of the same + assert_eq!(zn_primary, ::Scalar::one()); + let mut zn_secondary_direct = ::Scalar::zero(); + for _i in 0..num_steps { + zn_secondary_direct = CubicCircuit { + _p: Default::default(), + } + .compute(&zn_secondary_direct); + } + assert_eq!(zn_secondary, zn_secondary_direct); + assert_eq!(zn_secondary, ::Scalar::from(2460515u64)); } #[test] - fn test_ivc() { + fn test_ivc_base() { // produce public parameters let pp = PublicParams::< G1, @@ -462,18 +574,30 @@ mod tests { }, ); + let num_steps = 1; + // produce a recursive SNARK let res = RecursiveSNARK::prove( &pp, - ::Scalar::zero(), + num_steps, + ::Scalar::one(), ::Scalar::zero(), - 3, ); assert!(res.is_ok()); let recursive_snark = res.unwrap(); // verify the recursive SNARK - let res = recursive_snark.verify(&pp); + let res = recursive_snark.verify( + &pp, + num_steps, + ::Scalar::one(), + ::Scalar::zero(), + ); assert!(res.is_ok()); + + let (zn_primary, zn_secondary) = res.unwrap(); + + assert_eq!(zn_primary, ::Scalar::one()); + assert_eq!(zn_secondary, ::Scalar::from(5u64)); } } diff --git a/src/poseidon.rs b/src/poseidon.rs index abee7fb..cfcf550 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -88,7 +88,6 @@ where { type Constants = ROConstantsCircuit; - #[allow(dead_code)] fn new(constants: ROConstantsCircuit) -> Self { Self { state: Vec::new(), @@ -98,13 +97,11 @@ where } /// Absorb a new number into the state of the oracle - #[allow(dead_code)] fn absorb(&mut self, e: Base) { self.state.push(e); } /// Compute a challenge by hashing the current state - #[allow(dead_code)] fn get_challenge(&self) -> Scalar { let hash = self.hash_inner(); // Only keep NUM_CHALLENGE_BITS bits @@ -120,7 +117,6 @@ where res } - #[allow(dead_code)] fn get_hash(&self) -> Scalar { let hash = self.hash_inner(); // Only keep NUM_HASH_BITS bits @@ -152,7 +148,6 @@ where Scalar: PrimeField + PrimeFieldBits, { /// Initialize the internal state and set the poseidon constants - #[allow(dead_code)] pub fn new(constants: ROConstantsCircuit) -> Self { Self { state: Vec::new(), @@ -161,7 +156,6 @@ where } /// Absorb a new number into the state of the oracle - #[allow(dead_code)] pub fn absorb(&mut self, e: AllocatedNum) { self.state.push(e); } @@ -203,7 +197,6 @@ where } /// Compute a challenge by hashing the current state - #[allow(dead_code)] pub fn get_challenge(&mut self, mut cs: CS) -> Result, SynthesisError> where CS: ConstraintSystem, @@ -212,7 +205,6 @@ where Ok(bits[..NUM_CHALLENGE_BITS].into()) } - #[allow(dead_code)] pub fn get_hash(&mut self, mut cs: CS) -> Result, SynthesisError> where CS: ConstraintSystem,