
[refactorings] Leftovers (pot-pourri?) (#184)

* test: compute_path

* refactor: path computation

- Improve path concatenation by using the built-in `join` method (see the sketch below)
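In isolation, the refactor looks like this (a minimal sketch mirroring the new `compute_path` in `src/bellperson/shape_cs.rs`, shown in the diff further down): the namespace prefix is built with the slice's `join` instead of a hand-rolled loop, then the leaf name is appended.

```rust
fn compute_path(ns: &[String], this: &str) -> String {
    // Join the namespace segments in one call rather than tracking separators manually.
    let mut name = ns.join("/");
    if !name.is_empty() {
        name.push('/');
    }
    name.push_str(this);
    name
}

fn main() {
    let ns = vec!["path".to_string(), "to".to_string(), "dir".to_string()];
    assert_eq!(compute_path(&ns, "file"), "path/to/dir/file");
}
```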

* refactor: replace `PartialEq` with derived instance

- Derive `PartialEq` for `SatisfyingAssignment` struct
- Remove redundant manual implementation of `PartialEq`

Cargo-expand generates:
```
        #[automatically_derived]
        impl<G: ::core::cmp::PartialEq + Group> ::core::cmp::PartialEq
        for SatisfyingAssignment<G>
        where
            G::Scalar: PrimeField,
            G::Scalar: ::core::cmp::PartialEq,
            G::Scalar: ::core::cmp::PartialEq,
            G::Scalar: ::core::cmp::PartialEq,
            G::Scalar: ::core::cmp::PartialEq,
            G::Scalar: ::core::cmp::PartialEq,
        {
            #[inline]
            fn eq(&self, other: &SatisfyingAssignment<G>) -> bool {
                self.a_aux_density == other.a_aux_density
                    && self.b_input_density == other.b_input_density
                    && self.b_aux_density == other.b_aux_density && self.a == other.a
                    && self.b == other.b && self.c == other.c
                    && self.input_assignment == other.input_assignment
                    && self.aux_assignment == other.aux_assignment
            }
        }
```

* refactor: avoid default for PhantomData Unit type
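Illustrative sketch (not the crate's code): `PhantomData<T>` is a unit struct, so its value can be named directly rather than constructed through `Default`, which is what the `src/provider/poseidon.rs` hunk below does.

```rust
use std::marker::PhantomData;

struct Tagged<T> {
    value: u64,
    _p: PhantomData<T>,
}

impl<T> Tagged<T> {
    fn new(value: u64) -> Self {
        Self {
            value,
            // before: _p: PhantomData::default(),
            _p: PhantomData, // the unit value is named directly
        }
    }
}

fn main() {
    let t: Tagged<String> = Tagged::new(7);
    assert_eq!(t.value, 7);
}
```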

* refactor: replace fold with sum where applicable

- Simplify code by replacing `fold` with `sum` in various instances
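A sketch of the pattern with plain integers standing in for field elements (the change relies on the element type implementing `core::iter::Sum`, which the crate's scalars do): the explicit `fold` from zero and `sum()` compute the same reduction.

```rust
fn main() {
    // Stand-ins for evaluation values and challenge powers.
    let evals = [1u64, 2, 3, 4];
    let coeffs = [10u64, 20, 30, 40];

    // Before: explicit fold from an explicit zero.
    let combined_fold: u64 = evals
        .iter()
        .zip(coeffs.iter())
        .map(|(e, c)| e * c)
        .fold(0, |acc, item| acc + item);

    // After: `sum()` expresses the same reduction.
    let combined_sum: u64 = evals.iter().zip(coeffs.iter()).map(|(e, c)| e * c).sum();

    assert_eq!(combined_fold, combined_sum);
}
```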

* refactor: decompression method in sumcheck.rs
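A sketch of the decompression change with integers in place of scalars (the real method reassembles `UniPoly` coefficients in `src/spartan/sumcheck.rs`): the owned values are pushed directly instead of round-tripping through temporary `Vec`s of references.

```rust
fn main() {
    // Stand-ins for the compressed coefficients and the recovered linear term.
    let coeffs_except_linear_term = [3u64, 5, 7];
    let linear_term = 11u64;

    let mut coeffs: Vec<u64> = Vec::new();
    // Before (roughly): coeffs.extend(vec![&coeffs_except_linear_term[0]]); ...
    coeffs.push(coeffs_except_linear_term[0]);
    coeffs.push(linear_term);
    coeffs.extend(&coeffs_except_linear_term[1..]);

    assert_eq!(coeffs, vec![3, 11, 5, 7]);
    assert_eq!(coeffs.len(), coeffs_except_linear_term.len() + 1);
}
```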

* refactor: test functions to use slice instead of vector conversion
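Illustrative sketch: `&vec![x][..]` heap-allocates a temporary `Vec` only to re-borrow it as a slice, while `&[x]` borrows a stack array directly; both coerce to `&[T]` at the call site.

```rust
fn sum_of(xs: &[u64]) -> u64 {
    xs.iter().sum()
}

fn main() {
    // Before: build a Vec, then take the full-range slice of it.
    let a = sum_of(&vec![2u64][..]);
    // After: borrow an array literal, with no heap allocation.
    let b = sum_of(&[2u64]);
    assert_eq!(a, b);
}
```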

* refactor: use more references in functions

- Update parameter types to take references instead of owned values in functions that do not need ownership
- Replace cloning instances with references
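Illustrative sketch (types invented for the example, not the crate's): a function that only reads its argument can borrow it, so callers keep ownership and stop cloning.

```rust
struct Instance {
    x: Vec<u64>,
}

// Before: `fn absorb(inst: Instance) -> u64` forced callers to pass `inst.clone()`.
// After: borrowing leaves the caller's value intact.
fn absorb(inst: &Instance) -> u64 {
    inst.x.iter().sum()
}

fn main() {
    let inst = Instance { x: vec![1, 2, 3] };
    let first = absorb(&inst);
    let second = absorb(&inst); // `inst` is still usable; no clone was needed
    assert_eq!(first, second);
}
```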
Branch: main
Author: François Garillot, committed via GitHub, 1 year ago
Commit: 1e6bf942e2 (GPG Key ID: 4AEE18F83AFDEB23; no known key found for this signature in database)
15 changed files with 109 additions and 123 deletions
  1. benches/compressed-snark.rs (+2, -2)
  2. benches/recursive-snark.rs (+2, -2)
  3. examples/signature.rs (+2, -2)
  4. src/bellperson/shape_cs.rs (+28, -9)
  5. src/bellperson/solver.rs (+1, -16)
  6. src/circuit.rs (+13, -13)
  7. src/gadgets/ecc.rs (+2, -2)
  8. src/gadgets/nonnative/bignat.rs (+1, -1)
  9. src/gadgets/r1cs.rs (+15, -23)
  10. src/gadgets/utils.rs (+1, -1)
  11. src/lib.rs (+12, -12)
  12. src/provider/poseidon.rs (+2, -2)
  13. src/spartan/mod.rs (+6, -6)
  14. src/spartan/pp.rs (+17, -23)
  15. src/spartan/sumcheck.rs (+5, -9)

benches/compressed-snark.rs (+2, -2)

@@ -90,8 +90,8 @@ fn bench_compressed_snark(c: &mut Criterion) {
     let res = recursive_snark.verify(
       &pp,
       i + 1,
-      &vec![<G1 as Group>::Scalar::from(2u64)][..],
-      &vec![<G2 as Group>::Scalar::from(2u64)][..],
+      &[<G1 as Group>::Scalar::from(2u64)],
+      &[<G2 as Group>::Scalar::from(2u64)],
     );
     assert!(res.is_ok());
   }

benches/recursive-snark.rs (+2, -2)

@@ -113,8 +113,8 @@ fn bench_recursive_snark(c: &mut Criterion) {
       .verify(
         black_box(&pp),
         black_box(num_warmup_steps),
-        black_box(&vec![<G1 as Group>::Scalar::from(2u64)][..]),
-        black_box(&vec![<G2 as Group>::Scalar::from(2u64)][..]),
+        black_box(&[<G1 as Group>::Scalar::from(2u64)]),
+        black_box(&[<G2 as Group>::Scalar::from(2u64)]),
       )
       .is_ok());
   });

examples/signature.rs (+2, -2)

@@ -233,8 +233,8 @@ pub fn verify_signature>(
     |lc| lc + (G::Base::from_str_vartime("2").unwrap(), CS::one()),
   );
 
-  let sg = g.scalar_mul(cs.namespace(|| "[s]G"), s_bits)?;
-  let cpk = pk.scalar_mul(&mut cs.namespace(|| "[c]PK"), c_bits)?;
+  let sg = g.scalar_mul(cs.namespace(|| "[s]G"), &s_bits)?;
+  let cpk = pk.scalar_mul(&mut cs.namespace(|| "[c]PK"), &c_bits)?;
 
   let rcpk = cpk.add(&mut cs.namespace(|| "R + [c]PK"), &r)?;
   let (rcpk_x, rcpk_y, _) = rcpk.get_coordinates();

src/bellperson/shape_cs.rs (+28, -9)

@@ -308,17 +308,36 @@ fn compute_path(ns: &[String], this: &str) -> String {
     "'/' is not allowed in names"
   );
 
-  let mut name = String::new();
+  let mut name = ns.join("/");
+  if !name.is_empty() {
+    name.push('/');
+  }
 
-  let mut needs_separation = false;
-  for ns in ns.iter().chain(Some(this.to_string()).iter()) {
-    if needs_separation {
-      name += "/";
-    }
+  name.push_str(this);
+  name
+}
+
-    name += ns;
-    needs_separation = true;
+#[cfg(test)]
+mod tests {
+  use super::*;
+
+  #[test]
+  fn test_compute_path() {
+    let ns = vec!["path".to_string(), "to".to_string(), "dir".to_string()];
+    let this = "file";
+    assert_eq!(compute_path(&ns, this), "path/to/dir/file");
+
+    let ns = vec!["".to_string(), "".to_string(), "".to_string()];
+    let this = "file";
+    assert_eq!(compute_path(&ns, this), "///file");
   }
 
-  name
+  #[test]
+  #[should_panic(expected = "'/' is not allowed in names")]
+  fn test_compute_path_invalid() {
+    let ns = vec!["path".to_string(), "to".to_string(), "dir".to_string()];
+    let this = "fi/le";
+    compute_path(&ns, this);
+  }
 }

src/bellperson/solver.rs (+1, -16)

@@ -8,6 +8,7 @@ use bellperson::{
 };
 
 /// A `ConstraintSystem` which calculates witness values for a concrete instance of an R1CS circuit.
+#[derive(PartialEq)]
 pub struct SatisfyingAssignment<G: Group>
 where
   G::Scalar: PrimeField,
@@ -68,22 +69,6 @@ where
   }
 }
 
-impl<G: Group> PartialEq for SatisfyingAssignment<G>
-where
-  G::Scalar: PrimeField,
-{
-  fn eq(&self, other: &SatisfyingAssignment<G>) -> bool {
-    self.a_aux_density == other.a_aux_density
-      && self.b_input_density == other.b_input_density
-      && self.b_aux_density == other.b_aux_density
-      && self.a == other.a
-      && self.b == other.b
-      && self.c == other.c
-      && self.input_assignment == other.input_assignment
-      && self.aux_assignment == other.aux_assignment
-  }
-}
-
 impl<G: Group> ConstraintSystem<G::Scalar> for SatisfyingAssignment<G>
 where
   G::Scalar: PrimeField,

src/circuit.rs (+13, -13)

@@ -158,8 +158,8 @@ impl> NovaAugmentedCircuit {
     // Allocate the running instance
     let U: AllocatedRelaxedR1CSInstance<G> = AllocatedRelaxedR1CSInstance::alloc(
       cs.namespace(|| "Allocate U"),
-      self.inputs.get().map_or(None, |inputs| {
-        inputs.U.get().map_or(None, |U| Some(U.clone()))
+      self.inputs.get().as_ref().map_or(None, |inputs| {
+        inputs.U.get().as_ref().map_or(None, |U| Some(U))
       }),
       self.params.limb_width,
       self.params.n_limbs,
@@ -168,8 +168,8 @@ impl> NovaAugmentedCircuit {
     // Allocate the instance to be folded in
     let u = AllocatedR1CSInstance::alloc(
       cs.namespace(|| "allocate instance u to fold"),
-      self.inputs.get().map_or(None, |inputs| {
-        inputs.u.get().map_or(None, |u| Some(u.clone()))
+      self.inputs.get().as_ref().map_or(None, |inputs| {
+        inputs.u.get().as_ref().map_or(None, |u| Some(u))
       }),
     )?;
@@ -219,9 +219,9 @@ impl> NovaAugmentedCircuit {
     i: AllocatedNum<G::Base>,
     z_0: Vec<AllocatedNum<G::Base>>,
     z_i: Vec<AllocatedNum<G::Base>>,
-    U: AllocatedRelaxedR1CSInstance<G>,
-    u: AllocatedR1CSInstance<G>,
-    T: AllocatedPoint<G>,
+    U: &AllocatedRelaxedR1CSInstance<G>,
+    u: &AllocatedR1CSInstance<G>,
+    T: &AllocatedPoint<G>,
     arity: usize,
   ) -> Result<(AllocatedRelaxedR1CSInstance<G>, AllocatedBit), SynthesisError> {
     // Check that u.x[0] = Hash(params, U, i, z0, zi)
@@ -240,7 +240,7 @@ impl> NovaAugmentedCircuit {
     U.absorb_in_ro(cs.namespace(|| "absorb U"), &mut ro)?;
     let hash_bits = ro.squeeze(cs.namespace(|| "Input hash"), NUM_HASH_BITS)?;
-    let hash = le_bits_to_num(cs.namespace(|| "bits to hash"), hash_bits)?;
+    let hash = le_bits_to_num(cs.namespace(|| "bits to hash"), &hash_bits)?;
     let check_pass = alloc_num_equals(
       cs.namespace(|| "check consistency of u.X[0] with H(params, U, i, z0, zi)"),
       &u.X0,
@@ -290,9 +290,9 @@ impl> Circuit<::Base>
       i.clone(),
       z_0.clone(),
       z_i.clone(),
-      U,
-      u.clone(),
-      T,
+      &U,
+      &u,
+      &T,
       arity,
     )?;
@@ -312,7 +312,7 @@ impl> Circuit<::Base>
     // Compute the U_new
     let Unew = Unew_base.conditionally_select(
       cs.namespace(|| "compute U_new"),
-      Unew_non_base,
+      &Unew_non_base,
       &Boolean::from(is_base_case.clone()),
     )?;
@@ -357,7 +357,7 @@ impl> Circuit<::Base>
     }
     Unew.absorb_in_ro(cs.namespace(|| "absorb U_new"), &mut ro)?;
     let hash_bits = ro.squeeze(cs.namespace(|| "output hash bits"), NUM_HASH_BITS)?;
-    let hash = le_bits_to_num(cs.namespace(|| "convert hash to num"), hash_bits)?;
+    let hash = le_bits_to_num(cs.namespace(|| "convert hash to num"), &hash_bits)?;
     // Outputs the computed hash and u.X[1] that corresponds to the hash of the other circuit
     u.X1

src/gadgets/ecc.rs (+2, -2)

@@ -429,7 +429,7 @@ where
   pub fn scalar_mul<CS: ConstraintSystem<G::Base>>(
     &self,
     mut cs: CS,
-    scalar_bits: Vec<AllocatedBit>,
+    scalar_bits: &[AllocatedBit],
   ) -> Result<Self, SynthesisError> {
     let split_len = core::cmp::min(scalar_bits.len(), (G::Base::NUM_BITS - 2) as usize);
     let (incomplete_bits, complete_bits) = scalar_bits.split_at(split_len);
@@ -968,7 +968,7 @@ mod tests {
       .map(|(i, bit)| AllocatedBit::alloc(cs.namespace(|| format!("bit {i}")), Some(bit)))
       .collect::<Result<Vec<AllocatedBit>, SynthesisError>>()
       .unwrap();
-    let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), bits).unwrap();
+    let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), &bits).unwrap();
     inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap();
     (a, e, s)
   }

src/gadgets/nonnative/bignat.rs (+1, -1)

@@ -222,7 +222,7 @@ impl BigNat {
   /// The value is provided by an allocated number
   pub fn from_num<CS: ConstraintSystem<Scalar>>(
     mut cs: CS,
-    n: Num<Scalar>,
+    n: &Num<Scalar>,
     limb_width: usize,
     n_limbs: usize,
   ) -> Result<Self, SynthesisError> {

src/gadgets/r1cs.rs (+15, -23)

@@ -33,7 +33,7 @@ impl AllocatedR1CSInstance {
   /// Takes the r1cs instance and creates a new allocated r1cs instance
   pub fn alloc<CS: ConstraintSystem<<G as Group>::Base>>(
     mut cs: CS,
-    u: Option<R1CSInstance<G>>,
+    u: Option<&R1CSInstance<G>>,
   ) -> Result<Self, SynthesisError> {
     // Check that the incoming instance has exactly 2 io
     let W = AllocatedPoint::alloc(
@@ -76,7 +76,7 @@ impl AllocatedRelaxedR1CSInstance {
   /// Allocates the given RelaxedR1CSInstance as a witness of the circuit
   pub fn alloc<CS: ConstraintSystem<<G as Group>::Base>>(
     mut cs: CS,
-    inst: Option<RelaxedR1CSInstance<G>>,
+    inst: Option<&RelaxedR1CSInstance<G>>,
     limb_width: usize,
     n_limbs: usize,
   ) -> Result<Self, SynthesisError> {
@@ -104,22 +104,14 @@ impl AllocatedRelaxedR1CSInstance {
     // Allocate X0 and X1. If the input instance is None, then allocate default values 0.
     let X0 = BigNat::alloc_from_nat(
       cs.namespace(|| "allocate X[0]"),
-      || {
-        Ok(f_to_nat(
-          &inst.clone().map_or(G::Scalar::ZERO, |inst| inst.X[0]),
-        ))
-      },
+      || Ok(f_to_nat(&inst.map_or(G::Scalar::ZERO, |inst| inst.X[0]))),
       limb_width,
       n_limbs,
     )?;
 
     let X1 = BigNat::alloc_from_nat(
       cs.namespace(|| "allocate X[1]"),
-      || {
-        Ok(f_to_nat(
-          &inst.clone().map_or(G::Scalar::ZERO, |inst| inst.X[1]),
-        ))
-      },
+      || Ok(f_to_nat(&inst.map_or(G::Scalar::ZERO, |inst| inst.X[1]))),
       limb_width,
       n_limbs,
     )?;
@@ -170,14 +162,14 @@ impl AllocatedRelaxedR1CSInstance {
     let X0 = BigNat::from_num(
       cs.namespace(|| "allocate X0 from relaxed r1cs"),
-      Num::from(inst.X0.clone()),
+      &Num::from(inst.X0),
       limb_width,
       n_limbs,
     )?;
 
     let X1 = BigNat::from_num(
       cs.namespace(|| "allocate X1 from relaxed r1cs"),
-      Num::from(inst.X1.clone()),
+      &Num::from(inst.X1),
       limb_width,
       n_limbs,
     )?;
@@ -246,8 +238,8 @@ impl AllocatedRelaxedR1CSInstance {
     &self,
     mut cs: CS,
     params: AllocatedNum<G::Base>, // hash of R1CSShape of F'
-    u: AllocatedR1CSInstance<G>,
-    T: AllocatedPoint<G>,
+    u: &AllocatedR1CSInstance<G>,
+    T: &AllocatedPoint<G>,
     ro_consts: ROConstantsCircuit<G>,
     limb_width: usize,
     n_limbs: usize,
@@ -261,14 +253,14 @@ impl AllocatedRelaxedR1CSInstance {
     ro.absorb(T.y.clone());
     ro.absorb(T.is_infinity.clone());
     let r_bits = ro.squeeze(cs.namespace(|| "r bits"), NUM_CHALLENGE_BITS)?;
-    let r = le_bits_to_num(cs.namespace(|| "r"), r_bits.clone())?;
+    let r = le_bits_to_num(cs.namespace(|| "r"), &r_bits)?;
 
     // W_fold = self.W + r * u.W
-    let rW = u.W.scalar_mul(cs.namespace(|| "r * u.W"), r_bits.clone())?;
+    let rW = u.W.scalar_mul(cs.namespace(|| "r * u.W"), &r_bits)?;
     let W_fold = self.W.add(cs.namespace(|| "self.W + r * u.W"), &rW)?;
 
     // E_fold = self.E + r * T
-    let rT = T.scalar_mul(cs.namespace(|| "r * T"), r_bits)?;
+    let rT = T.scalar_mul(cs.namespace(|| "r * T"), &r_bits)?;
     let E_fold = self.E.add(cs.namespace(|| "self.E + r * T"), &rT)?;
 
     // u_fold = u_r + r
@@ -286,7 +278,7 @@ impl AllocatedRelaxedR1CSInstance {
     // Analyze r into limbs
     let r_bn = BigNat::from_num(
       cs.namespace(|| "allocate r_bn"),
-      Num::from(r.clone()),
+      &Num::from(r),
       limb_width,
       n_limbs,
     )?;
@@ -302,7 +294,7 @@ impl AllocatedRelaxedR1CSInstance {
     // Analyze X0 to bignat
     let X0_bn = BigNat::from_num(
       cs.namespace(|| "allocate X0_bn"),
-      Num::from(u.X0.clone()),
+      &Num::from(u.X0.clone()),
       limb_width,
       n_limbs,
     )?;
@@ -317,7 +309,7 @@ impl AllocatedRelaxedR1CSInstance {
     // Analyze X1 to bignat
     let X1_bn = BigNat::from_num(
       cs.namespace(|| "allocate X1_bn"),
-      Num::from(u.X1.clone()),
+      &Num::from(u.X1.clone()),
       limb_width,
       n_limbs,
     )?;
@@ -342,7 +334,7 @@ impl AllocatedRelaxedR1CSInstance {
   pub fn conditionally_select<CS: ConstraintSystem<<G as Group>::Base>>(
     &self,
     mut cs: CS,
-    other: AllocatedRelaxedR1CSInstance<G>,
+    other: &AllocatedRelaxedR1CSInstance<G>,
     condition: &Boolean,
   ) -> Result<AllocatedRelaxedR1CSInstance<G>, SynthesisError> {
     let W = AllocatedPoint::conditionally_select(

src/gadgets/utils.rs (+1, -1)

@@ -15,7 +15,7 @@ use num_bigint::BigInt;
 /// Gets as input the little indian representation of a number and spits out the number
 pub fn le_bits_to_num<Scalar, CS>(
   mut cs: CS,
-  bits: Vec<AllocatedBit>,
+  bits: &[AllocatedBit],
 ) -> Result<AllocatedNum<Scalar>, SynthesisError>
 where
   Scalar: PrimeField + PrimeFieldBits,

src/lib.rs (+12, -12)

@@ -938,8 +938,8 @@ mod tests {
     let res = recursive_snark.verify(
       &pp,
       num_steps,
-      &vec![<G1 as Group>::Scalar::ZERO][..],
-      &vec![<G2 as Group>::Scalar::ZERO][..],
+      &[<G1 as Group>::Scalar::ZERO],
+      &[<G2 as Group>::Scalar::ZERO],
     );
     assert!(res.is_ok());
   }
@@ -997,8 +997,8 @@ mod tests {
       let res = recursive_snark.verify(
         &pp,
         i + 1,
-        &vec![<G1 as Group>::Scalar::ONE][..],
-        &vec![<G2 as Group>::Scalar::ZERO][..],
+        &[<G1 as Group>::Scalar::ONE],
+        &[<G2 as Group>::Scalar::ZERO],
       );
       assert!(res.is_ok());
     }
@@ -1007,8 +1007,8 @@ mod tests {
     let res = recursive_snark.verify(
       &pp,
       num_steps,
-      &vec![<G1 as Group>::Scalar::ONE][..],
-      &vec![<G2 as Group>::Scalar::ZERO][..],
+      &[<G1 as Group>::Scalar::ONE],
+      &[<G2 as Group>::Scalar::ZERO],
     );
     assert!(res.is_ok());
@@ -1084,8 +1084,8 @@ mod tests {
     let res = recursive_snark.verify(
       &pp,
       num_steps,
-      &vec![<G1 as Group>::Scalar::ONE][..],
-      &vec![<G2 as Group>::Scalar::ZERO][..],
+      &[<G1 as Group>::Scalar::ONE],
+      &[<G2 as Group>::Scalar::ZERO],
     );
     assert!(res.is_ok());
@@ -1178,8 +1178,8 @@ mod tests {
     let res = recursive_snark.verify(
       &pp,
       num_steps,
-      &vec![<G1 as Group>::Scalar::ONE][..],
-      &vec![<G2 as Group>::Scalar::ZERO][..],
+      &[<G1 as Group>::Scalar::ONE],
+      &[<G2 as Group>::Scalar::ZERO],
    );
     assert!(res.is_ok());
@@ -1426,8 +1426,8 @@ mod tests {
     let res = recursive_snark.verify(
       &pp,
       num_steps,
-      &vec![<G1 as Group>::Scalar::ONE][..],
-      &vec![<G2 as Group>::Scalar::ZERO][..],
+      &[<G1 as Group>::Scalar::ONE],
+      &[<G2 as Group>::Scalar::ZERO],
     );
     assert!(res.is_ok());

src/provider/poseidon.rs (+2, -2)

@@ -65,7 +65,7 @@ where
       constants,
       num_absorbs,
       squeezed: false,
-      _p: PhantomData::default(),
+      _p: PhantomData,
     }
   }
@@ -236,7 +236,7 @@ mod tests {
     }
     let num = ro.squeeze(NUM_CHALLENGE_BITS);
     let num2_bits = ro_gadget.squeeze(&mut cs, NUM_CHALLENGE_BITS).unwrap();
-    let num2 = le_bits_to_num(&mut cs, num2_bits).unwrap();
+    let num2 = le_bits_to_num(&mut cs, &num2_bits).unwrap();
     assert_eq!(num.to_repr(), num2.get_value().unwrap().to_repr());
   }

src/spartan/mod.rs (+6, -6)

@@ -109,7 +109,7 @@ impl PolyEvalInstance {
       .iter()
       .zip(powers_of_s.iter())
      .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
     let c = c_vec
       .iter()
       .zip(powers_of_s.iter())
@@ -378,7 +378,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(u, p)| u.e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let mut polys_left: Vec<MultilinearPolynomial<G::Scalar>> = w_vec_padded
       .iter()
@@ -420,7 +420,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_gamma.iter())
       .map(|(e, g_i)| *e * *g_i)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let eval_arg = EE::prove(
       ck,
@@ -573,7 +573,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(u, p)| u.e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let num_rounds_z = u_vec_padded[0].x.len();
     let (claim_batch_final, r_z) =
@@ -593,7 +593,7 @@ impl> RelaxedR1CSSNARKTrait
         .zip(self.evals_batch.iter())
         .zip(powers_of_rho.iter())
         .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i)
-        .fold(G::Scalar::ZERO, |acc, item| acc + item)
+        .sum()
     };
 
     if claim_batch_final != claim_batch_final_expected {
@@ -615,7 +615,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_gamma.iter())
       .map(|(e, g_i)| *e * *g_i)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     // verify
     EE::verify(

src/spartan/pp.rs (+17, -23)

@@ -863,7 +863,7 @@ impl> RelaxedR1CSSNARK
       .iter()
       .zip(coeffs.iter())
      .map(|(c_1, c_2)| *c_1 * c_2)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let mut e = claim;
     let mut r: Vec<G::Scalar> = Vec::new();
@@ -876,15 +876,9 @@ impl> RelaxedR1CSSNARK
       evals.extend(inner.evaluation_points());
       assert_eq!(evals.len(), num_claims);
 
-      let evals_combined_0 = (0..evals.len())
-        .map(|i| evals[i][0] * coeffs[i])
-        .fold(G::Scalar::ZERO, |acc, item| acc + item);
-      let evals_combined_2 = (0..evals.len())
-        .map(|i| evals[i][1] * coeffs[i])
-        .fold(G::Scalar::ZERO, |acc, item| acc + item);
-      let evals_combined_3 = (0..evals.len())
-        .map(|i| evals[i][2] * coeffs[i])
-        .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      let evals_combined_0 = (0..evals.len()).map(|i| evals[i][0] * coeffs[i]).sum();
+      let evals_combined_2 = (0..evals.len()).map(|i| evals[i][1] * coeffs[i]).sum();
+      let evals_combined_3 = (0..evals.len()).map(|i| evals[i][2] * coeffs[i]).sum();
 
       let evals = vec![
         evals_combined_0,
@@ -1242,13 +1236,13 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
      .zip(powers_of_rho.iter())
       .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let eval_output = eval_output_vec
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let comm_output = mem_sc_inst
       .comm_output_vec
@@ -1274,7 +1268,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     // eval_output = output(r_sat)
     w_u_vec.push((
@@ -1466,7 +1460,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(u, p)| u.e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let mut polys_left: Vec<MultilinearPolynomial<G::Scalar>> = w_vec_padded
       .iter()
@@ -1508,7 +1502,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
      .zip(powers_of_gamma.iter())
       .map(|(e, g_i)| *e * *g_i)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let eval_arg = EE::prove(
       ck,
@@ -1677,13 +1671,13 @@ impl> RelaxedR1CSSNARKTrait
     // verify claim_sat_final
     let taus_bound_r_sat = EqPolynomial::new(tau.clone()).evaluate(&r_sat);
     let rand_eq_bound_r_sat = EqPolynomial::new(rand_eq).evaluate(&r_sat);
-    let claim_mem_final_expected = (0..8)
+    let claim_mem_final_expected: G::Scalar = (0..8)
       .map(|i| {
         coeffs[i]
           * rand_eq_bound_r_sat
           * (self.eval_left_arr[i] * self.eval_right_arr[i] - self.eval_output_arr[i])
       })
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
     let claim_outer_final_expected = coeffs[8]
       * taus_bound_r_sat
       * (self.eval_Az * self.eval_Bz - U.u * self.eval_Cz - self.eval_E);
@@ -1753,14 +1747,14 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
      .zip(powers_of_rho.iter())
       .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let eval_output = self
       .eval_output_arr
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let comm_output = comm_output_vec
       .iter()
@@ -1773,7 +1767,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
      .zip(powers_of_rho.iter())
       .map(|(e, p)| *e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     // eval_output = output(r_sat)
     u_vec.push(PolyEvalInstance {
@@ -1994,7 +1988,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
       .zip(powers_of_rho.iter())
       .map(|(u, p)| u.e * p)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     let num_rounds_z = u_vec_padded[0].x.len();
     let (claim_batch_final, r_z) =
@@ -2014,7 +2008,7 @@ impl> RelaxedR1CSSNARKTrait
         .zip(self.evals_batch_arr.iter())
        .zip(powers_of_rho.iter())
         .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i)
-        .fold(G::Scalar::ZERO, |acc, item| acc + item)
+        .sum()
     };
 
     if claim_batch_final != claim_batch_final_expected {
@@ -2036,7 +2030,7 @@ impl> RelaxedR1CSSNARKTrait
       .iter()
      .zip(powers_of_gamma.iter())
       .map(|(e, g_i)| *e * *g_i)
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+      .sum();
 
     // verify
     EE::verify(

src/spartan/sumcheck.rs (+5, -9)

@@ -163,12 +163,8 @@ impl SumcheckProof {
       evals.push((eval_point_0, eval_point_2));
     }
 
-    let evals_combined_0 = (0..evals.len())
-      .map(|i| evals[i].0 * coeffs[i])
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
-    let evals_combined_2 = (0..evals.len())
-      .map(|i| evals[i].1 * coeffs[i])
-      .fold(G::Scalar::ZERO, |acc, item| acc + item);
+    let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * coeffs[i]).sum();
+    let evals_combined_2 = (0..evals.len()).map(|i| evals[i].1 * coeffs[i]).sum();
 
     let evals = vec![evals_combined_0, e - evals_combined_0, evals_combined_2];
     let poly = UniPoly::from_evals(&evals);
@@ -387,9 +383,9 @@ impl CompressedUniPoly {
     }
 
     let mut coeffs: Vec<G::Scalar> = Vec::new();
-    coeffs.extend(vec![&self.coeffs_except_linear_term[0]]);
-    coeffs.extend(vec![&linear_term]);
-    coeffs.extend(self.coeffs_except_linear_term[1..].to_vec());
+    coeffs.push(self.coeffs_except_linear_term[0]);
+    coeffs.push(linear_term);
+    coeffs.extend(&self.coeffs_except_linear_term[1..]);
     assert_eq!(self.coeffs_except_linear_term.len() + 1, coeffs.len());
     UniPoly { coeffs }
   }
