From 1d5d11efaf8bd4fe6b747d1047297c96d13a94f7 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 24 Oct 2022 09:32:39 +0200 Subject: [PATCH] fix misc. nits --- crypto/src/hash/rpo/digest.rs | 50 ++++++++++++++++++++++++++++----- crypto/src/hash/rpo/mds_freq.rs | 16 +++++------ crypto/src/hash/rpo/mod.rs | 5 +++- 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/crypto/src/hash/rpo/digest.rs b/crypto/src/hash/rpo/digest.rs index 7c21298..a0c20a4 100644 --- a/crypto/src/hash/rpo/digest.rs +++ b/crypto/src/hash/rpo/digest.rs @@ -1,6 +1,7 @@ use super::DIGEST_SIZE; use crate::{Digest, Felt, StarkField}; -use core::slice; +use core::ops::Deref; + use winterfell::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; // DIGEST TRAIT IMPLEMENTATIONS @@ -15,13 +16,14 @@ impl RpoDigest256 { } pub fn as_elements(&self) -> &[Felt] { - &self.0 + self.as_ref() } - pub fn digests_as_elements(digests: &[Self]) -> &[Felt] { - let p = digests.as_ptr(); - let len = digests.len() * DIGEST_SIZE; - unsafe { slice::from_raw_parts(p as *const Felt, len) } + pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator + where + I: Iterator, + { + digests.map(|d| d.0.iter()).flatten() } } @@ -52,7 +54,6 @@ impl Serializable for RpoDigest256 { impl Deserializable for RpoDigest256 { fn read_from(source: &mut R) -> Result { - // TODO: check if the field elements are valid? let e1 = Felt::new(source.read_u64()?); let e2 = Felt::new(source.read_u64()?); let e3 = Felt::new(source.read_u64()?); @@ -80,6 +81,41 @@ impl From for [u8; 32] { } } +impl Deref for RpoDigest256 { + type Target = [Felt; DIGEST_SIZE]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl RpoDigest256 { + fn iter(&self) -> RpoDigest256Iter<'_> { + RpoDigest256Iter { + values: &self.0, + index: 0, + } + } +} + +pub struct RpoDigest256Iter<'a> { + values: &'a [Felt; DIGEST_SIZE], + index: usize, +} + +impl<'a> Iterator for RpoDigest256Iter<'a> { + type Item = &'a Felt; + + fn next(&mut self) -> Option { + if self.index >= self.values.len() { + return None; + } + + self.index += 1; + Some(&self.values[self.index - 1]) + } +} + // TESTS // ================================================================================================ diff --git a/crypto/src/hash/rpo/mds_freq.rs b/crypto/src/hash/rpo/mds_freq.rs index a510730..b3a6c90 100644 --- a/crypto/src/hash/rpo/mds_freq.rs +++ b/crypto/src/hash/rpo/mds_freq.rs @@ -26,7 +26,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1]; // We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain. #[inline(always)] -pub(crate) fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] { +pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] { let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state; let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]); @@ -56,18 +56,18 @@ pub(crate) fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] { // We use the real FFT to avoid redundant computations. See https://www.mdpi.com/2076-3417/12/9/4700 #[inline(always)] -fn fft2_real(x: [u64; 2]) -> [i64; 2] { +const fn fft2_real(x: [u64; 2]) -> [i64; 2] { [(x[0] as i64 + x[1] as i64), (x[0] as i64 - x[1] as i64)] } #[inline(always)] -fn ifft2_real(y: [i64; 2]) -> [u64; 2] { +const fn ifft2_real(y: [i64; 2]) -> [u64; 2] { // We avoid divisions by 2 by appropriately scaling the MDS matrix constants. [(y[0] + y[1]) as u64, (y[0] - y[1]) as u64] } #[inline(always)] -fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) { +const fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) { let [z0, z2] = fft2_real([x[0], x[2]]); let [z1, z3] = fft2_real([x[1], x[3]]); let y0 = z0 + z1; @@ -77,7 +77,7 @@ fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) { } #[inline(always)] -fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] { +const fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] { // In calculating 'z0' and 'z1', division by 2 is avoided by appropriately scaling // the MDS matrix constants. let z0 = y.0 + y.2; @@ -92,7 +92,7 @@ fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] { } #[inline(always)] -fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { +const fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { let [x0, x1, x2] = x; let [y0, y1, y2] = y; let z0 = x0 * y0 + x1 * y2 + x2 * y1; @@ -103,7 +103,7 @@ fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { } #[inline(always)] -fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] { +const fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] { let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x; let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y; let x0s = x0r + x0i; @@ -141,7 +141,7 @@ fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] { } #[inline(always)] -fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { +const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { let [x0, x1, x2] = x; let [y0, y1, y2] = y; let z0 = x0 * y0 - x1 * y2 - x2 * y1; diff --git a/crypto/src/hash/rpo/mod.rs b/crypto/src/hash/rpo/mod.rs index c711416..b9370cb 100644 --- a/crypto/src/hash/rpo/mod.rs +++ b/crypto/src/hash/rpo/mod.rs @@ -153,7 +153,10 @@ impl HashFn for Rpo256 { // initialize the state by copying the digest elements into the rate portion of the state // (8 total elements), and set the capacity elements to 0. let mut state = [ZERO; STATE_WIDTH]; - state[RATE_RANGE].copy_from_slice(Self::Digest::digests_as_elements(values)); + let it = Self::Digest::digests_as_elements(values.into_iter()); + for (i, v) in it.enumerate() { + state[RATE_RANGE.start + i] = *v; + } // apply the RPO permutation and return the first four elements of the state Self::apply_permutation(&mut state);