From bcbdec2982ca6ff0e839fc3c03be8590154e42ff Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Mon, 25 Apr 2022 21:09:47 -0700 Subject: [PATCH] Accelerated MSM prep (#41) * remove send + sync * introduce a new associative type to capture any form of preprocessing on group elements * update pasta_curves version * simplify trait requirements * fix clippy --- Cargo.toml | 2 +- src/commitments.rs | 11 +++++++--- src/pasta.rs | 54 ++++++++++++++++++---------------------------- src/r1cs.rs | 1 - src/traits.rs | 17 +++++++-------- 5 files changed, 38 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 007637f..eef36b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ rand_core = { version = "0.5", default-features = false } rand_chacha = "0.3" itertools = "0.9.0" subtle = "2.4" -pasta_curves = "0.3.0" +pasta_curves = "0.3.1" neptune = "6.1" generic-array = "0.14.4" bellperson-nonnative = { version = "0.2.1", default-features = false, features = ["wasm"] } diff --git a/src/commitments.rs b/src/commitments.rs index 1e1c2e5..e2cec07 100644 --- a/src/commitments.rs +++ b/src/commitments.rs @@ -4,6 +4,7 @@ use super::{ }; use core::{ fmt::Debug, + marker::PhantomData, ops::{Add, AddAssign, Mul, MulAssign}, }; use digest::{ExtendableOutput, Input}; @@ -13,7 +14,8 @@ use std::io::Read; #[derive(Debug)] pub struct CommitGens { - gens: Vec, + gens: Vec, + _p: PhantomData, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -31,14 +33,17 @@ impl CommitGens { let mut shake = Shake256::default(); shake.input(label); let mut reader = shake.xof_result(); - let mut gens: Vec = Vec::new(); + let mut gens: Vec = Vec::new(); let mut uniform_bytes = [0u8; 64]; for _ in 0..n { reader.read_exact(&mut uniform_bytes).unwrap(); gens.push(G::from_uniform_bytes(&uniform_bytes).unwrap()); } - CommitGens { gens } + CommitGens { + gens, + _p: PhantomData::default(), + } } } diff --git a/src/pasta.rs b/src/pasta.rs index b6dd70f..df073af 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -1,5 +1,6 @@ //! This module implements the Nova traits for pallas::Point, pallas::Scalar, vesta::Point, vesta::Scalar. use crate::traits::{ChallengeTrait, CompressedGroup, Group}; +use core::ops::Mul; use ff::Field; use merlin::Transcript; use pasta_curves::{ @@ -11,7 +12,6 @@ use pasta_curves::{ use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use rug::Integer; -use std::{borrow::Borrow, ops::Mul}; //////////////////////////////////////Pallas/////////////////////////////////////////////// @@ -28,27 +28,21 @@ impl PallasCompressedElementWrapper { } } -unsafe impl Send for PallasCompressedElementWrapper {} -unsafe impl Sync for PallasCompressedElementWrapper {} - impl Group for pallas::Point { type Base = pallas::Base; type Scalar = pallas::Scalar; type CompressedGroupElement = PallasCompressedElementWrapper; + type PreprocessedGroupElement = pallas::Affine; - fn vartime_multiscalar_mul(scalars: I, points: J) -> Self - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - Self: Clone, - { + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self { // Unoptimized. scalars - .into_iter() - .zip(points) - .map(|(scalar, point)| (*point.borrow()).mul(*scalar.borrow())) + .iter() + .zip(bases) + .map(|(scalar, base)| base.mul(scalar)) .fold(Ep::group_zero(), |acc, x| acc + x) } @@ -56,7 +50,7 @@ impl Group for pallas::Point { PallasCompressedElementWrapper::new(self.to_bytes()) } - fn from_uniform_bytes(bytes: &[u8]) -> Option { + fn from_uniform_bytes(bytes: &[u8]) -> Option { if bytes.len() != 64 { None } else { @@ -64,7 +58,7 @@ impl Group for pallas::Point { arr.copy_from_slice(&bytes[0..32]); let hash = Ep::hash_to_curve("from_uniform_bytes"); - Some(hash(&arr)) + Some(hash(&arr).to_affine()) } } @@ -121,27 +115,21 @@ impl VestaCompressedElementWrapper { } } -unsafe impl Send for VestaCompressedElementWrapper {} -unsafe impl Sync for VestaCompressedElementWrapper {} - impl Group for vesta::Point { type Base = vesta::Base; type Scalar = vesta::Scalar; type CompressedGroupElement = VestaCompressedElementWrapper; + type PreprocessedGroupElement = vesta::Affine; - fn vartime_multiscalar_mul(scalars: I, points: J) -> Self - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - Self: Clone, - { + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self { // Unoptimized. scalars - .into_iter() - .zip(points) - .map(|(scalar, point)| (*point.borrow()).mul(*scalar.borrow())) + .iter() + .zip(bases) + .map(|(scalar, base)| base.mul(scalar)) .fold(Eq::group_zero(), |acc, x| acc + x) } @@ -149,7 +137,7 @@ impl Group for vesta::Point { VestaCompressedElementWrapper::new(self.to_bytes()) } - fn from_uniform_bytes(bytes: &[u8]) -> Option { + fn from_uniform_bytes(bytes: &[u8]) -> Option { if bytes.len() != 64 { None } else { @@ -157,7 +145,7 @@ impl Group for vesta::Point { arr.copy_from_slice(&bytes[0..32]); let hash = Eq::hash_to_curve("from_uniform_bytes"); - Some(hash(&arr)) + Some(hash(&arr).to_affine()) } } diff --git a/src/r1cs.rs b/src/r1cs.rs index 3ac1469..6ec70f3 100644 --- a/src/r1cs.rs +++ b/src/r1cs.rs @@ -10,7 +10,6 @@ use itertools::concat; use rayon::prelude::*; /// Public parameters for a given R1CS -#[derive(Debug)] pub struct R1CSGens { pub(crate) gens_W: CommitGens, // TODO: avoid pub(crate) pub(crate) gens_E: CommitGens, diff --git a/src/traits.rs b/src/traits.rs index adf9a6e..7dc551a 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,7 +1,6 @@ //! This module defines various traits required by the users of the library to implement. use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use core::{ - borrow::Borrow, fmt::Debug, ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; @@ -30,21 +29,21 @@ pub trait Group: /// A type representing the compressed version of the group element type CompressedGroupElement: CompressedGroup; + /// A type representing preprocessed group element + type PreprocessedGroupElement; + /// A method to compute a multiexponentation - fn vartime_multiscalar_mul(scalars: I, points: J) -> Self - where - I: IntoIterator, - I::Item: Borrow, - J: IntoIterator, - J::Item: Borrow, - Self: Clone; + fn vartime_multiscalar_mul( + scalars: &[Self::Scalar], + bases: &[Self::PreprocessedGroupElement], + ) -> Self; /// Compresses the group element fn compress(&self) -> Self::CompressedGroupElement; /// Attempts to create a group element from a sequence of bytes, /// failing with a `None` if the supplied bytes do not encode the group element - fn from_uniform_bytes(bytes: &[u8]) -> Option; + fn from_uniform_bytes(bytes: &[u8]) -> Option; /// Returns the affine coordinates (x, y, infinty) for the point fn to_coordinates(&self) -> (Self::Base, Self::Base, bool);