From ccc6ccd4c7773e1b180401b419f8d68505304b9e Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Tue, 16 Aug 2022 11:35:17 -0700 Subject: [PATCH] Support for arbitrary arity for step circuit's IO (#107) * support for arbitrary arity for F * revive MinRoot example * revive tests * revive ecdsa * remove unused code * use None instead of Some(1u32) * revive benches * fix clippy warning --- benches/compressed-snark.rs | 30 +++-- benches/recursive-snark.rs | 34 +++--- examples/ecdsa/circuit.rs | 118 +++--------------- examples/ecdsa/main.rs | 28 +---- examples/minroot.rs | 85 +++++-------- src/circuit.rs | 92 ++++++++++---- src/constants.rs | 2 +- src/errors.rs | 4 + src/gadgets/utils.rs | 16 +++ src/lib.rs | 213 ++++++++++++++++++++------------- src/poseidon.rs | 4 +- src/spartan_with_ipa_pc/mod.rs | 3 +- src/traits/circuit.rs | 26 ++-- 13 files changed, 323 insertions(+), 332 deletions(-) diff --git a/benches/compressed-snark.rs b/benches/compressed-snark.rs index 364786c..dde6003 100644 --- a/benches/compressed-snark.rs +++ b/benches/compressed-snark.rs @@ -57,8 +57,8 @@ fn bench_compressed_snark(c: &mut Criterion) { recursive_snark, NonTrivialTestCircuit::new(num_cons), TrivialTestCircuit::default(), - ::Scalar::from(2u64), - ::Scalar::from(2u64), + vec![::Scalar::from(2u64)], + vec![::Scalar::from(2u64)], ); assert!(res.is_ok()); let recursive_snark_unwrapped = res.unwrap(); @@ -67,8 +67,8 @@ fn bench_compressed_snark(c: &mut Criterion) { let res = recursive_snark_unwrapped.verify( &pp, i + 1, - ::Scalar::from(2u64), - ::Scalar::from(2u64), + vec![::Scalar::from(2u64)], + vec![::Scalar::from(2u64)], ); assert!(res.is_ok()); @@ -98,8 +98,8 @@ fn bench_compressed_snark(c: &mut Criterion) { .verify( black_box(&pp), black_box(num_steps), - black_box(::Scalar::from(2u64)), - black_box(::Scalar::from(2u64)), + black_box(vec![::Scalar::from(2u64)]), + black_box(vec![::Scalar::from(2u64)]), ) .is_ok()); }) @@ -130,28 +130,32 @@ impl StepCircuit for NonTrivialTestCircuit where F: PrimeField, { + fn arity(&self) -> usize { + 1 + } + fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { // Consider a an equation: `x^2 = y`, where `x` and `y` are respectively the input and output. - let mut x = z; + let mut x = z[0].clone(); let mut y = x.clone(); for i in 0..self.num_cons { y = x.square(cs.namespace(|| format!("x_sq_{}", i)))?; x = y.clone(); } - Ok(y) + Ok(vec![y]) } - fn output(&self, z: &F) -> F { - let mut x = *z; + fn output(&self, z: &[F]) -> Vec { + let mut x = z[0]; let mut y = x; for _i in 0..self.num_cons { y = x * x; x = y; } - y + vec![y] } } diff --git a/benches/recursive-snark.rs b/benches/recursive-snark.rs index 25511e8..8023f47 100644 --- a/benches/recursive-snark.rs +++ b/benches/recursive-snark.rs @@ -57,8 +57,8 @@ fn bench_recursive_snark(c: &mut Criterion) { recursive_snark, NonTrivialTestCircuit::new(num_cons), TrivialTestCircuit::default(), - ::Scalar::from(2u64), - ::Scalar::from(2u64), + vec![::Scalar::from(2u64)], + vec![::Scalar::from(2u64)], ); assert!(res.is_ok()); let recursive_snark_unwrapped = res.unwrap(); @@ -67,8 +67,8 @@ fn bench_recursive_snark(c: &mut Criterion) { let res = recursive_snark_unwrapped.verify( &pp, i + 1, - ::Scalar::from(2u64), - ::Scalar::from(2u64), + vec![::Scalar::from(2u64)], + vec![::Scalar::from(2u64)], ); assert!(res.is_ok()); @@ -84,8 +84,8 @@ fn bench_recursive_snark(c: &mut Criterion) { black_box(recursive_snark.clone()), black_box(NonTrivialTestCircuit::new(num_cons)), black_box(TrivialTestCircuit::default()), - black_box(::Scalar::from(2u64)), - black_box(::Scalar::from(2u64)), + black_box(vec![::Scalar::from(2u64)]), + black_box(vec![::Scalar::from(2u64)]), ) .is_ok()); }) @@ -100,8 +100,8 @@ fn bench_recursive_snark(c: &mut Criterion) { .verify( black_box(&pp), black_box(num_warmup_steps), - black_box(::Scalar::from(2u64)), - black_box(::Scalar::from(2u64)), + black_box(vec![::Scalar::from(2u64)]), + black_box(vec![::Scalar::from(2u64)]), ) .is_ok()); }); @@ -131,28 +131,32 @@ impl StepCircuit for NonTrivialTestCircuit where F: PrimeField, { + fn arity(&self) -> usize { + 1 + } + fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { // Consider a an equation: `x^2 = y`, where `x` and `y` are respectively the input and output. - let mut x = z; + let mut x = z[0].clone(); let mut y = x.clone(); for i in 0..self.num_cons { y = x.square(cs.namespace(|| format!("x_sq_{}", i)))?; x = y.clone(); } - Ok(y) + Ok(vec![y]) } - fn output(&self, z: &F) -> F { - let mut x = *z; + fn output(&self, z: &[F]) -> Vec { + let mut x = z[0]; let mut y = x; for _i in 0..self.num_cons { y = x * x; x = y; } - y + vec![y] } } diff --git a/examples/ecdsa/circuit.rs b/examples/ecdsa/circuit.rs index 237ba40..211007e 100644 --- a/examples/ecdsa/circuit.rs +++ b/examples/ecdsa/circuit.rs @@ -3,11 +3,6 @@ use bellperson::{ ConstraintSystem, SynthesisError, }; use ff::{PrimeField, PrimeFieldBits}; -use generic_array::typenum::U8; -use neptune::{ - circuit::poseidon_hash, - poseidon::{Poseidon, PoseidonConstants}, -}; use nova_snark::{gadgets::ecc::AllocatedPoint, traits::circuit::StepCircuit}; use subtle::Choice; @@ -66,11 +61,6 @@ pub struct EcdsaCircuit where F: PrimeField, { - pub z_r: Coordinate, - pub z_g: Coordinate, - pub z_pk: Coordinate, - pub z_c: F, - pub z_s: F, pub r: Coordinate, pub g: Coordinate, pub pk: Coordinate, @@ -78,7 +68,6 @@ where pub s: F, pub c_bits: Vec, pub s_bits: Vec, - pub pc: PoseidonConstants, } impl EcdsaCircuit @@ -88,42 +77,14 @@ where // Creates a new [`EcdsaCircuit`]. The base and scalar field elements from the curve // field used by the signature are converted to scalar field elements from the cyclic curve // field used by the circuit. - pub fn new( - num_steps: usize, - signatures: &[EcdsaSignature], - pc: &PoseidonConstants, - ) -> (F, Vec) + pub fn new(num_steps: usize, signatures: &[EcdsaSignature]) -> (Vec, Vec) where Fb: PrimeField, Fs: PrimeField + PrimeFieldBits, { - let mut z0 = F::zero(); + let mut z0 = Vec::new(); let mut circuits = Vec::new(); - for i in 0..num_steps { - let mut j = i; - if i > 0 { - j = i - 1 - }; - let z_signature = &signatures[j]; - let z_r = Coordinate::new( - F::from_repr(z_signature.r.x.to_repr()).unwrap(), - F::from_repr(z_signature.r.y.to_repr()).unwrap(), - ); - - let z_g = Coordinate::new( - F::from_repr(z_signature.g.x.to_repr()).unwrap(), - F::from_repr(z_signature.g.y.to_repr()).unwrap(), - ); - - let z_pk = Coordinate::new( - F::from_repr(z_signature.pk.x.to_repr()).unwrap(), - F::from_repr(z_signature.pk.y.to_repr()).unwrap(), - ); - - let z_c = F::from_repr(z_signature.c.to_repr()).unwrap(); - let z_s = F::from_repr(z_signature.s.to_repr()).unwrap(); - - let signature = &signatures[i]; + for (i, signature) in signatures.iter().enumerate().take(num_steps) { let r = Coordinate::new( F::from_repr(signature.r.x.to_repr()).unwrap(), F::from_repr(signature.r.y.to_repr()).unwrap(), @@ -145,11 +106,6 @@ where let s = F::from_repr(signature.s.to_repr()).unwrap(); let circuit = EcdsaCircuit { - z_r, - z_g, - z_pk, - z_c, - z_s, r, g, pk, @@ -157,13 +113,11 @@ where s, c_bits, s_bits, - pc: pc.clone(), }; circuits.push(circuit); if i == 0 { - z0 = - Poseidon::::new_with_preimage(&[r.x, r.y, g.x, g.y, pk.x, pk.y, c, s], pc).hash(); + z0 = vec![r.x, r.y, g.x, g.y, pk.x, pk.y, c, s]; } } @@ -208,36 +162,18 @@ impl StepCircuit for EcdsaCircuit where F: PrimeField + PrimeFieldBits, { + fn arity(&self) -> usize { + 8 + } + // Prove knowledge of the sk used to generate the Ecdsa signature (R,s) // with public key PK and message commitment c. // [s]G == R + [c]PK fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - let z_rx = AllocatedNum::alloc(cs.namespace(|| "z_rx"), || Ok(self.z_r.x))?; - let z_ry = AllocatedNum::alloc(cs.namespace(|| "z_ry"), || Ok(self.z_r.y))?; - let z_gx = AllocatedNum::alloc(cs.namespace(|| "z_gx"), || Ok(self.z_g.x))?; - let z_gy = AllocatedNum::alloc(cs.namespace(|| "z_gy"), || Ok(self.z_g.y))?; - let z_pkx = AllocatedNum::alloc(cs.namespace(|| "z_pkx"), || Ok(self.z_pk.x))?; - let z_pky = AllocatedNum::alloc(cs.namespace(|| "z_pky"), || Ok(self.z_pk.y))?; - let z_c = AllocatedNum::alloc(cs.namespace(|| "z_c"), || Ok(self.z_c))?; - let z_s = AllocatedNum::alloc(cs.namespace(|| "z_s"), || Ok(self.z_s))?; - - let z_hash = poseidon_hash( - cs.namespace(|| "input hash"), - vec![z_rx, z_ry, z_gx, z_gy, z_pkx, z_pky, z_c, z_s], - &self.pc, - )?; - - cs.enforce( - || "z == z1", - |lc| lc + z.get_variable(), - |lc| lc + CS::one(), - |lc| lc + z_hash.get_variable(), - ); - + _z: &[AllocatedNum], + ) -> Result>, SynthesisError> { let g = AllocatedPoint::alloc( cs.namespace(|| "G"), Some((self.g.x, self.g.y, self.g.is_infinity)), @@ -282,36 +218,12 @@ where let c = AllocatedNum::alloc(cs.namespace(|| "c"), || Ok(self.c))?; let s = AllocatedNum::alloc(cs.namespace(|| "s"), || Ok(self.s))?; - poseidon_hash( - cs.namespace(|| "output hash"), - vec![rx, ry, gx, gy, pkx, pky, c, s], - &self.pc, - ) + Ok(vec![rx, ry, gx, gy, pkx, pky, c, s]) } - fn output(&self, z: &F) -> F { - let z_hash = Poseidon::::new_with_preimage( - &[ - self.z_r.x, - self.z_r.y, - self.z_g.x, - self.z_g.y, - self.z_pk.x, - self.z_pk.y, - self.z_c, - self.z_s, - ], - &self.pc, - ) - .hash(); - debug_assert_eq!(z, &z_hash); - - Poseidon::::new_with_preimage( - &[ - self.r.x, self.r.y, self.g.x, self.g.y, self.pk.x, self.pk.y, self.c, self.s, - ], - &self.pc, - ) - .hash() + fn output(&self, _z: &[F]) -> Vec { + vec![ + self.r.x, self.r.y, self.g.x, self.g.y, self.pk.x, self.pk.y, self.c, self.s, + ] } } diff --git a/examples/ecdsa/main.rs b/examples/ecdsa/main.rs index 4fc854b..b2c47e8 100644 --- a/examples/ecdsa/main.rs +++ b/examples/ecdsa/main.rs @@ -9,8 +9,6 @@ use ff::{ derive::byteorder::{ByteOrder, LittleEndian}, Field, PrimeField, PrimeFieldBits, }; -use generic_array::typenum::U8; -use neptune::{poseidon::PoseidonConstants, Strength}; use nova_snark::{ traits::{circuit::TrivialTestCircuit, Group as Nova_Group}, CompressedSNARK, PublicParams, RecursiveSNARK, @@ -165,22 +163,7 @@ fn main() { // produce public parameters println!("Generating public parameters..."); - let pc = PoseidonConstants::<::Scalar, U8>::new_with_strength(Strength::Standard); let circuit_primary = EcdsaCircuit::<::Scalar> { - z_r: Coordinate::new( - ::Scalar::zero(), - ::Scalar::zero(), - ), - z_g: Coordinate::new( - ::Scalar::zero(), - ::Scalar::zero(), - ), - z_pk: Coordinate::new( - ::Scalar::zero(), - ::Scalar::zero(), - ), - z_c: ::Scalar::zero(), - z_s: ::Scalar::zero(), r: Coordinate::new( ::Scalar::zero(), ::Scalar::zero(), @@ -197,7 +180,6 @@ fn main() { s: ::Scalar::zero(), c_bits: vec![Choice::from(0u8); 256], s_bits: vec![Choice::from(0u8); 256], - pc: pc.clone(), }; let circuit_secondary = TrivialTestCircuit::default(); @@ -258,10 +240,10 @@ fn main() { let (z0_primary, circuits_primary) = EcdsaCircuit::<::Scalar>::new::< ::Base, ::Scalar, - >(num_steps, &signatures(), &pc); + >(num_steps, &signatures()); // Secondary circuit - let z0_secondary = ::Scalar::zero(); + let z0_secondary = vec![::Scalar::zero()]; // produce a recursive SNARK println!("Generating a RecursiveSNARK..."); @@ -277,8 +259,8 @@ fn main() { recursive_snark, circuit_primary.clone(), circuit_secondary.clone(), - z0_primary, - z0_secondary, + z0_primary.clone(), + z0_secondary.clone(), ); assert!(result.is_ok()); println!("RecursiveSNARK::prove_step {}: {:?}", i, result.is_ok()); @@ -290,7 +272,7 @@ fn main() { // verify the recursive SNARK println!("Verifying the RecursiveSNARK..."); - let res = recursive_snark.verify(&pp, num_steps, z0_primary, z0_secondary); + let res = recursive_snark.verify(&pp, num_steps, z0_primary.clone(), z0_secondary.clone()); println!("RecursiveSNARK::verify: {:?}", res.is_ok()); assert!(res.is_ok()); diff --git a/examples/minroot.rs b/examples/minroot.rs index d173608..761d561 100644 --- a/examples/minroot.rs +++ b/examples/minroot.rs @@ -5,12 +5,6 @@ type G1 = pasta_curves::pallas::Point; type G2 = pasta_curves::vesta::Point; use ::bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use ff::PrimeField; -use generic_array::typenum::U2; -use neptune::{ - circuit::poseidon_hash, - poseidon::{Poseidon, PoseidonConstants}, - Strength, -}; use nova_snark::{ traits::{ circuit::{StepCircuit, TrivialTestCircuit}, @@ -31,7 +25,7 @@ struct MinRootIteration { impl MinRootIteration { // produces a sample non-deterministic advice, executing one invocation of MinRoot per step - fn new(num_iters: usize, x_0: &F, y_0: &F, pc: &PoseidonConstants) -> (F, Vec) { + fn new(num_iters: usize, x_0: &F, y_0: &F) -> (Vec, Vec) { // although this code is written generically, it is tailored to Pallas' scalar field // (p - 3 / 5) let exp = BigUint::parse_bytes( @@ -65,7 +59,7 @@ impl MinRootIteration { y_i = y_i_plus_1; } - let z0 = Poseidon::::new_with_preimage(&[*x_0, *y_0], pc).hash(); + let z0 = vec![*x_0, *y_0]; (z0, res) } @@ -74,23 +68,27 @@ impl MinRootIteration { #[derive(Clone, Debug)] struct MinRootCircuit { seq: Vec>, - pc: PoseidonConstants, } impl StepCircuit for MinRootCircuit where F: PrimeField, { + fn arity(&self) -> usize { + 2 + } + fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - let mut z_out: Result, SynthesisError> = Err(SynthesisError::AssignmentMissing); + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { + let mut z_out: Result>, SynthesisError> = + Err(SynthesisError::AssignmentMissing); - // allocate variables to hold x_0 and y_0 - let x_0 = AllocatedNum::alloc(cs.namespace(|| "x_0"), || Ok(self.seq[0].x_i))?; - let y_0 = AllocatedNum::alloc(cs.namespace(|| "y_0"), || Ok(self.seq[0].y_i))?; + // use the provided inputs + let x_0 = z[0].clone(); + let y_0 = z[1].clone(); // variables to hold running x_i and y_i let mut x_i = x_0; @@ -102,21 +100,6 @@ where Ok(self.seq[i].x_i_plus_1) })?; - // check that z = hash(x_i, y_i), where z is an output from the prior step - if i == 0 { - let z_hash = poseidon_hash( - cs.namespace(|| "input hash"), - vec![x_i.clone(), y_i.clone()], - &self.pc, - )?; - cs.enforce( - || "z =? z_hash", - |lc| lc + z_hash.get_variable(), - |lc| lc + CS::one(), - |lc| lc + z.get_variable(), - ); - } - // check the following conditions hold: // (i) x_i_plus_1 = (x_i + y_i)^{1/5}, which can be more easily checked with x_i_plus_1^5 = x_i + y_i // (ii) y_i_plus_1 = x_i @@ -135,11 +118,7 @@ where // return hash(x_i_plus_1, y_i_plus_1) since Nova circuits expect a single output if i == self.seq.len() - 1 { - z_out = poseidon_hash( - cs.namespace(|| "output hash"), - vec![x_i_plus_1.clone(), x_i.clone()], - &self.pc, - ); + z_out = Ok(vec![x_i_plus_1.clone(), x_i.clone()]); } // update x_i and y_i for the next iteration @@ -150,22 +129,16 @@ where z_out } - fn output(&self, z: &F) -> F { + fn output(&self, z: &[F]) -> Vec { // sanity check - let z_hash = - Poseidon::::new_with_preimage(&[self.seq[0].x_i, self.seq[0].y_i], &self.pc).hash(); - debug_assert_eq!(z, &z_hash); - - // compute output hash using advice - let iters = self.seq.len(); - Poseidon::::new_with_preimage( - &[ - self.seq[iters - 1].x_i_plus_1, - self.seq[iters - 1].y_i_plus_1, - ], - &self.pc, - ) - .hash() + debug_assert_eq!(z[0], self.seq[0].x_i); + debug_assert_eq!(z[1], self.seq[0].y_i); + + // compute output using advice + vec![ + self.seq[self.seq.len() - 1].x_i_plus_1, + self.seq[self.seq.len() - 1].y_i_plus_1, + ] } } @@ -176,7 +149,6 @@ fn main() { let num_steps = 10; for num_iters_per_step in [1024, 2048, 4096, 8192, 16384, 32768, 65535] { // number of iterations of MinRoot per Nova's recursive step - let pc = PoseidonConstants::<::Scalar, U2>::new_with_strength(Strength::Standard); let circuit_primary = MinRootCircuit { seq: vec![ MinRootIteration { @@ -187,7 +159,6 @@ fn main() { }; num_iters_per_step ], - pc: pc.clone(), }; let circuit_secondary = TrivialTestCircuit::default(); @@ -228,7 +199,6 @@ fn main() { num_iters_per_step * num_steps, &::Scalar::zero(), &::Scalar::one(), - &pc, ); let minroot_circuits = (0..num_steps) .map(|i| MinRootCircuit { @@ -240,11 +210,10 @@ fn main() { y_i_plus_1: minroot_iterations[i * num_iters_per_step + j].y_i_plus_1, }) .collect::>(), - pc: pc.clone(), }) .collect::>(); - let z0_secondary = ::Scalar::zero(); + let z0_secondary = vec![::Scalar::zero()]; type C1 = MinRootCircuit<::Scalar>; type C2 = TrivialTestCircuit<::Scalar>; @@ -259,8 +228,8 @@ fn main() { recursive_snark, circuit_primary.clone(), circuit_secondary.clone(), - z0_primary, - z0_secondary, + z0_primary.clone(), + z0_secondary.clone(), ); assert!(res.is_ok()); println!( @@ -278,7 +247,7 @@ fn main() { // verify the recursive SNARK println!("Verifying a RecursiveSNARK..."); let start = Instant::now(); - let res = recursive_snark.verify(&pp, num_steps, z0_primary, z0_secondary); + let res = recursive_snark.verify(&pp, num_steps, z0_primary.clone(), z0_secondary.clone()); println!( "RecursiveSNARK::verify: {:?}, took {:?}", res.is_ok(), diff --git a/src/circuit.rs b/src/circuit.rs index 6213453..92e67eb 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -8,12 +8,12 @@ use super::{ commitments::Commitment, - constants::{NUM_FE_FOR_HASH, NUM_HASH_BITS}, + constants::{NUM_FE_WITHOUT_IO_FOR_CRHF, NUM_HASH_BITS}, gadgets::{ ecc::AllocatedPoint, r1cs::{AllocatedR1CSInstance, AllocatedRelaxedR1CSInstance}, utils::{ - alloc_num_equals, alloc_scalar_as_base, alloc_zero, conditionally_select, le_bits_to_num, + alloc_num_equals, alloc_scalar_as_base, alloc_zero, conditionally_select_vec, le_bits_to_num, }, }, r1cs::{R1CSInstance, RelaxedR1CSInstance}, @@ -50,8 +50,8 @@ impl NovaAugmentedCircuitParams { pub struct NovaAugmentedCircuitInputs { params: G::Scalar, // Hash(Shape of u2, Gens for u2). Needed for computing the challenge. i: G::Base, - z0: G::Base, - zi: Option, + z0: Vec, + zi: Option>, U: Option>, u: Option>, T: Option>, @@ -66,8 +66,8 @@ where pub fn new( params: G::Scalar, i: G::Base, - z0: G::Base, - zi: Option, + z0: Vec, + zi: Option>, U: Option>, u: Option>, T: Option>, @@ -121,12 +121,13 @@ where fn alloc_witness::Base>>( &self, mut cs: CS, + arity: usize, ) -> Result< ( AllocatedNum, AllocatedNum, - AllocatedNum, - AllocatedNum, + Vec>, + Vec>, AllocatedRelaxedR1CSInstance, AllocatedR1CSInstance, AllocatedPoint, @@ -143,12 +144,23 @@ where let i = AllocatedNum::alloc(cs.namespace(|| "i"), || Ok(self.inputs.get()?.i))?; // Allocate z0 - let z_0 = AllocatedNum::alloc(cs.namespace(|| "z0"), || Ok(self.inputs.get()?.z0))?; + let z_0 = (0..arity) + .map(|i| { + AllocatedNum::alloc(cs.namespace(|| format!("z0_{}", i)), || { + Ok(self.inputs.get()?.z0[i]) + }) + }) + .collect::>, _>>()?; // Allocate zi. If inputs.zi is not provided (base case) allocate default value 0 - let z_i = AllocatedNum::alloc(cs.namespace(|| "zi"), || { - Ok(self.inputs.get()?.zi.unwrap_or_else(G::Base::zero)) - })?; + let zero = vec![G::Base::zero(); arity]; + let z_i = (0..arity) + .map(|i| { + AllocatedNum::alloc(cs.namespace(|| format!("zi_{}", i)), || { + Ok(self.inputs.get()?.zi.as_ref().unwrap_or(&zero)[i]) + }) + }) + .collect::>, _>>()?; // Allocate the running instance let U: AllocatedRelaxedR1CSInstance = AllocatedRelaxedR1CSInstance::alloc( @@ -215,18 +227,26 @@ where mut cs: CS, params: AllocatedNum, i: AllocatedNum, - z_0: AllocatedNum, - z_i: AllocatedNum, + z_0: Vec>, + z_i: Vec>, U: AllocatedRelaxedR1CSInstance, u: AllocatedR1CSInstance, T: AllocatedPoint, + arity: usize, ) -> Result<(AllocatedRelaxedR1CSInstance, AllocatedBit), SynthesisError> { // Check that u.x[0] = Hash(params, U, i, z0, zi) - let mut ro = G::ROCircuit::new(self.ro_consts.clone(), NUM_FE_FOR_HASH); + let mut ro = G::ROCircuit::new( + self.ro_consts.clone(), + NUM_FE_WITHOUT_IO_FOR_CRHF + 2 * arity, + ); ro.absorb(params.clone()); ro.absorb(i); - ro.absorb(z_0); - ro.absorb(z_i); + for e in z_0 { + ro.absorb(e); + } + for e in z_i { + ro.absorb(e); + } U.absorb_in_ro(cs.namespace(|| "absorb U"), &mut ro)?; let hash_bits = ro.squeeze(cs.namespace(|| "Input hash"), NUM_HASH_BITS)?; @@ -261,9 +281,11 @@ where self, cs: &mut CS, ) -> Result<(), SynthesisError> { + let arity = self.step_circuit.arity(); + // Allocate all witnesses let (params, i, z_0, z_i, U, u, T) = - self.alloc_witness(cs.namespace(|| "allocate the circuit witness"))?; + self.alloc_witness(cs.namespace(|| "allocate the circuit witness"), arity)?; // Compute variable indicating if this is the base case let zero = alloc_zero(cs.namespace(|| "zero"))?; @@ -283,6 +305,7 @@ where U, u.clone(), T, + arity, )?; // Either check_non_base_pass=true or we are in the base case @@ -317,7 +340,7 @@ where ); // Compute z_{i+1} - let z_input = conditionally_select( + let z_input = conditionally_select_vec( cs.namespace(|| "select input to F"), &z_0, &z_i, @@ -326,14 +349,24 @@ where let z_next = self .step_circuit - .synthesize(&mut cs.namespace(|| "F"), z_input)?; + .synthesize(&mut cs.namespace(|| "F"), &z_input)?; + + if z_next.len() != arity { + return Err(SynthesisError::IncompatibleLengthVector( + "z_next".to_string(), + )); + } // Compute the new hash H(params, Unew, i+1, z0, z_{i+1}) - let mut ro = G::ROCircuit::new(self.ro_consts, NUM_FE_FOR_HASH); + let mut ro = G::ROCircuit::new(self.ro_consts, NUM_FE_WITHOUT_IO_FOR_CRHF + 2 * arity); ro.absorb(params); ro.absorb(i_new.clone()); - ro.absorb(z_0); - ro.absorb(z_next); + for e in z_0 { + ro.absorb(e); + } + for e in z_next { + ro.absorb(e); + } Unew.absorb_in_ro(cs.namespace(|| "absorb U_new"), &mut ro)?; let hash_bits = ro.squeeze(cs.namespace(|| "output hash bits"), NUM_HASH_BITS)?; let hash = le_bits_to_num(cs.namespace(|| "convert hash to num"), hash_bits)?; @@ -397,8 +430,15 @@ mod tests { // Execute the base case for the primary let zero1 = <::Base as Field>::zero(); let mut cs1: SatisfyingAssignment = SatisfyingAssignment::new(); - let inputs1: NovaAugmentedCircuitInputs = - NovaAugmentedCircuitInputs::new(shape2.get_digest(), zero1, zero1, None, None, None, None); + let inputs1: NovaAugmentedCircuitInputs = NovaAugmentedCircuitInputs::new( + shape2.get_digest(), + zero1, + vec![zero1], + None, + None, + None, + None, + ); let circuit1: NovaAugmentedCircuit::Base>> = NovaAugmentedCircuit::new( params1, @@ -417,7 +457,7 @@ mod tests { let inputs2: NovaAugmentedCircuitInputs = NovaAugmentedCircuitInputs::new( shape1.get_digest(), zero2, - zero2, + vec![zero2], None, None, Some(inst1), diff --git a/src/constants.rs b/src/constants.rs index aeb53e6..f0f7fbd 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -2,5 +2,5 @@ pub(crate) const NUM_CHALLENGE_BITS: usize = 128; pub(crate) const NUM_HASH_BITS: usize = 250; pub(crate) const BN_LIMB_WIDTH: usize = 64; pub(crate) const BN_N_LIMBS: usize = 4; -pub(crate) const NUM_FE_FOR_HASH: usize = 19; +pub(crate) const NUM_FE_WITHOUT_IO_FOR_CRHF: usize = 17; pub(crate) const NUM_FE_FOR_RO: usize = 24; diff --git a/src/errors.rs b/src/errors.rs index efde3e5..cc2081b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -24,4 +24,8 @@ pub enum NovaError { InvalidIPA, /// returned when an invalid sum-check proof is provided InvalidSumcheckProof, + /// returned when the initial input to an incremental computation differs from a previously declared arity + InvalidInitialInputLength, + /// returned when the step execution produces an output whose length differs from a previously declared arity + InvalidStepOutputLength, } diff --git a/src/gadgets/utils.rs b/src/gadgets/utils.rs index ce31154..926b8a0 100644 --- a/src/gadgets/utils.rs +++ b/src/gadgets/utils.rs @@ -211,6 +211,22 @@ pub fn conditionally_select>( Ok(c) } +/// If condition return a otherwise b +pub fn conditionally_select_vec>( + mut cs: CS, + a: &[AllocatedNum], + b: &[AllocatedNum], + condition: &Boolean, +) -> Result>, SynthesisError> { + a.iter() + .zip(b.iter()) + .enumerate() + .map(|(i, (a, b))| { + conditionally_select(cs.namespace(|| format!("select_{}", i)), a, b, condition) + }) + .collect::>, SynthesisError>>() +} + /// If condition return a otherwise b where a and b are BigNats pub fn conditionally_select_bignat>( mut cs: CS, diff --git a/src/lib.rs b/src/lib.rs index b40d0d9..f08aec5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,7 +26,7 @@ use crate::bellperson::{ }; use ::bellperson::{Circuit, ConstraintSystem}; use circuit::{NovaAugmentedCircuit, NovaAugmentedCircuitInputs, NovaAugmentedCircuitParams}; -use constants::{BN_LIMB_WIDTH, BN_N_LIMBS, NUM_FE_FOR_HASH, NUM_HASH_BITS}; +use constants::{BN_LIMB_WIDTH, BN_N_LIMBS, NUM_FE_WITHOUT_IO_FOR_CRHF, NUM_HASH_BITS}; use core::marker::PhantomData; use errors::NovaError; use ff::Field; @@ -48,6 +48,8 @@ where C1: StepCircuit, C2: StepCircuit, { + F_arity_primary: usize, + F_arity_secondary: usize, ro_consts_primary: ROConstants, ro_consts_circuit_primary: ROConstantsCircuit, r1cs_gens_primary: R1CSGens, @@ -81,6 +83,9 @@ where let ro_consts_primary: ROConstants = ROConstants::::new(); let ro_consts_secondary: ROConstants = ROConstants::::new(); + let F_arity_primary = c_primary.arity(); + let F_arity_secondary = c_secondary.arity(); + // ro_consts_circuit_primart are parameterized by G2 because the type alias uses G2::Base = G1::Scalar let ro_consts_circuit_primary: ROConstantsCircuit = ROConstantsCircuit::::new(); let ro_consts_circuit_secondary: ROConstantsCircuit = ROConstantsCircuit::::new(); @@ -110,6 +115,8 @@ where let r1cs_shape_padded_secondary = r1cs_shape_secondary.pad(); Self { + F_arity_primary, + F_arity_secondary, ro_consts_primary, ro_consts_circuit_primary, r1cs_gens_primary, @@ -162,8 +169,8 @@ where l_w_secondary: R1CSWitness, l_u_secondary: R1CSInstance, i: usize, - zi_primary: G1::Scalar, - zi_secondary: G2::Scalar, + zi_primary: Vec, + zi_secondary: Vec, _p_c1: PhantomData, _p_c2: PhantomData, } @@ -182,9 +189,13 @@ where recursive_snark: Option, c_primary: C1, c_secondary: C2, - z0_primary: G1::Scalar, - z0_secondary: G2::Scalar, + z0_primary: Vec, + z0_secondary: Vec, ) -> Result { + if z0_primary.len() != pp.F_arity_primary || z0_secondary.len() != pp.F_arity_secondary { + return Err(NovaError::InvalidInitialInputLength); + } + match recursive_snark { None => { // base case for the primary @@ -192,7 +203,7 @@ where let inputs_primary: NovaAugmentedCircuitInputs = NovaAugmentedCircuitInputs::new( pp.r1cs_shape_secondary.get_digest(), G1::Scalar::zero(), - z0_primary, + z0_primary.clone(), None, None, None, @@ -215,7 +226,7 @@ where let inputs_secondary: NovaAugmentedCircuitInputs = NovaAugmentedCircuitInputs::new( pp.r1cs_shape_primary.get_digest(), G2::Scalar::zero(), - z0_secondary, + z0_secondary.clone(), None, None, Some(u_primary.clone()), @@ -254,6 +265,10 @@ where let zi_primary = c_primary.output(&z0_primary); let zi_secondary = c_secondary.output(&z0_secondary); + if z0_primary.len() != pp.F_arity_primary || z0_secondary.len() != pp.F_arity_secondary { + return Err(NovaError::InvalidStepOutputLength); + } + Ok(Self { r_W_primary, r_U_primary, @@ -287,7 +302,7 @@ where pp.r1cs_shape_secondary.get_digest(), G1::Scalar::from(r_snark.i as u64), z0_primary, - Some(r_snark.zi_primary), + Some(r_snark.zi_primary.clone()), Some(r_snark.r_U_secondary.clone()), Some(r_snark.l_u_secondary.clone()), Some(nifs_secondary.comm_T.decompress()?), @@ -321,7 +336,7 @@ where pp.r1cs_shape_primary.get_digest(), G2::Scalar::from(r_snark.i as u64), z0_secondary, - Some(r_snark.zi_secondary), + Some(r_snark.zi_secondary.clone()), Some(r_snark.r_U_primary.clone()), Some(l_u_primary.clone()), Some(nifs_primary.comm_T.decompress()?), @@ -367,9 +382,9 @@ where &self, pp: &PublicParams, num_steps: usize, - z0_primary: G1::Scalar, - z0_secondary: G2::Scalar, - ) -> Result<(G1::Scalar, G2::Scalar), NovaError> { + z0_primary: Vec, + z0_secondary: Vec, + ) -> Result<(Vec, Vec), NovaError> { // number of steps cannot be zero if num_steps == 0 { return Err(NovaError::ProofVerifyError); @@ -391,18 +406,32 @@ where // check if the output hashes in R1CS instances point to the right running instances let (hash_primary, hash_secondary) = { - let mut hasher = ::RO::new(pp.ro_consts_secondary.clone(), NUM_FE_FOR_HASH); + let mut hasher = ::RO::new( + pp.ro_consts_secondary.clone(), + NUM_FE_WITHOUT_IO_FOR_CRHF + 2 * pp.F_arity_primary, + ); hasher.absorb(scalar_as_base::(pp.r1cs_shape_secondary.get_digest())); hasher.absorb(G1::Scalar::from(num_steps as u64)); - hasher.absorb(z0_primary); - hasher.absorb(self.zi_primary); + for e in &z0_primary { + hasher.absorb(*e); + } + for e in &self.zi_primary { + hasher.absorb(*e); + } self.r_U_secondary.absorb_in_ro(&mut hasher); - let mut hasher2 = ::RO::new(pp.ro_consts_primary.clone(), NUM_FE_FOR_HASH); + let mut hasher2 = ::RO::new( + pp.ro_consts_primary.clone(), + NUM_FE_WITHOUT_IO_FOR_CRHF + 2 * pp.F_arity_secondary, + ); hasher2.absorb(scalar_as_base::(pp.r1cs_shape_primary.get_digest())); hasher2.absorb(G2::Scalar::from(num_steps as u64)); - hasher2.absorb(z0_secondary); - hasher2.absorb(self.zi_secondary); + for e in &z0_secondary { + hasher2.absorb(*e); + } + for e in &self.zi_secondary { + hasher2.absorb(*e); + } self.r_U_primary.absorb_in_ro(&mut hasher2); ( @@ -463,7 +492,7 @@ where res_r_secondary?; res_l_secondary?; - Ok((self.zi_primary, self.zi_secondary)) + Ok((self.zi_primary.clone(), self.zi_secondary.clone())) } } @@ -488,8 +517,8 @@ where nifs_secondary: NIFS, f_W_snark_secondary: S2, - zn_primary: G1::Scalar, - zn_secondary: G2::Scalar, + zn_primary: Vec, + zn_secondary: Vec, _p_c1: PhantomData, _p_c2: PhantomData, @@ -574,8 +603,8 @@ where nifs_secondary, f_W_snark_secondary: f_W_snark_secondary?, - zn_primary: recursive_snark.zi_primary, - zn_secondary: recursive_snark.zi_secondary, + zn_primary: recursive_snark.zi_primary.clone(), + zn_secondary: recursive_snark.zi_secondary.clone(), _p_c1: Default::default(), _p_c2: Default::default(), @@ -587,9 +616,9 @@ where &self, pp: &PublicParams, num_steps: usize, - z0_primary: G1::Scalar, - z0_secondary: G2::Scalar, - ) -> Result<(G1::Scalar, G2::Scalar), NovaError> { + z0_primary: Vec, + z0_secondary: Vec, + ) -> Result<(Vec, Vec), NovaError> { // number of steps cannot be zero if num_steps == 0 { return Err(NovaError::ProofVerifyError); @@ -606,18 +635,32 @@ where // check if the output hashes in R1CS instances point to the right running instances let (hash_primary, hash_secondary) = { - let mut hasher = ::RO::new(pp.ro_consts_secondary.clone(), NUM_FE_FOR_HASH); + let mut hasher = ::RO::new( + pp.ro_consts_secondary.clone(), + NUM_FE_WITHOUT_IO_FOR_CRHF + 2 * pp.F_arity_primary, + ); hasher.absorb(scalar_as_base::(pp.r1cs_shape_secondary.get_digest())); hasher.absorb(G1::Scalar::from(num_steps as u64)); - hasher.absorb(z0_primary); - hasher.absorb(self.zn_primary); + for e in z0_primary { + hasher.absorb(e); + } + for e in &self.zn_primary { + hasher.absorb(*e); + } self.r_U_secondary.absorb_in_ro(&mut hasher); - let mut hasher2 = ::RO::new(pp.ro_consts_primary.clone(), NUM_FE_FOR_HASH); + let mut hasher2 = ::RO::new( + pp.ro_consts_primary.clone(), + NUM_FE_WITHOUT_IO_FOR_CRHF + 2 * pp.F_arity_secondary, + ); hasher2.absorb(scalar_as_base::(pp.r1cs_shape_primary.get_digest())); hasher2.absorb(G2::Scalar::from(num_steps as u64)); - hasher2.absorb(z0_secondary); - hasher2.absorb(self.zn_secondary); + for e in z0_secondary { + hasher2.absorb(e); + } + for e in &self.zn_secondary { + hasher2.absorb(*e); + } self.r_U_primary.absorb_in_ro(&mut hasher2); ( @@ -665,7 +708,7 @@ where res_primary?; res_secondary?; - Ok((self.zn_primary, self.zn_secondary)) + Ok((self.zn_primary.clone(), self.zn_secondary.clone())) } } @@ -690,15 +733,19 @@ mod tests { where F: PrimeField, { + fn arity(&self) -> usize { + 1 + } + fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { // Consider a cubic equation: `x^3 + x + 5 = y`, where `x` and `y` are respectively the input and output. - let x = z; + let x = &z[0]; let x_sq = x.square(cs.namespace(|| "x_sq"))?; - let x_cu = x_sq.mul(cs.namespace(|| "x_cu"), &x)?; + let x_cu = x_sq.mul(cs.namespace(|| "x_cu"), x)?; let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { Ok(x_cu.get_value().unwrap() + x.get_value().unwrap() + F::from(5u64)) })?; @@ -718,11 +765,11 @@ mod tests { |lc| lc + y.get_variable(), ); - Ok(y) + Ok(vec![y]) } - fn output(&self, z: &F) -> F { - *z * *z * *z + z + F::from(5u64) + fn output(&self, z: &[F]) -> Vec { + vec![z[0] * z[0] * z[0] + z[0] + F::from(5u64)] } } @@ -744,8 +791,8 @@ mod tests { None, TrivialTestCircuit::default(), TrivialTestCircuit::default(), - ::Scalar::zero(), - ::Scalar::zero(), + vec![::Scalar::zero()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); let recursive_snark = res.unwrap(); @@ -754,8 +801,8 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - ::Scalar::zero(), - ::Scalar::zero(), + vec![::Scalar::zero()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); } @@ -791,8 +838,8 @@ mod tests { recursive_snark, circuit_primary.clone(), circuit_secondary.clone(), - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); let recursive_snark_unwrapped = res.unwrap(); @@ -801,8 +848,8 @@ mod tests { let res = recursive_snark_unwrapped.verify( &pp, i + 1, - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); @@ -817,21 +864,21 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); let (zn_primary, zn_secondary) = res.unwrap(); // sanity: check the claimed output with a direct computation of the same - assert_eq!(zn_primary, ::Scalar::one()); - let mut zn_secondary_direct = ::Scalar::zero(); + assert_eq!(zn_primary, vec![::Scalar::one()]); + let mut zn_secondary_direct = vec![::Scalar::zero()]; for _i in 0..num_steps { zn_secondary_direct = CubicCircuit::default().output(&zn_secondary_direct); } assert_eq!(zn_secondary, zn_secondary_direct); - assert_eq!(zn_secondary, ::Scalar::from(2460515u64)); + assert_eq!(zn_secondary, vec![::Scalar::from(2460515u64)]); } #[test] @@ -865,8 +912,8 @@ mod tests { recursive_snark, circuit_primary.clone(), circuit_secondary.clone(), - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); recursive_snark = Some(res.unwrap()); @@ -879,21 +926,21 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); let (zn_primary, zn_secondary) = res.unwrap(); // sanity: check the claimed output with a direct computation of the same - assert_eq!(zn_primary, ::Scalar::one()); - let mut zn_secondary_direct = ::Scalar::zero(); + assert_eq!(zn_primary, vec![::Scalar::one()]); + let mut zn_secondary_direct = vec![::Scalar::zero()]; for _i in 0..num_steps { zn_secondary_direct = CubicCircuit::default().output(&zn_secondary_direct); } assert_eq!(zn_secondary, zn_secondary_direct); - assert_eq!(zn_secondary, ::Scalar::from(2460515u64)); + assert_eq!(zn_secondary, vec![::Scalar::from(2460515u64)]); // produce a compressed SNARK let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &recursive_snark); @@ -904,8 +951,8 @@ mod tests { let res = compressed_snark.verify( &pp, num_steps, - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); } @@ -922,7 +969,7 @@ mod tests { where F: PrimeField, { - fn new(num_steps: usize) -> (F, Vec) { + fn new(num_steps: usize) -> (Vec, Vec) { let mut powers = Vec::new(); let rng = &mut rand::rngs::OsRng; let mut seed = F::random(rng); @@ -939,7 +986,7 @@ mod tests { // reverse the powers to get roots let roots = powers.into_iter().rev().collect::>(); - (roots[0].y, roots[1..].to_vec()) + (vec![roots[0].y], roots[1..].to_vec()) } } @@ -947,12 +994,16 @@ mod tests { where F: PrimeField, { + fn arity(&self) -> usize { + 1 + } + fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - let x = z; + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { + let x = &z[0]; // we allocate a variable and set it to the provided non-derministic advice. let y = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(self.y))?; @@ -969,12 +1020,12 @@ mod tests { |lc| lc + x.get_variable(), ); - Ok(y) + Ok(vec![y]) } - fn output(&self, z: &F) -> F { + fn output(&self, z: &[F]) -> Vec { // sanity check - let x = *z; + let x = z[0]; let y_pow_5 = { let y = self.y; let y_sq = y.square(); @@ -985,7 +1036,7 @@ mod tests { // return non-deterministic advice // as the output of the step - self.y + vec![self.y] } } @@ -1007,7 +1058,7 @@ mod tests { // produce non-deterministic advice let (z0_primary, roots) = FifthRootCheckingCircuit::new(num_steps); - let z0_secondary = ::Scalar::zero(); + let z0_secondary = vec![::Scalar::zero()]; // produce a recursive SNARK let mut recursive_snark: Option< @@ -1025,8 +1076,8 @@ mod tests { recursive_snark, circuit_primary.clone(), circuit_secondary.clone(), - z0_primary, - z0_secondary, + z0_primary.clone(), + z0_secondary.clone(), ); assert!(res.is_ok()); recursive_snark = Some(res.unwrap()); @@ -1036,7 +1087,7 @@ mod tests { let recursive_snark = recursive_snark.unwrap(); // verify the recursive SNARK - let res = recursive_snark.verify(&pp, num_steps, z0_primary, z0_secondary); + let res = recursive_snark.verify(&pp, num_steps, z0_primary.clone(), z0_secondary.clone()); assert!(res.is_ok()); // produce a compressed SNARK @@ -1067,8 +1118,8 @@ mod tests { None, TrivialTestCircuit::default(), CubicCircuit::default(), - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); let recursive_snark = res.unwrap(); @@ -1077,14 +1128,14 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - ::Scalar::one(), - ::Scalar::zero(), + vec![::Scalar::one()], + vec![::Scalar::zero()], ); assert!(res.is_ok()); let (zn_primary, zn_secondary) = res.unwrap(); - assert_eq!(zn_primary, ::Scalar::one()); - assert_eq!(zn_secondary, ::Scalar::from(5u64)); + assert_eq!(zn_primary, vec![::Scalar::one()]); + assert_eq!(zn_secondary, vec![::Scalar::from(5u64)]); } } diff --git a/src/poseidon.rs b/src/poseidon.rs index ca1740d..167910c 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -86,7 +86,7 @@ where SpongeOp::Squeeze(1u32), ]); - sponge.start(parameter, Some(1u32), acc); + sponge.start(parameter, None, acc); assert_eq!(self.num_absorbs, self.state.len()); SpongeAPI::absorb(&mut sponge, self.num_absorbs as u32, &self.state, acc); let hash = SpongeAPI::squeeze(&mut sponge, 1, acc); @@ -163,7 +163,7 @@ where let acc = &mut ns; assert_eq!(self.num_absorbs, self.state.len()); - sponge.start(parameter, Some(1u32), acc); + sponge.start(parameter, None, acc); neptune::sponge::api::SpongeAPI::absorb( &mut sponge, self.num_absorbs as u32, diff --git a/src/spartan_with_ipa_pc/mod.rs b/src/spartan_with_ipa_pc/mod.rs index e3e245b..754e61b 100644 --- a/src/spartan_with_ipa_pc/mod.rs +++ b/src/spartan_with_ipa_pc/mod.rs @@ -323,8 +323,7 @@ impl RelaxedR1CSSNARKTrait for RelaxedR1CSSNARK { .map(|i| (i + 1, U.X[i])) .collect::>(), ); - SparsePolynomial::new((vk.S.num_vars as f64).log2() as usize, poly_X) - .evaluate(&r_y[1..].to_vec()) + SparsePolynomial::new((vk.S.num_vars as f64).log2() as usize, poly_X).evaluate(&r_y[1..]) }; (G::Scalar::one() - r_y[0]) * self.eval_W + r_y[0] * eval_X }; diff --git a/src/traits/circuit.rs b/src/traits/circuit.rs index b5dde10..f800a7e 100644 --- a/src/traits/circuit.rs +++ b/src/traits/circuit.rs @@ -5,16 +5,22 @@ use ff::PrimeField; /// A helper trait for a step of the incremental computation (i.e., circuit for F) pub trait StepCircuit: Send + Sync + Clone { + /// Return the the number of inputs or outputs of each step + /// (this method is called only at circuit synthesis time) + /// `synthesize` and `output` methods are expected to take as + /// input a vector of size equal to arity and output a vector of size equal to arity + fn arity(&self) -> usize; + /// Sythesize the circuit for a computation step and return variable /// that corresponds to the output of the step z_{i+1} fn synthesize>( &self, cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError>; + z: &[AllocatedNum], + ) -> Result>, SynthesisError>; /// return the output of the step when provided with with the step's input - fn output(&self, z: &F) -> F; + fn output(&self, z: &[F]) -> Vec; } /// A trivial step circuit that simply returns the input @@ -27,15 +33,19 @@ impl StepCircuit for TrivialTestCircuit where F: PrimeField, { + fn arity(&self) -> usize { + 1 + } + fn synthesize>( &self, _cs: &mut CS, - z: AllocatedNum, - ) -> Result, SynthesisError> { - Ok(z) + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { + Ok(z.to_vec()) } - fn output(&self, z: &F) -> F { - *z + fn output(&self, z: &[F]) -> Vec { + z.to_vec() } }