diff --git a/src/provider/keccak.rs b/src/provider/keccak.rs index e2e8661..6adf352 100644 --- a/src/provider/keccak.rs +++ b/src/provider/keccak.rs @@ -18,19 +18,22 @@ const KECCAK256_PREFIX_CHALLENGE_HI: u8 = 1; pub struct Keccak256Transcript { round: u16, state: [u8; KECCAK256_STATE_SIZE], - transcript: Vec, + transcript: Keccak256, _p: PhantomData, } -fn compute_updated_state(input: &[u8]) -> [u8; KECCAK256_STATE_SIZE] { - let input_lo = [input, &[KECCAK256_PREFIX_CHALLENGE_LO]].concat(); - let input_hi = [input, &[KECCAK256_PREFIX_CHALLENGE_HI]].concat(); +fn compute_updated_state(keccak_instance: Keccak256, input: &[u8]) -> [u8; KECCAK256_STATE_SIZE] { + let mut updated_instance = keccak_instance; + updated_instance.input(input); - let mut hasher_lo = Keccak256::new(); - let mut hasher_hi = Keccak256::new(); + let input_lo = &[KECCAK256_PREFIX_CHALLENGE_LO]; + let input_hi = &[KECCAK256_PREFIX_CHALLENGE_HI]; - hasher_lo.input(&input_lo); - hasher_hi.input(&input_hi); + let mut hasher_lo = updated_instance.clone(); + let mut hasher_hi = updated_instance; + + hasher_lo.input(input_lo); + hasher_hi.input(input_hi); let output_lo = hasher_lo.result(); let output_hi = hasher_hi.result(); @@ -44,27 +47,28 @@ fn compute_updated_state(input: &[u8]) -> [u8; KECCAK256_STATE_SIZE] { impl TranscriptEngineTrait for Keccak256Transcript { fn new(label: &'static [u8]) -> Self { + let keccak_instance = Keccak256::new(); let input = [PERSONA_TAG, label].concat(); - let output = compute_updated_state(&input); + let output = compute_updated_state(keccak_instance.clone(), &input); Self { round: 0u16, state: output, - transcript: vec![], + transcript: keccak_instance, _p: Default::default(), } } fn squeeze(&mut self, label: &'static [u8]) -> Result { + // we gather the full input from the round, preceded by the current state of the transcript let input = [ DOM_SEP_TAG, self.round.to_le_bytes().as_ref(), self.state.as_ref(), - self.transcript.as_ref(), label, ] .concat(); - let output = compute_updated_state(&input); + let output = compute_updated_state(self.transcript.clone(), &input); // update state self.round = { @@ -75,20 +79,20 @@ impl TranscriptEngineTrait for Keccak256Transcript { } }; self.state.copy_from_slice(&output); - self.transcript = vec![]; + self.transcript = Keccak256::new(); // squeeze out a challenge Ok(G::Scalar::from_uniform(&output)) } fn absorb>(&mut self, label: &'static [u8], o: &T) { - self.transcript.extend_from_slice(label); - self.transcript.extend_from_slice(&o.to_transcript_bytes()); + self.transcript.input(label); + self.transcript.input(&o.to_transcript_bytes()); } fn dom_sep(&mut self, bytes: &'static [u8]) { - self.transcript.extend_from_slice(DOM_SEP_TAG); - self.transcript.extend_from_slice(bytes); + self.transcript.input(DOM_SEP_TAG); + self.transcript.input(bytes); } } @@ -97,9 +101,10 @@ mod tests { use crate::{ provider::bn256_grumpkin::bn256, provider::keccak::Keccak256Transcript, - traits::{Group, TranscriptEngineTrait}, + traits::{Group, PrimeFieldExt, TranscriptEngineTrait, TranscriptReprTrait}, }; use ff::PrimeField; + use rand::Rng; use sha3::{Digest, Keccak256}; fn test_keccak_transcript_with(expected_h1: &'static str, expected_h2: &'static str) { @@ -131,13 +136,13 @@ mod tests { #[test] fn test_keccak_transcript() { test_keccak_transcript_with::( - "432d5811c8be3d44d47f52108a8749ae18482efd1a37b830f966456b5d75340c", - "65f7908d53abcd18f3b1d767456ef9009b91c7566a635e9ca7be26e21d4d7a10", + "5ddffa8dc091862132788b8976af88b9a2c70594727e611c7217ba4c30c8c70a", + "4d4bf42c065870395749fa1c4fb641df1e0d53f05309b03d5b1db7f0be3aa13d", ); test_keccak_transcript_with::( - "93f9160d5501865b399ee4ff0ffe17b697a4023e33e931e2597d36e6cc4ac602", - "bca8bdb96608a8277a7cb34bd493dfbc5baf2a080d1d6c9d32d7ab4f238eb803", + "9fb71e3b74bfd0b60d97349849b895595779a240b92a6fae86bd2812692b6b0e", + "bfd4c50b7d6317e9267d5d65c985eb455a3561129c0b3beef79bfc8461a84f18", ); } @@ -151,4 +156,88 @@ mod tests { "29045a592007d0c246ef02c2223570da9522d0cf0f73282c79a1bc8f0bb2c238" ); } + + use super::{ + DOM_SEP_TAG, KECCAK256_PREFIX_CHALLENGE_HI, KECCAK256_PREFIX_CHALLENGE_LO, + KECCAK256_STATE_SIZE, PERSONA_TAG, + }; + + fn compute_updated_state_for_testing(input: &[u8]) -> [u8; KECCAK256_STATE_SIZE] { + let input_lo = [input, &[KECCAK256_PREFIX_CHALLENGE_LO]].concat(); + let input_hi = [input, &[KECCAK256_PREFIX_CHALLENGE_HI]].concat(); + + let mut hasher_lo = Keccak256::new(); + let mut hasher_hi = Keccak256::new(); + + hasher_lo.input(&input_lo); + hasher_hi.input(&input_hi); + + let output_lo = hasher_lo.result(); + let output_hi = hasher_hi.result(); + + [output_lo, output_hi] + .concat() + .as_slice() + .try_into() + .unwrap() + } + + fn squeeze_for_testing( + transcript: &[u8], + round: u16, + state: [u8; KECCAK256_STATE_SIZE], + label: &'static [u8], + ) -> [u8; 64] { + let input = [ + transcript, + DOM_SEP_TAG, + round.to_le_bytes().as_ref(), + state.as_ref(), + label, + ] + .concat(); + compute_updated_state_for_testing(&input) + } + + // This test is meant to ensure compatibility between the incremental way of computing the transcript above, and + // the former, which materialized the entirety of the input vector before calling Keccak256 on it. + fn test_keccak_transcript_incremental_vs_explicit_with() { + let test_label = b"test"; + let mut transcript: Keccak256Transcript = Keccak256Transcript::new(test_label); + let mut rng = rand::thread_rng(); + + // ten scalars + let scalars = std::iter::from_fn(|| Some(::Scalar::from(rng.gen::()))) + .take(10) + .collect::>(); + + // add the scalars to the transcripts, + let mut manual_transcript: Vec = vec![]; + let labels = vec![ + b"s1", b"s2", b"s3", b"s4", b"s5", b"s6", b"s7", b"s8", b"s9", b"s0", + ]; + + for i in 0..10 { + transcript.absorb(&labels[i][..], &scalars[i]); + manual_transcript.extend(labels[i]); + manual_transcript.extend(scalars[i].to_transcript_bytes()); + } + + // compute the initial state + let input = [PERSONA_TAG, test_label].concat(); + let initial_state = compute_updated_state_for_testing(&input); + + // make a challenge + let c1: ::Scalar = transcript.squeeze(b"c1").unwrap(); + + let c1_bytes = squeeze_for_testing(&manual_transcript[..], 0u16, initial_state, b"c1"); + let to_hex = |g: G::Scalar| hex::encode(g.to_repr().as_ref()); + assert_eq!(to_hex(c1), to_hex(G::Scalar::from_uniform(&c1_bytes))); + } + + #[test] + fn test_keccak_transcript_incremental_vs_explicit() { + test_keccak_transcript_incremental_vs_explicit_with::(); + test_keccak_transcript_incremental_vs_explicit_with::(); + } }