Browse Source

fix misc. nits

al-gkr-basic-workflow
Al-Kindi-0 2 years ago
parent
commit
1d5d11efaf
3 changed files with 55 additions and 16 deletions
  1. +43
    -7
      crypto/src/hash/rpo/digest.rs
  2. +8
    -8
      crypto/src/hash/rpo/mds_freq.rs
  3. +4
    -1
      crypto/src/hash/rpo/mod.rs

+ 43
- 7
crypto/src/hash/rpo/digest.rs

@ -1,6 +1,7 @@
use super::DIGEST_SIZE;
use crate::{Digest, Felt, StarkField};
use core::slice;
use core::ops::Deref;
use winterfell::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
// DIGEST TRAIT IMPLEMENTATIONS
@ -15,13 +16,14 @@ impl RpoDigest256 {
}
pub fn as_elements(&self) -> &[Felt] {
&self.0
self.as_ref()
}
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
unsafe { slice::from_raw_parts(p as *const Felt, len) }
pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.map(|d| d.0.iter()).flatten()
}
}
@ -52,7 +54,6 @@ impl Serializable for RpoDigest256 {
impl Deserializable for RpoDigest256 {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
// TODO: check if the field elements are valid?
let e1 = Felt::new(source.read_u64()?);
let e2 = Felt::new(source.read_u64()?);
let e3 = Felt::new(source.read_u64()?);
@ -80,6 +81,41 @@ impl From for [u8; 32] {
}
}
impl Deref for RpoDigest256 {
type Target = [Felt; DIGEST_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl RpoDigest256 {
fn iter(&self) -> RpoDigest256Iter<'_> {
RpoDigest256Iter {
values: &self.0,
index: 0,
}
}
}
pub struct RpoDigest256Iter<'a> {
values: &'a [Felt; DIGEST_SIZE],
index: usize,
}
impl<'a> Iterator for RpoDigest256Iter<'a> {
type Item = &'a Felt;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.values.len() {
return None;
}
self.index += 1;
Some(&self.values[self.index - 1])
}
}
// TESTS
// ================================================================================================

+ 8
- 8
crypto/src/hash/rpo/mds_freq.rs

@ -26,7 +26,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];
// We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
#[inline(always)]
pub(crate) fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;
let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
@ -56,18 +56,18 @@ pub(crate) fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
// We use the real FFT to avoid redundant computations. See https://www.mdpi.com/2076-3417/12/9/4700
#[inline(always)]
fn fft2_real(x: [u64; 2]) -> [i64; 2] {
const fn fft2_real(x: [u64; 2]) -> [i64; 2] {
[(x[0] as i64 + x[1] as i64), (x[0] as i64 - x[1] as i64)]
}
#[inline(always)]
fn ifft2_real(y: [i64; 2]) -> [u64; 2] {
const fn ifft2_real(y: [i64; 2]) -> [u64; 2] {
// We avoid divisions by 2 by appropriately scaling the MDS matrix constants.
[(y[0] + y[1]) as u64, (y[0] - y[1]) as u64]
}
#[inline(always)]
fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
const fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
let [z0, z2] = fft2_real([x[0], x[2]]);
let [z1, z3] = fft2_real([x[1], x[3]]);
let y0 = z0 + z1;
@ -77,7 +77,7 @@ fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
}
#[inline(always)]
fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] {
const fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] {
// In calculating 'z0' and 'z1', division by 2 is avoided by appropriately scaling
// the MDS matrix constants.
let z0 = y.0 + y.2;
@ -92,7 +92,7 @@ fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] {
}
#[inline(always)]
fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
const fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
let [x0, x1, x2] = x;
let [y0, y1, y2] = y;
let z0 = x0 * y0 + x1 * y2 + x2 * y1;
@ -103,7 +103,7 @@ fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
}
#[inline(always)]
fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
const fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x;
let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y;
let x0s = x0r + x0i;
@ -141,7 +141,7 @@ fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
}
#[inline(always)]
fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
let [x0, x1, x2] = x;
let [y0, y1, y2] = y;
let z0 = x0 * y0 - x1 * y2 - x2 * y1;

+ 4
- 1
crypto/src/hash/rpo/mod.rs

@ -153,7 +153,10 @@ impl HashFn for Rpo256 {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
state[RATE_RANGE].copy_from_slice(Self::Digest::digests_as_elements(values));
let it = Self::Digest::digests_as_elements(values.into_iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);

Loading…
Cancel
Save