diff --git a/src/circuit.rs b/src/circuit.rs index e585ee3..92f9bb5 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -335,6 +335,7 @@ mod tests { use crate::{ bellperson::r1cs::{NovaShape, NovaWitness}, commitments::CommitTrait, + traits::HashFuncConstantsTrait, }; use ff::PrimeField; use std::marker::PhantomData; diff --git a/src/pasta.rs b/src/pasta.rs index df073af..0d1d92c 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -1,5 +1,8 @@ //! This module implements the Nova traits for pallas::Point, pallas::Scalar, vesta::Point, vesta::Scalar. -use crate::traits::{ChallengeTrait, CompressedGroup, Group}; +use crate::{ + poseidon::PoseidonRO, + traits::{ChallengeTrait, CompressedGroup, Group}, +}; use core::ops::Mul; use ff::Field; use merlin::Transcript; @@ -33,6 +36,7 @@ impl Group for pallas::Point { type Scalar = pallas::Scalar; type CompressedGroupElement = PallasCompressedElementWrapper; type PreprocessedGroupElement = pallas::Affine; + type HashFunc = PoseidonRO; fn vartime_multiscalar_mul( scalars: &[Self::Scalar], @@ -120,6 +124,7 @@ impl Group for vesta::Point { type Scalar = vesta::Scalar; type CompressedGroupElement = VestaCompressedElementWrapper; type PreprocessedGroupElement = vesta::Affine; + type HashFunc = PoseidonRO; fn vartime_multiscalar_mul( scalars: &[Self::Scalar], diff --git a/src/poseidon.rs b/src/poseidon.rs index 7dfe278..d745a95 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -1,4 +1,5 @@ //! Poseidon Constants and Poseidon-based RO used in Nova +use crate::traits::{HashFuncConstantsTrait, HashFuncTrait}; use bellperson::{ gadgets::{ boolean::{AllocatedBit, Boolean}, @@ -6,36 +7,34 @@ use bellperson::{ }, ConstraintSystem, SynthesisError, }; +use core::marker::PhantomData; use ff::{PrimeField, PrimeFieldBits}; use generic_array::typenum::{U27, U8}; use neptune::{ circuit::poseidon_hash, poseidon::{Poseidon, PoseidonConstants}, + Strength, }; -#[cfg(test)] -use neptune::Strength; - /// All Poseidon Constants that are used in Nova #[derive(Clone)] -pub struct NovaPoseidonConstants +pub struct NovaPoseidonConstants where - F: PrimeField, + Scalar: PrimeField, { - constants8: PoseidonConstants, - constants27: PoseidonConstants, + constants8: PoseidonConstants, + constants27: PoseidonConstants, } -#[cfg(test)] -impl NovaPoseidonConstants +impl HashFuncConstantsTrait for NovaPoseidonConstants where - F: PrimeField, + Scalar: PrimeField + PrimeFieldBits, { /// Generate Poseidon constants for the arities that Nova uses #[allow(clippy::new_without_default)] - pub fn new() -> Self { - let constants8 = PoseidonConstants::::new_with_strength(Strength::Strengthened); - let constants27 = PoseidonConstants::::new_with_strength(Strength::Strengthened); + fn new() -> Self { + let constants8 = PoseidonConstants::::new_with_strength(Strength::Strengthened); + let constants27 = PoseidonConstants::::new_with_strength(Strength::Strengthened); Self { constants8, constants27, @@ -44,41 +43,28 @@ where } /// A Poseidon-based RO to use outside circuits -pub struct PoseidonRO +pub struct PoseidonRO where + Base: PrimeField + PrimeFieldBits, Scalar: PrimeField + PrimeFieldBits, { // Internal State - state: Vec, + state: Vec, // Constants for Poseidon - constants: NovaPoseidonConstants, + constants: NovaPoseidonConstants, + _p: PhantomData, } -impl PoseidonRO +impl PoseidonRO where + Base: PrimeField + PrimeFieldBits, Scalar: PrimeField + PrimeFieldBits, { - #[allow(dead_code)] - pub fn new(constants: NovaPoseidonConstants) -> Self { - Self { - state: Vec::new(), - constants, - } - } - - /// Absorb a new number into the state of the oracle - #[allow(dead_code)] - pub fn absorb(&mut self, e: Scalar) { - self.state.push(e); - } - - fn hash_inner(&mut self) -> Scalar { + fn hash_inner(&self) -> Base { match self.state.len() { - 8 => { - Poseidon::::new_with_preimage(&self.state, &self.constants.constants8).hash() - } + 8 => Poseidon::::new_with_preimage(&self.state, &self.constants.constants8).hash(), 27 => { - Poseidon::::new_with_preimage(&self.state, &self.constants.constants27).hash() + Poseidon::::new_with_preimage(&self.state, &self.constants.constants27).hash() } _ => { panic!( @@ -88,10 +74,33 @@ where } } } +} + +impl HashFuncTrait for PoseidonRO +where + Base: PrimeField + PrimeFieldBits, + Scalar: PrimeField + PrimeFieldBits, +{ + type Constants = NovaPoseidonConstants; + + #[allow(dead_code)] + fn new(constants: NovaPoseidonConstants) -> Self { + Self { + state: Vec::new(), + constants, + _p: PhantomData::default(), + } + } + + /// Absorb a new number into the state of the oracle + #[allow(dead_code)] + fn absorb(&mut self, e: Base) { + self.state.push(e); + } /// Compute a challenge by hashing the current state #[allow(dead_code)] - pub fn get_challenge(&mut self) -> Scalar { + fn get_challenge(&self) -> Scalar { let hash = self.hash_inner(); // Only keep 128 bits let bits = hash.to_le_bits(); @@ -107,7 +116,7 @@ where } #[allow(dead_code)] - pub fn get_hash(&mut self) -> Scalar { + fn get_hash(&self) -> Scalar { let hash = self.hash_inner(); // Only keep 250 bits let bits = hash.to_le_bits(); @@ -214,6 +223,7 @@ where mod tests { use super::*; type S = pasta_curves::pallas::Scalar; + type B = pasta_curves::vesta::Scalar; type G = pasta_curves::pallas::Point; use crate::{bellperson::solver::SatisfyingAssignment, gadgets::utils::le_bits_to_num}; use ff::Field; @@ -224,7 +234,7 @@ mod tests { // Check that the number computed inside the circuit is equal to the number computed outside the circuit let mut csprng: OsRng = OsRng; let constants = NovaPoseidonConstants::new(); - let mut ro: PoseidonRO = PoseidonRO::new(constants.clone()); + let mut ro: PoseidonRO = PoseidonRO::new(constants.clone()); let mut ro_gadget: PoseidonROGadget = PoseidonROGadget::new(constants); let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); for i in 0..27 { @@ -240,6 +250,6 @@ mod tests { let num = ro.get_challenge(); let num2_bits = ro_gadget.get_challenge(&mut cs).unwrap(); let num2 = le_bits_to_num(&mut cs, num2_bits).unwrap(); - assert_eq!(num, num2.get_value().unwrap()); + assert_eq!(num.to_repr(), num2.get_value().unwrap().to_repr()); } } diff --git a/src/traits.rs b/src/traits.rs index 7dc551a..b40e99a 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -32,6 +32,10 @@ pub trait Group: /// A type representing preprocessed group element type PreprocessedGroupElement; + /// A type that represents a hash function that consumes elements + /// from the base field and squeezes out elements of the scalar field + type HashFunc: HashFuncTrait; + /// A method to compute a multiexponentation fn vartime_multiscalar_mul( scalars: &[Self::Scalar], @@ -70,6 +74,30 @@ pub trait ChallengeTrait { fn challenge(label: &'static [u8], transcript: &mut Transcript) -> Self; } +/// A helper trait that defines the behavior of a hash function that we use as an RO +pub trait HashFuncTrait { + /// A type representing constants/parameters associated with the hash function + type Constants: HashFuncConstantsTrait; + + /// Initializes the hash function + fn new(constants: Self::Constants) -> Self; + + /// Adds a scalar to the internal state + fn absorb(&mut self, e: Base); + + /// Returns a random challenge by hashing the internal state + fn get_challenge(&self) -> Scalar; + + /// Returns a hash of the internal state + fn get_hash(&self) -> Scalar; +} + +/// A helper trait that defines the constants associated with a hash function +pub trait HashFuncConstantsTrait { + /// produces constants/parameters associated with the hash function + fn new() -> Self; +} + /// A helper trait for types with a group operation. pub trait GroupOps: Add + Sub + AddAssign + SubAssign