mirror of
https://github.com/arnaucube/Nova.git
synced 2026-01-11 08:31:29 +01:00
Speed up MSMs for non-GPU accelerated MSMs and architectures that don't support GPU/semolina (#126)
* WASM target support * fast multiexp for WASM * add parallelisation for MSM https://github.com/zcash/halo2/blob/main/halo2_proofs/src/arithmetic.rs
This commit is contained in:
@@ -23,7 +23,6 @@ rand_chacha = "0.3"
|
|||||||
itertools = "0.9.0"
|
itertools = "0.9.0"
|
||||||
subtle = "2.4"
|
subtle = "2.4"
|
||||||
pasta_curves = { version = "0.4.0", features = ["repr-c"] }
|
pasta_curves = { version = "0.4.0", features = ["repr-c"] }
|
||||||
pasta-msm = "0.1.3"
|
|
||||||
neptune = { version = "8.1.0", default-features = false }
|
neptune = { version = "8.1.0", default-features = false }
|
||||||
generic-array = "0.14.4"
|
generic-array = "0.14.4"
|
||||||
num-bigint = { version = "0.4", features = ["serde", "rand"] }
|
num-bigint = { version = "0.4", features = ["serde", "rand"] }
|
||||||
@@ -47,4 +46,7 @@ name = "compressed-snark"
|
|||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "bellperson/default", "neptune/default" ]
|
default = ["bellperson/default", "neptune/default"]
|
||||||
|
|
||||||
|
[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
|
||||||
|
pasta-msm = "0.1.3"
|
||||||
|
|||||||
163
src/pasta.rs
163
src/pasta.rs
@@ -16,9 +16,138 @@ use pasta_curves::{
|
|||||||
};
|
};
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use rand_chacha::ChaCha20Rng;
|
use rand_chacha::ChaCha20Rng;
|
||||||
use rayon::prelude::*;
|
|
||||||
use sha3::Shake256;
|
use sha3::Shake256;
|
||||||
use std::{io::Read, ops::Mul};
|
use std::io::Read;
|
||||||
|
|
||||||
|
//////////////////////////////////////Shared MSM code for Pasta curves///////////////////////////////////////////////
|
||||||
|
|
||||||
|
/// Native implementation of fast multiexp for platforms that do not support pasta_msm/semolina
|
||||||
|
/// Forked from zcash/halo2
|
||||||
|
fn cpu_multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
|
||||||
|
use ff::PrimeField;
|
||||||
|
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
|
||||||
|
|
||||||
|
let c = if bases.len() < 4 {
|
||||||
|
1
|
||||||
|
} else if bases.len() < 32 {
|
||||||
|
3
|
||||||
|
} else {
|
||||||
|
(f64::from(bases.len() as u32)).ln().ceil() as usize
|
||||||
|
};
|
||||||
|
|
||||||
|
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
|
||||||
|
let skip_bits = segment * c;
|
||||||
|
let skip_bytes = skip_bits / 8;
|
||||||
|
|
||||||
|
if skip_bytes >= 32 {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut v = [0; 8];
|
||||||
|
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
|
||||||
|
*v = *o;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut tmp = u64::from_le_bytes(v);
|
||||||
|
tmp >>= skip_bits - (skip_bytes * 8);
|
||||||
|
tmp %= 1 << c;
|
||||||
|
|
||||||
|
tmp as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
let segments = (256 / c) + 1;
|
||||||
|
|
||||||
|
for current_segment in (0..segments).rev() {
|
||||||
|
for _ in 0..c {
|
||||||
|
*acc = acc.double();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
enum Bucket<C: CurveAffine> {
|
||||||
|
None,
|
||||||
|
Affine(C),
|
||||||
|
Projective(C::Curve),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: CurveAffine> Bucket<C> {
|
||||||
|
fn add_assign(&mut self, other: &C) {
|
||||||
|
*self = match *self {
|
||||||
|
Bucket::None => Bucket::Affine(*other),
|
||||||
|
Bucket::Affine(a) => Bucket::Projective(a + *other),
|
||||||
|
Bucket::Projective(mut a) => {
|
||||||
|
a += *other;
|
||||||
|
Bucket::Projective(a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add(self, mut other: C::Curve) -> C::Curve {
|
||||||
|
match self {
|
||||||
|
Bucket::None => other,
|
||||||
|
Bucket::Affine(a) => {
|
||||||
|
other += a;
|
||||||
|
other
|
||||||
|
}
|
||||||
|
Bucket::Projective(a) => other + a,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
|
||||||
|
|
||||||
|
for (coeff, base) in coeffs.iter().zip(bases.iter()) {
|
||||||
|
let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
|
||||||
|
if coeff != 0 {
|
||||||
|
buckets[coeff - 1].add_assign(base);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summation by parts
|
||||||
|
// e.g. 3a + 2b + 1c = a +
|
||||||
|
// (a) + b +
|
||||||
|
// ((a) + b) + c
|
||||||
|
let mut running_sum = C::Curve::identity();
|
||||||
|
for exp in buckets.into_iter().rev() {
|
||||||
|
running_sum = exp.add(running_sum);
|
||||||
|
*acc += &running_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Performs a multi-exponentiation operation without GPU acceleration.
|
||||||
|
///
|
||||||
|
/// This function will panic if coeffs and bases have a different length.
|
||||||
|
///
|
||||||
|
/// This will use multithreading if beneficial.
|
||||||
|
/// Forked from zcash/halo2
|
||||||
|
fn cpu_best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
|
||||||
|
assert_eq!(coeffs.len(), bases.len());
|
||||||
|
|
||||||
|
let num_threads = rayon::current_num_threads();
|
||||||
|
if coeffs.len() > num_threads {
|
||||||
|
let chunk = coeffs.len() / num_threads;
|
||||||
|
let num_chunks = coeffs.chunks(chunk).len();
|
||||||
|
let mut results = vec![C::Curve::identity(); num_chunks];
|
||||||
|
rayon::scope(|scope| {
|
||||||
|
let chunk = coeffs.len() / num_threads;
|
||||||
|
|
||||||
|
for ((coeffs, bases), acc) in coeffs
|
||||||
|
.chunks(chunk)
|
||||||
|
.zip(bases.chunks(chunk))
|
||||||
|
.zip(results.iter_mut())
|
||||||
|
{
|
||||||
|
scope.spawn(move |_| {
|
||||||
|
cpu_multiexp_serial(coeffs, bases, acc);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
results.iter().fold(C::Curve::identity(), |a, b| a + b)
|
||||||
|
} else {
|
||||||
|
let mut acc = C::Curve::identity();
|
||||||
|
cpu_multiexp_serial(coeffs, bases, &mut acc);
|
||||||
|
acc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////Pallas///////////////////////////////////////////////
|
//////////////////////////////////////Pallas///////////////////////////////////////////////
|
||||||
|
|
||||||
@@ -43,6 +172,7 @@ impl Group for pallas::Point {
|
|||||||
type RO = PoseidonRO<Self::Base, Self::Scalar>;
|
type RO = PoseidonRO<Self::Base, Self::Scalar>;
|
||||||
type ROCircuit = PoseidonROCircuit<Self::Base>;
|
type ROCircuit = PoseidonROCircuit<Self::Base>;
|
||||||
|
|
||||||
|
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
|
||||||
fn vartime_multiscalar_mul(
|
fn vartime_multiscalar_mul(
|
||||||
scalars: &[Self::Scalar],
|
scalars: &[Self::Scalar],
|
||||||
bases: &[Self::PreprocessedGroupElement],
|
bases: &[Self::PreprocessedGroupElement],
|
||||||
@@ -50,14 +180,18 @@ impl Group for pallas::Point {
|
|||||||
if scalars.len() >= 128 {
|
if scalars.len() >= 128 {
|
||||||
pasta_msm::pallas(bases, scalars)
|
pasta_msm::pallas(bases, scalars)
|
||||||
} else {
|
} else {
|
||||||
scalars
|
cpu_best_multiexp(scalars, bases)
|
||||||
.par_iter()
|
|
||||||
.zip(bases)
|
|
||||||
.map(|(scalar, base)| base.mul(scalar))
|
|
||||||
.reduce(Ep::group_zero, |x, y| x + y)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
|
||||||
|
fn vartime_multiscalar_mul(
|
||||||
|
scalars: &[Self::Scalar],
|
||||||
|
bases: &[Self::PreprocessedGroupElement],
|
||||||
|
) -> Self {
|
||||||
|
cpu_best_multiexp(scalars, bases)
|
||||||
|
}
|
||||||
|
|
||||||
fn preprocessed(&self) -> Self::PreprocessedGroupElement {
|
fn preprocessed(&self) -> Self::PreprocessedGroupElement {
|
||||||
self.to_affine()
|
self.to_affine()
|
||||||
}
|
}
|
||||||
@@ -153,6 +287,7 @@ impl Group for vesta::Point {
|
|||||||
type RO = PoseidonRO<Self::Base, Self::Scalar>;
|
type RO = PoseidonRO<Self::Base, Self::Scalar>;
|
||||||
type ROCircuit = PoseidonROCircuit<Self::Base>;
|
type ROCircuit = PoseidonROCircuit<Self::Base>;
|
||||||
|
|
||||||
|
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
|
||||||
fn vartime_multiscalar_mul(
|
fn vartime_multiscalar_mul(
|
||||||
scalars: &[Self::Scalar],
|
scalars: &[Self::Scalar],
|
||||||
bases: &[Self::PreprocessedGroupElement],
|
bases: &[Self::PreprocessedGroupElement],
|
||||||
@@ -160,14 +295,18 @@ impl Group for vesta::Point {
|
|||||||
if scalars.len() >= 128 {
|
if scalars.len() >= 128 {
|
||||||
pasta_msm::vesta(bases, scalars)
|
pasta_msm::vesta(bases, scalars)
|
||||||
} else {
|
} else {
|
||||||
scalars
|
cpu_best_multiexp(scalars, bases)
|
||||||
.par_iter()
|
|
||||||
.zip(bases)
|
|
||||||
.map(|(scalar, base)| base.mul(scalar))
|
|
||||||
.reduce(Eq::group_zero, |x, y| x + y)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
|
||||||
|
fn vartime_multiscalar_mul(
|
||||||
|
scalars: &[Self::Scalar],
|
||||||
|
bases: &[Self::PreprocessedGroupElement],
|
||||||
|
) -> Self {
|
||||||
|
cpu_best_multiexp(scalars, bases)
|
||||||
|
}
|
||||||
|
|
||||||
fn compress(&self) -> Self::CompressedGroupElement {
|
fn compress(&self) -> Self::CompressedGroupElement {
|
||||||
VestaCompressedElementWrapper::new(self.to_bytes())
|
VestaCompressedElementWrapper::new(self.to_bytes())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user