diff --git a/Cargo.toml b/Cargo.toml index c5d7fae..05fcddd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ rand_chacha = "0.3" itertools = "0.9.0" subtle = "2.4" pasta_curves = { version = "0.4.0", features = ["repr-c"] } -pasta-msm = "0.1.3" neptune = { version = "8.1.0", default-features = false } generic-array = "0.14.4" num-bigint = { version = "0.4", features = ["serde", "rand"] } @@ -47,4 +46,7 @@ name = "compressed-snark" harness = false [features] -default = [ "bellperson/default", "neptune/default" ] +default = ["bellperson/default", "neptune/default"] + +[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies] +pasta-msm = "0.1.3" diff --git a/src/pasta.rs b/src/pasta.rs index 3748225..a4610ae 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -16,9 +16,138 @@ use pasta_curves::{ }; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; -use rayon::prelude::*; use sha3::Shake256; -use std::{io::Read, ops::Mul}; +use std::io::Read; + +//////////////////////////////////////Shared MSM code for Pasta curves/////////////////////////////////////////////// + +/// Native implementation of fast multiexp for platforms that do not support pasta_msm/semolina +/// Forked from zcash/halo2 +fn cpu_multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + use ff::PrimeField; + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + +/// Performs a multi-exponentiation operation without GPU acceleration. +/// +/// This function will panic if coeffs and bases have a different length. +/// +/// This will use multithreading if beneficial. +/// Forked from zcash/halo2 +fn cpu_best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); + + let num_threads = rayon::current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![C::Curve::identity(); num_chunks]; + rayon::scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + cpu_multiexp_serial(coeffs, bases, acc); + }); + } + }); + results.iter().fold(C::Curve::identity(), |a, b| a + b) + } else { + let mut acc = C::Curve::identity(); + cpu_multiexp_serial(coeffs, bases, &mut acc); + acc + } +} //////////////////////////////////////Pallas/////////////////////////////////////////////// @@ -43,6 +172,7 @@ impl Group for pallas::Point { type RO = PoseidonRO; type ROCircuit = PoseidonROCircuit; + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] fn vartime_multiscalar_mul( scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], @@ -50,14 +180,18 @@ impl Group for pallas::Point { if scalars.len() >= 128 { pasta_msm::pallas(bases, scalars) } else { - scalars - .par_iter() - .zip(bases) - .map(|(scalar, base)| base.mul(scalar)) - .reduce(Ep::group_zero, |x, y| x + y) + cpu_best_multiexp(scalars, bases) } } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self { + cpu_best_multiexp(scalars, bases) + } + fn preprocessed(&self) -> Self::PreprocessedGroupElement { self.to_affine() } @@ -153,6 +287,7 @@ impl Group for vesta::Point { type RO = PoseidonRO; type ROCircuit = PoseidonROCircuit; + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] fn vartime_multiscalar_mul( scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], @@ -160,14 +295,18 @@ impl Group for vesta::Point { if scalars.len() >= 128 { pasta_msm::vesta(bases, scalars) } else { - scalars - .par_iter() - .zip(bases) - .map(|(scalar, base)| base.mul(scalar)) - .reduce(Eq::group_zero, |x, y| x + y) + cpu_best_multiexp(scalars, bases) } } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self { + cpu_best_multiexp(scalars, bases) + } + fn compress(&self) -> Self::CompressedGroupElement { VestaCompressedElementWrapper::new(self.to_bytes()) }