Browse Source

optimize ECC ops (#110)

* optimize ECC ops

* update version
main
Srinath Setty 2 years ago
committed by GitHub
parent
commit
7ec2f57b84
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 303 additions and 28 deletions
  1. +1
    -1
      Cargo.toml
  2. +2
    -2
      src/circuit.rs
  3. +300
    -25
      src/gadgets/ecc.rs

+ 1
- 1
Cargo.toml

@ -1,6 +1,6 @@
[package] [package]
name = "nova-snark" name = "nova-snark"
version = "0.8.0"
version = "0.8.1"
authors = ["Srinath Setty <srinath@microsoft.com>"] authors = ["Srinath Setty <srinath@microsoft.com>"]
edition = "2021" edition = "2021"
description = "Recursive zkSNARKs without trusted setup" description = "Recursive zkSNARKs without trusted setup"

+ 2
- 2
src/circuit.rs

@ -412,7 +412,7 @@ mod tests {
let mut cs: ShapeCS<G1> = ShapeCS::new(); let mut cs: ShapeCS<G1> = ShapeCS::new();
let _ = circuit1.synthesize(&mut cs); let _ = circuit1.synthesize(&mut cs);
let (shape1, gens1) = (cs.r1cs_shape(), cs.r1cs_gens()); let (shape1, gens1) = (cs.r1cs_shape(), cs.r1cs_gens());
assert_eq!(cs.num_constraints(), 18967);
assert_eq!(cs.num_constraints(), 9819);
// Initialize the shape and gens for the secondary // Initialize the shape and gens for the secondary
let circuit2: NovaAugmentedCircuit<G1, TrivialTestCircuit<<G1 as Group>::Base>> = let circuit2: NovaAugmentedCircuit<G1, TrivialTestCircuit<<G1 as Group>::Base>> =
@ -425,7 +425,7 @@ mod tests {
let mut cs: ShapeCS<G2> = ShapeCS::new(); let mut cs: ShapeCS<G2> = ShapeCS::new();
let _ = circuit2.synthesize(&mut cs); let _ = circuit2.synthesize(&mut cs);
let (shape2, gens2) = (cs.r1cs_shape(), cs.r1cs_gens()); let (shape2, gens2) = (cs.r1cs_shape(), cs.r1cs_gens());
assert_eq!(cs.num_constraints(), 19499);
assert_eq!(cs.num_constraints(), 10351);
// Execute the base case for the primary // Execute the base case for the primary
let zero1 = <<G2 as Group>::Base as Field>::zero(); let zero1 = <<G2 as Group>::Base as Field>::zero();

+ 300
- 25
src/gadgets/ecc.rs

@ -109,6 +109,24 @@ where
Ok(()) Ok(())
} }
/// Negates the provided point
pub fn negate<CS: ConstraintSystem<Fp>>(&self, mut cs: CS) -> Result<Self, SynthesisError> {
let y = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(-*self.y.get_value().get()?))?;
cs.enforce(
|| "check y = - self.y",
|lc| lc + self.y.get_variable(),
|lc| lc + CS::one(),
|lc| lc - y.get_variable(),
);
Ok(Self {
x: self.x.clone(),
y,
is_infinity: self.is_infinity.clone(),
})
}
/// Add two points (may be equal) /// Add two points (may be equal)
pub fn add<CS: ConstraintSystem<Fp>>( pub fn add<CS: ConstraintSystem<Fp>>(
&self, &self,
@ -434,34 +452,95 @@ where
Ok(Self { x, y, is_infinity }) Ok(Self { x, y, is_infinity })
} }
/// A gadget for scalar multiplication.
/// A gadget for scalar multiplication, optimized to use incomplete addition law.
/// The optimization here is analogous to https://github.com/arkworks-rs/r1cs-std/blob/6d64f379a27011b3629cf4c9cb38b7b7b695d5a0/src/groups/curves/short_weierstrass/mod.rs#L295,
/// except we use complete addition law over affine coordinates instead of projective coordinates for the tail bits
pub fn scalar_mul<CS: ConstraintSystem<Fp>>( pub fn scalar_mul<CS: ConstraintSystem<Fp>>(
&self, &self,
mut cs: CS, mut cs: CS,
scalar: Vec<AllocatedBit>,
scalar_bits: Vec<AllocatedBit>,
) -> Result<Self, SynthesisError> { ) -> Result<Self, SynthesisError> {
let mut res = Self::default(cs.namespace(|| "res"))?;
for i in (0..scalar.len()).rev() {
/*************************************************************/
// res = res.double();
/*************************************************************/
res = res.double(cs.namespace(|| format!("{}: double", i)))?;
/*************************************************************/
// if scalar[i] {
// res = self.add(&res);
// }
/*************************************************************/
let self_and_res = self.add(cs.namespace(|| format!("{}: add", i)), &res)?;
res = Self::conditionally_select(
cs.namespace(|| format!("{}: Update res", i)),
&self_and_res,
&res,
&Boolean::from(scalar[i].clone()),
let split_len = core::cmp::min(scalar_bits.len(), (Fp::NUM_BITS - 2) as usize);
let (incomplete_bits, complete_bits) = scalar_bits.split_at(split_len);
// we convert AllocatedPoint into AllocatedPointNonInfinity; we deal with the case where self.is_infinity = 1 below
let mut p = AllocatedPointNonInfinity::from_allocated_point(self);
// we assume the first bit to be 1, so we must initialize acc to self and double it
// we remove this assumption below
let mut acc = p.clone();
p = p.double_incomplete(cs.namespace(|| "double"))?;
// perform the double-and-add loop to compute the scalar mul using incomplete addition law
for (i, bit) in incomplete_bits.iter().enumerate().skip(1) {
let temp = acc.add_incomplete(cs.namespace(|| format!("add {}", i)), &p)?;
acc = AllocatedPointNonInfinity::conditionally_select(
cs.namespace(|| format!("acc_iteration_{}", i)),
&temp,
&acc,
&Boolean::from(bit.clone()),
)?;
p = p.double_incomplete(cs.namespace(|| format!("double {}", i)))?;
}
// convert back to AllocatedPoint
let res = {
// we set acc.is_infinity = self.is_infinity
let acc = acc.to_allocated_point(&self.is_infinity)?;
// we remove the initial slack if bits[0] is as not as assumed (i.e., it is not 1)
let acc_minus_initial = {
let neg = self.negate(cs.namespace(|| "negate"))?;
acc.add(cs.namespace(|| "res minus self"), &neg)
}?;
AllocatedPoint::conditionally_select(
cs.namespace(|| "remove slack if necessary"),
&acc,
&acc_minus_initial,
&Boolean::from(scalar_bits[0].clone()),
)?
};
// when self.is_infinity = 1, return the default point, else return res
// we already set res.is_infinity to be self.is_infinity, so we do not need to set it here
let default = Self::default(cs.namespace(|| "default"))?;
let x = conditionally_select2(
cs.namespace(|| "check if self.is_infinity is zero (x)"),
&default.x,
&res.x,
&self.is_infinity,
)?;
let y = conditionally_select2(
cs.namespace(|| "check if self.is_infinity is zero (y)"),
&default.y,
&res.y,
&self.is_infinity,
)?;
// we now perform the remaining scalar mul using complete addition law
let mut acc = AllocatedPoint {
x,
y,
is_infinity: res.is_infinity,
};
let mut p_complete = p.to_allocated_point(&self.is_infinity)?;
for (i, bit) in complete_bits.iter().enumerate() {
let temp = acc.add(cs.namespace(|| format!("add_complete {}", i)), &p_complete)?;
acc = AllocatedPoint::conditionally_select(
cs.namespace(|| format!("acc_complete_iteration_{}", i)),
&temp,
&acc,
&Boolean::from(bit.clone()),
)?; )?;
p_complete = p_complete.double(cs.namespace(|| format!("double_complete {}", i)))?;
} }
Ok(res)
Ok(acc)
} }
/// If condition outputs a otherwise outputs b /// If condition outputs a otherwise outputs b
@ -639,6 +718,199 @@ where
} }
} }
#[derive(Clone)]
/// AllocatedPoint but one that is guaranteed to be not infinity
pub struct AllocatedPointNonInfinity<Fp>
where
Fp: PrimeField,
{
x: AllocatedNum<Fp>,
y: AllocatedNum<Fp>,
}
impl<Fp> AllocatedPointNonInfinity<Fp>
where
Fp: PrimeField,
{
/// Creates a new AllocatedPointNonInfinity from the specified coordinates
pub fn new(x: AllocatedNum<Fp>, y: AllocatedNum<Fp>) -> Self {
Self { x, y }
}
/// Allocates a new point on the curve using coordinates provided by `coords`.
pub fn alloc<CS>(mut cs: CS, coords: Option<(Fp, Fp)>) -> Result<Self, SynthesisError>
where
CS: ConstraintSystem<Fp>,
{
let x = AllocatedNum::alloc(cs.namespace(|| "x"), || {
coords.map_or(Err(SynthesisError::AssignmentMissing), |c| Ok(c.0))
})?;
let y = AllocatedNum::alloc(cs.namespace(|| "y"), || {
coords.map_or(Err(SynthesisError::AssignmentMissing), |c| Ok(c.1))
})?;
Ok(Self { x, y })
}
/// Turns an AllocatedPoint into an AllocatedPointNonInfinity (assumes it is not infinity)
pub fn from_allocated_point(p: &AllocatedPoint<Fp>) -> Self {
Self {
x: p.x.clone(),
y: p.y.clone(),
}
}
/// Returns an AllocatedPoint from an AllocatedPointNonInfinity
pub fn to_allocated_point(
&self,
is_infinity: &AllocatedNum<Fp>,
) -> Result<AllocatedPoint<Fp>, SynthesisError> {
Ok(AllocatedPoint {
x: self.x.clone(),
y: self.y.clone(),
is_infinity: is_infinity.clone(),
})
}
/// Returns coordinates associated with the point.
pub fn get_coordinates(&self) -> (&AllocatedNum<Fp>, &AllocatedNum<Fp>) {
(&self.x, &self.y)
}
/// Add two points assuming self != +/- other
pub fn add_incomplete<CS>(&self, mut cs: CS, other: &Self) -> Result<Self, SynthesisError>
where
CS: ConstraintSystem<Fp>,
{
// allocate a free variable that an honest prover sets to lambda = (y2-y1)/(x2-x1)
let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || {
if *other.x.get_value().get()? == *self.x.get_value().get()? {
Ok(Fp::one())
} else {
Ok(
(*other.y.get_value().get()? - *self.y.get_value().get()?)
* (*other.x.get_value().get()? - *self.x.get_value().get()?)
.invert()
.unwrap(),
)
}
})?;
cs.enforce(
|| "Check that lambda is computed correctly",
|lc| lc + lambda.get_variable(),
|lc| lc + other.x.get_variable() - self.x.get_variable(),
|lc| lc + other.y.get_variable() - self.y.get_variable(),
);
//************************************************************************/
// x = lambda * lambda - self.x - other.x;
//************************************************************************/
let x = AllocatedNum::alloc(cs.namespace(|| "x"), || {
Ok(
*lambda.get_value().get()? * lambda.get_value().get()?
- *self.x.get_value().get()?
- *other.x.get_value().get()?,
)
})?;
cs.enforce(
|| "check that x is correct",
|lc| lc + lambda.get_variable(),
|lc| lc + lambda.get_variable(),
|lc| lc + x.get_variable() + self.x.get_variable() + other.x.get_variable(),
);
//************************************************************************/
// y = lambda * (self.x - x) - self.y;
//************************************************************************/
let y = AllocatedNum::alloc(cs.namespace(|| "y"), || {
Ok(
*lambda.get_value().get()? * (*self.x.get_value().get()? - *x.get_value().get()?)
- *self.y.get_value().get()?,
)
})?;
cs.enforce(
|| "Check that y is correct",
|lc| lc + lambda.get_variable(),
|lc| lc + self.x.get_variable() - x.get_variable(),
|lc| lc + y.get_variable() + self.y.get_variable(),
);
Ok(Self { x, y })
}
/// doubles the point; since this is called with a point not at infinity, it is guaranteed to be not infinity
pub fn double_incomplete<CS>(&self, mut cs: CS) -> Result<Self, SynthesisError>
where
CS: ConstraintSystem<Fp>,
{
// lambda = (3 x^2 + a) / 2 * y. For pasta curves, a = 0
let x_sq = self.x.square(cs.namespace(|| "x_sq"))?;
let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || {
let n = Fp::from(3) * x_sq.get_value().get()?;
let d = Fp::from(2) * *self.y.get_value().get()?;
if d == Fp::zero() {
Ok(Fp::one())
} else {
Ok(n * d.invert().unwrap())
}
})?;
cs.enforce(
|| "Check that lambda is computed correctly",
|lc| lc + lambda.get_variable(),
|lc| lc + (Fp::from(2), self.y.get_variable()),
|lc| lc + (Fp::from(3), x_sq.get_variable()),
);
let x = AllocatedNum::alloc(cs.namespace(|| "x"), || {
Ok(
*lambda.get_value().get()? * *lambda.get_value().get()?
- *self.x.get_value().get()?
- *self.x.get_value().get()?,
)
})?;
cs.enforce(
|| "check that x is correct",
|lc| lc + lambda.get_variable(),
|lc| lc + lambda.get_variable(),
|lc| lc + x.get_variable() + (Fp::from(2), self.x.get_variable()),
);
let y = AllocatedNum::alloc(cs.namespace(|| "y"), || {
Ok(
*lambda.get_value().get()? * (*self.x.get_value().get()? - *x.get_value().get()?)
- *self.y.get_value().get()?,
)
})?;
cs.enforce(
|| "Check that y is correct",
|lc| lc + lambda.get_variable(),
|lc| lc + self.x.get_variable() - x.get_variable(),
|lc| lc + y.get_variable() + self.y.get_variable(),
);
Ok(Self { x, y })
}
/// If condition outputs a otherwise outputs b
pub fn conditionally_select<CS: ConstraintSystem<Fp>>(
mut cs: CS,
a: &Self,
b: &Self,
condition: &Boolean,
) -> Result<Self, SynthesisError> {
let x = conditionally_select(cs.namespace(|| "select x"), &a.x, &b.x, condition)?;
let y = conditionally_select(cs.namespace(|| "select y"), &a.y, &b.y, condition)?;
Ok(Self { x, y })
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -696,14 +968,16 @@ mod tests {
assert_eq!(e_pasta, e_pasta_2); assert_eq!(e_pasta, e_pasta_2);
} }
use crate::bellperson::{shape_cs::ShapeCS, solver::SatisfyingAssignment};
use crate::bellperson::{
r1cs::{NovaShape, NovaWitness},
{shape_cs::ShapeCS, solver::SatisfyingAssignment},
};
use ff::{Field, PrimeFieldBits}; use ff::{Field, PrimeFieldBits};
use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine}; use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine};
use std::ops::Mul; use std::ops::Mul;
type G = pasta_curves::pallas::Point; type G = pasta_curves::pallas::Point;
type Fp = pasta_curves::pallas::Scalar; type Fp = pasta_curves::pallas::Scalar;
type Fq = pasta_curves::vesta::Scalar; type Fq = pasta_curves::vesta::Scalar;
use crate::bellperson::r1cs::{NovaShape, NovaWitness};
fn synthesize_smul<Fp, Fq, CS>(mut cs: CS) -> (AllocatedPoint<Fp>, AllocatedPoint<Fp>, Fq) fn synthesize_smul<Fp, Fq, CS>(mut cs: CS) -> (AllocatedPoint<Fp>, AllocatedPoint<Fp>, Fq)
where where
@ -713,8 +987,9 @@ mod tests {
{ {
let a = AllocatedPoint::<Fp>::random_vartime(cs.namespace(|| "a")).unwrap(); let a = AllocatedPoint::<Fp>::random_vartime(cs.namespace(|| "a")).unwrap();
a.inputize(cs.namespace(|| "inputize a")).unwrap(); a.inputize(cs.namespace(|| "inputize a")).unwrap();
let s = Fq::random(&mut OsRng); let s = Fq::random(&mut OsRng);
// Allocate random bits and only keep 128 bits
// Allocate bits for s
let bits: Vec<AllocatedBit> = s let bits: Vec<AllocatedBit> = s
.to_le_bits() .to_le_bits()
.into_iter() .into_iter()

Loading…
Cancel
Save