Browse Source

Improve handling of constant bits in scalar mul for SW curves (#43)

* We add a double_and_add method that computes 2 * self + other more
  efficiently than just doubling + addition; this is not used anywhere 
  yet, but I am planning on fiddling with it to see if we can leverage
  it somehow. (See zcash/zcash#3924 for details)

* We handle constant scalars better:
  * We skip the most-significant constant zeroes to avoid unnecessary
    doubling
  * When intermediate bits of the scalar are constants, instead of
    conditionally adding, we directly use the value of the bit to
    decide whether to add or not.


Co-authored-by: Dev Ojha <ValarDragon@users.noreply.github.com>
Co-authored-by: weikeng <w.k@berkeley.edu>
master
Pratyush Mishra 3 years ago
committed by GitHub
parent
commit
d9e0200433
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 12 deletions
  1. +2
    -2
      CHANGELOG.md
  2. +33
    -9
      src/groups/curves/short_weierstrass/mod.rs
  3. +32
    -1
      src/groups/curves/short_weierstrass/non_zero_affine.rs

+ 2
- 2
CHANGELOG.md

@ -1,7 +1,7 @@
## Pending ## Pending
### Breaking changes ### Breaking changes
- #12 Make the `ToBitsGadget` impl for `FpVar` output fixed-size
- #12 Make the output of the `ToBitsGadget` impl for `FpVar` fixed-size
### Features ### Features
@ -17,7 +17,7 @@
- #33 Speedup scalar multiplication by a constant - #33 Speedup scalar multiplication by a constant
- #35 Construct a `FpVar` from bits - #35 Construct a `FpVar` from bits
- #36 Implement `ToConstraintFieldGadget` for `Vec<Uint8>` - #36 Implement `ToConstraintFieldGadget` for `Vec<Uint8>`
- #40 Faster scalar multiplication for Short Weierstrass curves by relying on affine formulae
- #40, #43 Faster scalar multiplication for Short Weierstrass curves by relying on affine formulae
### Bug fixes ### Bug fixes
- #8 Fix bug in `three_bit_cond_neg_lookup` when using a constant lookup bit - #8 Fix bug in `three_bit_cond_neg_lookup` when using a constant lookup bit

+ 33
- 9
src/groups/curves/short_weierstrass/mod.rs

@ -2,7 +2,7 @@ use ark_ec::{
short_weierstrass_jacobian::{GroupAffine as SWAffine, GroupProjective as SWProjective}, short_weierstrass_jacobian::{GroupAffine as SWAffine, GroupProjective as SWProjective},
AffineCurve, ProjectiveCurve, SWModelParameters, AffineCurve, ProjectiveCurve, SWModelParameters,
}; };
use ark_ff::{BigInteger, BitIteratorBE, Field, FpParameters, One, PrimeField, Zero};
use ark_ff::{BigInteger, BitIteratorBE, Field, One, PrimeField, Zero};
use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError};
use core::{borrow::Borrow, marker::PhantomData}; use core::{borrow::Borrow, marker::PhantomData};
use non_zero_affine::NonZeroAffineVar; use non_zero_affine::NonZeroAffineVar;
@ -276,7 +276,7 @@ where
multiple_of_power_of_two: &mut NonZeroAffineVar<P, F>, multiple_of_power_of_two: &mut NonZeroAffineVar<P, F>,
bits: &[&Boolean<<P::BaseField as Field>::BasePrimeField>], bits: &[&Boolean<<P::BaseField as Field>::BasePrimeField>],
) -> Result<(), SynthesisError> { ) -> Result<(), SynthesisError> {
let scalar_modulus_bits = <<P::ScalarField as PrimeField>::Params>::MODULUS_BITS as usize;
let scalar_modulus_bits = <P::ScalarField as PrimeField>::size_in_bits();
assert!(scalar_modulus_bits >= bits.len()); assert!(scalar_modulus_bits >= bits.len());
let split_len = ark_std::cmp::min(scalar_modulus_bits - 2, bits.len()); let split_len = ark_std::cmp::min(scalar_modulus_bits - 2, bits.len());
@ -317,8 +317,14 @@ where
// As mentioned, we will skip the LSB, and will later handle it via a conditional subtraction. // As mentioned, we will skip the LSB, and will later handle it via a conditional subtraction.
for bit in affine_bits.iter().skip(1) { for bit in affine_bits.iter().skip(1) {
let temp = accumulator.add_unchecked(&multiple_of_power_of_two)?;
accumulator = bit.select(&temp, &accumulator)?;
if bit.is_constant() {
if *bit == &Boolean::TRUE {
accumulator = accumulator.add_unchecked(&multiple_of_power_of_two)?;
}
} else {
let temp = accumulator.add_unchecked(&multiple_of_power_of_two)?;
accumulator = bit.select(&temp, &accumulator)?;
}
multiple_of_power_of_two.double_in_place()?; multiple_of_power_of_two.double_in_place()?;
} }
// Perform conditional subtraction: // Perform conditional subtraction:
@ -332,8 +338,14 @@ where
// Now, let's finish off the rest of the bits using our complete formulae // Now, let's finish off the rest of the bits using our complete formulae
for bit in proj_bits { for bit in proj_bits {
let temp = &*mul_result + &multiple_of_power_of_two.into_projective();
*mul_result = bit.select(&temp, &mul_result)?;
if bit.is_constant() {
if *bit == &Boolean::TRUE {
*mul_result += &multiple_of_power_of_two.into_projective();
}
} else {
let temp = &*mul_result + &multiple_of_power_of_two.into_projective();
*mul_result = bit.select(&temp, &mul_result)?;
}
multiple_of_power_of_two.double_in_place()?; multiple_of_power_of_two.double_in_place()?;
} }
Ok(()) Ok(())
@ -485,11 +497,23 @@ where
// will conditionally select zero if `self` was zero. // will conditionally select zero if `self` was zero.
let non_zero_self = NonZeroAffineVar::new(x, y); let non_zero_self = NonZeroAffineVar::new(x, y);
let bits = bits.collect::<Vec<_>>();
let mut bits = bits.collect::<Vec<_>>();
if bits.len() == 0 { if bits.len() == 0 {
return Ok(Self::zero()); return Ok(Self::zero());
} }
let scalar_modulus_bits = <<P::ScalarField as PrimeField>::Params>::MODULUS_BITS as usize;
// Remove unnecessary constant zeros in the most-significant positions.
bits = bits
.into_iter()
// We iterate from the MSB down.
.rev()
// Skip leading zeros, if they are constants.
.skip_while(|b| b.is_constant() && (b.value().unwrap() == false))
.collect();
// After collecting we are in big-endian form; we have to reverse to get back to
// little-endian.
bits.reverse();
let scalar_modulus_bits = <P::ScalarField as PrimeField>::size_in_bits();
let mut mul_result = Self::zero(); let mut mul_result = Self::zero();
let mut power_of_two_times_self = non_zero_self; let mut power_of_two_times_self = non_zero_self;
// We chunk up `bits` into `p`-sized chunks. // We chunk up `bits` into `p`-sized chunks.
@ -497,7 +521,7 @@ where
self.fixed_scalar_mul_le(&mut mul_result, &mut power_of_two_times_self, bits)?; self.fixed_scalar_mul_le(&mut mul_result, &mut power_of_two_times_self, bits)?;
} }
// The foregoing algorithms rely on mixed/incomplete addition, and so do not
// The foregoing algorithm relies on incomplete addition, and so does not
// work when the input (`self`) is zero. We hence have to perform // work when the input (`self`) is zero. We hence have to perform
// a check to ensure that if the input is zero, then so is the output. // a check to ensure that if the input is zero, then so is the output.
// The cost of this check should be less than the benefit of using // The cost of this check should be less than the benefit of using

+ 32
- 1
src/groups/curves/short_weierstrass/non_zero_affine.rs

@ -74,7 +74,7 @@ where
let (x1, y1) = (&self.x, &self.y); let (x1, y1) = (&self.x, &self.y);
let x1_sqr = x1.square()?; let x1_sqr = x1.square()?;
// Then, // Then,
// tangent lambda := (3 * x1^2 + a) / y1·;
// tangent lambda := (3 * x1^2 + a) / (2 * y1);
// x3 = lambda^2 - 2x1 // x3 = lambda^2 - 2x1
// y3 = lambda * (x1 - x3) - y1 // y3 = lambda * (x1 - x3) - y1
let numerator = x1_sqr.double()? + &x1_sqr + P::COEFF_A; let numerator = x1_sqr.double()? + &x1_sqr + P::COEFF_A;
@ -86,6 +86,37 @@ where
} }
} }
/// Computes `2 * self + other`, evaluated as `(self + other) + self`.
///
/// When both points are allocated variables, this uses the fused formulae of
/// [\[ELM03\]](https://arxiv.org/abs/math/0208038): the slope of the second
/// addition is derived directly from the slope of the first, so the
/// `y`-coordinate of the intermediate point `self + other` is never computed.
/// This path is accounted as 5 constraints, fewer than the 7 needed when
/// computing via `self.double() + other`.
///
/// If either operand is a constant, the result is instead computed as
/// `self.double()?.add_unchecked(other)`.
///
/// NOTE(review): both slopes divide by a difference of `x`-coordinates
/// (`x2 - x1`, then `x3 - x1`); presumably callers must guarantee these are
/// non-zero, as with other incomplete-addition formulae. Confirm against the
/// contract of `mul_by_inverse`.
#[tracing::instrument(target = "r1cs", skip(self))]
pub(crate) fn double_and_add(&self, other: &Self) -> Result<Self, SynthesisError> {
    // Constant operands: fall back to a plain doubling followed by an addition.
    if [self].is_constant() || other.is_constant() {
        return self.double()?.add_unchecked(other);
    }

    let (self_x, self_y) = (&self.x, &self.y);
    let (other_x, other_y) = (&other.x, &other.y);

    // Step 1: the chord through `self` and `other`.
    //   slope_sum := (y2 - y1) / (x2 - x1)
    //   x_sum     := slope_sum^2 - x1 - x2
    // (The matching y-coordinate is deliberately not materialised.)
    let slope_sum = (other_y - self_y).mul_by_inverse(&(other_x - self_x))?;
    let x_sum = slope_sum.square()? - self_x - other_x;

    // Step 2: the chord through `self + other` and `self`, with its slope
    // expressed via the first slope:
    //   slope_final := -(slope_sum + 2 * y1 / (x_sum - x1))
    let slope_correction = self_y.double()?.mul_by_inverse(&(&x_sum - self_x))?;
    let slope_final = (slope_sum + slope_correction).negate()?;

    // Step 3: coordinates of the result.
    //   x_out := slope_final^2 - x1 - x_sum
    //   y_out := slope_final * (x1 - x_out) - y1
    let x_out = slope_final.square()? - self_x - x_sum;
    let x_delta = self_x - &x_out;
    let y_out = slope_final * &x_delta - self_y;

    Ok(Self::new(x_out, y_out))
}
/// Doubles `self` in place. /// Doubles `self` in place.
#[tracing::instrument(target = "r1cs", skip(self))] #[tracing::instrument(target = "r1cs", skip(self))]
pub(crate) fn double_in_place(&mut self) -> Result<(), SynthesisError> { pub(crate) fn double_in_place(&mut self) -> Result<(), SynthesisError> {

Loading…
Cancel
Save