From e67ed82eb5c3d93e59e29c6e829e11351114885f Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Thu, 30 Jan 2025 13:56:16 -0500 Subject: [PATCH] perf: save 1 dbl in scalar_mul_le + fmt --- src/groups/curves/short_weierstrass/mod.rs | 55 +++++++++++------ .../short_weierstrass/non_zero_affine.rs | 4 +- src/groups/mod.rs | 61 ++++++++++++++----- 3 files changed, 83 insertions(+), 37 deletions(-) diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index 8251bdb..b9d09bf 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -356,7 +356,7 @@ where *mul_result += result - subtrahend; // Now, let's finish off the rest of the bits using our complete formulae - for bit in proj_bits { + for bit in proj_bits.iter().rev().skip(1).rev() { if bit.is_constant() { if *bit == &Boolean::TRUE { *mul_result += &multiple_of_power_of_two.into_projective(); @@ -367,6 +367,21 @@ where } multiple_of_power_of_two.double_in_place()?; } + + // last bit + // we don't need the last doubling of multiple_of_power_of_two + let n = proj_bits.len(); + if n >= 1 { + if proj_bits[n - 1].is_constant() { + if proj_bits[n - 1] == &Boolean::TRUE { + *mul_result += &multiple_of_power_of_two.into_projective(); + } + } else { + let temp = &*mul_result + &multiple_of_power_of_two.into_projective(); + *mul_result = proj_bits[n - 1].select(&temp, &mul_result)?; + } + } + Ok(()) } } @@ -518,12 +533,13 @@ where // zero if `self` was zero. However, we also want to make sure that generated // constraints are satisfiable in both cases. // - // In particular, using non-sensible values for `x` and `y` in zero-case may cause - // `unchecked` operations to generate constraints that can never be satisfied, depending - // on the curve equation coefficients. + // In particular, using non-sensible values for `x` and `y` in zero-case may + // cause `unchecked` operations to generate constraints that can never + // be satisfied, depending on the curve equation coefficients. // - // The safest approach is to use coordinates of some point from the curve, thus not - // violating assumptions of `NonZeroAffine`. For instance, generator point. + // The safest approach is to use coordinates of some point from the curve, thus + // not violating assumptions of `NonZeroAffine`. For instance, generator + // point. let x = infinity.select(&F::constant(P::GENERATOR.x), &x)?; let y = infinity.select(&F::constant(P::GENERATOR.y), &y)?; let non_zero_self = NonZeroAffineVar::new(x, y); @@ -563,10 +579,7 @@ where // first bit let temp = NonZeroAffineVar::new(non_zero_self.x, non_zero_self.y.negate()?); acc1 = acc0.add_unchecked(&temp)?; - acc0 = bits[0].select( - &acc0, - &acc1, - )?; + acc0 = bits[0].select(&acc0, &acc1)?; let mul_result = acc0.into_projective(); infinity.select(&Self::zero(), &mul_result) @@ -590,12 +603,13 @@ where // zero if `self` was zero. However, we also want to make sure that generated // constraints are satisfiable in both cases. // - // In particular, using non-sensible values for `x` and `y` in zero-case may cause - // `unchecked` operations to generate constraints that can never be satisfied, depending - // on the curve equation coefficients. + // In particular, using non-sensible values for `x` and `y` in zero-case may + // cause `unchecked` operations to generate constraints that can never + // be satisfied, depending on the curve equation coefficients. // - // The safest approach is to use coordinates of some point from the curve, thus not - // violating assumptions of `NonZeroAffine`. For instance, generator point. + // The safest approach is to use coordinates of some point from the curve, thus + // not violating assumptions of `NonZeroAffine`. For instance, generator + // point. let x = infinity.select(&F::constant(P::GENERATOR.x), &x)?; let y = infinity.select(&F::constant(P::GENERATOR.y), &y)?; let non_zero_self = NonZeroAffineVar::new(x, y); @@ -632,8 +646,8 @@ where infinity.select(&Self::zero(), &mul_result) } - /// Computes `bits1 * self + bits2 * p`, where `bits1` and `bits2` are big-endian - /// `Boolean` representation of the scalars. + /// Computes `bits1 * self + bits2 * p`, where `bits1` and `bits2` are + /// big-endian `Boolean` representation of the scalars. /// /// `self` and `p` are non-zero and `self` ≠ `-p`. #[tracing::instrument(target = "r1cs", skip(bits1, bits2))] @@ -682,7 +696,8 @@ where let mut acc = nz_aff1.add_unchecked(&nz_aff2.clone())?; // double-and-add loop - for (bit1, bit2) in (bits1.iter().rev().skip(1).rev()).zip(bits2.iter().rev().skip(1).rev()) { + for (bit1, bit2) in (bits1.iter().rev().skip(1).rev()).zip(bits2.iter().rev().skip(1).rev()) + { let mut b = bit1.select(&nz_aff1, &aff1_neg)?; acc = acc.double_and_add_unchecked(&b)?; b = bit2.select(&nz_aff2, &aff2_neg)?; @@ -691,9 +706,9 @@ where // last bit aff1_neg = aff1_neg.add_unchecked(&acc)?; - acc = bits1[bits1.len()-1].select(&acc, &aff1_neg)?; + acc = bits1[bits1.len() - 1].select(&acc, &aff1_neg)?; aff2_neg = aff2_neg.add_unchecked(&acc)?; - acc = bits2[bits1.len()-1].select(&acc, &aff2_neg)?; + acc = bits2[bits1.len() - 1].select(&acc, &aff2_neg)?; Ok(acc.into_projective()) } diff --git a/src/groups/curves/short_weierstrass/non_zero_affine.rs b/src/groups/curves/short_weierstrass/non_zero_affine.rs index 8aa40b1..e1268af 100644 --- a/src/groups/curves/short_weierstrass/non_zero_affine.rs +++ b/src/groups/curves/short_weierstrass/non_zero_affine.rs @@ -130,8 +130,8 @@ where } } - /// Conditionally computes `(self + other) + self` or `(self + other) + other` - /// depending on the value of `cond`. + /// Conditionally computes `(self + other) + self` or `(self + other) + + /// other` depending on the value of `cond`. /// /// This follows the formulae from [\[ELM03\]](https://arxiv.org/abs/math/0208038). #[tracing::instrument(target = "r1cs", skip(self))] diff --git a/src/groups/mod.rs b/src/groups/mod.rs index d9fc650..bb26e0d 100644 --- a/src/groups/mod.rs +++ b/src/groups/mod.rs @@ -131,8 +131,8 @@ pub trait CurveVar: Ok(res) } - /// Computes a `I1 * self + I2 * p` in place, where `I1` and `I2` are `Boolean` *big-endian* - /// representation of the scalars. + /// Computes a `I1 * self + I2 * p` in place, where `I1` and `I2` are + /// `Boolean` *big-endian* representation of the scalars. #[tracing::instrument(target = "r1cs", skip(bits1, bits2))] fn joint_scalar_mul_be<'a>( &self, @@ -142,7 +142,7 @@ pub trait CurveVar: ) -> Result { let res1 = self.scalar_mul_le(bits1)?; let res2 = p.scalar_mul_le(bits2)?; - Ok(res1+res2) + Ok(res1 + res2) } /// Computes a `I * self` in place, where `I` is a `Boolean` *little-endian* @@ -217,6 +217,38 @@ mod test_sw_arithmetic { use ark_relations::r1cs::{ConstraintSystem, Result}; use ark_std::UniformRand; + fn point_scalar_mul_satisfied() -> Result + where + G: CurveGroup, + G::BaseField: PrimeField, + G::Config: SWCurveConfig, + { + let mut rng = ark_std::test_rng(); + + let cs = ConstraintSystem::new_ref(); + let point_in = Projective::::rand(&mut rng); + let scalar = G::ScalarField::rand(&mut rng); + let point_out = point_in * scalar; + + let point_in = + ProjectiveVar::>::new_witness(cs.clone(), || { + Ok(point_in) + })?; + let point_out = + ProjectiveVar::>::new_input(cs.clone(), || { + Ok(point_out) + })?; + let scalar = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar))?; + + let mul = point_in.scalar_mul_le(scalar.to_bits_le().unwrap().iter())?; + + point_out.enforce_equal(&mul)?; + + println!("#r1cs for scalar_mul_le: {}", cs.num_constraints()); + + cs.is_satisfied() + } + fn point_scalar_mul_joye_satisfied() -> Result where G: CurveGroup, @@ -244,11 +276,7 @@ mod test_sw_arithmetic { point_out.enforce_equal(&mul)?; - println!( - "#r1cs for scalar_mul_joye_le: {}", - cs.num_constraints() - ); - + println!("#r1cs for scalar_mul_joye_le: {}", cs.num_constraints()); cs.is_satisfied() } @@ -283,26 +311,29 @@ mod test_sw_arithmetic { let scalar1 = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar1))?; let scalar2 = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar2))?; - let res = point_in1.joint_scalar_mul_be(&point_in2, scalar1.to_bits_le().unwrap().iter(), scalar2.to_bits_le().unwrap().iter())?; + let res = point_in1.joint_scalar_mul_be( + &point_in2, + scalar1.to_bits_le().unwrap().iter(), + scalar2.to_bits_le().unwrap().iter(), + )?; point_out.enforce_equal(&res)?; - println!( - "#r1cs for joint_scalar_mul: {}", - cs.num_constraints() - ); - + println!("#r1cs for joint_scalar_mul: {}", cs.num_constraints()); cs.is_satisfied() } #[test] fn test_point_scalar_mul() { + assert!(point_scalar_mul_satisfied::().unwrap()); + } + #[test] + fn test_point_scalar_mul_joye() { assert!(point_scalar_mul_joye_satisfied::().unwrap()); } #[test] fn test_point_joint_scalar_mul() { assert!(point_joint_scalar_mul_satisfied::().unwrap()); } - }