From 35cb03f977fd9ebe4d90de9ab74ab04765443cfa Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Thu, 14 Jul 2022 16:15:45 -0700 Subject: [PATCH] reorganize traits into a module; cut boilerplate code (#91) use a default implementation for step circuit --- benches/compressed-snark.rs | 49 +++------------- benches/recursive-snark.rs | 57 +++---------------- examples/minroot.rs | 33 ++--------- src/circuit.rs | 52 ++++------------- src/lib.rs | 97 ++++++-------------------------- src/nifs.rs | 2 +- src/spartan_with_ipa_pc/mod.rs | 6 +- src/traits/circuit.rs | 41 ++++++++++++++ src/{traits.rs => traits/mod.rs} | 17 +----- src/{ => traits}/snark.rs | 4 +- 10 files changed, 100 insertions(+), 258 deletions(-) create mode 100644 src/traits/circuit.rs rename src/{traits.rs => traits/mod.rs} (93%) rename src/{ => traits}/snark.rs (92%) diff --git a/benches/compressed-snark.rs b/benches/compressed-snark.rs index 8133e76..090a2e9 100644 --- a/benches/compressed-snark.rs +++ b/benches/compressed-snark.rs @@ -1,46 +1,19 @@ #![allow(non_snake_case)] +use criterion::*; use nova_snark::{ - traits::{Group, StepCircuit}, + traits::{circuit::TrivialTestCircuit, Group}, CompressedSNARK, PublicParams, RecursiveSNARK, }; +use std::time::Duration; type G1 = pasta_curves::pallas::Point; type G2 = pasta_curves::vesta::Point; type S1 = nova_snark::spartan_with_ipa_pc::RelaxedR1CSSNARK; type S2 = nova_snark::spartan_with_ipa_pc::RelaxedR1CSSNARK; - -#[derive(Clone, Debug)] -struct TrivialTestCircuit { - _p: PhantomData, -} - -impl StepCircuit for TrivialTestCircuit -where - F: PrimeField, -{ - fn synthesize>( - &self, - _cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - Ok(z) - } - - fn compute(&self, z: &F) -> F { - *z - } -} - type C1 = TrivialTestCircuit<::Scalar>; type C2 = TrivialTestCircuit<::Scalar>; -use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; -use core::marker::PhantomData; -use criterion::*; -use ff::PrimeField; -use std::time::Duration; - fn compressed_snark_benchmark(c: &mut Criterion) { let num_samples = 10; bench_compressed_snark(c, num_samples); @@ -64,12 +37,8 @@ fn bench_compressed_snark(c: &mut Criterion, num_samples: usize) { // Produce public parameters let pp = PublicParams::::setup( - TrivialTestCircuit { - _p: Default::default(), - }, - TrivialTestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), + TrivialTestCircuit::default(), ); // produce a recursive SNARK @@ -80,12 +49,8 @@ fn bench_compressed_snark(c: &mut Criterion, num_samples: usize) { let res = RecursiveSNARK::prove_step( &pp, recursive_snark, - TrivialTestCircuit { - _p: Default::default(), - }, - TrivialTestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), + TrivialTestCircuit::default(), ::Scalar::one(), ::Scalar::zero(), ); diff --git a/benches/recursive-snark.rs b/benches/recursive-snark.rs index 761645e..31cb86b 100644 --- a/benches/recursive-snark.rs +++ b/benches/recursive-snark.rs @@ -1,44 +1,17 @@ #![allow(non_snake_case)] +use criterion::*; use nova_snark::{ - traits::{Group, StepCircuit}, + traits::{circuit::TrivialTestCircuit, Group}, PublicParams, RecursiveSNARK, }; +use std::time::Duration; type G1 = pasta_curves::pallas::Point; type G2 = pasta_curves::vesta::Point; - -#[derive(Clone, Debug)] -struct TrivialTestCircuit { - _p: PhantomData, -} - -impl StepCircuit for TrivialTestCircuit -where - F: PrimeField, -{ - fn synthesize>( - &self, - _cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - Ok(z) - } - - fn compute(&self, z: &F) -> F { - *z - } -} - type C1 = TrivialTestCircuit<::Scalar>; type C2 = TrivialTestCircuit<::Scalar>; -use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; -use core::marker::PhantomData; -use criterion::*; -use ff::PrimeField; -use std::time::Duration; - fn recursive_snark_benchmark(c: &mut Criterion) { let num_samples = 10; bench_recursive_snark(c, num_samples); @@ -62,12 +35,8 @@ fn bench_recursive_snark(c: &mut Criterion, num_samples: usize) { // Produce public parameters let pp = PublicParams::::setup( - TrivialTestCircuit { - _p: Default::default(), - }, - TrivialTestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), + TrivialTestCircuit::default(), ); // Bench time to produce a recursive SNARK; @@ -81,12 +50,8 @@ fn bench_recursive_snark(c: &mut Criterion, num_samples: usize) { let res = RecursiveSNARK::prove_step( &pp, recursive_snark, - TrivialTestCircuit { - _p: Default::default(), - }, - TrivialTestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), + TrivialTestCircuit::default(), ::Scalar::one(), ::Scalar::zero(), ); @@ -112,12 +77,8 @@ fn bench_recursive_snark(c: &mut Criterion, num_samples: usize) { assert!(RecursiveSNARK::prove_step( black_box(&pp), black_box(recursive_snark.clone()), - black_box(TrivialTestCircuit { - _p: Default::default(), - }), - black_box(TrivialTestCircuit { - _p: Default::default(), - }), + black_box(TrivialTestCircuit::default()), + black_box(TrivialTestCircuit::default()), black_box(::Scalar::zero()), black_box(::Scalar::zero()), ) diff --git a/examples/minroot.rs b/examples/minroot.rs index a75de23..4643791 100644 --- a/examples/minroot.rs +++ b/examples/minroot.rs @@ -12,11 +12,13 @@ use neptune::{ Strength, }; use nova_snark::{ - traits::{Group, StepCircuit}, + traits::{ + circuit::{StepCircuit, TrivialTestCircuit}, + Group, + }, CompressedSNARK, PublicParams, RecursiveSNARK, }; use num_bigint::BigUint; -use std::marker::PhantomData; use std::time::Instant; #[derive(Clone, Debug)] @@ -183,9 +185,7 @@ fn main() { pc: pc.clone(), }; - let circuit_secondary = TrivialTestCircuit { - _p: Default::default(), - }; + let circuit_secondary = TrivialTestCircuit::default(); println!("Nova-based VDF with MinRoot delay function"); println!("=========================================="); @@ -299,26 +299,3 @@ fn main() { ); assert!(res.is_ok()); } - -// A trivial test circuit that we use on the secondary curve -#[derive(Clone, Debug)] -struct TrivialTestCircuit { - _p: PhantomData, -} - -impl StepCircuit for TrivialTestCircuit -where - F: PrimeField, -{ - fn synthesize>( - &self, - _cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - Ok(z) - } - - fn compute(&self, z: &F) -> F { - *z - } -} diff --git a/src/circuit.rs b/src/circuit.rs index 87a4e92..ebe68f7 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -16,7 +16,7 @@ use super::{ }, }, r1cs::{R1CSInstance, RelaxedR1CSInstance}, - traits::{Group, HashFuncCircuitTrait, HashFuncConstantsCircuit, StepCircuit}, + traits::{circuit::StepCircuit, Group, HashFuncCircuitTrait, HashFuncConstantsCircuit}, }; use bellperson::{ gadgets::{ @@ -355,32 +355,8 @@ mod tests { use crate::{ bellperson::r1cs::{NovaShape, NovaWitness}, poseidon::PoseidonConstantsCircuit, - traits::HashFuncConstantsTrait, + traits::{circuit::TrivialTestCircuit, HashFuncConstantsTrait}, }; - use ff::PrimeField; - use std::marker::PhantomData; - - #[derive(Clone)] - struct TestCircuit { - _p: PhantomData, - } - - impl StepCircuit for TestCircuit - where - F: PrimeField, - { - fn synthesize>( - &self, - _cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - Ok(z) - } - - fn compute(&self, z: &F) -> F { - *z - } - } #[test] fn test_verification_circuit() { @@ -391,13 +367,11 @@ mod tests { let ro_consts2: HashFuncConstantsCircuit = PoseidonConstantsCircuit::new(); // Initialize the shape and gens for the primary - let circuit1: NIFSVerifierCircuit::Base>> = + let circuit1: NIFSVerifierCircuit::Base>> = NIFSVerifierCircuit::new( params1.clone(), None, - TestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), ro_consts1.clone(), ); let mut cs: ShapeCS = ShapeCS::new(); @@ -406,13 +380,11 @@ mod tests { assert_eq!(cs.num_constraints(), 20584); // Initialize the shape and gens for the secondary - let circuit2: NIFSVerifierCircuit::Base>> = + let circuit2: NIFSVerifierCircuit::Base>> = NIFSVerifierCircuit::new( params2.clone(), None, - TestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), ro_consts2.clone(), ); let mut cs: ShapeCS = ShapeCS::new(); @@ -425,13 +397,11 @@ mod tests { let mut cs1: SatisfyingAssignment = SatisfyingAssignment::new(); let inputs1: NIFSVerifierCircuitInputs = NIFSVerifierCircuitInputs::new(shape2.get_digest(), zero1, zero1, None, None, None, None); - let circuit1: NIFSVerifierCircuit::Base>> = + let circuit1: NIFSVerifierCircuit::Base>> = NIFSVerifierCircuit::new( params1, Some(inputs1), - TestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), ro_consts1, ); let _ = circuit1.synthesize(&mut cs1); @@ -451,13 +421,11 @@ mod tests { Some(inst1), None, ); - let circuit: NIFSVerifierCircuit::Base>> = + let circuit: NIFSVerifierCircuit::Base>> = NIFSVerifierCircuit::new( params2, Some(inputs2), - TestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), ro_consts2, ); let _ = circuit.synthesize(&mut cs2); diff --git a/src/lib.rs b/src/lib.rs index d7df862..b59d524 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,6 @@ mod r1cs; pub mod errors; pub mod gadgets; pub mod pasta; -pub mod snark; pub mod spartan_with_ipa_pc; pub mod traits; @@ -36,10 +35,9 @@ use nifs::NIFS; use r1cs::{ R1CSGens, R1CSInstance, R1CSShape, R1CSWitness, RelaxedR1CSInstance, RelaxedR1CSWitness, }; -use snark::RelaxedR1CSSNARKTrait; use traits::{ - AbsorbInROTrait, Group, HashFuncConstants, HashFuncConstantsCircuit, HashFuncConstantsTrait, - HashFuncTrait, StepCircuit, + circuit::StepCircuit, snark::RelaxedR1CSSNARKTrait, AbsorbInROTrait, Group, HashFuncConstants, + HashFuncConstantsCircuit, HashFuncConstantsTrait, HashFuncTrait, }; /// A type that holds public parameters of Nova @@ -665,32 +663,11 @@ mod tests { type S1 = spartan_with_ipa_pc::RelaxedR1CSSNARK; type S2 = spartan_with_ipa_pc::RelaxedR1CSSNARK; use ::bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; + use core::marker::PhantomData; use ff::PrimeField; - use std::marker::PhantomData; + use traits::circuit::TrivialTestCircuit; - #[derive(Clone, Debug)] - struct TrivialTestCircuit { - _p: PhantomData, - } - - impl StepCircuit for TrivialTestCircuit - where - F: PrimeField, - { - fn synthesize>( - &self, - _cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - Ok(z) - } - - fn compute(&self, z: &F) -> F { - *z - } - } - - #[derive(Clone, Debug)] + #[derive(Clone, Debug, Default)] struct CubicCircuit { _p: PhantomData, } @@ -743,14 +720,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, TrivialTestCircuit<::Scalar>, - >::setup( - TrivialTestCircuit { - _p: Default::default(), - }, - TrivialTestCircuit { - _p: Default::default(), - }, - ); + >::setup(TrivialTestCircuit::default(), TrivialTestCircuit::default()); let num_steps = 1; @@ -758,12 +728,8 @@ mod tests { let res = RecursiveSNARK::prove_step( &pp, None, - TrivialTestCircuit { - _p: Default::default(), - }, - TrivialTestCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), + TrivialTestCircuit::default(), ::Scalar::zero(), ::Scalar::zero(), ); @@ -782,12 +748,8 @@ mod tests { #[test] fn test_ivc_nontrivial() { - let circuit_primary = TrivialTestCircuit { - _p: Default::default(), - }; - let circuit_secondary = CubicCircuit { - _p: Default::default(), - }; + let circuit_primary = TrivialTestCircuit::default(); + let circuit_secondary = CubicCircuit::default(); // produce public parameters let pp = PublicParams::< @@ -852,10 +814,7 @@ mod tests { assert_eq!(zn_primary, ::Scalar::one()); let mut zn_secondary_direct = ::Scalar::zero(); for _i in 0..num_steps { - zn_secondary_direct = CubicCircuit { - _p: Default::default(), - } - .compute(&zn_secondary_direct); + zn_secondary_direct = CubicCircuit::default().compute(&zn_secondary_direct); } assert_eq!(zn_secondary, zn_secondary_direct); assert_eq!(zn_secondary, ::Scalar::from(2460515u64)); @@ -863,12 +822,8 @@ mod tests { #[test] fn test_ivc_nontrivial_with_compression() { - let circuit_primary = TrivialTestCircuit { - _p: Default::default(), - }; - let circuit_secondary = CubicCircuit { - _p: Default::default(), - }; + let circuit_primary = TrivialTestCircuit::default(); + let circuit_secondary = CubicCircuit::default(); // produce public parameters let pp = PublicParams::< @@ -921,10 +876,7 @@ mod tests { assert_eq!(zn_primary, ::Scalar::one()); let mut zn_secondary_direct = ::Scalar::zero(); for _i in 0..num_steps { - zn_secondary_direct = CubicCircuit { - _p: Default::default(), - } - .compute(&zn_secondary_direct); + zn_secondary_direct = CubicCircuit::default().compute(&zn_secondary_direct); } assert_eq!(zn_secondary, zn_secondary_direct); assert_eq!(zn_secondary, ::Scalar::from(2460515u64)); @@ -1027,9 +979,7 @@ mod tests { y: ::Scalar::zero(), }; - let circuit_secondary = TrivialTestCircuit { - _p: Default::default(), - }; + let circuit_secondary = TrivialTestCircuit::default(); // produce public parameters let pp = PublicParams::< @@ -1093,14 +1043,7 @@ mod tests { G2, TrivialTestCircuit<::Scalar>, CubicCircuit<::Scalar>, - >::setup( - TrivialTestCircuit { - _p: Default::default(), - }, - CubicCircuit { - _p: Default::default(), - }, - ); + >::setup(TrivialTestCircuit::default(), CubicCircuit::default()); let num_steps = 1; @@ -1108,12 +1051,8 @@ mod tests { let res = RecursiveSNARK::prove_step( &pp, None, - TrivialTestCircuit { - _p: Default::default(), - }, - CubicCircuit { - _p: Default::default(), - }, + TrivialTestCircuit::default(), + CubicCircuit::default(), ::Scalar::one(), ::Scalar::zero(), ); diff --git a/src/nifs.rs b/src/nifs.rs index 57b5a98..55cdc04 100644 --- a/src/nifs.rs +++ b/src/nifs.rs @@ -8,7 +8,7 @@ use super::r1cs::{ R1CSGens, R1CSInstance, R1CSShape, R1CSWitness, RelaxedR1CSInstance, RelaxedR1CSWitness, }; use super::traits::{AbsorbInROTrait, Group, HashFuncTrait}; -use std::marker::PhantomData; +use core::marker::PhantomData; /// A SNARK that holds the proof of a step of an incremental computation #[allow(clippy::upper_case_acronyms)] diff --git a/src/spartan_with_ipa_pc/mod.rs b/src/spartan_with_ipa_pc/mod.rs index 5063399..e3e245b 100644 --- a/src/spartan_with_ipa_pc/mod.rs +++ b/src/spartan_with_ipa_pc/mod.rs @@ -8,8 +8,10 @@ use super::{ commitments::CommitGens, errors::NovaError, r1cs::{R1CSGens, R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness}, - snark::{ProverKeyTrait, RelaxedR1CSSNARKTrait, VerifierKeyTrait}, - traits::{AppendToTranscriptTrait, ChallengeTrait, Group}, + traits::{ + snark::{ProverKeyTrait, RelaxedR1CSSNARKTrait, VerifierKeyTrait}, + AppendToTranscriptTrait, ChallengeTrait, Group, + }, }; use core::cmp::max; use ff::Field; diff --git a/src/traits/circuit.rs b/src/traits/circuit.rs new file mode 100644 index 0000000..3afaccd --- /dev/null +++ b/src/traits/circuit.rs @@ -0,0 +1,41 @@ +//! This module defines traits that a step function must implement +use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; +use core::marker::PhantomData; +use ff::PrimeField; + +/// A helper trait for a step of the incremental computation (i.e., circuit for F) +pub trait StepCircuit: Send + Sync + Clone { + /// Sythesize the circuit for a computation step and return variable + /// that corresponds to the output of the step z_{i+1} + fn synthesize>( + &self, + cs: &mut CS, + z: AllocatedNum, + ) -> Result, SynthesisError>; + + /// Execute the circuit for a computation step and return output + fn compute(&self, z: &F) -> F; +} + +/// A trivial step circuit that simply returns the input +#[derive(Clone, Debug, Default)] +pub struct TrivialTestCircuit { + _p: PhantomData, +} + +impl StepCircuit for TrivialTestCircuit +where + F: PrimeField, +{ + fn synthesize>( + &self, + _cs: &mut CS, + z: AllocatedNum, + ) -> Result, SynthesisError> { + Ok(z) + } + + fn compute(&self, z: &F) -> F { + *z + } +} diff --git a/src/traits.rs b/src/traits/mod.rs similarity index 93% rename from src/traits.rs rename to src/traits/mod.rs index 5b2e2c4..d8f79cf 100644 --- a/src/traits.rs +++ b/src/traits/mod.rs @@ -175,20 +175,6 @@ impl ScalarMul for T where T: Mul: for<'r> ScalarMul<&'r Rhs, Output> {} impl ScalarMulOwned for T where T: for<'r> ScalarMul<&'r Rhs, Output> {} -/// A helper trait for a step of the incremental computation (i.e., circuit for F) -pub trait StepCircuit: Send + Sync + Clone { - /// Sythesize the circuit for a computation step and return variable - /// that corresponds to the output of the step z_{i+1} - fn synthesize>( - &self, - cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError>; - - /// Execute the circuit for a computation step and return output - fn compute(&self, z: &F) -> F; -} - impl AppendToTranscriptTrait for F { fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) { transcript.append_message(label, self.to_repr().as_ref()); @@ -202,3 +188,6 @@ impl AppendToTranscriptTrait for [F] { } } } + +pub mod circuit; +pub mod snark; diff --git a/src/snark.rs b/src/traits/snark.rs similarity index 92% rename from src/snark.rs rename to src/traits/snark.rs index ff98280..283e622 100644 --- a/src/snark.rs +++ b/src/traits/snark.rs @@ -1,5 +1,5 @@ -//! A collection of traits that define the behavior of a zkSNARK for RelaxedR1CS -use super::{ +//! This module defines a collection of traits that define the behavior of a zkSNARK for RelaxedR1CS +use crate::{ errors::NovaError, r1cs::{R1CSGens, R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness}, traits::Group,