From 1e6bf942e215979b0e02d5b2b83a6c4450d740e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Mon, 19 Jun 2023 19:11:42 -0400 Subject: [PATCH] [refactorings] Leftovers (pot-pourri?) (#184) * test: compute_path * refactor: path computation - Improve path concatenation by utilizing built-in `join` method * refactor: replace `PartialEq` with derived instance - Derive `PartialEq` for `SatisfyingAssignment` struct - Remove redundant manual implementation of `PartialEq` Cargo-expand generates: ``` #[automatically_derived] impl ::core::cmp::PartialEq for SatisfyingAssignment where G::Scalar: PrimeField, G::Scalar: ::core::cmp::PartialEq, G::Scalar: ::core::cmp::PartialEq, G::Scalar: ::core::cmp::PartialEq, G::Scalar: ::core::cmp::PartialEq, G::Scalar: ::core::cmp::PartialEq, { #[inline] fn eq(&self, other: &SatisfyingAssignment) -> bool { self.a_aux_density == other.a_aux_density && self.b_input_density == other.b_input_density && self.b_aux_density == other.b_aux_density && self.a == other.a && self.b == other.b && self.c == other.c && self.input_assignment == other.input_assignment && self.aux_assignment == other.aux_assignment } } ``` * refactor: avoid default for PhantomData Unit type * refactor: replace fold with sum where applicable - Simplify code by replacing `fold` with `sum` in various instances * refactor: decompression method in sumcheck.rs * refactor: test functions to use slice instead of vector conversion * refactor: use more references in functions - Update parameter types to use references instead of owned values in various functions that do not need them - Replace cloning instances with references --- benches/compressed-snark.rs | 4 ++-- benches/recursive-snark.rs | 4 ++-- examples/signature.rs | 4 ++-- src/bellperson/shape_cs.rs | 37 ++++++++++++++++++++++-------- src/bellperson/solver.rs | 17 +------------- src/circuit.rs | 26 ++++++++++----------- src/gadgets/ecc.rs | 4 ++-- src/gadgets/nonnative/bignat.rs | 2 +- src/gadgets/r1cs.rs | 38 +++++++++++++------------------ src/gadgets/utils.rs | 2 +- src/lib.rs | 24 ++++++++++---------- src/provider/poseidon.rs | 4 ++-- src/spartan/mod.rs | 12 +++++----- src/spartan/pp.rs | 40 ++++++++++++++------------------- src/spartan/sumcheck.rs | 14 +++++------- 15 files changed, 109 insertions(+), 123 deletions(-) diff --git a/benches/compressed-snark.rs b/benches/compressed-snark.rs index db64f9a..4ae9741 100644 --- a/benches/compressed-snark.rs +++ b/benches/compressed-snark.rs @@ -90,8 +90,8 @@ fn bench_compressed_snark(c: &mut Criterion) { let res = recursive_snark.verify( &pp, i + 1, - &vec![::Scalar::from(2u64)][..], - &vec![::Scalar::from(2u64)][..], + &[::Scalar::from(2u64)], + &[::Scalar::from(2u64)], ); assert!(res.is_ok()); } diff --git a/benches/recursive-snark.rs b/benches/recursive-snark.rs index fe20513..eed8d48 100644 --- a/benches/recursive-snark.rs +++ b/benches/recursive-snark.rs @@ -113,8 +113,8 @@ fn bench_recursive_snark(c: &mut Criterion) { .verify( black_box(&pp), black_box(num_warmup_steps), - black_box(&vec![::Scalar::from(2u64)][..]), - black_box(&vec![::Scalar::from(2u64)][..]), + black_box(&[::Scalar::from(2u64)]), + black_box(&[::Scalar::from(2u64)]), ) .is_ok()); }); diff --git a/examples/signature.rs b/examples/signature.rs index 2410160..90438df 100644 --- a/examples/signature.rs +++ b/examples/signature.rs @@ -233,8 +233,8 @@ pub fn verify_signature>( |lc| lc + (G::Base::from_str_vartime("2").unwrap(), CS::one()), ); - let sg = g.scalar_mul(cs.namespace(|| "[s]G"), s_bits)?; - let cpk = pk.scalar_mul(&mut cs.namespace(|| "[c]PK"), c_bits)?; + let sg = g.scalar_mul(cs.namespace(|| "[s]G"), &s_bits)?; + let cpk = pk.scalar_mul(&mut cs.namespace(|| "[c]PK"), &c_bits)?; let rcpk = cpk.add(&mut cs.namespace(|| "R + [c]PK"), &r)?; let (rcpk_x, rcpk_y, _) = rcpk.get_coordinates(); diff --git a/src/bellperson/shape_cs.rs b/src/bellperson/shape_cs.rs index 1b259c8..bb96463 100644 --- a/src/bellperson/shape_cs.rs +++ b/src/bellperson/shape_cs.rs @@ -308,17 +308,36 @@ fn compute_path(ns: &[String], this: &str) -> String { "'/' is not allowed in names" ); - let mut name = String::new(); + let mut name = ns.join("/"); + if !name.is_empty() { + name.push('/'); + } - let mut needs_separation = false; - for ns in ns.iter().chain(Some(this.to_string()).iter()) { - if needs_separation { - name += "/"; - } + name.push_str(this); + + name +} - name += ns; - needs_separation = true; +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_path() { + let ns = vec!["path".to_string(), "to".to_string(), "dir".to_string()]; + let this = "file"; + assert_eq!(compute_path(&ns, this), "path/to/dir/file"); + + let ns = vec!["".to_string(), "".to_string(), "".to_string()]; + let this = "file"; + assert_eq!(compute_path(&ns, this), "///file"); } - name + #[test] + #[should_panic(expected = "'/' is not allowed in names")] + fn test_compute_path_invalid() { + let ns = vec!["path".to_string(), "to".to_string(), "dir".to_string()]; + let this = "fi/le"; + compute_path(&ns, this); + } } diff --git a/src/bellperson/solver.rs b/src/bellperson/solver.rs index ac86d3b..0eaf088 100644 --- a/src/bellperson/solver.rs +++ b/src/bellperson/solver.rs @@ -8,6 +8,7 @@ use bellperson::{ }; /// A `ConstraintSystem` which calculates witness values for a concrete instance of an R1CS circuit. +#[derive(PartialEq)] pub struct SatisfyingAssignment where G::Scalar: PrimeField, @@ -68,22 +69,6 @@ where } } -impl PartialEq for SatisfyingAssignment -where - G::Scalar: PrimeField, -{ - fn eq(&self, other: &SatisfyingAssignment) -> bool { - self.a_aux_density == other.a_aux_density - && self.b_input_density == other.b_input_density - && self.b_aux_density == other.b_aux_density - && self.a == other.a - && self.b == other.b - && self.c == other.c - && self.input_assignment == other.input_assignment - && self.aux_assignment == other.aux_assignment - } -} - impl ConstraintSystem for SatisfyingAssignment where G::Scalar: PrimeField, diff --git a/src/circuit.rs b/src/circuit.rs index d9ef590..d94fee1 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -158,8 +158,8 @@ impl> NovaAugmentedCircuit { // Allocate the running instance let U: AllocatedRelaxedR1CSInstance = AllocatedRelaxedR1CSInstance::alloc( cs.namespace(|| "Allocate U"), - self.inputs.get().map_or(None, |inputs| { - inputs.U.get().map_or(None, |U| Some(U.clone())) + self.inputs.get().as_ref().map_or(None, |inputs| { + inputs.U.get().as_ref().map_or(None, |U| Some(U)) }), self.params.limb_width, self.params.n_limbs, @@ -168,8 +168,8 @@ impl> NovaAugmentedCircuit { // Allocate the instance to be folded in let u = AllocatedR1CSInstance::alloc( cs.namespace(|| "allocate instance u to fold"), - self.inputs.get().map_or(None, |inputs| { - inputs.u.get().map_or(None, |u| Some(u.clone())) + self.inputs.get().as_ref().map_or(None, |inputs| { + inputs.u.get().as_ref().map_or(None, |u| Some(u)) }), )?; @@ -219,9 +219,9 @@ impl> NovaAugmentedCircuit { i: AllocatedNum, z_0: Vec>, z_i: Vec>, - U: AllocatedRelaxedR1CSInstance, - u: AllocatedR1CSInstance, - T: AllocatedPoint, + U: &AllocatedRelaxedR1CSInstance, + u: &AllocatedR1CSInstance, + T: &AllocatedPoint, arity: usize, ) -> Result<(AllocatedRelaxedR1CSInstance, AllocatedBit), SynthesisError> { // Check that u.x[0] = Hash(params, U, i, z0, zi) @@ -240,7 +240,7 @@ impl> NovaAugmentedCircuit { U.absorb_in_ro(cs.namespace(|| "absorb U"), &mut ro)?; let hash_bits = ro.squeeze(cs.namespace(|| "Input hash"), NUM_HASH_BITS)?; - let hash = le_bits_to_num(cs.namespace(|| "bits to hash"), hash_bits)?; + let hash = le_bits_to_num(cs.namespace(|| "bits to hash"), &hash_bits)?; let check_pass = alloc_num_equals( cs.namespace(|| "check consistency of u.X[0] with H(params, U, i, z0, zi)"), &u.X0, @@ -290,9 +290,9 @@ impl> Circuit<::Base> i.clone(), z_0.clone(), z_i.clone(), - U, - u.clone(), - T, + &U, + &u, + &T, arity, )?; @@ -312,7 +312,7 @@ impl> Circuit<::Base> // Compute the U_new let Unew = Unew_base.conditionally_select( cs.namespace(|| "compute U_new"), - Unew_non_base, + &Unew_non_base, &Boolean::from(is_base_case.clone()), )?; @@ -357,7 +357,7 @@ impl> Circuit<::Base> } Unew.absorb_in_ro(cs.namespace(|| "absorb U_new"), &mut ro)?; let hash_bits = ro.squeeze(cs.namespace(|| "output hash bits"), NUM_HASH_BITS)?; - let hash = le_bits_to_num(cs.namespace(|| "convert hash to num"), hash_bits)?; + let hash = le_bits_to_num(cs.namespace(|| "convert hash to num"), &hash_bits)?; // Outputs the computed hash and u.X[1] that corresponds to the hash of the other circuit u.X1 diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 997180d..86dc61e 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -429,7 +429,7 @@ where pub fn scalar_mul>( &self, mut cs: CS, - scalar_bits: Vec, + scalar_bits: &[AllocatedBit], ) -> Result { let split_len = core::cmp::min(scalar_bits.len(), (G::Base::NUM_BITS - 2) as usize); let (incomplete_bits, complete_bits) = scalar_bits.split_at(split_len); @@ -968,7 +968,7 @@ mod tests { .map(|(i, bit)| AllocatedBit::alloc(cs.namespace(|| format!("bit {i}")), Some(bit))) .collect::, SynthesisError>>() .unwrap(); - let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), bits).unwrap(); + let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), &bits).unwrap(); inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap(); (a, e, s) } diff --git a/src/gadgets/nonnative/bignat.rs b/src/gadgets/nonnative/bignat.rs index 8fcf29a..9db4848 100644 --- a/src/gadgets/nonnative/bignat.rs +++ b/src/gadgets/nonnative/bignat.rs @@ -222,7 +222,7 @@ impl BigNat { /// The value is provided by an allocated number pub fn from_num>( mut cs: CS, - n: Num, + n: &Num, limb_width: usize, n_limbs: usize, ) -> Result { diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index 2c0d33b..6841d71 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -33,7 +33,7 @@ impl AllocatedR1CSInstance { /// Takes the r1cs instance and creates a new allocated r1cs instance pub fn alloc::Base>>( mut cs: CS, - u: Option>, + u: Option<&R1CSInstance>, ) -> Result { // Check that the incoming instance has exactly 2 io let W = AllocatedPoint::alloc( @@ -76,7 +76,7 @@ impl AllocatedRelaxedR1CSInstance { /// Allocates the given RelaxedR1CSInstance as a witness of the circuit pub fn alloc::Base>>( mut cs: CS, - inst: Option>, + inst: Option<&RelaxedR1CSInstance>, limb_width: usize, n_limbs: usize, ) -> Result { @@ -104,22 +104,14 @@ impl AllocatedRelaxedR1CSInstance { // Allocate X0 and X1. If the input instance is None, then allocate default values 0. let X0 = BigNat::alloc_from_nat( cs.namespace(|| "allocate X[0]"), - || { - Ok(f_to_nat( - &inst.clone().map_or(G::Scalar::ZERO, |inst| inst.X[0]), - )) - }, + || Ok(f_to_nat(&inst.map_or(G::Scalar::ZERO, |inst| inst.X[0]))), limb_width, n_limbs, )?; let X1 = BigNat::alloc_from_nat( cs.namespace(|| "allocate X[1]"), - || { - Ok(f_to_nat( - &inst.clone().map_or(G::Scalar::ZERO, |inst| inst.X[1]), - )) - }, + || Ok(f_to_nat(&inst.map_or(G::Scalar::ZERO, |inst| inst.X[1]))), limb_width, n_limbs, )?; @@ -170,14 +162,14 @@ impl AllocatedRelaxedR1CSInstance { let X0 = BigNat::from_num( cs.namespace(|| "allocate X0 from relaxed r1cs"), - Num::from(inst.X0.clone()), + &Num::from(inst.X0), limb_width, n_limbs, )?; let X1 = BigNat::from_num( cs.namespace(|| "allocate X1 from relaxed r1cs"), - Num::from(inst.X1.clone()), + &Num::from(inst.X1), limb_width, n_limbs, )?; @@ -246,8 +238,8 @@ impl AllocatedRelaxedR1CSInstance { &self, mut cs: CS, params: AllocatedNum, // hash of R1CSShape of F' - u: AllocatedR1CSInstance, - T: AllocatedPoint, + u: &AllocatedR1CSInstance, + T: &AllocatedPoint, ro_consts: ROConstantsCircuit, limb_width: usize, n_limbs: usize, @@ -261,14 +253,14 @@ impl AllocatedRelaxedR1CSInstance { ro.absorb(T.y.clone()); ro.absorb(T.is_infinity.clone()); let r_bits = ro.squeeze(cs.namespace(|| "r bits"), NUM_CHALLENGE_BITS)?; - let r = le_bits_to_num(cs.namespace(|| "r"), r_bits.clone())?; + let r = le_bits_to_num(cs.namespace(|| "r"), &r_bits)?; // W_fold = self.W + r * u.W - let rW = u.W.scalar_mul(cs.namespace(|| "r * u.W"), r_bits.clone())?; + let rW = u.W.scalar_mul(cs.namespace(|| "r * u.W"), &r_bits)?; let W_fold = self.W.add(cs.namespace(|| "self.W + r * u.W"), &rW)?; // E_fold = self.E + r * T - let rT = T.scalar_mul(cs.namespace(|| "r * T"), r_bits)?; + let rT = T.scalar_mul(cs.namespace(|| "r * T"), &r_bits)?; let E_fold = self.E.add(cs.namespace(|| "self.E + r * T"), &rT)?; // u_fold = u_r + r @@ -286,7 +278,7 @@ impl AllocatedRelaxedR1CSInstance { // Analyze r into limbs let r_bn = BigNat::from_num( cs.namespace(|| "allocate r_bn"), - Num::from(r.clone()), + &Num::from(r), limb_width, n_limbs, )?; @@ -302,7 +294,7 @@ impl AllocatedRelaxedR1CSInstance { // Analyze X0 to bignat let X0_bn = BigNat::from_num( cs.namespace(|| "allocate X0_bn"), - Num::from(u.X0.clone()), + &Num::from(u.X0.clone()), limb_width, n_limbs, )?; @@ -317,7 +309,7 @@ impl AllocatedRelaxedR1CSInstance { // Analyze X1 to bignat let X1_bn = BigNat::from_num( cs.namespace(|| "allocate X1_bn"), - Num::from(u.X1.clone()), + &Num::from(u.X1.clone()), limb_width, n_limbs, )?; @@ -342,7 +334,7 @@ impl AllocatedRelaxedR1CSInstance { pub fn conditionally_select::Base>>( &self, mut cs: CS, - other: AllocatedRelaxedR1CSInstance, + other: &AllocatedRelaxedR1CSInstance, condition: &Boolean, ) -> Result, SynthesisError> { let W = AllocatedPoint::conditionally_select( diff --git a/src/gadgets/utils.rs b/src/gadgets/utils.rs index 1262f23..72562fd 100644 --- a/src/gadgets/utils.rs +++ b/src/gadgets/utils.rs @@ -15,7 +15,7 @@ use num_bigint::BigInt; /// Gets as input the little indian representation of a number and spits out the number pub fn le_bits_to_num( mut cs: CS, - bits: Vec, + bits: &[AllocatedBit], ) -> Result, SynthesisError> where Scalar: PrimeField + PrimeFieldBits, diff --git a/src/lib.rs b/src/lib.rs index 57b0c18..405f0da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -938,8 +938,8 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - &vec![::Scalar::ZERO][..], - &vec![::Scalar::ZERO][..], + &[::Scalar::ZERO], + &[::Scalar::ZERO], ); assert!(res.is_ok()); } @@ -997,8 +997,8 @@ mod tests { let res = recursive_snark.verify( &pp, i + 1, - &vec![::Scalar::ONE][..], - &vec![::Scalar::ZERO][..], + &[::Scalar::ONE], + &[::Scalar::ZERO], ); assert!(res.is_ok()); } @@ -1007,8 +1007,8 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - &vec![::Scalar::ONE][..], - &vec![::Scalar::ZERO][..], + &[::Scalar::ONE], + &[::Scalar::ZERO], ); assert!(res.is_ok()); @@ -1084,8 +1084,8 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - &vec![::Scalar::ONE][..], - &vec![::Scalar::ZERO][..], + &[::Scalar::ONE], + &[::Scalar::ZERO], ); assert!(res.is_ok()); @@ -1178,8 +1178,8 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - &vec![::Scalar::ONE][..], - &vec![::Scalar::ZERO][..], + &[::Scalar::ONE], + &[::Scalar::ZERO], ); assert!(res.is_ok()); @@ -1426,8 +1426,8 @@ mod tests { let res = recursive_snark.verify( &pp, num_steps, - &vec![::Scalar::ONE][..], - &vec![::Scalar::ZERO][..], + &[::Scalar::ONE], + &[::Scalar::ZERO], ); assert!(res.is_ok()); diff --git a/src/provider/poseidon.rs b/src/provider/poseidon.rs index 2445984..a37ad9c 100644 --- a/src/provider/poseidon.rs +++ b/src/provider/poseidon.rs @@ -65,7 +65,7 @@ where constants, num_absorbs, squeezed: false, - _p: PhantomData::default(), + _p: PhantomData, } } @@ -236,7 +236,7 @@ mod tests { } let num = ro.squeeze(NUM_CHALLENGE_BITS); let num2_bits = ro_gadget.squeeze(&mut cs, NUM_CHALLENGE_BITS).unwrap(); - let num2 = le_bits_to_num(&mut cs, num2_bits).unwrap(); + let num2 = le_bits_to_num(&mut cs, &num2_bits).unwrap(); assert_eq!(num.to_repr(), num2.get_value().unwrap().to_repr()); } diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 6e9b18d..b73cf48 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -109,7 +109,7 @@ impl PolyEvalInstance { .iter() .zip(powers_of_s.iter()) .map(|(e, p)| *e * p) - .fold(G::Scalar::ZERO, |acc, item| acc + item); + .sum(); let c = c_vec .iter() .zip(powers_of_s.iter()) @@ -378,7 +378,7 @@ impl> RelaxedR1CSSNARKTrait> = w_vec_padded .iter() @@ -420,7 +420,7 @@ impl> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARK .iter() .zip(coeffs.iter()) .map(|(c_1, c_2)| *c_1 * c_2) - .fold(G::Scalar::ZERO, |acc, item| acc + item); + .sum(); let mut e = claim; let mut r: Vec = Vec::new(); @@ -876,15 +876,9 @@ impl> RelaxedR1CSSNARK evals.extend(inner.evaluation_points()); assert_eq!(evals.len(), num_claims); - let evals_combined_0 = (0..evals.len()) - .map(|i| evals[i][0] * coeffs[i]) - .fold(G::Scalar::ZERO, |acc, item| acc + item); - let evals_combined_2 = (0..evals.len()) - .map(|i| evals[i][1] * coeffs[i]) - .fold(G::Scalar::ZERO, |acc, item| acc + item); - let evals_combined_3 = (0..evals.len()) - .map(|i| evals[i][2] * coeffs[i]) - .fold(G::Scalar::ZERO, |acc, item| acc + item); + let evals_combined_0 = (0..evals.len()).map(|i| evals[i][0] * coeffs[i]).sum(); + let evals_combined_2 = (0..evals.len()).map(|i| evals[i][1] * coeffs[i]).sum(); + let evals_combined_3 = (0..evals.len()).map(|i| evals[i][2] * coeffs[i]).sum(); let evals = vec![ evals_combined_0, @@ -1242,13 +1236,13 @@ impl> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> = w_vec_padded .iter() @@ -1508,7 +1502,7 @@ impl> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait> RelaxedR1CSSNARKTrait SumcheckProof { evals.push((eval_point_0, eval_point_2)); } - let evals_combined_0 = (0..evals.len()) - .map(|i| evals[i].0 * coeffs[i]) - .fold(G::Scalar::ZERO, |acc, item| acc + item); - let evals_combined_2 = (0..evals.len()) - .map(|i| evals[i].1 * coeffs[i]) - .fold(G::Scalar::ZERO, |acc, item| acc + item); + let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * coeffs[i]).sum(); + let evals_combined_2 = (0..evals.len()).map(|i| evals[i].1 * coeffs[i]).sum(); let evals = vec![evals_combined_0, e - evals_combined_0, evals_combined_2]; let poly = UniPoly::from_evals(&evals); @@ -387,9 +383,9 @@ impl CompressedUniPoly { } let mut coeffs: Vec = Vec::new(); - coeffs.extend(vec![&self.coeffs_except_linear_term[0]]); - coeffs.extend(vec![&linear_term]); - coeffs.extend(self.coeffs_except_linear_term[1..].to_vec()); + coeffs.push(self.coeffs_except_linear_term[0]); + coeffs.push(linear_term); + coeffs.extend(&self.coeffs_except_linear_term[1..]); assert_eq!(self.coeffs_except_linear_term.len() + 1, coeffs.len()); UniPoly { coeffs } }