From 188a7c56403bd4980e6ce3d9222f4a8307fc4883 Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Tue, 17 May 2022 18:38:42 +0530 Subject: [PATCH] Add a non-trivial step circuit (#66) --- src/lib.rs | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 13e9d0f..385095f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -348,11 +348,11 @@ mod tests { use std::marker::PhantomData; #[derive(Clone, Debug)] - struct TestCircuit { + struct TrivialTestCircuit { _p: PhantomData, } - impl StepCircuit for TestCircuit + impl StepCircuit for TrivialTestCircuit where F: PrimeField, { @@ -369,19 +369,96 @@ mod tests { } } + #[derive(Clone, Debug)] + struct CubicCircuit { + _p: PhantomData, + } + + impl StepCircuit for CubicCircuit + where + F: PrimeField, + { + fn synthesize>( + &self, + cs: &mut CS, + z: AllocatedNum, + ) -> Result, SynthesisError> { + // Consider a cubic equation: `x^3 + x + 5 = y`, where `x` and `y` are respectively the input and output. + let x = z; + let x_sq = x.square(cs.namespace(|| "x_sq"))?; + let x_cu = x_sq.mul(cs.namespace(|| "x_cu"), &x)?; + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + Ok(x_cu.get_value().unwrap() + x.get_value().unwrap() + F::from(5u64)) + })?; + + cs.enforce( + || "y = x^3 + x + 5", + |lc| { + lc + x_cu.get_variable() + + x.get_variable() + + CS::one() + + CS::one() + + CS::one() + + CS::one() + + CS::one() + }, + |lc| lc + CS::one(), + |lc| lc + y.get_variable(), + ); + + Ok(y) + } + + fn compute(&self, z: &F) -> F { + *z * *z * *z + z + F::from(5u64) + } + } + + #[test] + fn test_ivc_trivial() { + // produce public parameters + let pp = PublicParams::< + G1, + G2, + TrivialTestCircuit<::Base>, + TrivialTestCircuit<::Base>, + >::setup( + TrivialTestCircuit { + _p: Default::default(), + }, + TrivialTestCircuit { + _p: Default::default(), + }, + ); + + // produce a recursive SNARK + let res = RecursiveSNARK::prove( + &pp, + ::Base::zero(), + ::Base::zero(), + 3, + ); + assert!(res.is_ok()); + let recursive_snark = res.unwrap(); + + // verify the recursive SNARK + let res = recursive_snark.verify(&pp); + assert!(res.is_ok()); + } + #[test] fn test_ivc() { // produce public parameters let pp = PublicParams::< G1, G2, - TestCircuit<::Base>, - TestCircuit<::Base>, + TrivialTestCircuit<::Base>, + CubicCircuit<::Base>, >::setup( - TestCircuit { + TrivialTestCircuit { _p: Default::default(), }, - TestCircuit { + CubicCircuit { _p: Default::default(), }, );