diff --git a/src/groups/curves/short_weierstrass/mod.rs b/src/groups/curves/short_weierstrass/mod.rs index b605cb9..8251bdb 100644 --- a/src/groups/curves/short_weierstrass/mod.rs +++ b/src/groups/curves/short_weierstrass/mod.rs @@ -632,6 +632,72 @@ where infinity.select(&Self::zero(), &mul_result) } + /// Computes `bits1 * self + bits2 * p`, where `bits1` and `bits2` are big-endian + /// `Boolean` representation of the scalars. + /// + /// `self` and `p` are non-zero and `self` ≠ `-p`. + #[tracing::instrument(target = "r1cs", skip(bits1, bits2))] + fn joint_scalar_mul_be<'a>( + &self, + p: &Self, + bits1: impl Iterator::BasePrimeField>>, + bits2: impl Iterator::BasePrimeField>>, + ) -> Result { + // prepare bits decomposition + let mut bits1 = bits1.collect::>(); + if bits1.len() == 0 { + return Ok(Self::zero()); + } + // Remove unnecessary constant zeros in the most-significant positions. + bits1 = bits1 + .into_iter() + // We iterate from the MSB down. + .rev() + // Skip leading zeros, if they are constants. + .skip_while(|b| b.is_constant() && (b.value().unwrap() == false)) + .collect(); + + let mut bits2 = bits2.collect::>(); + if bits2.len() == 0 { + return Ok(Self::zero()); + } + // Remove unnecessary constant zeros in the most-significant positions. + bits2 = bits2 + .into_iter() + // We iterate from the MSB down. + .rev() + // Skip leading zeros, if they are constants. + .skip_while(|b| b.is_constant() && (b.value().unwrap() == false)) + .collect(); + + // precompute points + let aff1 = self.to_affine()?; + let nz_aff1 = NonZeroAffineVar::new(aff1.x, aff1.y); + + let aff2 = p.to_affine()?; + let nz_aff2 = NonZeroAffineVar::new(aff2.x, aff2.y); + + let mut aff1_neg = NonZeroAffineVar::new(nz_aff1.x.clone(), nz_aff1.y.negate()?); + let mut aff2_neg = NonZeroAffineVar::new(nz_aff2.x.clone(), nz_aff2.y.negate()?); + let mut acc = nz_aff1.add_unchecked(&nz_aff2.clone())?; + + // double-and-add loop + for (bit1, bit2) in (bits1.iter().rev().skip(1).rev()).zip(bits2.iter().rev().skip(1).rev()) { + let mut b = bit1.select(&nz_aff1, &aff1_neg)?; + acc = acc.double_and_add_unchecked(&b)?; + b = bit2.select(&nz_aff2, &aff2_neg)?; + acc = acc.add_unchecked(&b)?; + } + + // last bit + aff1_neg = aff1_neg.add_unchecked(&acc)?; + acc = bits1[bits1.len()-1].select(&acc, &aff1_neg)?; + aff2_neg = aff2_neg.add_unchecked(&acc)?; + acc = bits2[bits1.len()-1].select(&acc, &aff2_neg)?; + + Ok(acc.into_projective()) + } + #[tracing::instrument(target = "r1cs", skip(scalar_bits_with_bases))] fn precomputed_base_scalar_mul_le<'a, I, B>( &mut self, diff --git a/src/groups/mod.rs b/src/groups/mod.rs index 7596808..d9fc650 100644 --- a/src/groups/mod.rs +++ b/src/groups/mod.rs @@ -131,6 +131,20 @@ pub trait CurveVar: Ok(res) } + /// Computes a `I1 * self + I2 * p` in place, where `I1` and `I2` are `Boolean` *big-endian* + /// representation of the scalars. + #[tracing::instrument(target = "r1cs", skip(bits1, bits2))] + fn joint_scalar_mul_be<'a>( + &self, + p: &Self, + bits1: impl Iterator>, + bits2: impl Iterator>, + ) -> Result { + let res1 = self.scalar_mul_le(bits1)?; + let res2 = p.scalar_mul_le(bits2)?; + Ok(res1+res2) + } + /// Computes a `I * self` in place, where `I` is a `Boolean` *little-endian* /// representation of the scalar. /// @@ -239,8 +253,56 @@ mod test_sw_arithmetic { cs.is_satisfied() } + fn point_joint_scalar_mul_satisfied() -> Result + where + G: CurveGroup, + G::BaseField: PrimeField, + G::Config: SWCurveConfig, + { + let mut rng = ark_std::test_rng(); + + let cs = ConstraintSystem::new_ref(); + let point_in1 = Projective::::rand(&mut rng); + let point_in2 = Projective::::rand(&mut rng); + let scalar1 = G::ScalarField::rand(&mut rng); + let scalar2 = G::ScalarField::rand(&mut rng); + let point_out = point_in1 * scalar1 + point_in2 * scalar2; + + let point_in1 = + ProjectiveVar::>::new_witness(cs.clone(), || { + Ok(point_in1) + })?; + let point_in2 = + ProjectiveVar::>::new_witness(cs.clone(), || { + Ok(point_in2) + })?; + let point_out = + ProjectiveVar::>::new_input(cs.clone(), || { + Ok(point_out) + })?; + let scalar1 = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar1))?; + let scalar2 = NonNativeFieldVar::new_input(cs.clone(), || Ok(scalar2))?; + + let res = point_in1.joint_scalar_mul_be(&point_in2, scalar1.to_bits_le().unwrap().iter(), scalar2.to_bits_le().unwrap().iter())?; + + point_out.enforce_equal(&res)?; + + println!( + "#r1cs for joint_scalar_mul: {}", + cs.num_constraints() + ); + + + cs.is_satisfied() + } + #[test] - fn test_point_scalar_mul_joye() { + fn test_point_scalar_mul() { assert!(point_scalar_mul_joye_satisfied::().unwrap()); } + #[test] + fn test_point_joint_scalar_mul() { + assert!(point_joint_scalar_mul_satisfied::().unwrap()); + } + }