diff --git a/examples/minroot.rs b/examples/minroot.rs index 122abda..801cf55 100644 --- a/examples/minroot.rs +++ b/examples/minroot.rs @@ -169,6 +169,7 @@ fn main() { ); // produce public parameters + let start = Instant::now(); println!("Producing public parameters..."); let pp = PublicParams::< G1, @@ -176,6 +177,8 @@ fn main() { MinRootCircuit<::Scalar>, TrivialTestCircuit<::Scalar>, >::setup(circuit_primary, circuit_secondary.clone()); + println!("PublicParams::setup, took {:?} ", start.elapsed()); + println!( "Number of constraints per step (primary circuit): {}", pp.num_constraints().0 diff --git a/src/pasta.rs b/src/pasta.rs index 97f0319..f882229 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -11,18 +11,203 @@ use num_traits::Num; use pasta_curves::{ self, arithmetic::{CurveAffine, CurveExt, Group as OtherGroup}, - group::{Curve, Group as AnotherGroup, GroupEncoding}, - pallas, vesta, Ep, Eq, + group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}, + pallas, vesta, Ep, EpAffine, Eq, EqAffine, }; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; +use rayon::prelude::*; use serde::{Deserialize, Serialize}; use sha3::Shake256; use std::io::Read; -//////////////////////////////////////Shared MSM code for Pasta curves/////////////////////////////////////////////// +/// A wrapper for compressed group elements of pallas +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct PallasCompressedElementWrapper { + repr: [u8; 32], +} + +impl PallasCompressedElementWrapper { + /// Wraps repr into the wrapper + pub fn new(repr: [u8; 32]) -> Self { + Self { repr } + } +} + +/// A wrapper for compressed group elements of vesta +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct VestaCompressedElementWrapper { + repr: [u8; 32], +} + +impl VestaCompressedElementWrapper { + /// Wraps repr into the wrapper + pub fn new(repr: [u8; 32]) -> Self { + Self { repr } + } +} + +macro_rules! impl_traits { + ( + $name:ident, + $name_compressed:ident, + $name_curve:ident, + $name_curve_affine:ident, + $order_str:literal + ) => { + impl Group for $name::Point { + type Base = $name::Base; + type Scalar = $name::Scalar; + type CompressedGroupElement = $name_compressed; + type PreprocessedGroupElement = $name::Affine; + type RO = PoseidonRO; + type ROCircuit = PoseidonROCircuit; + + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self { + if scalars.len() >= 128 { + pasta_msm::$name(bases, scalars) + } else { + cpu_best_multiexp(scalars, bases) + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self { + cpu_best_multiexp(scalars, bases) + } + + fn preprocessed(&self) -> Self::PreprocessedGroupElement { + self.to_affine() + } + + fn compress(&self) -> Self::CompressedGroupElement { + $name_compressed::new(self.to_bytes()) + } + + fn from_label(label: &'static [u8], n: usize) -> Vec { + let mut shake = Shake256::default(); + shake.input(label); + let mut reader = shake.xof_result(); + let mut uniform_bytes_vec = Vec::new(); + for _ in 0..n { + let mut uniform_bytes = [0u8; 32]; + reader.read_exact(&mut uniform_bytes).unwrap(); + uniform_bytes_vec.push(uniform_bytes); + } + let gens_proj: Vec<$name_curve> = (0..n) + .collect::>() + .into_par_iter() + .map(|i| { + let hash = $name_curve::hash_to_curve("from_uniform_bytes"); + hash(&uniform_bytes_vec[i]) + }) + .collect(); + + let num_threads = rayon::current_num_threads(); + if gens_proj.len() > num_threads { + let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize; + (0..num_threads) + .collect::>() + .into_par_iter() + .map(|i| { + let start = i * chunk; + let end = if i == num_threads - 1 { + gens_proj.len() + } else { + core::cmp::min((i + 1) * chunk, gens_proj.len()) + }; + if end > start { + let mut gens = vec![$name_curve_affine::identity(); end - start]; + ::batch_normalize(&gens_proj[start..end], &mut gens); + gens + } else { + vec![] + } + }) + .collect::>>() + .into_par_iter() + .flatten() + .collect() + } else { + let mut gens = vec![$name_curve_affine::identity(); n]; + ::batch_normalize(&gens_proj, &mut gens); + gens + } + } + + fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { + let coordinates = self.to_affine().coordinates(); + if coordinates.is_some().unwrap_u8() == 1 { + (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) + } else { + (Self::Base::zero(), Self::Base::zero(), true) + } + } + + fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { + let A = Self::Base::zero(); + let B = Self::Base::from(5); + let order = BigInt::from_str_radix($order_str, 16).unwrap(); + + (A, B, order) + } + + fn zero() -> Self { + $name::Point::group_zero() + } + + fn get_generator() -> Self { + $name::Point::generator() + } + } + + impl ChallengeTrait for $name::Scalar { + fn challenge(label: &'static [u8], transcript: &mut Transcript) -> Self { + let mut key: ::Seed = Default::default(); + transcript.challenge_bytes(label, &mut key); + let mut rng = ChaCha20Rng::from_seed(key); + $name::Scalar::random(&mut rng) + } + } + + impl CompressedGroup for $name_compressed { + type GroupElement = $name::Point; + + fn decompress(&self) -> Option<$name::Point> { + Some($name_curve::from_bytes(&self.repr).unwrap()) + } + + fn as_bytes(&self) -> &[u8] { + &self.repr + } + } + }; +} + +impl_traits!( + pallas, + PallasCompressedElementWrapper, + Ep, + EpAffine, + "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001" +); + +impl_traits!( + vesta, + VestaCompressedElementWrapper, + Eq, + EqAffine, + "40000000000000000000000000000000224698fc094cf91b992d30ed00000001" +); /// Native implementation of fast multiexp for platforms that do not support pasta_msm/semolina -/// Forked from zcash/halo2 +/// Adapted from zcash/halo2 fn cpu_multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { use ff::PrimeField; let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); @@ -119,7 +304,7 @@ fn cpu_multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. -/// Forked from zcash/halo2 +/// Adapted from zcash/halo2 fn cpu_best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); @@ -129,8 +314,6 @@ fn cpu_best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu let num_chunks = coeffs.chunks(chunk).len(); let mut results = vec![C::Curve::identity(); num_chunks]; rayon::scope(|scope| { - let chunk = coeffs.len() / num_threads; - for ((coeffs, bases), acc) in coeffs .chunks(chunk) .zip(bases.chunks(chunk)) @@ -149,64 +332,18 @@ fn cpu_best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu } } -//////////////////////////////////////Pallas/////////////////////////////////////////////// - -/// A wrapper for compressed group elements that come from the pallas curve -#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub struct PallasCompressedElementWrapper { - repr: [u8; 32], -} - -impl PallasCompressedElementWrapper { - /// Wraps repr into the wrapper - pub fn new(repr: [u8; 32]) -> Self { - Self { repr } - } -} +#[cfg(test)] +mod tests { + use super::*; + type G = pasta_curves::pallas::Point; -impl Group for pallas::Point { - type Base = pallas::Base; - type Scalar = pallas::Scalar; - type CompressedGroupElement = PallasCompressedElementWrapper; - type PreprocessedGroupElement = pallas::Affine; - type RO = PoseidonRO; - type ROCircuit = PoseidonROCircuit; - - #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - if scalars.len() >= 128 { - pasta_msm::pallas(bases, scalars) - } else { - cpu_best_multiexp(scalars, bases) - } - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - cpu_best_multiexp(scalars, bases) - } - - fn preprocessed(&self) -> Self::PreprocessedGroupElement { - self.to_affine() - } - - fn compress(&self) -> Self::CompressedGroupElement { - PallasCompressedElementWrapper::new(self.to_bytes()) - } - - fn from_label(label: &'static [u8], n: usize) -> Vec { + fn from_label_serial(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.input(label); let mut reader = shake.xof_result(); - let mut gens: Vec = Vec::new(); - let mut uniform_bytes = [0u8; 32]; + let mut gens = Vec::new(); for _ in 0..n { + let mut uniform_bytes = [0u8; 32]; reader.read_exact(&mut uniform_bytes).unwrap(); let hash = Ep::hash_to_curve("from_uniform_bytes"); gens.push(hash(&uniform_bytes).to_affine()); @@ -214,167 +351,17 @@ impl Group for pallas::Point { gens } - fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { - let coordinates = self.to_affine().coordinates(); - if coordinates.is_some().unwrap_u8() == 1 { - (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) - } else { - (Self::Base::zero(), Self::Base::zero(), true) + #[test] + fn test_from_label() { + let label = b"test_from_label"; + for n in [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, + ] { + let gens_par = ::from_label(label, n); + let gens_ser = from_label_serial(label, n); + assert_eq!(gens_par.len(), n); + assert_eq!(gens_ser.len(), n); + assert_eq!(gens_par, gens_ser); } } - - fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { - let A = Self::Base::zero(); - let B = Self::Base::from(5); - let order = BigInt::from_str_radix( - "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001", - 16, - ) - .unwrap(); - - (A, B, order) - } - - fn zero() -> Self { - pallas::Point::group_zero() - } - - fn get_generator() -> Self { - pallas::Point::generator() - } -} - -impl ChallengeTrait for pallas::Scalar { - fn challenge(label: &'static [u8], transcript: &mut Transcript) -> Self { - let mut key: ::Seed = Default::default(); - transcript.challenge_bytes(label, &mut key); - let mut rng = ChaCha20Rng::from_seed(key); - pallas::Scalar::random(&mut rng) - } -} - -impl CompressedGroup for PallasCompressedElementWrapper { - type GroupElement = pallas::Point; - - fn decompress(&self) -> Option { - Some(Ep::from_bytes(&self.repr).unwrap()) - } - fn as_bytes(&self) -> &[u8] { - &self.repr - } -} - -//////////////////////////////////////Vesta//////////////////////////////////////////////// - -/// A wrapper for compressed group elements that come from the vesta curve -#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub struct VestaCompressedElementWrapper { - repr: [u8; 32], -} - -impl VestaCompressedElementWrapper { - /// Wraps repr into the wrapper - pub fn new(repr: [u8; 32]) -> Self { - Self { repr } - } -} - -impl Group for vesta::Point { - type Base = vesta::Base; - type Scalar = vesta::Scalar; - type CompressedGroupElement = VestaCompressedElementWrapper; - type PreprocessedGroupElement = vesta::Affine; - type RO = PoseidonRO; - type ROCircuit = PoseidonROCircuit; - - #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - if scalars.len() >= 128 { - pasta_msm::vesta(bases, scalars) - } else { - cpu_best_multiexp(scalars, bases) - } - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - cpu_best_multiexp(scalars, bases) - } - - fn compress(&self) -> Self::CompressedGroupElement { - VestaCompressedElementWrapper::new(self.to_bytes()) - } - - fn preprocessed(&self) -> Self::PreprocessedGroupElement { - self.to_affine() - } - - fn from_label(label: &'static [u8], n: usize) -> Vec { - let mut shake = Shake256::default(); - shake.input(label); - let mut reader = shake.xof_result(); - let mut gens: Vec = Vec::new(); - let mut uniform_bytes = [0u8; 32]; - for _ in 0..n { - reader.read_exact(&mut uniform_bytes).unwrap(); - let hash = Eq::hash_to_curve("from_uniform_bytes"); - gens.push(hash(&uniform_bytes).to_affine()); - } - gens - } - - fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { - let coordinates = self.to_affine().coordinates(); - if coordinates.is_some().unwrap_u8() == 1 { - (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) - } else { - (Self::Base::zero(), Self::Base::zero(), true) - } - } - - fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { - let A = Self::Base::zero(); - let B = Self::Base::from(5); - let order = BigInt::from_str_radix( - "40000000000000000000000000000000224698fc094cf91b992d30ed00000001", - 16, - ) - .unwrap(); - - (A, B, order) - } - - fn zero() -> Self { - vesta::Point::group_zero() - } - - fn get_generator() -> Self { - vesta::Point::generator() - } -} - -impl ChallengeTrait for vesta::Scalar { - fn challenge(label: &'static [u8], transcript: &mut Transcript) -> Self { - let mut key: ::Seed = Default::default(); - transcript.challenge_bytes(label, &mut key); - let mut rng = ChaCha20Rng::from_seed(key); - vesta::Scalar::random(&mut rng) - } -} - -impl CompressedGroup for VestaCompressedElementWrapper { - type GroupElement = vesta::Point; - - fn decompress(&self) -> Option { - Some(Eq::from_bytes(&self.repr).unwrap()) - } - fn as_bytes(&self) -> &[u8] { - &self.repr - } }