//! This module implements the Nova traits for pallas::Point, pallas::Scalar, vesta::Point, vesta::Scalar. use crate::{ errors::NovaError, provider::{ keccak::Keccak256Transcript, pedersen::CommitmentEngine, poseidon::{PoseidonRO, PoseidonROCircuit}, }, traits::{ChallengeTrait, CompressedGroup, Group, PrimeFieldExt, TranscriptEngineTrait}, }; use digest::{ExtendableOutput, Input}; use ff::PrimeField; use num_bigint::BigInt; use num_traits::Num; use pasta_curves::{ self, arithmetic::{CurveAffine, CurveExt, FieldExt, Group as OtherGroup}, group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}, pallas, vesta, Ep, EpAffine, Eq, EqAffine, }; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use sha3::Shake256; use std::io::Read; /// A wrapper for compressed group elements of pallas #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct PallasCompressedElementWrapper { repr: [u8; 32], } impl PallasCompressedElementWrapper { /// Wraps repr into the wrapper pub fn new(repr: [u8; 32]) -> Self { Self { repr } } } /// A wrapper for compressed group elements of vesta #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct VestaCompressedElementWrapper { repr: [u8; 32], } impl VestaCompressedElementWrapper { /// Wraps repr into the wrapper pub fn new(repr: [u8; 32]) -> Self { Self { repr } } } macro_rules! impl_traits { ( $name:ident, $name_compressed:ident, $name_curve:ident, $name_curve_affine:ident, $order_str:literal ) => { impl Group for $name::Point { type Base = $name::Base; type Scalar = $name::Scalar; type CompressedGroupElement = $name_compressed; type PreprocessedGroupElement = $name::Affine; type RO = PoseidonRO; type ROCircuit = PoseidonROCircuit; type TE = Keccak256Transcript; type CE = CommitmentEngine; #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] fn vartime_multiscalar_mul( scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], ) -> Self { if scalars.len() >= 128 { pasta_msm::$name(bases, scalars) } else { cpu_best_multiexp(scalars, bases) } } #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] fn vartime_multiscalar_mul( scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], ) -> Self { cpu_best_multiexp(scalars, bases) } fn preprocessed(&self) -> Self::PreprocessedGroupElement { self.to_affine() } fn compress(&self) -> Self::CompressedGroupElement { $name_compressed::new(self.to_bytes()) } fn from_label(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.input(label); let mut reader = shake.xof_result(); let mut uniform_bytes_vec = Vec::new(); for _ in 0..n { let mut uniform_bytes = [0u8; 32]; reader.read_exact(&mut uniform_bytes).unwrap(); uniform_bytes_vec.push(uniform_bytes); } let ck_proj: Vec<$name_curve> = (0..n) .collect::>() .into_par_iter() .map(|i| { let hash = $name_curve::hash_to_curve("from_uniform_bytes"); hash(&uniform_bytes_vec[i]) }) .collect(); let num_threads = rayon::current_num_threads(); if ck_proj.len() > num_threads { let chunk = (ck_proj.len() as f64 / num_threads as f64).ceil() as usize; (0..num_threads) .collect::>() .into_par_iter() .map(|i| { let start = i * chunk; let end = if i == num_threads - 1 { ck_proj.len() } else { core::cmp::min((i + 1) * chunk, ck_proj.len()) }; if end > start { let mut ck = vec![$name_curve_affine::identity(); end - start]; ::batch_normalize(&ck_proj[start..end], &mut ck); ck } else { vec![] } }) .collect::>>() .into_par_iter() .flatten() .collect() } else { let mut ck = vec![$name_curve_affine::identity(); n]; ::batch_normalize(&ck_proj, &mut ck); ck } } fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { let coordinates = self.to_affine().coordinates(); if coordinates.is_some().unwrap_u8() == 1 { (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) } else { (Self::Base::zero(), Self::Base::zero(), true) } } fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { let A = Self::Base::zero(); let B = Self::Base::from(5); let order = BigInt::from_str_radix($order_str, 16).unwrap(); (A, B, order) } fn zero() -> Self { $name::Point::group_zero() } fn get_generator() -> Self { $name::Point::generator() } } impl PrimeFieldExt for $name::Scalar { fn from_uniform(bytes: &[u8]) -> Self { let bytes_arr: [u8; 64] = bytes.try_into().unwrap(); $name::Scalar::from_bytes_wide(&bytes_arr) } fn to_bytes(&self) -> Vec { self.to_repr().as_ref().to_vec() } } impl CompressedGroup for $name_compressed { type GroupElement = $name::Point; fn decompress(&self) -> Option<$name::Point> { Some($name_curve::from_bytes(&self.repr).unwrap()) } fn as_bytes(&self) -> Vec { self.repr.to_vec() } } }; } impl, F: PrimeField> ChallengeTrait for F { fn challenge(label: &'static [u8], transcript: &mut G::TE) -> Result { transcript.squeeze_scalar(label) } } impl_traits!( pallas, PallasCompressedElementWrapper, Ep, EpAffine, "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001" ); impl_traits!( vesta, VestaCompressedElementWrapper, Eq, EqAffine, "40000000000000000000000000000000224698fc094cf91b992d30ed00000001" ); /// Native implementation of fast multiexp for platforms that do not support pasta_msm/semolina /// Adapted from zcash/halo2 fn cpu_multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let c = if bases.len() < 4 { 1 } else if bases.len() < 32 { 3 } else { (f64::from(bases.len() as u32)).ln().ceil() as usize }; fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { let skip_bits = segment * c; let skip_bytes = skip_bits / 8; if skip_bytes >= 32 { return 0; } let mut v = [0; 8]; for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { *v = *o; } let mut tmp = u64::from_le_bytes(v); tmp >>= skip_bits - (skip_bytes * 8); tmp %= 1 << c; tmp as usize } let segments = (256 / c) + 1; for current_segment in (0..segments).rev() { for _ in 0..c { *acc = acc.double(); } #[derive(Clone, Copy)] enum Bucket { None, Affine(C), Projective(C::Curve), } impl Bucket { fn add_assign(&mut self, other: &C) { *self = match *self { Bucket::None => Bucket::Affine(*other), Bucket::Affine(a) => Bucket::Projective(a + *other), Bucket::Projective(mut a) => { a += *other; Bucket::Projective(a) } } } fn add(self, mut other: C::Curve) -> C::Curve { match self { Bucket::None => other, Bucket::Affine(a) => { other += a; other } Bucket::Projective(a) => other + a, } } } let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; for (coeff, base) in coeffs.iter().zip(bases.iter()) { let coeff = get_at::(current_segment, c, coeff); if coeff != 0 { buckets[coeff - 1].add_assign(base); } } // Summation by parts // e.g. 3a + 2b + 1c = a + // (a) + b + // ((a) + b) + c let mut running_sum = C::Curve::identity(); for exp in buckets.into_iter().rev() { running_sum = exp.add(running_sum); *acc += &running_sum; } } } /// Performs a multi-exponentiation operation without GPU acceleration. /// /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. /// Adapted from zcash/halo2 fn cpu_best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); let num_threads = rayon::current_num_threads(); if coeffs.len() > num_threads { let chunk = coeffs.len() / num_threads; let num_chunks = coeffs.chunks(chunk).len(); let mut results = vec![C::Curve::identity(); num_chunks]; rayon::scope(|scope| { for ((coeffs, bases), acc) in coeffs .chunks(chunk) .zip(bases.chunks(chunk)) .zip(results.iter_mut()) { scope.spawn(move |_| { cpu_multiexp_serial(coeffs, bases, acc); }); } }); results.iter().fold(C::Curve::identity(), |a, b| a + b) } else { let mut acc = C::Curve::identity(); cpu_multiexp_serial(coeffs, bases, &mut acc); acc } } #[cfg(test)] mod tests { use super::*; type G = pasta_curves::pallas::Point; fn from_label_serial(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.input(label); let mut reader = shake.xof_result(); let mut ck = Vec::new(); for _ in 0..n { let mut uniform_bytes = [0u8; 32]; reader.read_exact(&mut uniform_bytes).unwrap(); let hash = Ep::hash_to_curve("from_uniform_bytes"); ck.push(hash(&uniform_bytes).to_affine()); } ck } #[test] fn test_from_label() { let label = b"test_from_label"; for n in [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, ] { let ck_par = ::from_label(label, n); let ck_ser = from_label_serial(label, n); assert_eq!(ck_par.len(), n); assert_eq!(ck_ser.len(), n); assert_eq!(ck_par, ck_ser); } } }