//! This module implements the Nova traits for bn256::Point, bn256::Scalar, grumpkin::Point, grumpkin::Scalar. use crate::{ provider::{ keccak::Keccak256Transcript, pedersen::CommitmentEngine, poseidon::{PoseidonRO, PoseidonROCircuit}, }, traits::{CompressedGroup, Group, PrimeFieldExt, TranscriptReprTrait}, }; use digest::{ExtendableOutput, Input}; use ff::{FromUniformBytes, PrimeField}; use num_bigint::BigInt; use num_traits::Num; use pasta_curves::{ self, arithmetic::{CurveAffine, CurveExt}, group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}, }; use rayon::prelude::*; use sha3::Shake256; use std::io::Read; use halo2curves::bn256::{ G1Affine as Bn256Affine, G1Compressed as Bn256Compressed, G1 as Bn256Point, }; use halo2curves::grumpkin::{ G1Affine as GrumpkinAffine, G1Compressed as GrumpkinCompressed, G1 as GrumpkinPoint, }; /// Re-exports that give access to the standard aliases used in the code base, for bn256 pub mod bn256 { pub use halo2curves::bn256::{ Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point, }; } /// Re-exports that give access to the standard aliases used in the code base, for grumpkin pub mod grumpkin { pub use halo2curves::grumpkin::{ Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point, }; } // This implementation behaves in ways specific to the bn256/grumpkin curves in: // - to_coordinates, // - vartime_multiscalar_mul, where it does not call into accelerated implementations. macro_rules! impl_traits { ( $name:ident, $name_compressed:ident, $name_curve:ident, $name_curve_affine:ident, $order_str:literal ) => { impl Group for $name::Point { type Base = $name::Base; type Scalar = $name::Scalar; type CompressedGroupElement = $name_compressed; type PreprocessedGroupElement = $name::Affine; type RO = PoseidonRO; type ROCircuit = PoseidonROCircuit; type TE = Keccak256Transcript; type CE = CommitmentEngine; fn vartime_multiscalar_mul( scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], ) -> Self { cpu_best_multiexp(scalars, bases) } fn preprocessed(&self) -> Self::PreprocessedGroupElement { self.to_affine() } fn compress(&self) -> Self::CompressedGroupElement { self.to_bytes() } fn from_label(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.input(label); let mut reader = shake.xof_result(); let mut uniform_bytes_vec = Vec::new(); for _ in 0..n { let mut uniform_bytes = [0u8; 32]; reader.read_exact(&mut uniform_bytes).unwrap(); uniform_bytes_vec.push(uniform_bytes); } let gens_proj: Vec<$name_curve> = (0..n) .collect::>() .into_par_iter() .map(|i| { let hash = $name_curve::hash_to_curve("from_uniform_bytes"); hash(&uniform_bytes_vec[i]) }) .collect(); let num_threads = rayon::current_num_threads(); if gens_proj.len() > num_threads { let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize; (0..num_threads) .collect::>() .into_par_iter() .map(|i| { let start = i * chunk; let end = if i == num_threads - 1 { gens_proj.len() } else { core::cmp::min((i + 1) * chunk, gens_proj.len()) }; if end > start { let mut gens = vec![$name_curve_affine::identity(); end - start]; ::batch_normalize(&gens_proj[start..end], &mut gens); gens } else { vec![] } }) .collect::>>() .into_par_iter() .flatten() .collect() } else { let mut gens = vec![$name_curve_affine::identity(); n]; ::batch_normalize(&gens_proj, &mut gens); gens } } fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { let coordinates = self.to_affine().coordinates(); if coordinates.is_some().unwrap_u8() == 1 // The bn256/grumpkin convention is to define and return the identity point's affine encoding (not None) && (Self::PreprocessedGroupElement::identity() != self.to_affine()) { (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) } else { (Self::Base::zero(), Self::Base::zero(), true) } } fn get_curve_params() -> (Self::Base, Self::Base, BigInt) { let A = $name::Point::a(); let B = $name::Point::b(); let order = BigInt::from_str_radix($order_str, 16).unwrap(); (A, B, order) } fn zero() -> Self { $name::Point::identity() } fn get_generator() -> Self { $name::Point::generator() } } impl PrimeFieldExt for $name::Scalar { fn from_uniform(bytes: &[u8]) -> Self { let bytes_arr: [u8; 64] = bytes.try_into().unwrap(); $name::Scalar::from_uniform_bytes(&bytes_arr) } } impl TranscriptReprTrait for $name_compressed { fn to_transcript_bytes(&self) -> Vec { self.as_ref().to_vec() } } impl CompressedGroup for $name_compressed { type GroupElement = $name::Point; fn decompress(&self) -> Option<$name::Point> { Some($name_curve::from_bytes(&self).unwrap()) } } }; } impl TranscriptReprTrait for grumpkin::Base { fn to_transcript_bytes(&self) -> Vec { self.to_repr().to_vec() } } impl TranscriptReprTrait for grumpkin::Scalar { fn to_transcript_bytes(&self) -> Vec { self.to_repr().to_vec() } } impl_traits!( bn256, Bn256Compressed, Bn256Point, Bn256Affine, "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001" ); impl_traits!( grumpkin, GrumpkinCompressed, GrumpkinPoint, GrumpkinAffine, "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" ); /// Performs a multi-exponentiation operation without GPU acceleration. /// /// This function will panic if coeffs and bases have a different length. /// /// This will use multithreading if beneficial. /// Adapted from zcash/halo2 // TODO: update once https://github.com/privacy-scaling-explorations/halo2curves/pull/29 // (or a successor thereof) is merged fn cpu_best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { crate::provider::pasta::cpu_best_multiexp(coeffs, bases) } #[cfg(test)] mod tests { use super::*; type G = bn256::Point; fn from_label_serial(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.input(label); let mut reader = shake.xof_result(); let mut ck = Vec::new(); for _ in 0..n { let mut uniform_bytes = [0u8; 32]; reader.read_exact(&mut uniform_bytes).unwrap(); let hash = bn256::Point::hash_to_curve("from_uniform_bytes"); ck.push(hash(&uniform_bytes).to_affine()); } ck } #[test] fn test_from_label() { let label = b"test_from_label"; for n in [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, ] { let ck_par = ::from_label(label, n); let ck_ser = from_label_serial(label, n); assert_eq!(ck_par.len(), n); assert_eq!(ck_ser.len(), n); assert_eq!(ck_par, ck_ser); } } }