From f9672faf2320b967d1dbc1cf0d7251a0d74db5f9 Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Thu, 22 Sep 2022 13:31:55 -0700 Subject: [PATCH] Make Nova's ecc gadgets read curve parameters from the group trait (#115) * make ecc gadgets defined over Group rather than PrimeField * use curve parameters from Group trait --- examples/signature.rs | 20 +- src/circuit.rs | 8 +- src/gadgets/ecc.rs | 607 ++++++++++++++++++++---------------------- src/gadgets/r1cs.rs | 10 +- src/pasta.rs | 20 +- src/traits/mod.rs | 7 +- 6 files changed, 324 insertions(+), 348 deletions(-) diff --git a/examples/signature.rs b/examples/signature.rs index 6516712..097dcca 100644 --- a/examples/signature.rs +++ b/examples/signature.rs @@ -7,7 +7,7 @@ use ff::{ derive::byteorder::{ByteOrder, LittleEndian}, Field, PrimeField, PrimeFieldBits, }; -use nova_snark::gadgets::ecc::AllocatedPoint; +use nova_snark::{gadgets::ecc::AllocatedPoint, traits::Group as NovaGroup}; use num_bigint::BigUint; use pasta_curves::{ arithmetic::CurveAffine, @@ -192,21 +192,21 @@ pub fn synthesize_bits>( .collect::, SynthesisError>>() } -pub fn verify_signature>( +pub fn verify_signature>( cs: &mut CS, - pk: AllocatedPoint, - r: AllocatedPoint, + pk: AllocatedPoint, + r: AllocatedPoint, s_bits: Vec, c_bits: Vec, ) -> Result<(), SynthesisError> { - let g = AllocatedPoint::alloc( + let g = AllocatedPoint::::alloc( cs.namespace(|| "g"), Some(( - F::from_str_vartime( + G::Base::from_str_vartime( "28948022309329048855892746252171976963363056481941647379679742748393362948096", ) .unwrap(), - F::from_str_vartime("2").unwrap(), + G::Base::from_str_vartime("2").unwrap(), false, )), ) @@ -218,7 +218,7 @@ pub fn verify_signature> |lc| lc + CS::one(), |lc| { lc + ( - F::from_str_vartime( + G::Base::from_str_vartime( "28948022309329048855892746252171976963363056481941647379679742748393362948096", ) .unwrap(), @@ -231,7 +231,7 @@ pub fn verify_signature> || "gy is vesta curve", |lc| lc + g.get_coordinates().1.get_variable(), |lc| lc + CS::one(), - |lc| lc + (F::from_str_vartime("2").unwrap(), CS::one()), + |lc| lc + (G::Base::from_str_vartime("2").unwrap(), CS::one()), ); let sg = g.scalar_mul(cs.namespace(|| "[s]G"), s_bits)?; @@ -281,7 +281,7 @@ fn main() { let pk = { let pkxy = pk.0.to_affine().coordinates().unwrap(); - AllocatedPoint::alloc( + AllocatedPoint::::alloc( cs.namespace(|| "pub key"), Some((*pkxy.x(), *pkxy.y(), false)), ) diff --git a/src/circuit.rs b/src/circuit.rs index e7abc14..eecf8a3 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -130,7 +130,7 @@ where Vec>, AllocatedRelaxedR1CSInstance, AllocatedR1CSInstance, - AllocatedPoint, + AllocatedPoint, ), SynthesisError, > { @@ -231,7 +231,7 @@ where z_i: Vec>, U: AllocatedRelaxedR1CSInstance, u: AllocatedR1CSInstance, - T: AllocatedPoint, + T: AllocatedPoint, arity: usize, ) -> Result<(AllocatedRelaxedR1CSInstance, AllocatedBit), SynthesisError> { // Check that u.x[0] = Hash(params, U, i, z0, zi) @@ -412,7 +412,7 @@ mod tests { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit1.synthesize(&mut cs); let (shape1, gens1) = (cs.r1cs_shape(), cs.r1cs_gens()); - assert_eq!(cs.num_constraints(), 9819); + assert_eq!(cs.num_constraints(), 9815); // Initialize the shape and gens for the secondary let circuit2: NovaAugmentedCircuit::Base>> = @@ -425,7 +425,7 @@ mod tests { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit2.synthesize(&mut cs); let (shape2, gens2) = (cs.r1cs_shape(), cs.r1cs_gens()); - assert_eq!(cs.num_constraints(), 10351); + assert_eq!(cs.num_constraints(), 10347); // Execute the base case for the primary let zero1 = <::Base as Field>::zero(); diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 5a8dd79..e6aaddf 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -1,9 +1,12 @@ //! This module implements various elliptic curve gadgets #![allow(non_snake_case)] -use crate::gadgets::utils::{ - alloc_num_equals, alloc_one, alloc_zero, conditionally_select, conditionally_select2, - select_num_or_one, select_num_or_zero, select_num_or_zero2, select_one_or_diff2, - select_one_or_num2, select_zero_or_num2, +use crate::{ + gadgets::utils::{ + alloc_num_equals, alloc_one, alloc_zero, conditionally_select, conditionally_select2, + select_num_or_one, select_num_or_zero, select_num_or_zero2, select_one_or_diff2, + select_one_or_num2, select_zero_or_num2, + }, + traits::Group, }; use bellperson::{ gadgets::{ @@ -13,40 +16,43 @@ use bellperson::{ }, ConstraintSystem, SynthesisError, }; -use ff::PrimeField; +use ff::{Field, PrimeField}; /// AllocatedPoint provides an elliptic curve abstraction inside a circuit. #[derive(Clone)] -pub struct AllocatedPoint +pub struct AllocatedPoint where - Fp: PrimeField, + G: Group, { - pub(crate) x: AllocatedNum, - pub(crate) y: AllocatedNum, - pub(crate) is_infinity: AllocatedNum, + pub(crate) x: AllocatedNum, + pub(crate) y: AllocatedNum, + pub(crate) is_infinity: AllocatedNum, } -impl AllocatedPoint +impl AllocatedPoint where - Fp: PrimeField, + G: Group, { /// Allocates a new point on the curve using coordinates provided by `coords`. /// If coords = None, it allocates the default infinity point - pub fn alloc(mut cs: CS, coords: Option<(Fp, Fp, bool)>) -> Result + pub fn alloc( + mut cs: CS, + coords: Option<(G::Base, G::Base, bool)>, + ) -> Result where - CS: ConstraintSystem, + CS: ConstraintSystem, { let x = AllocatedNum::alloc(cs.namespace(|| "x"), || { - Ok(coords.map_or(Fp::zero(), |c| c.0)) + Ok(coords.map_or(G::Base::zero(), |c| c.0)) })?; let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { - Ok(coords.map_or(Fp::zero(), |c| c.1)) + Ok(coords.map_or(G::Base::zero(), |c| c.1)) })?; let is_infinity = AllocatedNum::alloc(cs.namespace(|| "is_infinity"), || { Ok(if coords.map_or(true, |c| c.2) { - Fp::one() + G::Base::one() } else { - Fp::zero() + G::Base::zero() }) })?; cs.enforce( @@ -62,7 +68,7 @@ where /// Allocates a default point on the curve. pub fn default(mut cs: CS) -> Result where - CS: ConstraintSystem, + CS: ConstraintSystem, { let zero = alloc_zero(cs.namespace(|| "zero"))?; let one = alloc_one(cs.namespace(|| "one"))?; @@ -75,42 +81,18 @@ where } /// Returns coordinates associated with the point. - pub fn get_coordinates(&self) -> (&AllocatedNum, &AllocatedNum, &AllocatedNum) { + pub fn get_coordinates( + &self, + ) -> ( + &AllocatedNum, + &AllocatedNum, + &AllocatedNum, + ) { (&self.x, &self.y, &self.is_infinity) } - // Allocate a random point. Only used for testing - #[cfg(test)] - pub fn random_vartime>(mut cs: CS) -> Result { - loop { - let x = Fp::random(&mut OsRng); - let y = (x * x * x + Fp::one() + Fp::one() + Fp::one() + Fp::one() + Fp::one()).sqrt(); - if y.is_some().unwrap_u8() == 1 { - let x_alloc = AllocatedNum::alloc(cs.namespace(|| "x"), || Ok(x))?; - let y_alloc = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(y.unwrap()))?; - let is_infinity = alloc_zero(cs.namespace(|| "Is Infinity"))?; - return Ok(Self { - x: x_alloc, - y: y_alloc, - is_infinity, - }); - } - } - } - - /// Make the point io - #[cfg(test)] - pub fn inputize>(&self, mut cs: CS) -> Result<(), SynthesisError> { - let _ = self.x.inputize(cs.namespace(|| "Input point.x")); - let _ = self.y.inputize(cs.namespace(|| "Input point.y")); - let _ = self - .is_infinity - .inputize(cs.namespace(|| "Input point.is_infinity")); - Ok(()) - } - /// Negates the provided point - pub fn negate>(&self, mut cs: CS) -> Result { + pub fn negate>(&self, mut cs: CS) -> Result { let y = AllocatedNum::alloc(cs.namespace(|| "y"), || Ok(-*self.y.get_value().get()?))?; cs.enforce( @@ -128,10 +110,10 @@ where } /// Add two points (may be equal) - pub fn add>( + pub fn add>( &self, mut cs: CS, - other: &AllocatedPoint, + other: &AllocatedPoint, ) -> Result { // Compute boolean equal indicating if self = other @@ -177,10 +159,10 @@ where /// Adds other point to this point and returns the result. Assumes that the two points are /// different and that both other.is_infinity and this.is_infinty are bits - pub fn add_internal>( + pub fn add_internal>( &self, mut cs: CS, - other: &AllocatedPoint, + other: &AllocatedPoint, equal_x: &AllocatedBit, ) -> Result { //************************************************************************/ @@ -195,9 +177,9 @@ where // NOT(NOT(self.is_ifninity) AND NOT(other.is_infinity)) let at_least_one_inf = AllocatedNum::alloc(cs.namespace(|| "at least one inf"), || { Ok( - Fp::one() - - (Fp::one() - *self.is_infinity.get_value().get()?) - * (Fp::one() - *other.is_infinity.get_value().get()?), + G::Base::one() + - (G::Base::one() - *self.is_infinity.get_value().get()?) + * (G::Base::one() - *other.is_infinity.get_value().get()?), ) })?; cs.enforce( @@ -211,7 +193,7 @@ where let x_diff_is_actual = AllocatedNum::alloc(cs.namespace(|| "allocate x_diff_is_actual"), || { Ok(if *equal_x.get_value().get()? { - Fp::one() + G::Base::one() } else { *at_least_one_inf.get_value().get()? }) @@ -233,9 +215,9 @@ where )?; let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || { - let x_diff_inv = if *x_diff_is_actual.get_value().get()? == Fp::one() { + let x_diff_inv = if *x_diff_is_actual.get_value().get()? == G::Base::one() { // Set to default - Fp::one() + G::Base::one() } else { // Set to the actual inverse (*other.x.get_value().get()? - *self.x.get_value().get()?) @@ -340,15 +322,13 @@ where } /// Doubles the supplied point. - pub fn double>(&self, mut cs: CS) -> Result { + pub fn double>(&self, mut cs: CS) -> Result { //*************************************************************/ - // lambda = (Fp::one() + Fp::one() + Fp::one()) - // * self.x - // * self.x - // * ((Fp::one() + Fp::one()) * self.y).invert().unwrap(); + // lambda = (G::Base::from(3) * self.x * self.x + G::A()) + // * (G::Base::from(2)) * self.y).invert().unwrap(); /*************************************************************/ - // Compute tmp = (Fp::one() + Fp::one())* self.y ? self != inf : 1 + // Compute tmp = (G::Base::one() + G::Base::one())* self.y ? self != inf : 1 let tmp_actual = AllocatedNum::alloc(cs.namespace(|| "tmp_actual"), || { Ok(*self.y.get_value().get()? + *self.y.get_value().get()?) })?; @@ -361,44 +341,35 @@ where let tmp = select_one_or_num2(cs.namespace(|| "tmp"), &tmp_actual, &self.is_infinity)?; - // Now compute lambda as (Fp::one() + Fp::one + Fp::one()) * self.x * self.x * tmp_inv - let prod_1 = AllocatedNum::alloc(cs.namespace(|| "alloc prod 1"), || { - let tmp_inv = if *self.is_infinity.get_value().get()? == Fp::one() { - // Return default value 1 - Fp::one() - } else { - // Return the actual inverse - (*tmp.get_value().get()?).invert().unwrap() - }; + // Now compute lambda as (G::Base::from(3) * self.x * self.x + G::A()) * tmp_inv - Ok(tmp_inv * self.x.get_value().get()?) + let prod_1 = AllocatedNum::alloc(cs.namespace(|| "alloc prod 1"), || { + Ok(G::Base::from(3) * self.x.get_value().get()? * self.x.get_value().get()?) })?; - cs.enforce( || "Check prod 1", - |lc| lc + tmp.get_variable(), - |lc| lc + prod_1.get_variable(), - |lc| lc + self.x.get_variable(), - ); - - let prod_2 = AllocatedNum::alloc(cs.namespace(|| "alloc prod 2"), || { - Ok(*prod_1.get_value().get()? * self.x.get_value().get()?) - })?; - cs.enforce( - || "Check prod 2", + |lc| lc + (G::Base::from(3), self.x.get_variable()), |lc| lc + self.x.get_variable(), |lc| lc + prod_1.get_variable(), - |lc| lc + prod_2.get_variable(), ); - let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || { - Ok(*prod_2.get_value().get()? * (Fp::one() + Fp::one() + Fp::one())) + let lambda = AllocatedNum::alloc(cs.namespace(|| "alloc lambda"), || { + let tmp_inv = if *self.is_infinity.get_value().get()? == G::Base::one() { + // Return default value 1 + G::Base::one() + } else { + // Return the actual inverse + (*tmp.get_value().get()?).invert().unwrap() + }; + + Ok(tmp_inv * (*prod_1.get_value().get()? + G::get_curve_params().0)) })?; + cs.enforce( || "Check lambda", - |lc| lc + CS::one() + CS::one() + CS::one(), - |lc| lc + prod_2.get_variable(), + |lc| lc + tmp.get_variable(), |lc| lc + lambda.get_variable(), + |lc| lc + prod_1.get_variable() + (G::get_curve_params().0, CS::one()), ); /*************************************************************/ @@ -455,12 +426,12 @@ where /// A gadget for scalar multiplication, optimized to use incomplete addition law. /// The optimization here is analogous to https://github.com/arkworks-rs/r1cs-std/blob/6d64f379a27011b3629cf4c9cb38b7b7b695d5a0/src/groups/curves/short_weierstrass/mod.rs#L295, /// except we use complete addition law over affine coordinates instead of projective coordinates for the tail bits - pub fn scalar_mul>( + pub fn scalar_mul>( &self, mut cs: CS, scalar_bits: Vec, ) -> Result { - let split_len = core::cmp::min(scalar_bits.len(), (Fp::NUM_BITS - 2) as usize); + let split_len = core::cmp::min(scalar_bits.len(), (G::Base::NUM_BITS - 2) as usize); let (incomplete_bits, complete_bits) = scalar_bits.split_at(split_len); // we convert AllocatedPoint into AllocatedPointNonInfinity; we deal with the case where self.is_infinity = 1 below @@ -544,7 +515,7 @@ where } /// If condition outputs a otherwise outputs b - pub fn conditionally_select>( + pub fn conditionally_select>( mut cs: CS, a: &Self, b: &Self, @@ -565,7 +536,7 @@ where } /// If condition outputs a otherwise infinity - pub fn select_point_or_infinity>( + pub fn select_point_or_infinity>( mut cs: CS, a: &Self, condition: &Boolean, @@ -584,163 +555,29 @@ where } } -#[cfg(test)] -use ff::PrimeFieldBits; -#[cfg(test)] -use rand::rngs::OsRng; -#[cfg(test)] -use std::marker::PhantomData; - -#[cfg(test)] -#[derive(Debug, Clone)] -pub struct Point -where - Fp: PrimeField, - Fq: PrimeField + PrimeFieldBits, -{ - x: Fp, - y: Fp, - is_infinity: bool, - _p: PhantomData, -} - -#[cfg(test)] -impl Point -where - Fp: PrimeField, - Fq: PrimeField + PrimeFieldBits, -{ - pub fn new(x: Fp, y: Fp, is_infinity: bool) -> Self { - Self { - x, - y, - is_infinity, - _p: Default::default(), - } - } - - pub fn random_vartime() -> Self { - loop { - let x = Fp::random(&mut OsRng); - let y = (x * x * x + Fp::one() + Fp::one() + Fp::one() + Fp::one() + Fp::one()).sqrt(); - if y.is_some().unwrap_u8() == 1 { - return Self { - x, - y: y.unwrap(), - is_infinity: false, - _p: Default::default(), - }; - } - } - } - - /// Add any two points - pub fn add(&self, other: &Point) -> Self { - if self.x == other.x { - // If self == other then call double - if self.y == other.y { - self.double() - } else { - // if self.x == other.x and self.y != other.y then return infinity - Self { - x: Fp::zero(), - y: Fp::zero(), - is_infinity: true, - _p: Default::default(), - } - } - } else { - self.add_internal(other) - } - } - - /// Add two different points - pub fn add_internal(&self, other: &Point) -> Self { - if self.is_infinity { - return other.clone(); - } - - if other.is_infinity { - return self.clone(); - } - - let lambda = (other.y - self.y) * (other.x - self.x).invert().unwrap(); - let x = lambda * lambda - self.x - other.x; - let y = lambda * (self.x - x) - self.y; - Self { - x, - y, - is_infinity: false, - _p: Default::default(), - } - } - - pub fn double(&self) -> Self { - if self.is_infinity { - return Self { - x: Fp::zero(), - y: Fp::zero(), - is_infinity: true, - _p: Default::default(), - }; - } - - let lambda = (Fp::one() + Fp::one() + Fp::one()) - * self.x - * self.x - * ((Fp::one() + Fp::one()) * self.y).invert().unwrap(); - let x = lambda * lambda - self.x - self.x; - let y = lambda * (self.x - x) - self.y; - Self { - x, - y, - is_infinity: false, - _p: Default::default(), - } - } - - pub fn scalar_mul(&self, scalar: &Fq) -> Self { - let mut res = Self { - x: Fp::zero(), - y: Fp::zero(), - is_infinity: true, - _p: Default::default(), - }; - - let bits = scalar.to_le_bits(); - for i in (0..bits.len()).rev() { - res = res.double(); - if bits[i] { - res = self.add(&res); - } - } - res - } -} - #[derive(Clone)] /// AllocatedPoint but one that is guaranteed to be not infinity -pub struct AllocatedPointNonInfinity +pub struct AllocatedPointNonInfinity where - Fp: PrimeField, + G: Group, { - x: AllocatedNum, - y: AllocatedNum, + x: AllocatedNum, + y: AllocatedNum, } -impl AllocatedPointNonInfinity +impl AllocatedPointNonInfinity where - Fp: PrimeField, + G: Group, { /// Creates a new AllocatedPointNonInfinity from the specified coordinates - pub fn new(x: AllocatedNum, y: AllocatedNum) -> Self { + pub fn new(x: AllocatedNum, y: AllocatedNum) -> Self { Self { x, y } } /// Allocates a new point on the curve using coordinates provided by `coords`. - pub fn alloc(mut cs: CS, coords: Option<(Fp, Fp)>) -> Result + pub fn alloc(mut cs: CS, coords: Option<(G::Base, G::Base)>) -> Result where - CS: ConstraintSystem, + CS: ConstraintSystem, { let x = AllocatedNum::alloc(cs.namespace(|| "x"), || { coords.map_or(Err(SynthesisError::AssignmentMissing), |c| Ok(c.0)) @@ -753,7 +590,7 @@ where } /// Turns an AllocatedPoint into an AllocatedPointNonInfinity (assumes it is not infinity) - pub fn from_allocated_point(p: &AllocatedPoint) -> Self { + pub fn from_allocated_point(p: &AllocatedPoint) -> Self { Self { x: p.x.clone(), y: p.y.clone(), @@ -763,8 +600,8 @@ where /// Returns an AllocatedPoint from an AllocatedPointNonInfinity pub fn to_allocated_point( &self, - is_infinity: &AllocatedNum, - ) -> Result, SynthesisError> { + is_infinity: &AllocatedNum, + ) -> Result, SynthesisError> { Ok(AllocatedPoint { x: self.x.clone(), y: self.y.clone(), @@ -773,19 +610,19 @@ where } /// Returns coordinates associated with the point. - pub fn get_coordinates(&self) -> (&AllocatedNum, &AllocatedNum) { + pub fn get_coordinates(&self) -> (&AllocatedNum, &AllocatedNum) { (&self.x, &self.y) } /// Add two points assuming self != +/- other pub fn add_incomplete(&self, mut cs: CS, other: &Self) -> Result where - CS: ConstraintSystem, + CS: ConstraintSystem, { // allocate a free variable that an honest prover sets to lambda = (y2-y1)/(x2-x1) let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || { if *other.x.get_value().get()? == *self.x.get_value().get()? { - Ok(Fp::one()) + Ok(G::Base::one()) } else { Ok( (*other.y.get_value().get()? - *self.y.get_value().get()?) @@ -842,17 +679,17 @@ where /// doubles the point; since this is called with a point not at infinity, it is guaranteed to be not infinity pub fn double_incomplete(&self, mut cs: CS) -> Result where - CS: ConstraintSystem, + CS: ConstraintSystem, { - // lambda = (3 x^2 + a) / 2 * y. For pasta curves, a = 0 + // lambda = (3 x^2 + a) / 2 * y let x_sq = self.x.square(cs.namespace(|| "x_sq"))?; let lambda = AllocatedNum::alloc(cs.namespace(|| "lambda"), || { - let n = Fp::from(3) * x_sq.get_value().get()?; - let d = Fp::from(2) * *self.y.get_value().get()?; - if d == Fp::zero() { - Ok(Fp::one()) + let n = G::Base::from(3) * x_sq.get_value().get()? + G::get_curve_params().0; + let d = G::Base::from(2) * *self.y.get_value().get()?; + if d == G::Base::zero() { + Ok(G::Base::one()) } else { Ok(n * d.invert().unwrap()) } @@ -860,8 +697,8 @@ where cs.enforce( || "Check that lambda is computed correctly", |lc| lc + lambda.get_variable(), - |lc| lc + (Fp::from(2), self.y.get_variable()), - |lc| lc + (Fp::from(3), x_sq.get_variable()), + |lc| lc + (G::Base::from(2), self.y.get_variable()), + |lc| lc + (G::Base::from(3), x_sq.get_variable()) + (G::get_curve_params().0, CS::one()), ); let x = AllocatedNum::alloc(cs.namespace(|| "x"), || { @@ -876,7 +713,7 @@ where || "check that x is correct", |lc| lc + lambda.get_variable(), |lc| lc + lambda.get_variable(), - |lc| lc + x.get_variable() + (Fp::from(2), self.x.get_variable()), + |lc| lc + x.get_variable() + (G::Base::from(2), self.x.get_variable()), ); let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { @@ -897,14 +734,13 @@ where } /// If condition outputs a otherwise outputs b - pub fn conditionally_select>( + pub fn conditionally_select>( mut cs: CS, a: &Self, b: &Self, condition: &Boolean, ) -> Result { let x = conditionally_select(cs.namespace(|| "select x"), &a.x, &b.x, condition)?; - let y = conditionally_select(cs.namespace(|| "select y"), &a.y, &b.y, condition)?; Ok(Self { x, y }) @@ -914,18 +750,161 @@ where #[cfg(test)] mod tests { use super::*; + use crate::bellperson::{ + r1cs::{NovaShape, NovaWitness}, + {shape_cs::ShapeCS, solver::SatisfyingAssignment}, + }; + use ff::{Field, PrimeFieldBits}; + use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine}; + use rand::rngs::OsRng; + use std::ops::Mul; + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + #[derive(Debug, Clone)] + pub struct Point + where + G: Group, + { + x: G::Base, + y: G::Base, + is_infinity: bool, + } + + #[cfg(test)] + impl Point + where + G: Group, + { + pub fn new(x: G::Base, y: G::Base, is_infinity: bool) -> Self { + Self { x, y, is_infinity } + } + + pub fn random_vartime() -> Self { + loop { + let x = G::Base::random(&mut OsRng); + let y = (x * x * x + G::Base::from(5)).sqrt(); + if y.is_some().unwrap_u8() == 1 { + return Self { + x, + y: y.unwrap(), + is_infinity: false, + }; + } + } + } + + /// Add any two points + pub fn add(&self, other: &Point) -> Self { + if self.x == other.x { + // If self == other then call double + if self.y == other.y { + self.double() + } else { + // if self.x == other.x and self.y != other.y then return infinity + Self { + x: G::Base::zero(), + y: G::Base::zero(), + is_infinity: true, + } + } + } else { + self.add_internal(other) + } + } + + /// Add two different points + pub fn add_internal(&self, other: &Point) -> Self { + if self.is_infinity { + return other.clone(); + } + + if other.is_infinity { + return self.clone(); + } + + let lambda = (other.y - self.y) * (other.x - self.x).invert().unwrap(); + let x = lambda * lambda - self.x - other.x; + let y = lambda * (self.x - x) - self.y; + Self { + x, + y, + is_infinity: false, + } + } + + pub fn double(&self) -> Self { + if self.is_infinity { + return Self { + x: G::Base::zero(), + y: G::Base::zero(), + is_infinity: true, + }; + } + + let lambda = G::Base::from(3) + * self.x + * self.x + * ((G::Base::one() + G::Base::one()) * self.y) + .invert() + .unwrap(); + let x = lambda * lambda - self.x - self.x; + let y = lambda * (self.x - x) - self.y; + Self { + x, + y, + is_infinity: false, + } + } + + pub fn scalar_mul(&self, scalar: &G::Scalar) -> Self { + let mut res = Self { + x: G::Base::zero(), + y: G::Base::zero(), + is_infinity: true, + }; + + let bits = scalar.to_le_bits(); + for i in (0..bits.len()).rev() { + res = res.double(); + if bits[i] { + res = self.add(&res); + } + } + res + } + } + + // Allocate a random point. Only used for testing + pub fn alloc_random_point>( + mut cs: CS, + ) -> Result, SynthesisError> { + // get a random point + let p = Point::::random_vartime(); + AllocatedPoint::alloc(cs.namespace(|| "alloc p"), Some((p.x, p.y, p.is_infinity))) + } + + /// Make the point io + pub fn inputize_allocted_point>( + p: &AllocatedPoint, + mut cs: CS, + ) -> Result<(), SynthesisError> { + let _ = p.x.inputize(cs.namespace(|| "Input point.x")); + let _ = p.y.inputize(cs.namespace(|| "Input point.y")); + let _ = p + .is_infinity + .inputize(cs.namespace(|| "Input point.is_infinity")); + Ok(()) + } #[test] fn test_ecc_ops() { - type Fp = pasta_curves::pallas::Base; - type Fq = pasta_curves::pallas::Scalar; - // perform some curve arithmetic - let a = Point::::random_vartime(); - let b = Point::::random_vartime(); + let a = Point::::random_vartime(); + let b = Point::::random_vartime(); let c = a.add(&b); let d = a.double(); - let s = Fq::random(&mut OsRng); + let s = ::Scalar::random(&mut OsRng); let e = a.scalar_mul(&s); // perform the same computation by translating to pasta_curve types @@ -968,27 +947,15 @@ mod tests { assert_eq!(e_pasta, e_pasta_2); } - use crate::bellperson::{ - r1cs::{NovaShape, NovaWitness}, - {shape_cs::ShapeCS, solver::SatisfyingAssignment}, - }; - use ff::{Field, PrimeFieldBits}; - use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine}; - use std::ops::Mul; - type G = pasta_curves::pallas::Point; - type Fp = pasta_curves::pallas::Scalar; - type Fq = pasta_curves::vesta::Scalar; - - fn synthesize_smul(mut cs: CS) -> (AllocatedPoint, AllocatedPoint, Fq) + fn synthesize_smul(mut cs: CS) -> (AllocatedPoint, AllocatedPoint, G::Scalar) where - Fp: PrimeField, - Fq: PrimeField + PrimeFieldBits, - CS: ConstraintSystem, + G: Group, + CS: ConstraintSystem, { - let a = AllocatedPoint::::random_vartime(cs.namespace(|| "a")).unwrap(); - a.inputize(cs.namespace(|| "inputize a")).unwrap(); + let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); + inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); - let s = Fq::random(&mut OsRng); + let s = G::Scalar::random(&mut OsRng); // Allocate bits for s let bits: Vec = s .to_le_bits() @@ -998,33 +965,33 @@ mod tests { .collect::, SynthesisError>>() .unwrap(); let e = a.scalar_mul(cs.namespace(|| "Scalar Mul"), bits).unwrap(); - e.inputize(cs.namespace(|| "inputize e")).unwrap(); + inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap(); (a, e, s) } #[test] fn test_ecc_circuit_ops() { // First create the shape - let mut cs: ShapeCS = ShapeCS::new(); - let _ = synthesize_smul::(cs.namespace(|| "synthesize")); + let mut cs: ShapeCS = ShapeCS::new(); + let _ = synthesize_smul::(cs.namespace(|| "synthesize")); println!("Number of constraints: {}", cs.num_constraints()); let shape = cs.r1cs_shape(); let gens = cs.r1cs_gens(); // Then the satisfying assignment - let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); - let (a, e, s) = synthesize_smul::(cs.namespace(|| "synthesize")); + let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); + let (a, e, s) = synthesize_smul::(cs.namespace(|| "synthesize")); let (inst, witness) = cs.r1cs_instance_and_witness(&shape, &gens).unwrap(); - let a_p: Point = Point::new( + let a_p: Point = Point::new( a.x.get_value().unwrap(), a.y.get_value().unwrap(), - a.is_infinity.get_value().unwrap() == Fp::one(), + a.is_infinity.get_value().unwrap() == ::Base::one(), ); - let e_p: Point = Point::new( + let e_p: Point = Point::new( e.x.get_value().unwrap(), e.y.get_value().unwrap(), - e.is_infinity.get_value().unwrap() == Fp::one(), + e.is_infinity.get_value().unwrap() == ::Base::one(), ); let e_new = a_p.scalar_mul(&s); assert!(e_p.x == e_new.x && e_p.y == e_new.y); @@ -1032,41 +999,40 @@ mod tests { assert!(shape.is_sat(&gens, &inst, &witness).is_ok()); } - fn synthesize_add_equal(mut cs: CS) -> (AllocatedPoint, AllocatedPoint) + fn synthesize_add_equal(mut cs: CS) -> (AllocatedPoint, AllocatedPoint) where - Fp: PrimeField, - Fq: PrimeField + PrimeFieldBits, - CS: ConstraintSystem, + G: Group, + CS: ConstraintSystem, { - let a = AllocatedPoint::::random_vartime(cs.namespace(|| "a")).unwrap(); - a.inputize(cs.namespace(|| "inputize a")).unwrap(); + let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); + inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); let e = a.add(cs.namespace(|| "add a to a"), &a).unwrap(); - e.inputize(cs.namespace(|| "inputize e")).unwrap(); + inputize_allocted_point(&e, cs.namespace(|| "inputize e")).unwrap(); (a, e) } #[test] fn test_ecc_circuit_add_equal() { // First create the shape - let mut cs: ShapeCS = ShapeCS::new(); - let _ = synthesize_add_equal::(cs.namespace(|| "synthesize add equal")); + let mut cs: ShapeCS = ShapeCS::new(); + let _ = synthesize_add_equal::(cs.namespace(|| "synthesize add equal")); println!("Number of constraints: {}", cs.num_constraints()); let shape = cs.r1cs_shape(); let gens = cs.r1cs_gens(); // Then the satisfying assignment - let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); - let (a, e) = synthesize_add_equal::(cs.namespace(|| "synthesize add equal")); + let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); + let (a, e) = synthesize_add_equal::(cs.namespace(|| "synthesize add equal")); let (inst, witness) = cs.r1cs_instance_and_witness(&shape, &gens).unwrap(); - let a_p: Point = Point::new( + let a_p: Point = Point::new( a.x.get_value().unwrap(), a.y.get_value().unwrap(), - a.is_infinity.get_value().unwrap() == Fp::one(), + a.is_infinity.get_value().unwrap() == ::Base::one(), ); - let e_p: Point = Point::new( + let e_p: Point = Point::new( e.x.get_value().unwrap(), e.y.get_value().unwrap(), - e.is_infinity.get_value().unwrap() == Fp::one(), + e.is_infinity.get_value().unwrap() == ::Base::one(), ); let e_new = a_p.add(&a_p); assert!(e_p.x == e_new.x && e_p.y == e_new.y); @@ -1074,18 +1040,19 @@ mod tests { assert!(shape.is_sat(&gens, &inst, &witness).is_ok()); } - fn synthesize_add_negation(mut cs: CS) -> AllocatedPoint + fn synthesize_add_negation(mut cs: CS) -> AllocatedPoint where - Fp: PrimeField, - Fq: PrimeField + PrimeFieldBits, - CS: ConstraintSystem, + G: Group, + CS: ConstraintSystem, { - let a = AllocatedPoint::::random_vartime(cs.namespace(|| "a")).unwrap(); - a.inputize(cs.namespace(|| "inputize a")).unwrap(); + let a = alloc_random_point(cs.namespace(|| "a")).unwrap(); + inputize_allocted_point(&a, cs.namespace(|| "inputize a")).unwrap(); let mut b = a.clone(); - b.y = - AllocatedNum::alloc(cs.namespace(|| "allocate negation of a"), || Ok(Fp::zero())).unwrap(); - b.inputize(cs.namespace(|| "inputize b")).unwrap(); + b.y = AllocatedNum::alloc(cs.namespace(|| "allocate negation of a"), || { + Ok(G::Base::zero()) + }) + .unwrap(); + inputize_allocted_point(&b, cs.namespace(|| "inputize b")).unwrap(); let e = a.add(cs.namespace(|| "add a to b"), &b).unwrap(); e } @@ -1093,20 +1060,20 @@ mod tests { #[test] fn test_ecc_circuit_add_negation() { // First create the shape - let mut cs: ShapeCS = ShapeCS::new(); - let _ = synthesize_add_negation::(cs.namespace(|| "synthesize add equal")); + let mut cs: ShapeCS = ShapeCS::new(); + let _ = synthesize_add_negation::(cs.namespace(|| "synthesize add equal")); println!("Number of constraints: {}", cs.num_constraints()); let shape = cs.r1cs_shape(); let gens = cs.r1cs_gens(); // Then the satisfying assignment - let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); - let e = synthesize_add_negation::(cs.namespace(|| "synthesize add negation")); + let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); + let e = synthesize_add_negation::(cs.namespace(|| "synthesize add negation")); let (inst, witness) = cs.r1cs_instance_and_witness(&shape, &gens).unwrap(); - let e_p: Point = Point::new( + let e_p: Point = Point::new( e.x.get_value().unwrap(), e.y.get_value().unwrap(), - e.is_infinity.get_value().unwrap() == Fp::one(), + e.is_infinity.get_value().unwrap() == ::Base::one(), ); assert!(e_p.is_infinity); // Make sure that it is satisfiable diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index 229f073..459a60a 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -27,7 +27,7 @@ pub struct AllocatedR1CSInstance where G: Group, { - pub(crate) W: AllocatedPoint, + pub(crate) W: AllocatedPoint, pub(crate) X0: AllocatedNum, pub(crate) X1: AllocatedNum, } @@ -75,8 +75,8 @@ pub struct AllocatedRelaxedR1CSInstance where G: Group, { - pub(crate) W: AllocatedPoint, - pub(crate) E: AllocatedPoint, + pub(crate) W: AllocatedPoint, + pub(crate) E: AllocatedPoint, pub(crate) u: AllocatedNum, pub(crate) X0: BigNat, pub(crate) X1: BigNat, @@ -262,7 +262,7 @@ where mut cs: CS, params: AllocatedNum, // hash of R1CSShape of F' u: AllocatedR1CSInstance, - T: AllocatedPoint, + T: AllocatedPoint, ro_consts: ROConstantsCircuit, limb_width: usize, n_limbs: usize, @@ -309,7 +309,7 @@ where // Allocate the order of the non-native field as a constant let m_bn = alloc_bignat_constant( cs.namespace(|| "alloc m"), - &G::get_order(), + &G::get_curve_params().2, limb_width, n_limbs, )?; diff --git a/src/pasta.rs b/src/pasta.rs index 62289f1..3748225 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -89,12 +89,16 @@ impl Group for pallas::Point { } } - fn get_order() -> BigInt { - BigInt::from_str_radix( + fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { + let A = Self::Base::zero(); + let B = Self::Base::from(5); + let order = BigInt::from_str_radix( "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001", 16, ) - .unwrap() + .unwrap(); + + (A, B, order) } fn zero() -> Self { @@ -195,12 +199,16 @@ impl Group for vesta::Point { } } - fn get_order() -> BigInt { - BigInt::from_str_radix( + fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { + let A = Self::Base::zero(); + let B = Self::Base::from(5); + let order = BigInt::from_str_radix( "40000000000000000000000000000000224698fc094cf91b992d30ed00000001", 16, ) - .unwrap() + .unwrap(); + + (A, B, order) } fn zero() -> Self { diff --git a/src/traits/mod.rs b/src/traits/mod.rs index 4f1868f..1c81563 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -12,6 +12,7 @@ use merlin::Transcript; use num_bigint::BigInt; /// Represents an element of a group +/// This is currently tailored for an elliptic curve group pub trait Group: Clone + Copy @@ -62,14 +63,14 @@ pub trait Group: /// Returns the affine coordinates (x, y, infinty) for the point fn to_coordinates(&self) -> (Self::Base, Self::Base, bool); - /// Returns the order of the group as a big integer - fn get_order() -> BigInt; - /// Returns an element that is the additive identity of the group fn zero() -> Self; /// Returns the generator of the group fn get_generator() -> Self; + + /// Returns A, B, and the order of the group as a big integer + fn get_curve_params() -> (Self::Base, Self::Base, BigInt); } /// Represents a compressed version of a group element