diff --git a/src/circuit.rs b/src/circuit.rs index 8ca56c6..f44c01a 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -15,9 +15,8 @@ use super::{ alloc_num_equals, alloc_scalar_as_base, alloc_zero, conditionally_select, le_bits_to_num, }, }, - poseidon::{PoseidonROGadget, ROConstantsCircuit}, r1cs::{R1CSInstance, RelaxedR1CSInstance}, - traits::{Group, StepCircuit}, + traits::{Group, HashFuncCircuitTrait, HashFuncConstantsCircuit, StepCircuit}, }; use bellperson::{ gadgets::{ @@ -91,7 +90,7 @@ where SC: StepCircuit, { params: NIFSVerifierCircuitParams, - ro_consts: ROConstantsCircuit, + ro_consts: HashFuncConstantsCircuit, inputs: Option>, step_circuit: SC, // The function that is applied for each step } @@ -106,7 +105,7 @@ where params: NIFSVerifierCircuitParams, inputs: Option>, step_circuit: SC, - ro_consts: ROConstantsCircuit, + ro_consts: HashFuncConstantsCircuit, ) -> Self { Self { params, @@ -221,7 +220,7 @@ where T: AllocatedPoint, ) -> Result<(AllocatedRelaxedR1CSInstance, AllocatedBit), SynthesisError> { // Check that u.x[0] = Hash(params, U, i, z0, zi) - let mut ro: PoseidonROGadget = PoseidonROGadget::new(self.ro_consts.clone()); + let mut ro = G::HashFuncCircuit::new(self.ro_consts.clone()); ro.absorb(params.clone()); ro.absorb(i); ro.absorb(z_0); @@ -328,7 +327,7 @@ where .synthesize(&mut cs.namespace(|| "F"), z_input)?; // Compute the new hash H(params, Unew, i+1, z0, z_{i+1}) - let mut ro: PoseidonROGadget = PoseidonROGadget::new(self.ro_consts); + let mut ro = G::HashFuncCircuit::new(self.ro_consts); ro.absorb(params); ro.absorb(i_new.clone()); ro.absorb(z_0); @@ -356,6 +355,7 @@ mod tests { use crate::constants::{BN_LIMB_WIDTH, BN_N_LIMBS}; use crate::{ bellperson::r1cs::{NovaShape, NovaWitness}, + poseidon::PoseidonConstantsCircuit, traits::HashFuncConstantsTrait, }; use ff::PrimeField; @@ -388,8 +388,8 @@ mod tests { // In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit let params1 = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); let params2 = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); - let ro_consts1: ROConstantsCircuit<::Base> = ROConstantsCircuit::new(); - let ro_consts2: ROConstantsCircuit<::Base> = ROConstantsCircuit::new(); + let ro_consts1: HashFuncConstantsCircuit = PoseidonConstantsCircuit::new(); + let ro_consts2: HashFuncConstantsCircuit = PoseidonConstantsCircuit::new(); // Initialize the shape and gens for the primary let circuit1: NIFSVerifierCircuit::Base>> = diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index b54b77d..33cc7d7 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -7,9 +7,8 @@ use crate::{ conditionally_select_bignat, le_bits_to_num, }, }, - poseidon::{PoseidonROGadget, ROConstantsCircuit}, r1cs::{R1CSInstance, RelaxedR1CSInstance}, - traits::Group, + traits::{Group, HashFuncCircuitTrait, HashFuncConstantsCircuit}, }; use bellperson::{ gadgets::{boolean::Boolean, num::AllocatedNum, Assignment}, @@ -61,7 +60,7 @@ where } /// Absorb the provided instance in the RO - pub fn absorb_in_ro(&self, ro: &mut PoseidonROGadget) { + pub fn absorb_in_ro(&self, ro: &mut G::HashFuncCircuit) { ro.absorb(self.W.x.clone()); ro.absorb(self.W.y.clone()); ro.absorb(self.W.is_infinity.clone()); @@ -208,7 +207,7 @@ where pub fn absorb_in_ro::Base>>( &self, mut cs: CS, - ro: &mut PoseidonROGadget, + ro: &mut G::HashFuncCircuit, ) -> Result<(), SynthesisError> { ro.absorb(self.W.x.clone()); ro.absorb(self.W.y.clone()); @@ -263,12 +262,12 @@ where params: AllocatedNum, // hash of R1CSShape of F' u: AllocatedR1CSInstance, T: AllocatedPoint, - ro_consts: ROConstantsCircuit, + ro_consts: HashFuncConstantsCircuit, limb_width: usize, n_limbs: usize, ) -> Result, SynthesisError> { // Compute r: - let mut ro: PoseidonROGadget = PoseidonROGadget::new(ro_consts); + let mut ro = G::HashFuncCircuit::new(ro_consts); ro.absorb(params); self.absorb_in_ro(cs.namespace(|| "absorb running instance"), &mut ro)?; u.absorb_in_ro(&mut ro); diff --git a/src/lib.rs b/src/lib.rs index c011d8d..acf39ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,31 +33,30 @@ use errors::NovaError; use ff::Field; use gadgets::utils::scalar_as_base; use nifs::NIFS; -use poseidon::ROConstantsCircuit; // TODO: make this a trait so we can use it without the concrete implementation use r1cs::{ R1CSGens, R1CSInstance, R1CSShape, R1CSWitness, RelaxedR1CSInstance, RelaxedR1CSWitness, }; use snark::RelaxedR1CSSNARKTrait; -use traits::{AbsorbInROTrait, Group, HashFuncConstantsTrait, HashFuncTrait, StepCircuit}; - -type ROConstants = - <::HashFunc as HashFuncTrait<::Base, ::Scalar>>::Constants; +use traits::{ + AbsorbInROTrait, Group, HashFuncConstants, HashFuncConstantsCircuit, HashFuncConstantsTrait, + HashFuncTrait, StepCircuit, +}; /// A type that holds public parameters of Nova pub struct PublicParams where G1: Group::Scalar>, G2: Group::Scalar>, - C1: StepCircuit + Clone, - C2: StepCircuit + Clone, + C1: StepCircuit, + C2: StepCircuit, { - ro_consts_primary: ROConstants, - ro_consts_circuit_primary: ROConstantsCircuit<::Base>, + ro_consts_primary: HashFuncConstants, + ro_consts_circuit_primary: HashFuncConstantsCircuit, r1cs_gens_primary: R1CSGens, r1cs_shape_primary: R1CSShape, r1cs_shape_padded_primary: R1CSShape, - ro_consts_secondary: ROConstants, - ro_consts_circuit_secondary: ROConstantsCircuit<::Base>, + ro_consts_secondary: HashFuncConstants, + ro_consts_circuit_secondary: HashFuncConstantsCircuit, r1cs_gens_secondary: R1CSGens, r1cs_shape_secondary: R1CSShape, r1cs_shape_padded_secondary: R1CSShape, @@ -71,21 +70,22 @@ impl PublicParams where G1: Group::Scalar>, G2: Group::Scalar>, - C1: StepCircuit + Clone, - C2: StepCircuit + Clone, + C1: StepCircuit, + C2: StepCircuit, { /// Create a new `PublicParams` pub fn setup(c_primary: C1, c_secondary: C2) -> Self { let params_primary = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); let params_secondary = NIFSVerifierCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); - let ro_consts_primary: ROConstants = ROConstants::::new(); - let ro_consts_secondary: ROConstants = ROConstants::::new(); + let ro_consts_primary: HashFuncConstants = HashFuncConstants::::new(); + let ro_consts_secondary: HashFuncConstants = HashFuncConstants::::new(); - let ro_consts_circuit_primary: ROConstantsCircuit<::Base> = - ROConstantsCircuit::new(); - let ro_consts_circuit_secondary: ROConstantsCircuit<::Base> = - ROConstantsCircuit::new(); + // ro_consts_circuit_primart are parameterized by G2 because the type alias uses G2::Base = G1::Scalar + let ro_consts_circuit_primary: HashFuncConstantsCircuit = + HashFuncConstantsCircuit::::new(); + let ro_consts_circuit_secondary: HashFuncConstantsCircuit = + HashFuncConstantsCircuit::::new(); // Initialize gens for the primary let circuit_primary: NIFSVerifierCircuit = NIFSVerifierCircuit::new( @@ -135,8 +135,8 @@ pub struct RecursiveSNARK where G1: Group::Scalar>, G2: Group::Scalar>, - C1: StepCircuit + Clone, - C2: StepCircuit + Clone, + C1: StepCircuit, + C2: StepCircuit, { r_W_primary: RelaxedR1CSWitness, r_U_primary: RelaxedR1CSInstance, @@ -156,8 +156,8 @@ impl RecursiveSNARK where G1: Group::Scalar>, G2: Group::Scalar>, - C1: StepCircuit + Clone, - C2: StepCircuit + Clone, + C1: StepCircuit, + C2: StepCircuit, { /// Create a new `RecursiveSNARK` pub fn prove( diff --git a/src/nifs.rs b/src/nifs.rs index a9ed77c..05153b5 100644 --- a/src/nifs.rs +++ b/src/nifs.rs @@ -17,7 +17,7 @@ pub struct NIFS { _p: PhantomData, } -type ROConstants = +type HashFuncConstants = <::HashFunc as HashFuncTrait<::Base, ::Scalar>>::Constants; impl NIFS { @@ -29,7 +29,7 @@ impl NIFS { /// if and only if `W1` satisfies `U1` and `W2` satisfies `U2`. pub fn prove( gens: &R1CSGens, - ro_consts: &ROConstants, + ro_consts: &HashFuncConstants, S: &R1CSShape, U1: &RelaxedR1CSInstance, W1: &RelaxedR1CSWitness, @@ -78,7 +78,7 @@ impl NIFS { /// if and only if `U1` and `U2` are satisfiable. pub fn verify( &self, - ro_consts: &ROConstants, + ro_consts: &HashFuncConstants, S: &R1CSShape, U1: &RelaxedR1CSInstance, U2: &R1CSInstance, diff --git a/src/pasta.rs b/src/pasta.rs index 0deb040..9992354 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -1,6 +1,6 @@ //! This module implements the Nova traits for pallas::Point, pallas::Scalar, vesta::Point, vesta::Scalar. use crate::{ - poseidon::PoseidonRO, + poseidon::{PoseidonHashFunc, PoseidonHashFuncCircuit}, traits::{ChallengeTrait, CompressedGroup, Group}, }; use digest::{ExtendableOutput, Input}; @@ -40,7 +40,8 @@ impl Group for pallas::Point { type Scalar = pallas::Scalar; type CompressedGroupElement = PallasCompressedElementWrapper; type PreprocessedGroupElement = pallas::Affine; - type HashFunc = PoseidonRO; + type HashFunc = PoseidonHashFunc; + type HashFuncCircuit = PoseidonHashFuncCircuit; fn vartime_multiscalar_mul( scalars: &[Self::Scalar], @@ -137,7 +138,8 @@ impl Group for vesta::Point { type Scalar = vesta::Scalar; type CompressedGroupElement = VestaCompressedElementWrapper; type PreprocessedGroupElement = vesta::Affine; - type HashFunc = PoseidonRO; + type HashFunc = PoseidonHashFunc; + type HashFuncCircuit = PoseidonHashFuncCircuit; fn vartime_multiscalar_mul( scalars: &[Self::Scalar], diff --git a/src/poseidon.rs b/src/poseidon.rs index f1988bb..961788a 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -1,7 +1,7 @@ //! Poseidon Constants and Poseidon-based RO used in Nova use super::{ constants::{NUM_CHALLENGE_BITS, NUM_HASH_BITS}, - traits::{HashFuncConstantsTrait, HashFuncTrait}, + traits::{HashFuncCircuitTrait, HashFuncConstantsTrait, HashFuncTrait}, }; use bellperson::{ gadgets::{ @@ -21,7 +21,7 @@ use neptune::{ /// All Poseidon Constants that are used in Nova #[derive(Clone)] -pub struct ROConstantsCircuit +pub struct PoseidonConstantsCircuit where Scalar: PrimeField, { @@ -29,7 +29,7 @@ where constants32: PoseidonConstants, } -impl HashFuncConstantsTrait for ROConstantsCircuit +impl HashFuncConstantsTrait for PoseidonConstantsCircuit where Scalar: PrimeField + PrimeFieldBits, { @@ -46,7 +46,7 @@ where } /// A Poseidon-based RO to use outside circuits -pub struct PoseidonRO +pub struct PoseidonHashFunc where Base: PrimeField + PrimeFieldBits, Scalar: PrimeField + PrimeFieldBits, @@ -54,11 +54,11 @@ where // Internal State state: Vec, // Constants for Poseidon - constants: ROConstantsCircuit, + constants: PoseidonConstantsCircuit, _p: PhantomData, } -impl PoseidonRO +impl PoseidonHashFunc where Base: PrimeField + PrimeFieldBits, Scalar: PrimeField + PrimeFieldBits, @@ -81,14 +81,14 @@ where } } -impl HashFuncTrait for PoseidonRO +impl HashFuncTrait for PoseidonHashFunc where Base: PrimeField + PrimeFieldBits, Scalar: PrimeField + PrimeFieldBits, { - type Constants = ROConstantsCircuit; + type Constants = PoseidonConstantsCircuit; - fn new(constants: ROConstantsCircuit) -> Self { + fn new(constants: PoseidonConstantsCircuit) -> Self { Self { state: Vec::new(), constants, @@ -134,32 +134,19 @@ where } /// A Poseidon-based RO gadget to use inside the verifier circuit. -pub struct PoseidonROGadget +pub struct PoseidonHashFuncCircuit where Scalar: PrimeField + PrimeFieldBits, { // Internal state state: Vec>, - constants: ROConstantsCircuit, + constants: PoseidonConstantsCircuit, } -impl PoseidonROGadget +impl PoseidonHashFuncCircuit where Scalar: PrimeField + PrimeFieldBits, { - /// Initialize the internal state and set the poseidon constants - pub fn new(constants: ROConstantsCircuit) -> Self { - Self { - state: Vec::new(), - constants, - } - } - - /// Absorb a new number into the state of the oracle - pub fn absorb(&mut self, e: AllocatedNum) { - self.state.push(e); - } - fn hash_inner(&mut self, mut cs: CS) -> Result, SynthesisError> where CS: ConstraintSystem, @@ -195,9 +182,29 @@ where .collect(), ) } +} + +impl HashFuncCircuitTrait for PoseidonHashFuncCircuit +where + Scalar: PrimeField + PrimeFieldBits, +{ + type Constants = PoseidonConstantsCircuit; + + /// Initialize the internal state and set the poseidon constants + fn new(constants: PoseidonConstantsCircuit) -> Self { + Self { + state: Vec::new(), + constants, + } + } + + /// Absorb a new number into the state of the oracle + fn absorb(&mut self, e: AllocatedNum) { + self.state.push(e); + } /// Compute a challenge by hashing the current state - pub fn get_challenge(&mut self, mut cs: CS) -> Result, SynthesisError> + fn get_challenge(&mut self, mut cs: CS) -> Result, SynthesisError> where CS: ConstraintSystem, { @@ -205,7 +212,7 @@ where Ok(bits[..NUM_CHALLENGE_BITS].into()) } - pub fn get_hash(&mut self, mut cs: CS) -> Result, SynthesisError> + fn get_hash(&mut self, mut cs: CS) -> Result, SynthesisError> where CS: ConstraintSystem, { @@ -228,9 +235,9 @@ mod tests { fn test_poseidon_ro() { // Check that the number computed inside the circuit is equal to the number computed outside the circuit let mut csprng: OsRng = OsRng; - let constants = ROConstantsCircuit::new(); - let mut ro: PoseidonRO = PoseidonRO::new(constants.clone()); - let mut ro_gadget: PoseidonROGadget = PoseidonROGadget::new(constants); + let constants = PoseidonConstantsCircuit::new(); + let mut ro: PoseidonHashFunc = PoseidonHashFunc::new(constants.clone()); + let mut ro_gadget: PoseidonHashFuncCircuit = PoseidonHashFuncCircuit::new(constants); let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); for i in 0..27 { let num = S::random(&mut csprng); diff --git a/src/traits.rs b/src/traits.rs index 56d5718..5b2e2c4 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,5 +1,8 @@ //! This module defines various traits required by the users of the library to implement. -use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; +use bellperson::{ + gadgets::{boolean::AllocatedBit, num::AllocatedNum}, + ConstraintSystem, SynthesisError, +}; use core::{ fmt::Debug, ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, @@ -38,6 +41,9 @@ pub trait Group: /// from the base field and squeezes out elements of the scalar field type HashFunc: HashFuncTrait; + /// An alternate implementation of Self::HashFunc in the circuit model + type HashFuncCircuit: HashFuncCircuitTrait; + /// A method to compute a multiexponentation fn vartime_multiscalar_mul( scalars: &[Self::Scalar], @@ -108,12 +114,42 @@ pub trait HashFuncTrait { fn get_hash(&self) -> Scalar; } +/// A helper trait that defines the behavior of a hash function that we use as an RO in the circuit model +pub trait HashFuncCircuitTrait { + /// A type representing constants/parameters associated with the hash function + type Constants: HashFuncConstantsTrait + Clone + Send + Sync; + + /// Initializes the hash function + fn new(constants: Self::Constants) -> Self; + + /// Adds a scalar to the internal state + fn absorb(&mut self, e: AllocatedNum); + + /// Returns a random challenge by hashing the internal state + fn get_challenge(&mut self, cs: CS) -> Result, SynthesisError> + where + CS: ConstraintSystem; + + /// Returns a hash of the internal state + fn get_hash(&mut self, cs: CS) -> Result, SynthesisError> + where + CS: ConstraintSystem; +} + /// A helper trait that defines the constants associated with a hash function pub trait HashFuncConstantsTrait { /// produces constants/parameters associated with the hash function fn new() -> Self; } +/// An alias for constants associated with G::HashFunc +pub type HashFuncConstants = + <::HashFunc as HashFuncTrait<::Base, ::Scalar>>::Constants; + +/// An alias for constants associated with G::HashFuncCircuit +pub type HashFuncConstantsCircuit = + <::HashFuncCircuit as HashFuncCircuitTrait<::Base>>::Constants; + /// A helper trait for types with a group operation. pub trait GroupOps: Add + Sub + AddAssign + SubAssign