From 876e23c159250bcc32c6782cf428dabbc68077bb Mon Sep 17 00:00:00 2001 From: Pierre Date: Fri, 15 Dec 2023 14:21:07 +0100 Subject: [PATCH] Feature/sumcheck (#40) * feat: init sumcheck.rs * chore: rename * feat: update lib and add trait for transcript with vec storing challenges * bugfix: mut self ref of transcript * feat: tentative sum-check using poseidon * refactor: remove extension trait and use initial trait * refactor: stop using extension trait, use initial Transcript trait * feat: generic over CurveGroup sum-check verifier and algorithm * feat: implement generic sum-check veriy * bugfix: cargo clippy --fix * chore: cargo fmt * feat: (unstable) sum-check implementation * feat: start benches * chore: run clippy * chore: run cargo fmt * feat: add sum-check tests + benches * chore: clippy + fmt * chore: remove unstable sumcheck * chore: delete duplicated sum-check code * chore: remove deleted sum-check code from lib.rs imports * feat: remove non generic traits, implement sum-check with generic trait and add test * chore: remove non-generic struct * chore: remove non generic verifier * feat: make nifms generic over transcript and update to use poseidon transcript * chore: cargo fmt * chore: remove tmp benches * chore: update cargo.toml * refactor: remove Generic suffix * feat: prover state generic over CurveGroup * chore: disable clippy type complexity warning * refactor: remove Transcript type and espresso transcript dependency * refactor: SumCheckProver generic over CurveGroup * chore: add line to eof for `Cargo.toml` * bugfix: add error handling on sum-check prove and verify * chore: clippy fix * chore: add line at eof * fix: use `map_err` and call `to_string()` on `PolyIOPErrors` --- Cargo.toml | 2 +- src/folding/hypernova/nimfs.rs | 215 +++++++++++++---------- src/lib.rs | 4 + src/utils/espresso/sum_check/mod.rs | 182 +++++++++++-------- src/utils/espresso/sum_check/prover.rs | 30 ++-- src/utils/espresso/sum_check/structs.rs | 18 +- src/utils/espresso/sum_check/verifier.rs | 51 +++--- 7 files changed, 291 insertions(+), 211 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dde7c9a..1392665 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ color-eyre = "=0.6.2" # tmp imports for espresso's sumcheck ark-serialize = "^0.4.0" espresso_subroutines = {git="https://github.com/EspressoSystems/hyperplonk", package="subroutines"} -espresso_transcript = {git="https://github.com/EspressoSystems/hyperplonk", package="transcript"} [dev-dependencies] ark-pallas = {version="0.4.0", features=["r1cs"]} @@ -47,3 +46,4 @@ parallel = [ # changes done between v0.4.0 and this fix which would break compatibility. [patch.crates-io] ark-r1cs-std = { git = "https://github.com/arnaucube/ark-r1cs-std-cherry-picked/" } + diff --git a/src/folding/hypernova/nimfs.rs b/src/folding/hypernova/nimfs.rs index 3f9e882..5d39700 100644 --- a/src/folding/hypernova/nimfs.rs +++ b/src/folding/hypernova/nimfs.rs @@ -1,20 +1,21 @@ -use ark_ec::CurveGroup; +use ark_crypto_primitives::sponge::Absorb; +use ark_ec::{CurveGroup, Group}; use ark_ff::{Field, PrimeField}; use ark_std::{One, Zero}; -use espresso_subroutines::PolyIOP; -use espresso_transcript::IOPTranscript; - use super::cccs::{Witness, CCCS}; use super::lcccs::LCCCS; use super::utils::{compute_c_from_sigmas_and_thetas, compute_g, compute_sigmas_and_thetas}; use crate::ccs::CCS; +use crate::transcript::Transcript; use crate::utils::hypercube::BooleanHypercube; use crate::utils::sum_check::structs::IOPProof as SumCheckProof; -use crate::utils::sum_check::{verifier::interpolate_uni_poly, SumCheck}; +use crate::utils::sum_check::verifier::interpolate_uni_poly; +use crate::utils::sum_check::{IOPSumCheck, SumCheck}; use crate::utils::virtual_polynomial::VPAuxInfo; use crate::Error; +use std::fmt::Debug; use std::marker::PhantomData; /// Proof defines a multifolding proof @@ -30,11 +31,15 @@ pub struct SigmasThetas(pub Vec>, pub Vec>); #[derive(Debug)] /// Implements the Non-Interactive Multi Folding Scheme described in section 5 of /// [HyperNova](https://eprint.iacr.org/2023/573.pdf) -pub struct NIMFS { +pub struct NIMFS> { pub _c: PhantomData, + pub _t: PhantomData, } -impl NIMFS { +impl> NIMFS +where + ::ScalarField: Absorb, +{ pub fn fold( lcccs: &[LCCCS], cccs: &[CCCS], @@ -144,7 +149,7 @@ impl NIMFS { /// contains the sumcheck proof and the helper sumcheck claim sigmas and thetas. #[allow(clippy::type_complexity)] pub fn prove( - transcript: &mut IOPTranscript, + transcript: &mut impl Transcript, ccs: &CCS, running_instances: &[LCCCS], new_instances: &[CCCS], @@ -185,17 +190,19 @@ impl NIMFS { } // Step 1: Get some challenges - let gamma: C::ScalarField = transcript.get_and_append_challenge(b"gamma").unwrap(); - let beta: Vec = transcript - .get_and_append_challenge_vectors(b"beta", ccs.s) - .unwrap(); + let gamma_scalar = C::ScalarField::from_le_bytes_mod_order(b"gamma"); + let beta_scalar = C::ScalarField::from_le_bytes_mod_order(b"beta"); + transcript.absorb(&gamma_scalar); + let gamma: C::ScalarField = transcript.get_challenge(); + transcript.absorb(&beta_scalar); + let beta: Vec = transcript.get_challenges(ccs.s); // Compute g(x) let g = compute_g(ccs, running_instances, &z_lcccs, &z_cccs, gamma, &beta); // Step 3: Run the sumcheck prover - let sumcheck_proof = - as SumCheck>::prove(&g, transcript).unwrap(); // XXX unwrap + let sumcheck_proof = IOPSumCheck::::prove(&g, transcript) + .map_err(|err| Error::SumCheckProveError(err.to_string()))?; // Note: The following two "sanity checks" are done for this prototype, in a final version // they should be removed. @@ -209,8 +216,8 @@ impl NIMFS { } // note: this is the sum of g(x) over the whole boolean hypercube - let extracted_sum = - as SumCheck>::extract_sum(&sumcheck_proof); + let extracted_sum = IOPSumCheck::::extract_sum(&sumcheck_proof); + if extracted_sum != g_over_bhc { return Err(Error::NotEqual); } @@ -238,7 +245,9 @@ impl NIMFS { let sigmas_thetas = compute_sigmas_and_thetas(ccs, &z_lcccs, &z_cccs, &r_x_prime); // Step 6: Get the folding challenge - let rho: C::ScalarField = transcript.get_and_append_challenge(b"rho").unwrap(); + let rho_scalar = C::ScalarField::from_le_bytes_mod_order(b"rho"); + transcript.absorb(&rho_scalar); + let rho: C::ScalarField = transcript.get_challenge(); // Step 7: Create the folded instance let folded_lcccs = Self::fold( @@ -266,7 +275,7 @@ impl NIMFS { /// into a single LCCCS instance. /// Returns the folded LCCCS instance. pub fn verify( - transcript: &mut IOPTranscript, + transcript: &mut impl Transcript, ccs: &CCS, running_instances: &[LCCCS], new_instances: &[CCCS], @@ -282,10 +291,13 @@ impl NIMFS { } // Step 1: Get some challenges - let gamma: C::ScalarField = transcript.get_and_append_challenge(b"gamma").unwrap(); - let beta: Vec = transcript - .get_and_append_challenge_vectors(b"beta", ccs.s) - .unwrap(); + let gamma_scalar = C::ScalarField::from_le_bytes_mod_order(b"gamma"); + transcript.absorb(&gamma_scalar); + let gamma: C::ScalarField = transcript.get_challenge(); + + let beta_scalar = C::ScalarField::from_le_bytes_mod_order(b"beta"); + transcript.absorb(&beta_scalar); + let beta: Vec = transcript.get_challenges(ccs.s); let vp_aux_info = VPAuxInfo:: { max_degree: ccs.d + 1, @@ -304,13 +316,9 @@ impl NIMFS { } // Verify the interactive part of the sumcheck - let sumcheck_subclaim = as SumCheck>::verify( - sum_v_j_gamma, - &proof.sc_proof, - &vp_aux_info, - transcript, - ) - .unwrap(); + let sumcheck_subclaim = + IOPSumCheck::::verify(sum_v_j_gamma, &proof.sc_proof, &vp_aux_info, transcript) + .map_err(|err| Error::SumCheckVerifyError(err.to_string()))?; // Step 2: Dig into the sumcheck claim and extract the randomness used let r_x_prime = sumcheck_subclaim.point.clone(); @@ -347,7 +355,9 @@ impl NIMFS { } // Step 6: Get the folding challenge - let rho: C::ScalarField = transcript.get_and_append_challenge(b"rho").unwrap(); + let rho_scalar = C::ScalarField::from_le_bytes_mod_order(b"rho"); + transcript.absorb(&rho_scalar); + let rho: C::ScalarField = transcript.get_challenge(); // Step 7: Compute the folded instance Ok(Self::fold( @@ -364,6 +374,8 @@ impl NIMFS { pub mod tests { use super::*; use crate::ccs::tests::{get_test_ccs, get_test_z}; + use crate::transcript::poseidon::tests::poseidon_test_config; + use crate::transcript::poseidon::PoseidonTranscript; use ark_std::test_rng; use ark_std::UniformRand; @@ -395,9 +407,16 @@ pub mod tests { let mut rng = test_rng(); let rho = Fr::rand(&mut rng); - let folded = NIMFS::::fold(&[lcccs], &[cccs], &sigmas_thetas, r_x_prime, rho); + let folded = NIMFS::>::fold( + &[lcccs], + &[cccs], + &sigmas_thetas, + r_x_prime, + rho, + ); - let w_folded = NIMFS::::fold_witness(&[w1], &[w2], rho); + let w_folded = + NIMFS::>::fold_witness(&[w1], &[w2], rho); // check lcccs relation folded @@ -425,26 +444,30 @@ pub mod tests { let (new_instance, w2) = ccs.to_cccs(&mut rng, &pedersen_params, &z_2).unwrap(); // Prover's transcript - let mut transcript_p = IOPTranscript::::new(b"multifolding"); - transcript_p.append_message(b"init", b"init").unwrap(); + let poseidon_config = poseidon_test_config::(); + let mut transcript_p: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init")); // Run the prover side of the multifolding - let (proof, folded_lcccs, folded_witness) = NIMFS::::prove( - &mut transcript_p, - &ccs, - &[running_instance.clone()], - &[new_instance.clone()], - &[w1], - &[w2], - ) - .unwrap(); + let (proof, folded_lcccs, folded_witness) = + NIMFS::>::prove( + &mut transcript_p, + &ccs, + &[running_instance.clone()], + &[new_instance.clone()], + &[w1], + &[w2], + ) + .unwrap(); // Verifier's transcript - let mut transcript_v = IOPTranscript::::new(b"multifolding"); - transcript_v.append_message(b"init", b"init").unwrap(); + let mut transcript_v: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init")); // Run the verifier side of the multifolding - let folded_lcccs_v = NIMFS::::verify( + let folded_lcccs_v = NIMFS::>::verify( &mut transcript_v, &ccs, &[running_instance.clone()], @@ -474,10 +497,15 @@ pub mod tests { let (mut running_instance, mut w1) = ccs.to_lcccs(&mut rng, &pedersen_params, &z_1).unwrap(); - let mut transcript_p = IOPTranscript::::new(b"multifolding"); - let mut transcript_v = IOPTranscript::::new(b"multifolding"); - transcript_p.append_message(b"init", b"init").unwrap(); - transcript_v.append_message(b"init", b"init").unwrap(); + let poseidon_config = poseidon_test_config::(); + + let mut transcript_p: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init")); + + let mut transcript_v: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init")); let n: usize = 10; for i in 3..n { @@ -490,18 +518,19 @@ pub mod tests { let (new_instance, w2) = ccs.to_cccs(&mut rng, &pedersen_params, &z_2).unwrap(); // run the prover side of the multifolding - let (proof, folded_lcccs, folded_witness) = NIMFS::::prove( - &mut transcript_p, - &ccs, - &[running_instance.clone()], - &[new_instance.clone()], - &[w1], - &[w2], - ) - .unwrap(); + let (proof, folded_lcccs, folded_witness) = + NIMFS::>::prove( + &mut transcript_p, + &ccs, + &[running_instance.clone()], + &[new_instance.clone()], + &[w1], + &[w2], + ) + .unwrap(); // run the verifier side of the multifolding - let folded_lcccs_v = NIMFS::::verify( + let folded_lcccs_v = NIMFS::>::verify( &mut transcript_v, &ccs, &[running_instance.clone()], @@ -509,7 +538,6 @@ pub mod tests { proof, ) .unwrap(); - assert_eq!(folded_lcccs, folded_lcccs_v); // check that the folded instance with the folded witness holds the LCCCS relation @@ -565,26 +593,30 @@ pub mod tests { } // Prover's transcript - let mut transcript_p = IOPTranscript::::new(b"multifolding"); - transcript_p.append_message(b"init", b"init").unwrap(); + let poseidon_config = poseidon_test_config::(); + let mut transcript_p: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init")); // Run the prover side of the multifolding - let (proof, folded_lcccs, folded_witness) = NIMFS::::prove( - &mut transcript_p, - &ccs, - &lcccs_instances, - &cccs_instances, - &w_lcccs, - &w_cccs, - ) - .unwrap(); + let (proof, folded_lcccs, folded_witness) = + NIMFS::>::prove( + &mut transcript_p, + &ccs, + &lcccs_instances, + &cccs_instances, + &w_lcccs, + &w_cccs, + ) + .unwrap(); // Verifier's transcript - let mut transcript_v = IOPTranscript::::new(b"multifolding"); - transcript_v.append_message(b"init", b"init").unwrap(); + let mut transcript_v: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init")); // Run the verifier side of the multifolding - let folded_lcccs_v = NIMFS::::verify( + let folded_lcccs_v = NIMFS::>::verify( &mut transcript_v, &ccs, &lcccs_instances, @@ -610,13 +642,16 @@ pub mod tests { let ccs = get_test_ccs::(); let pedersen_params = Pedersen::new_params(&mut rng, ccs.n - ccs.l - 1); + let poseidon_config = poseidon_test_config::(); // Prover's transcript - let mut transcript_p = IOPTranscript::::new(b"multifolding"); - transcript_p.append_message(b"init", b"init").unwrap(); + let mut transcript_p: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init")); // Verifier's transcript - let mut transcript_v = IOPTranscript::::new(b"multifolding"); - transcript_v.append_message(b"init", b"init").unwrap(); + let mut transcript_v: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init")); let n_steps = 3; @@ -655,18 +690,19 @@ pub mod tests { } // Run the prover side of the multifolding - let (proof, folded_lcccs, folded_witness) = NIMFS::::prove( - &mut transcript_p, - &ccs, - &lcccs_instances, - &cccs_instances, - &w_lcccs, - &w_cccs, - ) - .unwrap(); + let (proof, folded_lcccs, folded_witness) = + NIMFS::>::prove( + &mut transcript_p, + &ccs, + &lcccs_instances, + &cccs_instances, + &w_lcccs, + &w_cccs, + ) + .unwrap(); // Run the verifier side of the multifolding - let folded_lcccs_v = NIMFS::::verify( + let folded_lcccs_v = NIMFS::>::verify( &mut transcript_v, &ccs, &lcccs_instances, @@ -674,6 +710,7 @@ pub mod tests { proof, ) .unwrap(); + assert_eq!(folded_lcccs, folded_lcccs_v); // Check that the folded LCCCS instance is a valid instance with respect to the folded witness diff --git a/src/lib.rs b/src/lib.rs index b6ffd50..5170728 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,10 @@ pub enum Error { R1CSUnrelaxedFail, #[error("Could not find the inner ConstraintSystem")] NoInnerConstraintSystem, + #[error("Sum-check prove failed: {0}")] + SumCheckProveError(String), + #[error("Sum-check verify failed: {0}")] + SumCheckVerifyError(String), } /// FoldingScheme defines trait that is implemented by the diverse folding schemes. It is defined diff --git a/src/utils/espresso/sum_check/mod.rs b/src/utils/espresso/sum_check/mod.rs index 9126cda..c30c39b 100644 --- a/src/utils/espresso/sum_check/mod.rs +++ b/src/utils/espresso/sum_check/mod.rs @@ -9,60 +9,56 @@ //! This module implements the sum check protocol. -use crate::utils::virtual_polynomial::{VPAuxInfo, VirtualPolynomial}; +use crate::{ + transcript::Transcript, + utils::virtual_polynomial::{VPAuxInfo, VirtualPolynomial}, +}; +use ark_ec::{CurveGroup, Group}; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; use ark_std::{end_timer, start_timer}; -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, marker::PhantomData, sync::Arc}; -use espresso_subroutines::poly_iop::{prelude::PolyIOPErrors, PolyIOP}; -use espresso_transcript::IOPTranscript; -use structs::{IOPProof, IOPProverState, IOPVerifierState}; +use crate::utils::sum_check::structs::IOPProverMessage; +use crate::utils::sum_check::structs::IOPVerifierState; +use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; +use structs::{IOPProof, IOPProverState}; mod prover; pub mod structs; pub mod verifier; -/// Trait for doing sum check protocols. -pub trait SumCheck { +/// A generic sum-check trait over a curve group +pub trait SumCheck { type VirtualPolynomial; type VPAuxInfo; type MultilinearExtension; type SumCheckProof: Clone + Debug + Default + PartialEq; - type Transcript; type SumCheckSubClaim: Clone + Debug + Default + PartialEq; /// Extract sum from the proof - fn extract_sum(proof: &Self::SumCheckProof) -> F; - - /// Initialize the system with a transcript - /// - /// This function is optional -- in the case where a SumCheck is - /// an building block for a more complex protocol, the transcript - /// may be initialized by this complex protocol, and passed to the - /// SumCheck prover/verifier. - fn init_transcript() -> Self::Transcript; + fn extract_sum(proof: &Self::SumCheckProof) -> C::ScalarField; /// Generate proof of the sum of polynomial over {0,1}^`num_vars` /// /// The polynomial is represented in the form of a VirtualPolynomial. fn prove( poly: &Self::VirtualPolynomial, - transcript: &mut Self::Transcript, + transcript: &mut impl Transcript, ) -> Result; /// Verify the claimed sum using the proof fn verify( - sum: F, + sum: C::ScalarField, proof: &Self::SumCheckProof, aux_info: &Self::VPAuxInfo, - transcript: &mut Self::Transcript, + transcript: &mut impl Transcript, ) -> Result; } /// Trait for sum check protocol prover side APIs. -pub trait SumCheckProver +pub trait SumCheckProver where Self: Sized, { @@ -79,16 +75,15 @@ where /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). fn prove_round_and_update_state( &mut self, - challenge: &Option, + challenge: &Option, ) -> Result; } /// Trait for sum check protocol verifier side APIs. -pub trait SumCheckVerifier { +pub trait SumCheckVerifier { type VPAuxInfo; type ProverMessage; type Challenge; - type Transcript; type SumCheckSubClaim; /// Initialize the verifier's state. @@ -103,7 +98,7 @@ pub trait SumCheckVerifier { fn verify_round_and_update_state( &mut self, prover_msg: &Self::ProverMessage, - transcript: &mut Self::Transcript, + transcript: &mut impl Transcript, ) -> Result; /// This function verifies the deferred checks in the interactive version of @@ -116,7 +111,7 @@ pub trait SumCheckVerifier { /// Larger field size guarantees smaller soundness error. fn check_and_generate_subclaim( &self, - asserted_sum: &F, + asserted_sum: &C::ScalarField, ) -> Result; } @@ -131,52 +126,52 @@ pub struct SumCheckSubClaim { pub expected_evaluation: F, } -impl SumCheck for PolyIOP { - type SumCheckProof = IOPProof; - type VirtualPolynomial = VirtualPolynomial; - type VPAuxInfo = VPAuxInfo; - type MultilinearExtension = Arc>; - type SumCheckSubClaim = SumCheckSubClaim; - type Transcript = IOPTranscript; +#[derive(Clone, Debug, Default, Copy, PartialEq, Eq)] +pub struct IOPSumCheck> { + #[doc(hidden)] + phantom: PhantomData, + #[doc(hidden)] + phantom2: PhantomData, +} - fn extract_sum(proof: &Self::SumCheckProof) -> F { +impl> SumCheck for IOPSumCheck { + type SumCheckProof = IOPProof; + type VirtualPolynomial = VirtualPolynomial; + type VPAuxInfo = VPAuxInfo; + type MultilinearExtension = Arc>; + type SumCheckSubClaim = SumCheckSubClaim; + + fn extract_sum(proof: &Self::SumCheckProof) -> C::ScalarField { let start = start_timer!(|| "extract sum"); let res = proof.proofs[0].evaluations[0] + proof.proofs[0].evaluations[1]; end_timer!(start); res } - fn init_transcript() -> Self::Transcript { - let start = start_timer!(|| "init transcript"); - let res = IOPTranscript::::new(b"Initializing SumCheck transcript"); - end_timer!(start); - res - } - fn prove( - poly: &Self::VirtualPolynomial, - transcript: &mut Self::Transcript, - ) -> Result { - let start = start_timer!(|| "sum check prove"); - - transcript.append_serializable_element(b"aux info", &poly.aux_info)?; - - let mut prover_state = IOPProverState::prover_init(poly)?; - let mut challenge = None; - let mut prover_msgs = Vec::with_capacity(poly.aux_info.num_variables); + poly: &VirtualPolynomial, + transcript: &mut impl Transcript, + ) -> Result, PolyIOPErrors> { + transcript.absorb(&::ScalarField::from( + poly.aux_info.num_variables as u64, + )); + transcript.absorb(&::ScalarField::from( + poly.aux_info.max_degree as u64, + )); + let mut prover_state: IOPProverState = IOPProverState::prover_init(poly)?; + let mut challenge: Option = None; + let mut prover_msgs: Vec> = + Vec::with_capacity(poly.aux_info.num_variables); for _ in 0..poly.aux_info.num_variables { - let prover_msg = + let prover_msg: IOPProverMessage = IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge)?; - transcript.append_serializable_element(b"prover msg", &prover_msg)?; + transcript.absorb_vec(&prover_msg.evaluations); prover_msgs.push(prover_msg); - challenge = Some(transcript.get_and_append_challenge(b"Internal round")?); + challenge = Some(transcript.get_challenge()); } - // pushing the last challenge point to the state if let Some(p) = challenge { prover_state.challenges.push(p) }; - - end_timer!(start); Ok(IOPProof { point: prover_state.challenges, proofs: prover_msgs, @@ -184,18 +179,19 @@ impl SumCheck for PolyIOP { } fn verify( - claimed_sum: F, - proof: &Self::SumCheckProof, - aux_info: &Self::VPAuxInfo, - transcript: &mut Self::Transcript, - ) -> Result { - let start = start_timer!(|| "sum check verify"); - - transcript.append_serializable_element(b"aux info", aux_info)?; + claimed_sum: C::ScalarField, + proof: &IOPProof, + aux_info: &VPAuxInfo, + transcript: &mut impl Transcript, + ) -> Result, PolyIOPErrors> { + transcript.absorb(&::ScalarField::from( + aux_info.num_variables as u64, + )); + transcript.absorb(&::ScalarField::from(aux_info.max_degree as u64)); let mut verifier_state = IOPVerifierState::verifier_init(aux_info); for i in 0..aux_info.num_variables { let prover_msg = proof.proofs.get(i).expect("proof is incomplete"); - transcript.append_serializable_element(b"prover msg", prover_msg)?; + transcript.absorb_vec(&prover_msg.evaluations); IOPVerifierState::verify_round_and_update_state( &mut verifier_state, prover_msg, @@ -203,9 +199,57 @@ impl SumCheck for PolyIOP { )?; } - let res = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &claimed_sum); + IOPVerifierState::check_and_generate_subclaim(&verifier_state, &claimed_sum) + } +} - end_timer!(start); - res +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use ark_ff::Field; + use ark_pallas::Fr; + use ark_pallas::Projective; + use ark_poly::DenseMultilinearExtension; + use ark_poly::MultilinearExtension; + use ark_std::test_rng; + + use crate::transcript::poseidon::tests::poseidon_test_config; + use crate::transcript::poseidon::PoseidonTranscript; + use crate::transcript::Transcript; + use crate::utils::sum_check::SumCheck; + use crate::utils::virtual_polynomial::VirtualPolynomial; + + use super::IOPSumCheck; + + #[test] + pub fn sumcheck_poseidon() { + let mut rng = test_rng(); + let poly_mle = DenseMultilinearExtension::rand(5, &mut rng); + let virtual_poly = VirtualPolynomial::new_from_mle(&Arc::new(poly_mle), Fr::ONE); + let poseidon_config = poseidon_test_config::(); + + // sum-check prove + let mut poseidon_transcript_prove: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + let sum_check = IOPSumCheck::>::prove( + &virtual_poly, + &mut poseidon_transcript_prove, + ) + .unwrap(); + + // sum-check verify + let claimed_sum = + IOPSumCheck::>::extract_sum(&sum_check); + let mut poseidon_transcript_verify: PoseidonTranscript = + PoseidonTranscript::::new(&poseidon_config); + let res_verify = IOPSumCheck::>::verify( + claimed_sum, + &sum_check, + &virtual_poly.aux_info, + &mut poseidon_transcript_verify, + ); + + assert!(res_verify.is_ok()); } } diff --git a/src/utils/espresso/sum_check/prover.rs b/src/utils/espresso/sum_check/prover.rs index d43cc65..56fd829 100644 --- a/src/utils/espresso/sum_check/prover.rs +++ b/src/utils/espresso/sum_check/prover.rs @@ -12,6 +12,8 @@ use super::SumCheckProver; use crate::utils::multilinear_polynomial::fix_variables; use crate::utils::virtual_polynomial::VirtualPolynomial; +use ark_ec::CurveGroup; +use ark_ff::Field; use ark_ff::{batch_inversion, PrimeField}; use ark_poly::DenseMultilinearExtension; use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec}; @@ -24,9 +26,9 @@ use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; // #[cfg(feature = "parallel")] use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; -impl SumCheckProver for IOPProverState { - type VirtualPolynomial = VirtualPolynomial; - type ProverMessage = IOPProverMessage; +impl SumCheckProver for IOPProverState { + type VirtualPolynomial = VirtualPolynomial; + type ProverMessage = IOPProverMessage; /// Initialize the prover state to argue for the sum of the input polynomial /// over {0,1}^`num_vars`. @@ -45,7 +47,9 @@ impl SumCheckProver for IOPProverState { poly: polynomial.clone(), extrapolation_aux: (1..polynomial.aux_info.max_degree) .map(|degree| { - let points = (0..1 + degree as u64).map(F::from).collect::>(); + let points = (0..1 + degree as u64) + .map(C::ScalarField::from) + .collect::>(); let weights = barycentric_weights(&points); (points, weights) }) @@ -59,7 +63,7 @@ impl SumCheckProver for IOPProverState { /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). fn prove_round_and_update_state( &mut self, - challenge: &Option, + challenge: &Option, ) -> Result { // let start = // start_timer!(|| format!("sum check prove {}-th round and update state", @@ -84,7 +88,7 @@ impl SumCheckProver for IOPProverState { // g(r_1, ..., r_{m-1}, x_m ... x_n) // // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) - let mut flattened_ml_extensions: Vec> = self + let mut flattened_ml_extensions: Vec> = self .poly .flattened_ml_extensions .par_iter() @@ -118,7 +122,7 @@ impl SumCheckProver for IOPProverState { self.round += 1; let products_list = self.poly.products.clone(); - let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1]; + let mut products_sum = vec![C::ScalarField::ZERO; self.poly.aux_info.max_degree + 1]; // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) @@ -128,8 +132,8 @@ impl SumCheckProver for IOPProverState { .fold( || { ( - vec![(F::zero(), F::zero()); products.len()], - vec![F::zero(); products.len() + 1], + vec![(C::ScalarField::ZERO, C::ScalarField::ZERO); products.len()], + vec![C::ScalarField::ZERO; products.len() + 1], ) }, |(mut buf, mut acc), b| { @@ -140,17 +144,17 @@ impl SumCheckProver for IOPProverState { *eval = table[b << 1]; *step = table[(b << 1) + 1] - table[b << 1]; }); - acc[0] += buf.iter().map(|(eval, _)| eval).product::(); + acc[0] += buf.iter().map(|(eval, _)| eval).product::(); acc[1..].iter_mut().for_each(|acc| { buf.iter_mut().for_each(|(eval, step)| *eval += step as &_); - *acc += buf.iter().map(|(eval, _)| eval).product::(); + *acc += buf.iter().map(|(eval, _)| eval).product::(); }); (buf, acc) }, ) .map(|(_, partial)| partial) .reduce( - || vec![F::zero(); products.len() + 1], + || vec![C::ScalarField::ZERO; products.len() + 1], |mut sum, partial| { sum.iter_mut() .zip(partial.iter()) @@ -162,7 +166,7 @@ impl SumCheckProver for IOPProverState { let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len()) .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; - let at = F::from((products.len() + 1 + i) as u64); + let at = C::ScalarField::from((products.len() + 1 + i) as u64); extrapolate(points, weights, &sum, &at) }) .collect::>(); diff --git a/src/utils/espresso/sum_check/structs.rs b/src/utils/espresso/sum_check/structs.rs index 88b855a..40dbeff 100644 --- a/src/utils/espresso/sum_check/structs.rs +++ b/src/utils/espresso/sum_check/structs.rs @@ -10,6 +10,7 @@ //! This module defines structs that are shared by all sub protocols. use crate::utils::virtual_polynomial::VirtualPolynomial; +use ark_ec::CurveGroup; use ark_ff::PrimeField; use ark_serialize::CanonicalSerialize; @@ -32,28 +33,29 @@ pub struct IOPProverMessage { /// Prover State of a PolyIOP. #[derive(Debug)] -pub struct IOPProverState { +pub struct IOPProverState { /// sampled randomness given by the verifier - pub challenges: Vec, + pub challenges: Vec, /// the current round number pub(crate) round: usize, /// pointer to the virtual polynomial - pub(crate) poly: VirtualPolynomial, + pub(crate) poly: VirtualPolynomial, /// points with precomputed barycentric weights for extrapolating smaller /// degree uni-polys to `max_degree + 1` evaluations. - pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, + #[allow(clippy::type_complexity)] + pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, } -/// Prover State of a PolyIOP +/// Verifier State of a PolyIOP, generic over a curve group #[derive(Debug)] -pub struct IOPVerifierState { +pub struct IOPVerifierState { pub(crate) round: usize, pub(crate) num_vars: usize, pub(crate) max_degree: usize, pub(crate) finished: bool, /// a list storing the univariate polynomial in evaluation form sent by the /// prover at each round - pub(crate) polynomials_received: Vec>, + pub(crate) polynomials_received: Vec>, /// a list storing the randomness sampled by the verifier at each round - pub(crate) challenges: Vec, + pub(crate) challenges: Vec, } diff --git a/src/utils/espresso/sum_check/verifier.rs b/src/utils/espresso/sum_check/verifier.rs index e9b4470..21dca45 100644 --- a/src/utils/espresso/sum_check/verifier.rs +++ b/src/utils/espresso/sum_check/verifier.rs @@ -9,24 +9,25 @@ //! Verifier subroutines for a SumCheck protocol. -use super::{SumCheckSubClaim, SumCheckVerifier}; -use crate::utils::virtual_polynomial::VPAuxInfo; +use super::{ + structs::{IOPProverMessage, IOPVerifierState}, + SumCheckSubClaim, SumCheckVerifier, +}; +use crate::{transcript::Transcript, utils::virtual_polynomial::VPAuxInfo}; +use ark_ec::CurveGroup; use ark_ff::PrimeField; use ark_std::{end_timer, start_timer}; -use super::structs::{IOPProverMessage, IOPVerifierState}; use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; -use espresso_transcript::IOPTranscript; #[cfg(feature = "parallel")] use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; -impl SumCheckVerifier for IOPVerifierState { - type VPAuxInfo = VPAuxInfo; - type ProverMessage = IOPProverMessage; - type Challenge = F; - type Transcript = IOPTranscript; - type SumCheckSubClaim = SumCheckSubClaim; +impl SumCheckVerifier for IOPVerifierState { + type VPAuxInfo = VPAuxInfo; + type ProverMessage = IOPProverMessage; + type Challenge = C::ScalarField; + type SumCheckSubClaim = SumCheckSubClaim; /// Initialize the verifier's state. fn verifier_init(index_info: &Self::VPAuxInfo) -> Self { @@ -43,17 +44,11 @@ impl SumCheckVerifier for IOPVerifierState { res } - /// Run verifier for the current round, given a prover message. - /// - /// Note that `verify_round_and_update_state` only samples and stores - /// challenges; and update the verifier's state accordingly. The actual - /// verifications are deferred (in batch) to `check_and_generate_subclaim` - /// at the last step. fn verify_round_and_update_state( &mut self, - prover_msg: &Self::ProverMessage, - transcript: &mut Self::Transcript, - ) -> Result { + prover_msg: & as SumCheckVerifier>::ProverMessage, + transcript: &mut impl Transcript, + ) -> Result< as SumCheckVerifier>::Challenge, PolyIOPErrors> { let start = start_timer!(|| format!("sum check verify {}-th round and update state", self.round)); @@ -70,8 +65,7 @@ impl SumCheckVerifier for IOPVerifierState { // // When we turn the protocol to a non-interactive one, it is sufficient to defer // such checks to `check_and_generate_subclaim` after the last round. - - let challenge = transcript.get_and_append_challenge(b"Internal round")?; + let challenge = transcript.get_challenge(); self.challenges.push(challenge); self.polynomials_received .push(prover_msg.evaluations.to_vec()); @@ -88,17 +82,9 @@ impl SumCheckVerifier for IOPVerifierState { Ok(challenge) } - /// This function verifies the deferred checks in the interactive version of - /// the protocol; and generate the subclaim. Returns an error if the - /// proof failed to verify. - /// - /// If the asserted sum is correct, then the multilinear polynomial - /// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`. - /// Otherwise, it is highly unlikely that those two will be equal. - /// Larger field size guarantees smaller soundness error. fn check_and_generate_subclaim( &self, - asserted_sum: &F, + asserted_sum: &C::ScalarField, ) -> Result { let start = start_timer!(|| "sum check check and generate subclaim"); if !self.finished { @@ -129,7 +115,7 @@ impl SumCheckVerifier for IOPVerifierState { self.max_degree + 1 ))); } - interpolate_uni_poly::(&evaluations, challenge) + interpolate_uni_poly::(&evaluations, challenge) }) .collect::, PolyIOPErrors>>()?; @@ -160,6 +146,9 @@ impl SumCheckVerifier for IOPVerifierState { .zip(expected_vec.iter()) .take(self.num_vars) { + let eval_: C::ScalarField = evaluations[0] + evaluations[1]; + + println!("evaluations: {:?}, expected: {:?}", eval_, expected); // the deferred check during the interactive phase: // 1. check if the received 'P(0) + P(1) = expected`. if evaluations[0] + evaluations[1] != expected {