From f16fa1e40103e9bc0d63766cee11e4c00e2633d2 Mon Sep 17 00:00:00 2001 From: Leo Date: Wed, 10 May 2023 22:52:05 +0200 Subject: [PATCH] make ecc tests generic (#165) --- src/gadgets/ecc.rs | 104 +++++++++++++++++++++++++++++------------- src/provider/pasta.rs | 4 +- 2 files changed, 74 insertions(+), 34 deletions(-) diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 47b10d3..0947826 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -755,11 +755,8 @@ mod tests { {shape_cs::ShapeCS, solver::SatisfyingAssignment}, }; use ff::{Field, PrimeFieldBits}; - use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine}; + use pasta_curves::{arithmetic::CurveAffine, group::Curve, pallas, vesta}; use rand::rngs::OsRng; - use std::ops::Mul; - type G1 = pasta_curves::pallas::Point; - type G2 = pasta_curves::vesta::Point; #[derive(Debug, Clone)] pub struct Point @@ -783,7 +780,7 @@ mod tests { pub fn random_vartime() -> Self { loop { let x = G::Base::random(&mut OsRng); - let y = (x * x * x + G::Base::from(5)).sqrt(); + let y = (x.square() * x + G::get_curve_params().1).sqrt(); if y.is_some().unwrap_u8() == 1 { return Self { x, @@ -897,52 +894,61 @@ mod tests { #[test] fn test_ecc_ops() { + test_ecc_ops_with::(); + test_ecc_ops_with::(); + } + + fn test_ecc_ops_with() + where + C: CurveAffine, + G: Group, + { // perform some curve arithmetic - let a = Point::::random_vartime(); - let b = Point::::random_vartime(); + let a = Point::::random_vartime(); + let b = Point::::random_vartime(); let c = a.add(&b); let d = a.double(); - let s = ::Scalar::random(&mut OsRng); + let s = ::Scalar::random(&mut OsRng); let e = a.scalar_mul(&s); - // perform the same computation by translating to pasta_curve types - let a_pasta = EpAffine::from_xy( - pasta_curves::Fp::from_repr(a.x.to_repr()).unwrap(), - pasta_curves::Fp::from_repr(a.y.to_repr()).unwrap(), + // perform the same computation by translating to curve types + let a_curve = C::from_xy( + C::Base::from_repr(a.x.to_repr()).unwrap(), + C::Base::from_repr(a.y.to_repr()).unwrap(), ) .unwrap(); - let b_pasta = EpAffine::from_xy( - pasta_curves::Fp::from_repr(b.x.to_repr()).unwrap(), - pasta_curves::Fp::from_repr(b.y.to_repr()).unwrap(), + let b_curve = C::from_xy( + C::Base::from_repr(b.x.to_repr()).unwrap(), + C::Base::from_repr(b.y.to_repr()).unwrap(), ) .unwrap(); - let c_pasta = (a_pasta + b_pasta).to_affine(); - let d_pasta = (a_pasta + a_pasta).to_affine(); - let e_pasta = a_pasta - .mul(pasta_curves::Fq::from_repr(s.to_repr()).unwrap()) + let c_curve = (a_curve + b_curve).to_affine(); + let d_curve = (a_curve + a_curve).to_affine(); + let e_curve = a_curve + .mul(C::Scalar::from_repr(s.to_repr()).unwrap()) .to_affine(); - // transform c, d, and e into pasta_curve types - let c_pasta_2 = EpAffine::from_xy( - pasta_curves::Fp::from_repr(c.x.to_repr()).unwrap(), - pasta_curves::Fp::from_repr(c.y.to_repr()).unwrap(), + // transform c, d, and e into curve types + let c_curve_2 = C::from_xy( + C::Base::from_repr(c.x.to_repr()).unwrap(), + C::Base::from_repr(c.y.to_repr()).unwrap(), ) .unwrap(); - let d_pasta_2 = EpAffine::from_xy( - pasta_curves::Fp::from_repr(d.x.to_repr()).unwrap(), - pasta_curves::Fp::from_repr(d.y.to_repr()).unwrap(), + let d_curve_2 = C::from_xy( + C::Base::from_repr(d.x.to_repr()).unwrap(), + C::Base::from_repr(d.y.to_repr()).unwrap(), ) .unwrap(); - let e_pasta_2 = EpAffine::from_xy( - pasta_curves::Fp::from_repr(e.x.to_repr()).unwrap(), - pasta_curves::Fp::from_repr(e.y.to_repr()).unwrap(), + let e_curve_2 = C::from_xy( + C::Base::from_repr(e.x.to_repr()).unwrap(), + C::Base::from_repr(e.y.to_repr()).unwrap(), ) .unwrap(); // check that we have the same outputs - assert_eq!(c_pasta, c_pasta_2); - assert_eq!(d_pasta, d_pasta_2); - assert_eq!(e_pasta, e_pasta_2); + assert_eq!(c_curve, c_curve_2); + assert_eq!(d_curve, d_curve_2); + assert_eq!(e_curve, e_curve_2); } fn synthesize_smul(mut cs: CS) -> (AllocatedPoint, AllocatedPoint, G::Scalar) @@ -969,6 +975,17 @@ mod tests { #[test] fn test_ecc_circuit_ops() { + test_ecc_circuit_ops_with::(); + test_ecc_circuit_ops_with::(); + } + + fn test_ecc_circuit_ops_with() + where + B: PrimeField, + S: PrimeField, + G1: Group, + G2: Group, + { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); let _ = synthesize_smul::(cs.namespace(|| "synthesize")); @@ -1010,6 +1027,17 @@ mod tests { #[test] fn test_ecc_circuit_add_equal() { + test_ecc_circuit_add_equal_with::(); + test_ecc_circuit_add_equal_with::(); + } + + fn test_ecc_circuit_add_equal_with() + where + B: PrimeField, + S: PrimeField, + G1: Group, + G2: Group, + { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); let _ = synthesize_add_equal::(cs.namespace(|| "synthesize add equal")); @@ -1055,6 +1083,18 @@ mod tests { #[test] fn test_ecc_circuit_add_negation() { + test_ecc_circuit_add_negation_with::( + ); + test_ecc_circuit_add_negation_with::(); + } + + fn test_ecc_circuit_add_negation_with() + where + B: PrimeField, + S: PrimeField, + G1: Group, + G2: Group, + { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); let _ = synthesize_add_negation::(cs.namespace(|| "synthesize add equal")); diff --git a/src/provider/pasta.rs b/src/provider/pasta.rs index f67041f..0fa2525 100644 --- a/src/provider/pasta.rs +++ b/src/provider/pasta.rs @@ -155,8 +155,8 @@ macro_rules! impl_traits { } fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { - let A = Self::Base::zero(); - let B = Self::Base::from(5); + let A = $name::Point::a(); + let B = $name::Point::b(); let order = BigInt::from_str_radix($order_str, 16).unwrap(); (A, B, order)