Browse Source

feat: implement joint scalar mul

pull/1/head
Youssef El Housni 2 months ago
parent
commit
0ae123fbb1
2 changed files with 129 additions and 1 deletions
  1. +66
    -0
      src/groups/curves/short_weierstrass/mod.rs
  2. +63
    -1
      src/groups/mod.rs

+ 66
- 0
src/groups/curves/short_weierstrass/mod.rs

@ -632,6 +632,72 @@ where
infinity.select(&Self::zero(), &mul_result) infinity.select(&Self::zero(), &mul_result)
} }
/// Computes `bits1 * self + bits2 * p`, where `bits1` and `bits2` are big-endian
/// `Boolean` representation of the scalars.
///
/// `self` and `p` are non-zero and `self` ≠ `-p`.
#[tracing::instrument(target = "r1cs", skip(bits1, bits2))]
fn joint_scalar_mul_be<'a>(
&self,
p: &Self,
bits1: impl Iterator<Item = &'a Boolean<<P::BaseField as Field>::BasePrimeField>>,
bits2: impl Iterator<Item = &'a Boolean<<P::BaseField as Field>::BasePrimeField>>,
) -> Result<Self, SynthesisError> {
// prepare bits decomposition
let mut bits1 = bits1.collect::<Vec<_>>();
if bits1.len() == 0 {
return Ok(Self::zero());
}
// Remove unnecessary constant zeros in the most-significant positions.
bits1 = bits1
.into_iter()
// We iterate from the MSB down.
.rev()
// Skip leading zeros, if they are constants.
.skip_while(|b| b.is_constant() && (b.value().unwrap() == false))
.collect();
let mut bits2 = bits2.collect::<Vec<_>>();
if bits2.len() == 0 {
return Ok(Self::zero());
}
// Remove unnecessary constant zeros in the most-significant positions.
bits2 = bits2
.into_iter()
// We iterate from the MSB down.
.rev()
// Skip leading zeros, if they are constants.
.skip_while(|b| b.is_constant() && (b.value().unwrap() == false))
.collect();
// precompute points
let aff1 = self.to_affine()?;
let nz_aff1 = NonZeroAffineVar::new(aff1.x, aff1.y);
let aff2 = p.to_affine()?;
let nz_aff2 = NonZeroAffineVar::new(aff2.x, aff2.y);
let mut aff1_neg = NonZeroAffineVar::new(nz_aff1.x.clone(), nz_aff1.y.negate()?);
let mut aff2_neg = NonZeroAffineVar::new(nz_aff2.x.clone(), nz_aff2.y.negate()?);
let mut acc = nz_aff1.add_unchecked(&nz_aff2.clone())?;
// double-and-add loop
for (bit1, bit2) in (bits1.iter().rev().skip(1).rev()).zip(bits2.iter().rev().skip(1).rev()) {
let mut b = bit1.select(&nz_aff1, &aff1_neg)?;
acc = acc.double_and_add_unchecked(&b)?;
b = bit2.select(&nz_aff2, &aff2_neg)?;
acc = acc.add_unchecked(&b)?;
}
// last bit
aff1_neg = aff1_neg.add_unchecked(&acc)?;
acc = bits1[bits1.len()-1].select(&acc, &aff1_neg)?;
aff2_neg = aff2_neg.add_unchecked(&acc)?;
acc = bits2[bits1.len()-1].select(&acc, &aff2_neg)?;
Ok(acc.into_projective())
}
#[tracing::instrument(target = "r1cs", skip(scalar_bits_with_bases))] #[tracing::instrument(target = "r1cs", skip(scalar_bits_with_bases))]
fn precomputed_base_scalar_mul_le<'a, I, B>( fn precomputed_base_scalar_mul_le<'a, I, B>(
&mut self, &mut self,

+ 63
- 1
src/groups/mod.rs

@ -131,6 +131,20 @@ pub trait CurveVar:
Ok(res) Ok(res)
} }
/// Computes a `I1 * self + I2 * p` in place, where `I1` and `I2` are `Boolean` *big-endian*
/// representation of the scalars.
#[tracing::instrument(target = "r1cs", skip(bits1, bits2))]
fn joint_scalar_mul_be<'a>(
&self,
p: &Self,
bits1: impl Iterator<Item = &'a Boolean<ConstraintF>>,
bits2: impl Iterator<Item = &'a Boolean<ConstraintF>>,
) -> Result<Self, SynthesisError> {
let res1 = self.scalar_mul_le(bits1)?;
let res2 = p.scalar_mul_le(bits2)?;
Ok(res1+res2)
}
/// Computes a `I * self` in place, where `I` is a `Boolean` *little-endian* /// Computes a `I * self` in place, where `I` is a `Boolean` *little-endian*
/// representation of the scalar. /// representation of the scalar.
/// ///
@ -239,8 +253,56 @@ mod test_sw_arithmetic {
cs.is_satisfied() cs.is_satisfied()
} }
fn point_joint_scalar_mul_satisfied<G>() -> Result<bool>
where
G: CurveGroup,
G::BaseField: PrimeField,
G::Config: SWCurveConfig,
{
let mut rng = ark_std::test_rng();
let cs = ConstraintSystem::new_ref();
let point_in1 = Projective::<G::Config>::rand(&mut rng);
let point_in2 = Projective::<G::Config>::rand(&mut rng);
let scalar1 = G::ScalarField::rand(&mut rng);
let scalar2 = G::ScalarField::rand(&mut rng);
let point_out = point_in1 * scalar1 + point_in2 * scalar2;
let point_in1 =
ProjectiveVar::<G::Config, FpVar<G::BaseField>>::new_witness(cs.clone(), || {
Ok(point_in1)
})?;
let point_in2 =
ProjectiveVar::<G::Config, FpVar<G::BaseField>>::new_witness(cs.clone(), || {
Ok(point_in2)
})?;
let point_out =
ProjectiveVar::<G::Config, FpVar<G::BaseField>>::new_input(cs.clone(), || {
Ok(point_out)
})?;
let scalar1 = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar1))?;
let scalar2 = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar2))?;
let res = point_in1.joint_scalar_mul_be(&point_in2, scalar1.to_bits_le().unwrap().iter(), scalar2.to_bits_le().unwrap().iter())?;
point_out.enforce_equal(&res)?;
println!(
"#r1cs for joint_scalar_mul: {}",
cs.num_constraints()
);
cs.is_satisfied()
}
#[test] #[test]
fn test_point_scalar_mul_joye() {
fn test_point_scalar_mul() {
assert!(point_scalar_mul_joye_satisfied::<ark_bn254::G1Projective>().unwrap()); assert!(point_scalar_mul_joye_satisfied::<ark_bn254::G1Projective>().unwrap());
} }
#[test]
fn test_point_joint_scalar_mul() {
assert!(point_joint_scalar_mul_satisfied::<ark_bn254::G1Projective>().unwrap());
}
} }

Loading…
Cancel
Save