Improve performance of recursive (#163)

* Improve performance of recursive

* Fix the test after rebase

* Fix CI/CD warnings

* Update benchmark to work with new interface of RecursiveSNARK

* Fix example to make sure step 1 is correct

* refactor: Removes unneeded pass-by value in verification

- Update function arguments to use borrowing instead of passing ownership

* Resolve the conflict with upstream branch

* refactor: Avoid extra input cloning in RecursiveSNARK::new

* Update criterion to 0.5.1 to prevent the panic with its plot

* Fix benchmark issue with new recursive_snark instance

* Fix CI/CD warning with

* refactor: Make mutation easier to observe

- Utilize mutable references to Points for better memory management

* chore: Downgrade clippy dependency for compatibility

---------

Co-authored-by: François Garillot <francois@garillot.net>
This commit is contained in:
Chiro Hiro
2023-06-20 02:52:57 +07:00
committed by GitHub
parent 031738de51
commit af886d6ce7
7 changed files with 372 additions and 351 deletions

View File

@@ -37,7 +37,7 @@ thiserror = "1.0"
pasta-msm = { version = "0.1.4" } pasta-msm = { version = "0.1.4" }
[dev-dependencies] [dev-dependencies]
criterion = "0.3.1" criterion = { version = "0.4", features = ["html_reports"] }
rand = "0.8.4" rand = "0.8.4"
hex = "0.4.3" hex = "0.4.3"

View File

@@ -43,46 +43,46 @@ fn bench_compressed_snark(c: &mut Criterion) {
let mut group = c.benchmark_group(format!("CompressedSNARK-StepCircuitSize-{num_cons}")); let mut group = c.benchmark_group(format!("CompressedSNARK-StepCircuitSize-{num_cons}"));
group.sample_size(num_samples); group.sample_size(num_samples);
let c_primary = NonTrivialTestCircuit::new(num_cons);
let c_secondary = TrivialTestCircuit::default();
// Produce public parameters // Produce public parameters
let pp = PublicParams::<G1, G2, C1, C2>::setup( let pp = PublicParams::<G1, G2, C1, C2>::setup(c_primary.clone(), c_secondary.clone());
NonTrivialTestCircuit::new(num_cons),
TrivialTestCircuit::default(),
);
// Produce prover and verifier keys for CompressedSNARK // Produce prover and verifier keys for CompressedSNARK
let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); let (pk, vk) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap();
// produce a recursive SNARK // produce a recursive SNARK
let num_steps = 3; let num_steps = 3;
let mut recursive_snark: Option<RecursiveSNARK<G1, G2, C1, C2>> = None; let mut recursive_snark: RecursiveSNARK<G1, G2, C1, C2> = RecursiveSNARK::new(
&pp,
&c_primary,
&c_secondary,
vec![<G1 as Group>::Scalar::from(2u64)],
vec![<G2 as Group>::Scalar::from(2u64)],
);
for i in 0..num_steps { for i in 0..num_steps {
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, &c_primary,
NonTrivialTestCircuit::new(num_cons), &c_secondary,
TrivialTestCircuit::default(),
vec![<G1 as Group>::Scalar::from(2u64)], vec![<G1 as Group>::Scalar::from(2u64)],
vec![<G2 as Group>::Scalar::from(2u64)], vec![<G2 as Group>::Scalar::from(2u64)],
); );
assert!(res.is_ok()); assert!(res.is_ok());
let recursive_snark_unwrapped = res.unwrap();
// verify the recursive snark at each step of recursion // verify the recursive snark at each step of recursion
let res = recursive_snark_unwrapped.verify( let res = recursive_snark.verify(
&pp, &pp,
i + 1, i + 1,
vec![<G1 as Group>::Scalar::from(2u64)], &vec![<G1 as Group>::Scalar::from(2u64)][..],
vec![<G2 as Group>::Scalar::from(2u64)], &vec![<G2 as Group>::Scalar::from(2u64)][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());
// set the running variable for the next iteration
recursive_snark = Some(recursive_snark_unwrapped);
} }
// Bench time to produce a compressed SNARK // Bench time to produce a compressed SNARK
let recursive_snark = recursive_snark.unwrap();
group.bench_function("Prove", |b| { group.bench_function("Prove", |b| {
b.iter(|| { b.iter(|| {
assert!(CompressedSNARK::<_, _, _, _, S1, S2>::prove( assert!(CompressedSNARK::<_, _, _, _, S1, S2>::prove(

View File

@@ -38,52 +38,53 @@ fn bench_recursive_snark(c: &mut Criterion) {
let mut group = c.benchmark_group(format!("RecursiveSNARK-StepCircuitSize-{num_cons}")); let mut group = c.benchmark_group(format!("RecursiveSNARK-StepCircuitSize-{num_cons}"));
group.sample_size(10); group.sample_size(10);
let c_primary = NonTrivialTestCircuit::new(num_cons);
let c_secondary = TrivialTestCircuit::default();
// Produce public parameters // Produce public parameters
let pp = PublicParams::<G1, G2, C1, C2>::setup( let pp = PublicParams::<G1, G2, C1, C2>::setup(c_primary.clone(), c_secondary.clone());
NonTrivialTestCircuit::new(num_cons),
TrivialTestCircuit::default(),
);
// Bench time to produce a recursive SNARK; // Bench time to produce a recursive SNARK;
// we execute a certain number of warm-up steps since executing // we execute a certain number of warm-up steps since executing
// the first step is cheaper than other steps owing to the presence of // the first step is cheaper than other steps owing to the presence of
// a lot of zeros in the satisfying assignment // a lot of zeros in the satisfying assignment
let num_warmup_steps = 10; let num_warmup_steps = 10;
let mut recursive_snark: Option<RecursiveSNARK<G1, G2, C1, C2>> = None; let mut recursive_snark: RecursiveSNARK<G1, G2, C1, C2> = RecursiveSNARK::new(
&pp,
&c_primary,
&c_secondary,
vec![<G1 as Group>::Scalar::from(2u64)],
vec![<G2 as Group>::Scalar::from(2u64)],
);
for i in 0..num_warmup_steps { for i in 0..num_warmup_steps {
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, &c_primary,
NonTrivialTestCircuit::new(num_cons), &c_secondary,
TrivialTestCircuit::default(),
vec![<G1 as Group>::Scalar::from(2u64)], vec![<G1 as Group>::Scalar::from(2u64)],
vec![<G2 as Group>::Scalar::from(2u64)], vec![<G2 as Group>::Scalar::from(2u64)],
); );
assert!(res.is_ok()); assert!(res.is_ok());
let recursive_snark_unwrapped = res.unwrap();
// verify the recursive snark at each step of recursion // verify the recursive snark at each step of recursion
let res = recursive_snark_unwrapped.verify( let res = recursive_snark.verify(
&pp, &pp,
i + 1, i + 1,
vec![<G1 as Group>::Scalar::from(2u64)], &[<G1 as Group>::Scalar::from(2u64)],
vec![<G2 as Group>::Scalar::from(2u64)], &[<G2 as Group>::Scalar::from(2u64)],
); );
assert!(res.is_ok()); assert!(res.is_ok());
// set the running variable for the next iteration
recursive_snark = Some(recursive_snark_unwrapped);
} }
group.bench_function("Prove", |b| { group.bench_function("Prove", |b| {
b.iter(|| { b.iter(|| {
// produce a recursive SNARK for a step of the recursion // produce a recursive SNARK for a step of the recursion
assert!(RecursiveSNARK::prove_step( assert!(black_box(&mut recursive_snark.clone())
.prove_step(
black_box(&pp), black_box(&pp),
black_box(recursive_snark.clone()), black_box(&c_primary),
black_box(NonTrivialTestCircuit::new(num_cons)), black_box(&c_secondary),
black_box(TrivialTestCircuit::default()),
black_box(vec![<G1 as Group>::Scalar::from(2u64)]), black_box(vec![<G1 as Group>::Scalar::from(2u64)]),
black_box(vec![<G2 as Group>::Scalar::from(2u64)]), black_box(vec![<G2 as Group>::Scalar::from(2u64)]),
) )
@@ -91,8 +92,6 @@ fn bench_recursive_snark(c: &mut Criterion) {
}) })
}); });
let recursive_snark = recursive_snark.unwrap();
// Benchmark the verification time // Benchmark the verification time
group.bench_function("Verify", |b| { group.bench_function("Verify", |b| {
b.iter(|| { b.iter(|| {
@@ -100,8 +99,8 @@ fn bench_recursive_snark(c: &mut Criterion) {
.verify( .verify(
black_box(&pp), black_box(&pp),
black_box(num_warmup_steps), black_box(num_warmup_steps),
black_box(vec![<G1 as Group>::Scalar::from(2u64)]), black_box(&vec![<G1 as Group>::Scalar::from(2u64)][..]),
black_box(vec![<G2 as Group>::Scalar::from(2u64)]), black_box(&vec![<G2 as Group>::Scalar::from(2u64)][..]),
) )
.is_ok()); .is_ok());
}); });

View File

@@ -172,7 +172,7 @@ fn main() {
G2, G2,
MinRootCircuit<<G1 as Group>::Scalar>, MinRootCircuit<<G1 as Group>::Scalar>,
TrivialTestCircuit<<G2 as Group>::Scalar>, TrivialTestCircuit<<G2 as Group>::Scalar>,
>::setup(circuit_primary, circuit_secondary.clone()); >::setup(circuit_primary.clone(), circuit_secondary.clone());
println!("PublicParams::setup, took {:?} ", start.elapsed()); println!("PublicParams::setup, took {:?} ", start.elapsed());
println!( println!(
@@ -218,15 +218,20 @@ fn main() {
type C2 = TrivialTestCircuit<<G2 as Group>::Scalar>; type C2 = TrivialTestCircuit<<G2 as Group>::Scalar>;
// produce a recursive SNARK // produce a recursive SNARK
println!("Generating a RecursiveSNARK..."); println!("Generating a RecursiveSNARK...");
let mut recursive_snark: Option<RecursiveSNARK<G1, G2, C1, C2>> = None; let mut recursive_snark: RecursiveSNARK<G1, G2, C1, C2> = RecursiveSNARK::<G1, G2, C1, C2>::new(
&pp,
&minroot_circuits[0],
&circuit_secondary,
z0_primary.clone(),
z0_secondary.clone(),
);
for (i, circuit_primary) in minroot_circuits.iter().take(num_steps).enumerate() { for (i, circuit_primary) in minroot_circuits.iter().take(num_steps).enumerate() {
let start = Instant::now(); let start = Instant::now();
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, circuit_primary,
circuit_primary.clone(), &circuit_secondary,
circuit_secondary.clone(),
z0_primary.clone(), z0_primary.clone(),
z0_secondary.clone(), z0_secondary.clone(),
); );
@@ -237,16 +242,12 @@ fn main() {
res.is_ok(), res.is_ok(),
start.elapsed() start.elapsed()
); );
recursive_snark = Some(res.unwrap());
} }
assert!(recursive_snark.is_some());
let recursive_snark = recursive_snark.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
println!("Verifying a RecursiveSNARK..."); println!("Verifying a RecursiveSNARK...");
let start = Instant::now(); let start = Instant::now();
let res = recursive_snark.verify(&pp, num_steps, z0_primary.clone(), z0_secondary.clone()); let res = recursive_snark.verify(&pp, num_steps, &z0_primary, &z0_secondary);
println!( println!(
"RecursiveSNARK::verify: {:?}, took {:?}", "RecursiveSNARK::verify: {:?}, took {:?}",
res.is_ok(), res.is_ok(),

View File

@@ -1067,13 +1067,13 @@ mod tests {
{ {
let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); let a = alloc_random_point(cs.namespace(|| "a")).unwrap();
inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap();
let mut b = a.clone(); let mut b = &mut a.clone();
b.y = AllocatedNum::alloc(cs.namespace(|| "allocate negation of a"), || { b.y = AllocatedNum::alloc(cs.namespace(|| "allocate negation of a"), || {
Ok(G::Base::ZERO) Ok(G::Base::ZERO)
}) })
.unwrap(); .unwrap();
inputize_allocted_point(&b, cs.namespace(|| "inputize b")).unwrap(); inputize_allocted_point(b, cs.namespace(|| "inputize b")).unwrap();
let e = a.add(cs.namespace(|| "add a to b"), &b).unwrap(); let e = a.add(cs.namespace(|| "add a to b"), b).unwrap();
e e
} }

View File

@@ -192,28 +192,24 @@ where
C1: StepCircuit<G1::Scalar>, C1: StepCircuit<G1::Scalar>,
C2: StepCircuit<G2::Scalar>, C2: StepCircuit<G2::Scalar>,
{ {
/// Create a new `RecursiveSNARK` (or updates the provided `RecursiveSNARK`) /// Create new instance of recursive SNARK
/// by executing a step of the incremental computation pub fn new(
pub fn prove_step(
pp: &PublicParams<G1, G2, C1, C2>, pp: &PublicParams<G1, G2, C1, C2>,
recursive_snark: Option<Self>, c_primary: &C1,
c_primary: C1, c_secondary: &C2,
c_secondary: C2,
z0_primary: Vec<G1::Scalar>, z0_primary: Vec<G1::Scalar>,
z0_secondary: Vec<G2::Scalar>, z0_secondary: Vec<G2::Scalar>,
) -> Result<Self, NovaError> { ) -> Self {
if z0_primary.len() != pp.F_arity_primary || z0_secondary.len() != pp.F_arity_secondary { // Expected outputs of the two circuits
return Err(NovaError::InvalidInitialInputLength); let zi_primary = c_primary.output(&z0_primary);
} let zi_secondary = c_secondary.output(&z0_secondary);
match recursive_snark {
None => {
// base case for the primary // base case for the primary
let mut cs_primary: SatisfyingAssignment<G1> = SatisfyingAssignment::new(); let mut cs_primary: SatisfyingAssignment<G1> = SatisfyingAssignment::new();
let inputs_primary: NovaAugmentedCircuitInputs<G2> = NovaAugmentedCircuitInputs::new( let inputs_primary: NovaAugmentedCircuitInputs<G2> = NovaAugmentedCircuitInputs::new(
scalar_as_base::<G1>(pp.digest), scalar_as_base::<G1>(pp.digest),
G1::Scalar::ZERO, G1::Scalar::ZERO,
z0_primary.clone(), z0_primary,
None, None,
None, None,
None, None,
@@ -229,14 +225,15 @@ where
let _ = circuit_primary.synthesize(&mut cs_primary); let _ = circuit_primary.synthesize(&mut cs_primary);
let (u_primary, w_primary) = cs_primary let (u_primary, w_primary) = cs_primary
.r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary) .r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary)
.map_err(|_e| NovaError::UnSat)?; .map_err(|_e| NovaError::UnSat)
.expect("Nova error unsat");
// base case for the secondary // base case for the secondary
let mut cs_secondary: SatisfyingAssignment<G2> = SatisfyingAssignment::new(); let mut cs_secondary: SatisfyingAssignment<G2> = SatisfyingAssignment::new();
let inputs_secondary: NovaAugmentedCircuitInputs<G1> = NovaAugmentedCircuitInputs::new( let inputs_secondary: NovaAugmentedCircuitInputs<G1> = NovaAugmentedCircuitInputs::new(
pp.digest, pp.digest,
G2::Scalar::ZERO, G2::Scalar::ZERO,
z0_secondary.clone(), z0_secondary,
None, None,
None, None,
Some(u_primary.clone()), Some(u_primary.clone()),
@@ -251,18 +248,15 @@ where
let _ = circuit_secondary.synthesize(&mut cs_secondary); let _ = circuit_secondary.synthesize(&mut cs_secondary);
let (u_secondary, w_secondary) = cs_secondary let (u_secondary, w_secondary) = cs_secondary
.r1cs_instance_and_witness(&pp.r1cs_shape_secondary, &pp.ck_secondary) .r1cs_instance_and_witness(&pp.r1cs_shape_secondary, &pp.ck_secondary)
.map_err(|_e| NovaError::UnSat)?; .map_err(|_e| NovaError::UnSat)
.expect("Nova error unsat");
// IVC proof for the primary circuit // IVC proof for the primary circuit
let l_w_primary = w_primary; let l_w_primary = w_primary;
let l_u_primary = u_primary; let l_u_primary = u_primary;
let r_W_primary = let r_W_primary = RelaxedR1CSWitness::from_r1cs_witness(&pp.r1cs_shape_primary, &l_w_primary);
RelaxedR1CSWitness::from_r1cs_witness(&pp.r1cs_shape_primary, &l_w_primary); let r_U_primary =
let r_U_primary = RelaxedR1CSInstance::from_r1cs_instance( RelaxedR1CSInstance::from_r1cs_instance(&pp.ck_primary, &pp.r1cs_shape_primary, &l_u_primary);
&pp.ck_primary,
&pp.r1cs_shape_primary,
&l_u_primary,
);
// IVC proof of the secondary circuit // IVC proof of the secondary circuit
let l_w_secondary = w_secondary; let l_w_secondary = w_secondary;
@@ -271,49 +265,66 @@ where
let r_U_secondary = let r_U_secondary =
RelaxedR1CSInstance::<G2>::default(&pp.ck_secondary, &pp.r1cs_shape_secondary); RelaxedR1CSInstance::<G2>::default(&pp.ck_secondary, &pp.r1cs_shape_secondary);
// Outputs of the two circuits thus far
let zi_primary = c_primary.output(&z0_primary);
let zi_secondary = c_secondary.output(&z0_secondary);
if zi_primary.len() != pp.F_arity_primary || zi_secondary.len() != pp.F_arity_secondary { if zi_primary.len() != pp.F_arity_primary || zi_secondary.len() != pp.F_arity_secondary {
return Err(NovaError::InvalidStepOutputLength); panic!("Invalid step length");
} }
Ok(Self { Self {
r_W_primary, r_W_primary,
r_U_primary, r_U_primary,
r_W_secondary, r_W_secondary,
r_U_secondary, r_U_secondary,
l_w_secondary, l_w_secondary,
l_u_secondary, l_u_secondary,
i: 1_usize, i: 0,
zi_primary, zi_primary,
zi_secondary, zi_secondary,
_p_c1: Default::default(), _p_c1: Default::default(),
_p_c2: Default::default(), _p_c2: Default::default(),
})
} }
Some(r_snark) => { }
/// Create a new `RecursiveSNARK` (or updates the provided `RecursiveSNARK`)
/// by executing a step of the incremental computation
pub fn prove_step(
&mut self,
pp: &PublicParams<G1, G2, C1, C2>,
c_primary: &C1,
c_secondary: &C2,
z0_primary: Vec<G1::Scalar>,
z0_secondary: Vec<G2::Scalar>,
) -> Result<(), NovaError> {
if z0_primary.len() != pp.F_arity_primary || z0_secondary.len() != pp.F_arity_secondary {
return Err(NovaError::InvalidInitialInputLength);
}
// Frist step was already done in the constructor
if self.i == 0 {
self.i = 1;
return Ok(());
}
// fold the secondary circuit's instance // fold the secondary circuit's instance
let (nifs_secondary, (r_U_secondary, r_W_secondary)) = NIFS::prove( let (nifs_secondary, (r_U_secondary, r_W_secondary)) = NIFS::prove(
&pp.ck_secondary, &pp.ck_secondary,
&pp.ro_consts_secondary, &pp.ro_consts_secondary,
&scalar_as_base::<G1>(pp.digest), &scalar_as_base::<G1>(pp.digest),
&pp.r1cs_shape_secondary, &pp.r1cs_shape_secondary,
&r_snark.r_U_secondary, &self.r_U_secondary,
&r_snark.r_W_secondary, &self.r_W_secondary,
&r_snark.l_u_secondary, &self.l_u_secondary,
&r_snark.l_w_secondary, &self.l_w_secondary,
)?; )
.expect("Unable to fold secondary");
let mut cs_primary: SatisfyingAssignment<G1> = SatisfyingAssignment::new(); let mut cs_primary: SatisfyingAssignment<G1> = SatisfyingAssignment::new();
let inputs_primary: NovaAugmentedCircuitInputs<G2> = NovaAugmentedCircuitInputs::new( let inputs_primary: NovaAugmentedCircuitInputs<G2> = NovaAugmentedCircuitInputs::new(
scalar_as_base::<G1>(pp.digest), scalar_as_base::<G1>(pp.digest),
G1::Scalar::from(r_snark.i as u64), G1::Scalar::from(self.i as u64),
z0_primary, z0_primary,
Some(r_snark.zi_primary.clone()), Some(self.zi_primary.clone()),
Some(r_snark.r_U_secondary.clone()), Some(self.r_U_secondary.clone()),
Some(r_snark.l_u_secondary.clone()), Some(self.l_u_secondary.clone()),
Some(Commitment::<G2>::decompress(&nifs_secondary.comm_T)?), Some(Commitment::<G2>::decompress(&nifs_secondary.comm_T)?),
); );
@@ -327,7 +338,8 @@ where
let (l_u_primary, l_w_primary) = cs_primary let (l_u_primary, l_w_primary) = cs_primary
.r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary) .r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary)
.map_err(|_e| NovaError::UnSat)?; .map_err(|_e| NovaError::UnSat)
.expect("Nova error unsat");
// fold the primary circuit's instance // fold the primary circuit's instance
let (nifs_primary, (r_U_primary, r_W_primary)) = NIFS::prove( let (nifs_primary, (r_U_primary, r_W_primary)) = NIFS::prove(
@@ -335,19 +347,20 @@ where
&pp.ro_consts_primary, &pp.ro_consts_primary,
&pp.digest, &pp.digest,
&pp.r1cs_shape_primary, &pp.r1cs_shape_primary,
&r_snark.r_U_primary, &self.r_U_primary,
&r_snark.r_W_primary, &self.r_W_primary,
&l_u_primary, &l_u_primary,
&l_w_primary, &l_w_primary,
)?; )
.expect("Unable to fold primary");
let mut cs_secondary: SatisfyingAssignment<G2> = SatisfyingAssignment::new(); let mut cs_secondary: SatisfyingAssignment<G2> = SatisfyingAssignment::new();
let inputs_secondary: NovaAugmentedCircuitInputs<G1> = NovaAugmentedCircuitInputs::new( let inputs_secondary: NovaAugmentedCircuitInputs<G1> = NovaAugmentedCircuitInputs::new(
pp.digest, pp.digest,
G2::Scalar::from(r_snark.i as u64), G2::Scalar::from(self.i as u64),
z0_secondary, z0_secondary,
Some(r_snark.zi_secondary.clone()), Some(self.zi_secondary.clone()),
Some(r_snark.r_U_primary.clone()), Some(self.r_U_primary.clone()),
Some(l_u_primary), Some(l_u_primary),
Some(Commitment::<G1>::decompress(&nifs_primary.comm_T)?), Some(Commitment::<G1>::decompress(&nifs_primary.comm_T)?),
); );
@@ -365,24 +378,21 @@ where
.map_err(|_e| NovaError::UnSat)?; .map_err(|_e| NovaError::UnSat)?;
// update the running instances and witnesses // update the running instances and witnesses
let zi_primary = c_primary.output(&r_snark.zi_primary); self.zi_primary = c_primary.output(&self.zi_primary);
let zi_secondary = c_secondary.output(&r_snark.zi_secondary); self.zi_secondary = c_secondary.output(&self.zi_secondary);
Ok(Self { self.l_u_secondary = l_u_secondary;
r_W_primary, self.l_w_secondary = l_w_secondary;
r_U_primary,
r_W_secondary, self.r_U_primary = r_U_primary;
r_U_secondary, self.r_W_primary = r_W_primary;
l_w_secondary,
l_u_secondary, self.i += 1;
i: r_snark.i + 1,
zi_primary, self.r_U_secondary = r_U_secondary;
zi_secondary, self.r_W_secondary = r_W_secondary;
_p_c1: Default::default(),
_p_c2: Default::default(), Ok(())
})
}
}
} }
/// Verify the correctness of the `RecursiveSNARK` /// Verify the correctness of the `RecursiveSNARK`
@@ -390,8 +400,8 @@ where
&self, &self,
pp: &PublicParams<G1, G2, C1, C2>, pp: &PublicParams<G1, G2, C1, C2>,
num_steps: usize, num_steps: usize,
z0_primary: Vec<G1::Scalar>, z0_primary: &[G1::Scalar],
z0_secondary: Vec<G2::Scalar>, z0_secondary: &[G2::Scalar],
) -> Result<(Vec<G1::Scalar>, Vec<G2::Scalar>), NovaError> { ) -> Result<(Vec<G1::Scalar>, Vec<G2::Scalar>), NovaError> {
// number of steps cannot be zero // number of steps cannot be zero
if num_steps == 0 { if num_steps == 0 {
@@ -419,7 +429,7 @@ where
); );
hasher.absorb(pp.digest); hasher.absorb(pp.digest);
hasher.absorb(G1::Scalar::from(num_steps as u64)); hasher.absorb(G1::Scalar::from(num_steps as u64));
for e in &z0_primary { for e in z0_primary {
hasher.absorb(*e); hasher.absorb(*e);
} }
for e in &self.zi_primary { for e in &self.zi_primary {
@@ -433,7 +443,7 @@ where
); );
hasher2.absorb(scalar_as_base::<G1>(pp.digest)); hasher2.absorb(scalar_as_base::<G1>(pp.digest));
hasher2.absorb(G2::Scalar::from(num_steps as u64)); hasher2.absorb(G2::Scalar::from(num_steps as u64));
for e in &z0_secondary { for e in z0_secondary {
hasher2.absorb(*e); hasher2.absorb(*e);
} }
for e in &self.zi_secondary { for e in &self.zi_secondary {
@@ -906,23 +916,30 @@ mod tests {
let num_steps = 1; let num_steps = 1;
// produce a recursive SNARK // produce a recursive SNARK
let res = RecursiveSNARK::prove_step( let mut recursive_snark = RecursiveSNARK::new(
&pp, &pp,
None, &test_circuit1,
test_circuit1, &test_circuit2,
test_circuit2,
vec![<G1 as Group>::Scalar::ZERO], vec![<G1 as Group>::Scalar::ZERO],
vec![<G2 as Group>::Scalar::ZERO], vec![<G2 as Group>::Scalar::ZERO],
); );
let res = recursive_snark.prove_step(
&pp,
&test_circuit1,
&test_circuit2,
vec![<G1 as Group>::Scalar::ZERO],
vec![<G2 as Group>::Scalar::ZERO],
);
assert!(res.is_ok()); assert!(res.is_ok());
let recursive_snark = res.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
let res = recursive_snark.verify( let res = recursive_snark.verify(
&pp, &pp,
num_steps, num_steps,
vec![<G1 as Group>::Scalar::ZERO], &vec![<G1 as Group>::Scalar::ZERO][..],
vec![<G2 as Group>::Scalar::ZERO], &vec![<G2 as Group>::Scalar::ZERO][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());
} }
@@ -953,49 +970,45 @@ mod tests {
let num_steps = 3; let num_steps = 3;
// produce a recursive SNARK // produce a recursive SNARK
let mut recursive_snark: Option< let mut recursive_snark = RecursiveSNARK::<
RecursiveSNARK<
G1, G1,
G2, G2,
TrivialTestCircuit<<G1 as Group>::Scalar>, TrivialTestCircuit<<G1 as Group>::Scalar>,
CubicCircuit<<G2 as Group>::Scalar>, CubicCircuit<<G2 as Group>::Scalar>,
>, >::new(
> = None; &pp,
&circuit_primary,
&circuit_secondary,
vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO],
);
for i in 0..num_steps { for i in 0..num_steps {
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, &circuit_primary,
circuit_primary.clone(), &circuit_secondary,
circuit_secondary.clone(),
vec![<G1 as Group>::Scalar::ONE], vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO], vec![<G2 as Group>::Scalar::ZERO],
); );
assert!(res.is_ok()); assert!(res.is_ok());
let recursive_snark_unwrapped = res.unwrap();
// verify the recursive snark at each step of recursion // verify the recursive snark at each step of recursion
let res = recursive_snark_unwrapped.verify( let res = recursive_snark.verify(
&pp, &pp,
i + 1, i + 1,
vec![<G1 as Group>::Scalar::ONE], &vec![<G1 as Group>::Scalar::ONE][..],
vec![<G2 as Group>::Scalar::ZERO], &vec![<G2 as Group>::Scalar::ZERO][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());
// set the running variable for the next iteration
recursive_snark = Some(recursive_snark_unwrapped);
} }
assert!(recursive_snark.is_some());
let recursive_snark = recursive_snark.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
let res = recursive_snark.verify( let res = recursive_snark.verify(
&pp, &pp,
num_steps, num_steps,
vec![<G1 as Group>::Scalar::ONE], &vec![<G1 as Group>::Scalar::ONE][..],
vec![<G2 as Group>::Scalar::ZERO], &vec![<G2 as Group>::Scalar::ZERO][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());
@@ -1043,37 +1056,36 @@ mod tests {
let num_steps = 3; let num_steps = 3;
// produce a recursive SNARK // produce a recursive SNARK
let mut recursive_snark: Option< let mut recursive_snark = RecursiveSNARK::<
RecursiveSNARK<
G1, G1,
G2, G2,
TrivialTestCircuit<<G1 as Group>::Scalar>, TrivialTestCircuit<<G1 as Group>::Scalar>,
CubicCircuit<<G2 as Group>::Scalar>, CubicCircuit<<G2 as Group>::Scalar>,
>, >::new(
> = None; &pp,
&circuit_primary,
&circuit_secondary,
vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO],
);
for _i in 0..num_steps { for _i in 0..num_steps {
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, &circuit_primary,
circuit_primary.clone(), &circuit_secondary,
circuit_secondary.clone(),
vec![<G1 as Group>::Scalar::ONE], vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO], vec![<G2 as Group>::Scalar::ZERO],
); );
assert!(res.is_ok()); assert!(res.is_ok());
recursive_snark = Some(res.unwrap());
} }
assert!(recursive_snark.is_some());
let recursive_snark = recursive_snark.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
let res = recursive_snark.verify( let res = recursive_snark.verify(
&pp, &pp,
num_steps, num_steps,
vec![<G1 as Group>::Scalar::ONE], &vec![<G1 as Group>::Scalar::ONE][..],
vec![<G2 as Group>::Scalar::ZERO], &vec![<G2 as Group>::Scalar::ZERO][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());
@@ -1138,37 +1150,36 @@ mod tests {
let num_steps = 3; let num_steps = 3;
// produce a recursive SNARK // produce a recursive SNARK
let mut recursive_snark: Option< let mut recursive_snark = RecursiveSNARK::<
RecursiveSNARK<
G1, G1,
G2, G2,
TrivialTestCircuit<<G1 as Group>::Scalar>, TrivialTestCircuit<<G1 as Group>::Scalar>,
CubicCircuit<<G2 as Group>::Scalar>, CubicCircuit<<G2 as Group>::Scalar>,
>, >::new(
> = None; &pp,
&circuit_primary,
&circuit_secondary,
vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO],
);
for _i in 0..num_steps { for _i in 0..num_steps {
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, &circuit_primary,
circuit_primary.clone(), &circuit_secondary,
circuit_secondary.clone(),
vec![<G1 as Group>::Scalar::ONE], vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO], vec![<G2 as Group>::Scalar::ZERO],
); );
assert!(res.is_ok()); assert!(res.is_ok());
recursive_snark = Some(res.unwrap());
} }
assert!(recursive_snark.is_some());
let recursive_snark = recursive_snark.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
let res = recursive_snark.verify( let res = recursive_snark.verify(
&pp, &pp,
num_steps, num_steps,
vec![<G1 as Group>::Scalar::ONE], &vec![<G1 as Group>::Scalar::ONE][..],
vec![<G2 as Group>::Scalar::ZERO], &vec![<G2 as Group>::Scalar::ZERO][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());
@@ -1237,14 +1248,9 @@ mod tests {
let rng = &mut rand::rngs::OsRng; let rng = &mut rand::rngs::OsRng;
let mut seed = F::random(rng); let mut seed = F::random(rng);
for _i in 0..num_steps + 1 { for _i in 0..num_steps + 1 {
let mut power = seed; seed *= seed.clone().square().square();
power = power.square();
power = power.square();
power *= seed;
powers.push(Self { y: power }); powers.push(Self { y: seed });
seed = power;
} }
// reverse the powers to get roots // reverse the powers to get roots
@@ -1289,12 +1295,7 @@ mod tests {
fn output(&self, z: &[F]) -> Vec<F> { fn output(&self, z: &[F]) -> Vec<F> {
// sanity check // sanity check
let x = z[0]; let x = z[0];
let y_pow_5 = { let y_pow_5 = self.y * self.y.clone().square().square();
let y = self.y;
let y_sq = y.square();
let y_quad = y_sq.square();
y_quad * self.y
};
assert_eq!(x, y_pow_5); assert_eq!(x, y_pow_5);
// return non-deterministic advice // return non-deterministic advice
@@ -1324,33 +1325,37 @@ mod tests {
let z0_secondary = vec![<G2 as Group>::Scalar::ZERO]; let z0_secondary = vec![<G2 as Group>::Scalar::ZERO];
// produce a recursive SNARK // produce a recursive SNARK
let mut recursive_snark: Option< let mut recursive_snark: RecursiveSNARK<
RecursiveSNARK<
G1, G1,
G2, G2,
FifthRootCheckingCircuit<<G1 as Group>::Scalar>, FifthRootCheckingCircuit<<G1 as Group>::Scalar>,
TrivialTestCircuit<<G2 as Group>::Scalar>, TrivialTestCircuit<<G2 as Group>::Scalar>,
>, > = RecursiveSNARK::<
> = None; G1,
G2,
FifthRootCheckingCircuit<<G1 as Group>::Scalar>,
TrivialTestCircuit<<G2 as Group>::Scalar>,
>::new(
&pp,
&roots[0],
&circuit_secondary,
z0_primary.clone(),
z0_secondary.clone(),
);
for circuit_primary in roots.iter().take(num_steps) { for circuit_primary in roots.iter().take(num_steps) {
let res = RecursiveSNARK::prove_step( let res = recursive_snark.prove_step(
&pp, &pp,
recursive_snark, circuit_primary,
circuit_primary.clone(), &circuit_secondary.clone(),
circuit_secondary.clone(),
z0_primary.clone(), z0_primary.clone(),
z0_secondary.clone(), z0_secondary.clone(),
); );
assert!(res.is_ok()); assert!(res.is_ok());
recursive_snark = Some(res.unwrap());
} }
assert!(recursive_snark.is_some());
let recursive_snark = recursive_snark.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
let res = recursive_snark.verify(&pp, num_steps, z0_primary.clone(), z0_secondary.clone()); let res = recursive_snark.verify(&pp, num_steps, &z0_primary, &z0_secondary);
assert!(res.is_ok()); assert!(res.is_ok());
// produce the prover and verifier keys for compressed snark // produce the prover and verifier keys for compressed snark
@@ -1379,34 +1384,50 @@ mod tests {
G1: Group<Base = <G2 as Group>::Scalar>, G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>, G2: Group<Base = <G1 as Group>::Scalar>,
{ {
let test_circuit1 = TrivialTestCircuit::<<G1 as Group>::Scalar>::default();
let test_circuit2 = CubicCircuit::<<G2 as Group>::Scalar>::default();
// produce public parameters // produce public parameters
let pp = PublicParams::< let pp = PublicParams::<
G1, G1,
G2, G2,
TrivialTestCircuit<<G1 as Group>::Scalar>, TrivialTestCircuit<<G1 as Group>::Scalar>,
CubicCircuit<<G2 as Group>::Scalar>, CubicCircuit<<G2 as Group>::Scalar>,
>::setup(TrivialTestCircuit::default(), CubicCircuit::default()); >::setup(test_circuit1.clone(), test_circuit2.clone());
let num_steps = 1; let num_steps = 1;
// produce a recursive SNARK // produce a recursive SNARK
let res = RecursiveSNARK::prove_step( let mut recursive_snark = RecursiveSNARK::<
G1,
G2,
TrivialTestCircuit<<G1 as Group>::Scalar>,
CubicCircuit<<G2 as Group>::Scalar>,
>::new(
&pp, &pp,
None, &test_circuit1,
TrivialTestCircuit::default(), &test_circuit2,
CubicCircuit::default(),
vec![<G1 as Group>::Scalar::ONE], vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO], vec![<G2 as Group>::Scalar::ZERO],
); );
// produce a recursive SNARK
let res = recursive_snark.prove_step(
&pp,
&test_circuit1,
&test_circuit2,
vec![<G1 as Group>::Scalar::ONE],
vec![<G2 as Group>::Scalar::ZERO],
);
assert!(res.is_ok()); assert!(res.is_ok());
let recursive_snark = res.unwrap();
// verify the recursive SNARK // verify the recursive SNARK
let res = recursive_snark.verify( let res = recursive_snark.verify(
&pp, &pp,
num_steps, num_steps,
vec![<G1 as Group>::Scalar::ONE], &vec![<G1 as Group>::Scalar::ONE][..],
vec![<G2 as Group>::Scalar::ZERO], &vec![<G2 as Group>::Scalar::ZERO][..],
); );
assert!(res.is_ok()); assert!(res.is_ok());

View File

@@ -69,7 +69,7 @@ impl<G: Group> Default for Commitment<G> {
impl<G: Group> TranscriptReprTrait<G> for Commitment<G> { impl<G: Group> TranscriptReprTrait<G> for Commitment<G> {
fn to_transcript_bytes(&self) -> Vec<u8> { fn to_transcript_bytes(&self) -> Vec<u8> {
let (x, y, is_infinity) = self.comm.to_coordinates(); let (x, y, is_infinity) = self.comm.to_coordinates();
let is_infinity_byte = if is_infinity { 0u8 } else { 1u8 }; let is_infinity_byte = (!is_infinity).into();
[ [
x.to_transcript_bytes(), x.to_transcript_bytes(),
y.to_transcript_bytes(), y.to_transcript_bytes(),