From 07b3c4289bf70f1b34bcf11b2b08a8187a6bba87 Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Sun, 15 May 2022 12:05:18 +0530 Subject: [PATCH] Recursion APIs (#62) * recursion APIs (WIP) * PublicParams struct and associated new * fix build * draft of APIs * start with tests * add a test case for the base case of recursion --- src/circuit.rs | 30 +++-- src/gadgets/r1cs.rs | 6 +- src/lib.rs | 291 +++++++++++++++++++++++++++++++++++++++++++- src/poseidon.rs | 16 +-- 4 files changed, 314 insertions(+), 29 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 7b20ad6..243212e 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -16,7 +16,7 @@ use super::{ alloc_num_equals, alloc_scalar_as_base, alloc_zero, conditionally_select, le_bits_to_num, }, }, - poseidon::{NovaPoseidonConstants, PoseidonROGadget}, + poseidon::{PoseidonROGadget, ROConstantsCircuit}, r1cs::{R1CSInstance, RelaxedR1CSInstance}, traits::{Group, StepCircuit}, }; @@ -87,7 +87,7 @@ where } /// Circuit that encodes only the folding verifier -pub struct NIFSVerifierCircuit +pub struct NIFSVerifierCircuit where G: Group, SC: StepCircuit, @@ -95,7 +95,7 @@ where params: NIFSVerifierCircuitParams, inputs: Option>, step_circuit: SC, // The function that is applied for each step - poseidon_constants: NovaPoseidonConstants, + ro_consts: ROConstantsCircuit, } impl NIFSVerifierCircuit @@ -109,13 +109,13 @@ where params: NIFSVerifierCircuitParams, inputs: Option>, step_circuit: SC, - poseidon_constants: NovaPoseidonConstants, + ro_consts: ROConstantsCircuit, ) -> Self { Self { params, inputs, step_circuit, - poseidon_constants, + ro_consts, } } @@ -219,7 +219,7 @@ where T: AllocatedPoint, ) -> Result<(AllocatedRelaxedR1CSInstance, AllocatedBit), SynthesisError> { // Check that u.x[0] = Hash(params, U, i, z0, zi) - let mut ro: PoseidonROGadget = PoseidonROGadget::new(self.poseidon_constants.clone()); + let mut ro: PoseidonROGadget = PoseidonROGadget::new(self.ro_consts.clone()); ro.absorb(params.clone()); ro.absorb(i); ro.absorb(z_0); @@ -240,7 +240,7 @@ where params, u, T, - self.poseidon_constants.clone(), + self.ro_consts.clone(), self.params.limb_width, self.params.n_limbs, )?; @@ -325,7 +325,7 @@ where .synthesize(&mut cs.namespace(|| "F"), z_input)?; // Compute the new hash H(params, Unew, i+1, z0, z_{i+1}) - let mut ro: PoseidonROGadget = PoseidonROGadget::new(self.poseidon_constants); + let mut ro: PoseidonROGadget = PoseidonROGadget::new(self.ro_consts); ro.absorb(params); ro.absorb(i_new.clone()); ro.absorb(z_0); @@ -380,10 +380,8 @@ mod tests { // In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit let params1 = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); let params2 = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); - let poseidon_constants1: NovaPoseidonConstants<::Base> = - NovaPoseidonConstants::new(); - let poseidon_constants2: NovaPoseidonConstants<::Base> = - NovaPoseidonConstants::new(); + let ro_consts1: ROConstantsCircuit<::Base> = ROConstantsCircuit::new(); + let ro_consts2: ROConstantsCircuit<::Base> = ROConstantsCircuit::new(); // Initialize the shape and gens for the primary let circuit1: NIFSVerifierCircuit::Base>> = @@ -393,7 +391,7 @@ mod tests { TestCircuit { _p: Default::default(), }, - poseidon_constants1.clone(), + ro_consts1.clone(), ); let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit1.synthesize(&mut cs); @@ -411,7 +409,7 @@ mod tests { TestCircuit { _p: Default::default(), }, - poseidon_constants2.clone(), + ro_consts2.clone(), ); let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit2.synthesize(&mut cs); @@ -440,7 +438,7 @@ mod tests { TestCircuit { _p: Default::default(), }, - poseidon_constants1, + ro_consts1, ); let _ = circuit1.synthesize(&mut cs1); let (inst1, witness1) = cs1.r1cs_instance_and_witness(&shape1, &gens1).unwrap(); @@ -466,7 +464,7 @@ mod tests { TestCircuit { _p: Default::default(), }, - poseidon_constants2, + ro_consts2, ); let _ = circuit.synthesize(&mut cs2); let (inst2, witness2) = cs2.r1cs_instance_and_witness(&shape2, &gens2).unwrap(); diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index 990121a..b54b77d 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -7,7 +7,7 @@ use crate::{ conditionally_select_bignat, le_bits_to_num, }, }, - poseidon::{NovaPoseidonConstants, PoseidonROGadget}, + poseidon::{PoseidonROGadget, ROConstantsCircuit}, r1cs::{R1CSInstance, RelaxedR1CSInstance}, traits::Group, }; @@ -263,12 +263,12 @@ where params: AllocatedNum, // hash of R1CSShape of F' u: AllocatedR1CSInstance, T: AllocatedPoint, - poseidon_constants: NovaPoseidonConstants, + ro_consts: ROConstantsCircuit, limb_width: usize, n_limbs: usize, ) -> Result, SynthesisError> { // Compute r: - let mut ro: PoseidonROGadget = PoseidonROGadget::new(poseidon_constants); + let mut ro: PoseidonROGadget = PoseidonROGadget::new(ro_consts); ro.absorb(params); self.absorb_in_ro(cs.namespace(|| "absorb running instance"), &mut ro)?; u.absorb_in_ro(&mut ro); diff --git a/src/lib.rs b/src/lib.rs index 9ce2c94..27ae7ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,9 +15,237 @@ mod poseidon; pub mod r1cs; pub mod traits; +use crate::bellperson::{ + r1cs::{NovaShape, NovaWitness}, + shape_cs::ShapeCS, + solver::SatisfyingAssignment, +}; +use crate::poseidon::ROConstantsCircuit; // TODO: make this a trait so we can use it without the concrete implementation +use ::bellperson::{Circuit, ConstraintSystem}; +use circuit::{NIFSVerifierCircuit, NIFSVerifierCircuitInputs, NIFSVerifierCircuitParams}; +use constants::{BN_LIMB_WIDTH, BN_N_LIMBS}; +use core::marker::PhantomData; use errors::NovaError; -use r1cs::{R1CSGens, R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness}; -use traits::Group; +use ff::Field; +use r1cs::{ + R1CSGens, R1CSInstance, R1CSShape, R1CSWitness, RelaxedR1CSInstance, RelaxedR1CSWitness, +}; +use traits::{Group, HashFuncConstantsTrait, HashFuncTrait, StepCircuit}; + +type ROConstants = + <::HashFunc as HashFuncTrait<::Base, ::Scalar>>::Constants; + +/// A type that holds public parameters of Nova +pub struct PublicParams +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit + Clone, + C2: StepCircuit + Clone, +{ + _ro_consts_primary: ROConstants, + ro_consts_circuit_primary: ROConstantsCircuit<::Base>, + r1cs_gens_primary: R1CSGens, + r1cs_shape_primary: R1CSShape, + _ro_consts_secondary: ROConstants, + ro_consts_circuit_secondary: ROConstantsCircuit<::Base>, + r1cs_gens_secondary: R1CSGens, + r1cs_shape_secondary: R1CSShape, + c_primary: C1, + c_secondary: C2, + params_primary: NIFSVerifierCircuitParams, + params_secondary: NIFSVerifierCircuitParams, +} + +impl PublicParams +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit + Clone, + C2: StepCircuit + Clone, +{ + /// Create a new `PublicParams` + pub fn setup(c_primary: C1, c_secondary: C2) -> Self { + let params_primary = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let params_secondary = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + + let _ro_consts_primary: ROConstants = ROConstants::::new(); + let _ro_consts_secondary: ROConstants = ROConstants::::new(); + + let ro_consts_circuit_primary: ROConstantsCircuit<::Base> = + ROConstantsCircuit::new(); + let ro_consts_circuit_secondary: ROConstantsCircuit<::Base> = + ROConstantsCircuit::new(); + + // Initialize gens for the primary + let circuit_primary: NIFSVerifierCircuit = NIFSVerifierCircuit::new( + params_primary.clone(), + None, + c_primary.clone(), + ro_consts_circuit_primary.clone(), + ); + let mut cs: ShapeCS = ShapeCS::new(); + let _ = circuit_primary.synthesize(&mut cs); + let (r1cs_shape_primary, r1cs_gens_primary) = (cs.r1cs_shape(), cs.r1cs_gens()); + + // Initialize gens for the secondary + let circuit_secondary: NIFSVerifierCircuit = NIFSVerifierCircuit::new( + params_secondary.clone(), + None, + c_secondary.clone(), + ro_consts_circuit_secondary.clone(), + ); + let mut cs: ShapeCS = ShapeCS::new(); + let _ = circuit_secondary.synthesize(&mut cs); + let (r1cs_shape_secondary, r1cs_gens_secondary) = (cs.r1cs_shape(), cs.r1cs_gens()); + + Self { + _ro_consts_primary, + ro_consts_circuit_primary, + r1cs_gens_primary, + r1cs_shape_primary, + _ro_consts_secondary, + ro_consts_circuit_secondary, + r1cs_gens_secondary, + r1cs_shape_secondary, + c_primary, + c_secondary, + params_primary, + params_secondary, + } + } +} + +/// A SNARK that proves the correct execution of an incremental computation +pub struct RecursiveSNARK +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit + Clone, + C2: StepCircuit + Clone, +{ + r_W_primary: RelaxedR1CSWitness, + r_U_primary: RelaxedR1CSInstance, + l_w_primary: R1CSWitness, + l_u_primary: R1CSInstance, + r_W_secondary: RelaxedR1CSWitness, + r_U_secondary: RelaxedR1CSInstance, + l_w_secondary: R1CSWitness, + l_u_secondary: R1CSInstance, + _p_c1: PhantomData, + _p_c2: PhantomData, +} + +impl RecursiveSNARK +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit + Clone, + C2: StepCircuit + Clone, +{ + /// Create a new `RecursiveSNARK` + pub fn prove( + pp: &PublicParams, + z0_primary: G1::Scalar, + z0_secondary: G2::Scalar, + ) -> Result { + // Execute the base case for the primary + let mut cs_primary: SatisfyingAssignment = SatisfyingAssignment::new(); + let inputs_primary: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new( + pp.r1cs_shape_secondary.get_digest(), + ::Base::zero(), + z0_primary, + None, + None, + None, + None, + ); + let circuit_primary: NIFSVerifierCircuit = NIFSVerifierCircuit::new( + pp.params_primary.clone(), + Some(inputs_primary), + pp.c_primary.clone(), + pp.ro_consts_circuit_primary.clone(), + ); + let _ = circuit_primary.synthesize(&mut cs_primary); + let (u_primary, w_primary) = cs_primary + .r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.r1cs_gens_primary) + .map_err(|_e| NovaError::UnSat)?; + + // check if the base case is satisfied + pp.r1cs_shape_primary + .is_sat(&pp.r1cs_gens_primary, &u_primary, &w_primary) + .map_err(|_e| NovaError::UnSat)?; + + // Execute the base case for the secondary + let mut cs_secondary: SatisfyingAssignment = SatisfyingAssignment::new(); + let inputs_secondary: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new( + pp.r1cs_shape_primary.get_digest(), + ::Base::zero(), + z0_secondary, + None, + None, + Some(u_primary.clone()), + None, + ); + let circuit_secondary: NIFSVerifierCircuit = NIFSVerifierCircuit::new( + pp.params_secondary.clone(), + Some(inputs_secondary), + pp.c_secondary.clone(), + pp.ro_consts_circuit_secondary.clone(), + ); + let _ = circuit_secondary.synthesize(&mut cs_secondary); + let (u_secondary, w_secondary) = cs_secondary + .r1cs_instance_and_witness(&pp.r1cs_shape_secondary, &pp.r1cs_gens_secondary) + .map_err(|_e| NovaError::UnSat)?; + + // check if the base case is satisfied + pp.r1cs_shape_secondary + .is_sat(&pp.r1cs_gens_secondary, &u_secondary, &w_secondary) + .map_err(|_e| NovaError::UnSat)?; + + Ok(Self { + r_W_primary: RelaxedR1CSWitness::::default(&pp.r1cs_shape_primary), + r_U_primary: RelaxedR1CSInstance::::default( + &pp.r1cs_gens_primary, + &pp.r1cs_shape_primary, + ), + l_w_primary: w_primary, + l_u_primary: u_primary, + r_W_secondary: RelaxedR1CSWitness::::default(&pp.r1cs_shape_secondary), + r_U_secondary: RelaxedR1CSInstance::::default( + &pp.r1cs_gens_secondary, + &pp.r1cs_shape_secondary, + ), + l_w_secondary: w_secondary, + l_u_secondary: u_secondary, + _p_c1: Default::default(), + _p_c2: Default::default(), + }) + } + + /// Verify the correctness of the `RecursiveSNARK` + pub fn verify(&self, pp: &PublicParams) -> Result<(), NovaError> { + pp.r1cs_shape_primary.is_sat_relaxed( + &pp.r1cs_gens_primary, + &self.r_U_primary, + &self.r_W_primary, + )?; + pp.r1cs_shape_primary + .is_sat(&pp.r1cs_gens_primary, &self.l_u_primary, &self.l_w_primary)?; + pp.r1cs_shape_secondary.is_sat_relaxed( + &pp.r1cs_gens_secondary, + &self.r_U_secondary, + &self.r_W_secondary, + )?; + pp.r1cs_shape_secondary.is_sat( + &pp.r1cs_gens_secondary, + &self.l_u_secondary, + &self.l_w_secondary, + )?; + + Ok(()) + } +} /// A SNARK that proves the knowledge of a valid `RecursiveSNARK` pub struct CompressedSNARKTrivial { @@ -41,3 +269,62 @@ impl CompressedSNARKTrivial { S.is_sat_relaxed(gens, U, &self.W) } } + +#[cfg(test)] +mod tests { + use super::*; + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + use ::bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; + use ff::PrimeField; + use std::marker::PhantomData; + + #[derive(Clone, Debug)] + struct TestCircuit { + _p: PhantomData, + } + + impl StepCircuit for TestCircuit + where + F: PrimeField, + { + fn synthesize>( + &self, + _cs: &mut CS, + z: AllocatedNum, + ) -> Result, SynthesisError> { + Ok(z) + } + } + + #[test] + fn test_base_case() { + // produce public parameters + let pp = PublicParams::< + G1, + G2, + TestCircuit<::Base>, + TestCircuit<::Base>, + >::setup( + TestCircuit { + _p: Default::default(), + }, + TestCircuit { + _p: Default::default(), + }, + ); + + // produce a recursive SNARK + let res = RecursiveSNARK::prove( + &pp, + ::Base::zero(), + ::Base::zero(), + ); + assert!(res.is_ok()); + let recursive_snark = res.unwrap(); + + // verify the recursive SNARK + let res = recursive_snark.verify(&pp); + assert!(res.is_ok()); + } +} diff --git a/src/poseidon.rs b/src/poseidon.rs index 92aa8f4..abee7fb 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -21,7 +21,7 @@ use neptune::{ /// All Poseidon Constants that are used in Nova #[derive(Clone)] -pub struct NovaPoseidonConstants +pub struct ROConstantsCircuit where Scalar: PrimeField, { @@ -29,7 +29,7 @@ where constants32: PoseidonConstants, } -impl HashFuncConstantsTrait for NovaPoseidonConstants +impl HashFuncConstantsTrait for ROConstantsCircuit where Scalar: PrimeField + PrimeFieldBits, { @@ -54,7 +54,7 @@ where // Internal State state: Vec, // Constants for Poseidon - constants: NovaPoseidonConstants, + constants: ROConstantsCircuit, _p: PhantomData, } @@ -86,10 +86,10 @@ where Base: PrimeField + PrimeFieldBits, Scalar: PrimeField + PrimeFieldBits, { - type Constants = NovaPoseidonConstants; + type Constants = ROConstantsCircuit; #[allow(dead_code)] - fn new(constants: NovaPoseidonConstants) -> Self { + fn new(constants: ROConstantsCircuit) -> Self { Self { state: Vec::new(), constants, @@ -144,7 +144,7 @@ where { // Internal state state: Vec>, - constants: NovaPoseidonConstants, + constants: ROConstantsCircuit, } impl PoseidonROGadget @@ -153,7 +153,7 @@ where { /// Initialize the internal state and set the poseidon constants #[allow(dead_code)] - pub fn new(constants: NovaPoseidonConstants) -> Self { + pub fn new(constants: ROConstantsCircuit) -> Self { Self { state: Vec::new(), constants, @@ -236,7 +236,7 @@ mod tests { fn test_poseidon_ro() { // Check that the number computed inside the circuit is equal to the number computed outside the circuit let mut csprng: OsRng = OsRng; - let constants = NovaPoseidonConstants::new(); + let constants = ROConstantsCircuit::new(); let mut ro: PoseidonRO = PoseidonRO::new(constants.clone()); let mut ro_gadget: PoseidonROGadget = PoseidonROGadget::new(constants); let mut cs: SatisfyingAssignment = SatisfyingAssignment::new();