From 7ec2f57b84938e3f4c9e569c77d684eeb8ea6909 Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Tue, 23 Aug 2022 15:05:04 -0700 Subject: [PATCH] optimize ECC ops (#110) * optimize ECC ops * update version --- Cargo.toml | 2 +- src/circuit.rs | 4 +- src/gadgets/ecc.rs | 325 +++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 303 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0af2e5d..128b6f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nova-snark" -version = "0.8.0" +version = "0.8.1" authors = ["Srinath Setty "] edition = "2021" description = "Recursive zkSNARKs without trusted setup" diff --git a/src/circuit.rs b/src/circuit.rs index 6443f27..e7abc14 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -412,7 +412,7 @@ mod tests { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit1.synthesize(&mut cs); let (shape1, gens1) = (cs.r1cs_shape(), cs.r1cs_gens()); - assert_eq!(cs.num_constraints(), 18967); + assert_eq!(cs.num_constraints(), 9819); // Initialize the shape and gens for the secondary let circuit2: NovaAugmentedCircuit::Base>> = @@ -425,7 +425,7 @@ mod tests { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit2.synthesize(&mut cs); let (shape2, gens2) = (cs.r1cs_shape(), cs.r1cs_gens()); - assert_eq!(cs.num_constraints(), 19499); + assert_eq!(cs.num_constraints(), 10351); // Execute the base case for the primary let zero1 = <::Base as Field>::zero(); diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 86a1443..5a8dd79 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -109,6 +109,24 @@ where Ok(()) } + /// Negates the provided point + pub fn negate>(&self, mut cs: CS) -> Result { + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(-*self.y.get_value().get()?))?; + + cs.enforce( + || "check y = - self.y", + |lc| lc + self.y.get_variable(), + |lc| lc + CS::one(), + |lc| lc - y.get_variable(), + ); + + Ok(Self { + x: self.x.clone(), + y, + is_infinity: self.is_infinity.clone(), + }) + } + /// Add two points (may be equal) pub fn add>( &self, @@ -434,34 +452,95 @@ where Ok(Self { x, y, is_infinity }) } - /// A gadget for scalar multiplication. + /// A gadget for scalar multiplication, optimized to use incomplete addition law. + /// The optimization here is analogous to https://github.com/arkworks-rs/r1cs-std/blob/6d64f379a27011b3629cf4c9cb38b7b7b695d5a0/src/groups/curves/short_weierstrass/mod.rs#L295, + /// except we use complete addition law over affine coordinates instead of projective coordinates for the tail bits pub fn scalar_mul>( &self, mut cs: CS, - scalar: Vec, + scalar_bits: Vec, ) -> Result { - let mut res = Self::default(cs.namespace(|| "res"))?; - for i in (0..scalar.len()).rev() { - /*************************************************************/ - // res = res.double(); - /*************************************************************/ - - res = res.double(cs.namespace(|| format!("{}: double", i)))?; - - /*************************************************************/ - // if scalar[i] { - // res = self.add(&res); - // } - /*************************************************************/ - let self_and_res = self.add(cs.namespace(|| format!("{}: add", i)), &res)?; - res = Self::conditionally_select( - cs.namespace(|| format!("{}: Update res", i)), - &self_and_res, - &res, - &Boolean::from(scalar[i].clone()), + let split_len = core::cmp::min(scalar_bits.len(), (Fp::NUM_BITS - 2) as usize); + let (incomplete_bits, complete_bits) = scalar_bits.split_at(split_len); + + // we convert AllocatedPoint into AllocatedPointNonInfinity; we deal with the case where self.is_infinity = 1 below + let mut p = AllocatedPointNonInfinity::from_allocated_point(self); + + // we assume the first bit to be 1, so we must initialize acc to self and double it + // we remove this assumption below + let mut acc = p.clone(); + p = p.double_incomplete(cs.namespace(|| "double"))?; + + // perform the double-and-add loop to compute the scalar mul using incomplete addition law + for (i, bit) in incomplete_bits.iter().enumerate().skip(1) { + let temp = acc.add_incomplete(cs.namespace(|| format!("add {}", i)), &p)?; + acc = AllocatedPointNonInfinity::conditionally_select( + cs.namespace(|| format!("acc_iteration_{}", i)), + &temp, + &acc, + &Boolean::from(bit.clone()), + )?; + + p = p.double_incomplete(cs.namespace(|| format!("double {}", i)))?; + } + + // convert back to AllocatedPoint + let res = { + // we set acc.is_infinity = self.is_infinity + let acc = acc.to_allocated_point(&self.is_infinity)?; + + // we remove the initial slack if bits[0] is as not as assumed (i.e., it is not 1) + let acc_minus_initial = { + let neg = self.negate(cs.namespace(|| "negate"))?; + acc.add(cs.namespace(|| "res minus self"), &neg) + }?; + + AllocatedPoint::conditionally_select( + cs.namespace(|| "remove slack if necessary"), + &acc, + &acc_minus_initial, + &Boolean::from(scalar_bits[0].clone()), + )? + }; + + // when self.is_infinity = 1, return the default point, else return res + // we already set res.is_infinity to be self.is_infinity, so we do not need to set it here + let default = Self::default(cs.namespace(|| "default"))?; + let x = conditionally_select2( + cs.namespace(|| "check if self.is_infinity is zero (x)"), + &default.x, + &res.x, + &self.is_infinity, + )?; + + let y = conditionally_select2( + cs.namespace(|| "check if self.is_infinity is zero (y)"), + &default.y, + &res.y, + &self.is_infinity, + )?; + + // we now perform the remaining scalar mul using complete addition law + let mut acc = AllocatedPoint { + x, + y, + is_infinity: res.is_infinity, + }; + let mut p_complete = p.to_allocated_point(&self.is_infinity)?; + + for (i, bit) in complete_bits.iter().enumerate() { + let temp = acc.add(cs.namespace(|| format!("add_complete {}", i)), &p_complete)?; + acc = AllocatedPoint::conditionally_select( + cs.namespace(|| format!("acc_complete_iteration_{}", i)), + &temp, + &acc, + &Boolean::from(bit.clone()), )?; + + p_complete = p_complete.double(cs.namespace(|| format!("double_complete {}", i)))?; } - Ok(res) + + Ok(acc) } /// If condition outputs a otherwise outputs b @@ -639,6 +718,199 @@ where } } +#[derive(Clone)] +/// AllocatedPoint but one that is guaranteed to be not infinity +pub struct AllocatedPointNonInfinity +where + Fp: PrimeField, +{ + x: AllocatedNum, + y: AllocatedNum, +} + +impl AllocatedPointNonInfinity +where + Fp: PrimeField, +{ + /// Creates a new AllocatedPointNonInfinity from the specified coordinates + pub fn new(x: AllocatedNum, y: AllocatedNum) -> Self { + Self { x, y } + } + + /// Allocates a new point on the curve using coordinates provided by `coords`. + pub fn alloc(mut cs: CS, coords: Option<(Fp, Fp)>) -> Result + where + CS: ConstraintSystem, + { + let x = AllocatedNum::alloc(cs.namespace(|| "x"), || { + coords.map_or(Err(SynthesisError::AssignmentMissing), |c| Ok(c.0)) + })?; + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + coords.map_or(Err(SynthesisError::AssignmentMissing), |c| Ok(c.1)) + })?; + + Ok(Self { x, y }) + } + + /// Turns an AllocatedPoint into an AllocatedPointNonInfinity (assumes it is not infinity) + pub fn from_allocated_point(p: &AllocatedPoint) -> Self { + Self { + x: p.x.clone(), + y: p.y.clone(), + } + } + + /// Returns an AllocatedPoint from an AllocatedPointNonInfinity + pub fn to_allocated_point( + &self, + is_infinity: &AllocatedNum, + ) -> Result, SynthesisError> { + Ok(AllocatedPoint { + x: self.x.clone(), + y: self.y.clone(), + is_infinity: is_infinity.clone(), + }) + } + + /// Returns coordinates associated with the point. + pub fn get_coordinates(&self) -> (&AllocatedNum, &AllocatedNum) { + (&self.x, &self.y) + } + + /// Add two points assuming self != +/- other + pub fn add_incomplete(&self, mut cs: CS, other: &Self) -> Result + where + CS: ConstraintSystem, + { + // allocate a free variable that an honest prover sets to lambda = (y2-y1)/(x2-x1) + let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || { + if *other.x.get_value().get()? == *self.x.get_value().get()? { + Ok(Fp::one()) + } else { + Ok( + (*other.y.get_value().get()? - *self.y.get_value().get()?) + * (*other.x.get_value().get()? - *self.x.get_value().get()?) + .invert() + .unwrap(), + ) + } + })?; + cs.enforce( + || "Check that lambda is computed correctly", + |lc| lc + lambda.get_variable(), + |lc| lc + other.x.get_variable() - self.x.get_variable(), + |lc| lc + other.y.get_variable() - self.y.get_variable(), + ); + + //************************************************************************/ + // x = lambda * lambda - self.x - other.x; + //************************************************************************/ + let x = AllocatedNum::alloc(cs.namespace(|| "x"), || { + Ok( + *lambda.get_value().get()? * lambda.get_value().get()? + - *self.x.get_value().get()? + - *other.x.get_value().get()?, + ) + })?; + cs.enforce( + || "check that x is correct", + |lc| lc + lambda.get_variable(), + |lc| lc + lambda.get_variable(), + |lc| lc + x.get_variable() + self.x.get_variable() + other.x.get_variable(), + ); + + //************************************************************************/ + // y = lambda * (self.x - x) - self.y; + //************************************************************************/ + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + Ok( + *lambda.get_value().get()? * (*self.x.get_value().get()? - *x.get_value().get()?) + - *self.y.get_value().get()?, + ) + })?; + + cs.enforce( + || "Check that y is correct", + |lc| lc + lambda.get_variable(), + |lc| lc + self.x.get_variable() - x.get_variable(), + |lc| lc + y.get_variable() + self.y.get_variable(), + ); + + Ok(Self { x, y }) + } + + /// doubles the point; since this is called with a point not at infinity, it is guaranteed to be not infinity + pub fn double_incomplete(&self, mut cs: CS) -> Result + where + CS: ConstraintSystem, + { + // lambda = (3 x^2 + a) / 2 * y. For pasta curves, a = 0 + + let x_sq = self.x.square(cs.namespace(|| "x_sq"))?; + + let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || { + let n = Fp::from(3) * x_sq.get_value().get()?; + let d = Fp::from(2) * *self.y.get_value().get()?; + if d == Fp::zero() { + Ok(Fp::one()) + } else { + Ok(n * d.invert().unwrap()) + } + })?; + cs.enforce( + || "Check that lambda is computed correctly", + |lc| lc + lambda.get_variable(), + |lc| lc + (Fp::from(2), self.y.get_variable()), + |lc| lc + (Fp::from(3), x_sq.get_variable()), + ); + + let x = AllocatedNum::alloc(cs.namespace(|| "x"), || { + Ok( + *lambda.get_value().get()? * *lambda.get_value().get()? + - *self.x.get_value().get()? + - *self.x.get_value().get()?, + ) + })?; + + cs.enforce( + || "check that x is correct", + |lc| lc + lambda.get_variable(), + |lc| lc + lambda.get_variable(), + |lc| lc + x.get_variable() + (Fp::from(2), self.x.get_variable()), + ); + + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + Ok( + *lambda.get_value().get()? * (*self.x.get_value().get()? - *x.get_value().get()?) + - *self.y.get_value().get()?, + ) + })?; + + cs.enforce( + || "Check that y is correct", + |lc| lc + lambda.get_variable(), + |lc| lc + self.x.get_variable() - x.get_variable(), + |lc| lc + y.get_variable() + self.y.get_variable(), + ); + + Ok(Self { x, y }) + } + + /// If condition outputs a otherwise outputs b + pub fn conditionally_select>( + mut cs: CS, + a: &Self, + b: &Self, + condition: &Boolean, + ) -> Result { + let x = conditionally_select(cs.namespace(|| "select x"), &a.x, &b.x, condition)?; + + let y = conditionally_select(cs.namespace(|| "select y"), &a.y, &b.y, condition)?; + + Ok(Self { x, y }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -696,14 +968,16 @@ mod tests { assert_eq!(e_pasta, e_pasta_2); } - use crate::bellperson::{shape_cs::ShapeCS, solver::SatisfyingAssignment}; + use crate::bellperson::{ + r1cs::{NovaShape, NovaWitness}, + {shape_cs::ShapeCS, solver::SatisfyingAssignment}, + }; use ff::{Field, PrimeFieldBits}; use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine}; use std::ops::Mul; type G = pasta_curves::pallas::Point; type Fp = pasta_curves::pallas::Scalar; type Fq = pasta_curves::vesta::Scalar; - use crate::bellperson::r1cs::{NovaShape, NovaWitness}; fn synthesize_smul(mut cs: CS) -> (AllocatedPoint, AllocatedPoint, Fq) where @@ -713,8 +987,9 @@ mod tests { { let a = AllocatedPoint::::random_vartime(cs.namespace(|| "a")).unwrap(); a.inputize(cs.namespace(|| "inputize a")).unwrap(); + let s = Fq::random(&mut OsRng); - // Allocate random bits and only keep 128 bits + // Allocate bits for s let bits: Vec = s .to_le_bits() .into_iter()