Browse Source

Feature/sumcheck (#40)

* feat: init sumcheck.rs

* chore: rename

* feat: update lib and add trait for transcript with vec storing challenges

* bugfix: mut self ref of transcript

* feat: tentative sum-check using poseidon

* refactor: remove extension trait and use initial trait

* refactor: stop using extension trait, use initial Transcript trait

* feat: generic over CurveGroup sum-check verifier and algorithm

* feat: implement generic sum-check veriy

* bugfix: cargo clippy --fix

* chore: cargo fmt

* feat: (unstable) sum-check implementation

* feat: start benches

* chore: run clippy

* chore: run cargo fmt

* feat: add sum-check tests + benches

* chore: clippy + fmt

* chore: remove unstable sumcheck

* chore: delete duplicated sum-check code

* chore: remove deleted sum-check code from lib.rs imports

* feat: remove non generic traits, implement sum-check with generic trait and add test

* chore: remove non-generic struct

* chore: remove non generic verifier

* feat: make nifms generic over transcript and update to use poseidon transcript

* chore: cargo fmt

* chore: remove tmp benches

* chore: update cargo.toml

* refactor: remove Generic suffix

* feat: prover state generic over CurveGroup

* chore: disable clippy type complexity warning

* refactor: remove Transcript type and espresso transcript dependency

* refactor: SumCheckProver generic over CurveGroup

* chore: add line to eof for `Cargo.toml`

* bugfix: add error handling on sum-check prove and verify

* chore: clippy fix

* chore: add line at eof

* fix: use `map_err` and call `to_string()` on `PolyIOPErrors`
main
Pierre 1 year ago
committed by GitHub
parent
commit
876e23c159
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 291 additions and 211 deletions
  1. +1
    -1
      Cargo.toml
  2. +126
    -89
      src/folding/hypernova/nimfs.rs
  3. +4
    -0
      src/lib.rs
  4. +113
    -69
      src/utils/espresso/sum_check/mod.rs
  5. +17
    -13
      src/utils/espresso/sum_check/prover.rs
  6. +10
    -8
      src/utils/espresso/sum_check/structs.rs
  7. +20
    -31
      src/utils/espresso/sum_check/verifier.rs

+ 1
- 1
Cargo.toml

@ -20,7 +20,6 @@ color-eyre = "=0.6.2"
# tmp imports for espresso's sumcheck
ark-serialize = "^0.4.0"
espresso_subroutines = {git="https://github.com/EspressoSystems/hyperplonk", package="subroutines"}
espresso_transcript = {git="https://github.com/EspressoSystems/hyperplonk", package="transcript"}
[dev-dependencies]
ark-pallas = {version="0.4.0", features=["r1cs"]}
@ -47,3 +46,4 @@ parallel = [
# changes done between v0.4.0 and this fix which would break compatibility.
[patch.crates-io]
ark-r1cs-std = { git = "https://github.com/arnaucube/ark-r1cs-std-cherry-picked/" }

+ 126
- 89
src/folding/hypernova/nimfs.rs

@ -1,20 +1,21 @@
use ark_ec::CurveGroup;
use ark_crypto_primitives::sponge::Absorb;
use ark_ec::{CurveGroup, Group};
use ark_ff::{Field, PrimeField};
use ark_std::{One, Zero};
use espresso_subroutines::PolyIOP;
use espresso_transcript::IOPTranscript;
use super::cccs::{Witness, CCCS};
use super::lcccs::LCCCS;
use super::utils::{compute_c_from_sigmas_and_thetas, compute_g, compute_sigmas_and_thetas};
use crate::ccs::CCS;
use crate::transcript::Transcript;
use crate::utils::hypercube::BooleanHypercube;
use crate::utils::sum_check::structs::IOPProof as SumCheckProof;
use crate::utils::sum_check::{verifier::interpolate_uni_poly, SumCheck};
use crate::utils::sum_check::verifier::interpolate_uni_poly;
use crate::utils::sum_check::{IOPSumCheck, SumCheck};
use crate::utils::virtual_polynomial::VPAuxInfo;
use crate::Error;
use std::fmt::Debug;
use std::marker::PhantomData;
/// Proof defines a multifolding proof
@ -30,11 +31,15 @@ pub struct SigmasThetas(pub Vec>, pub Vec>);
#[derive(Debug)]
/// Implements the Non-Interactive Multi Folding Scheme described in section 5 of
/// [HyperNova](https://eprint.iacr.org/2023/573.pdf)
pub struct NIMFS<C: CurveGroup> {
pub struct NIMFS<C: CurveGroup, T: Transcript<C>> {
pub _c: PhantomData<C>,
pub _t: PhantomData<T>,
}
impl<C: CurveGroup> NIMFS<C> {
impl<C: CurveGroup, T: Transcript<C>> NIMFS<C, T>
where
<C as Group>::ScalarField: Absorb,
{
pub fn fold(
lcccs: &[LCCCS<C>],
cccs: &[CCCS<C>],
@ -144,7 +149,7 @@ impl NIMFS {
/// contains the sumcheck proof and the helper sumcheck claim sigmas and thetas.
#[allow(clippy::type_complexity)]
pub fn prove(
transcript: &mut IOPTranscript<C::ScalarField>,
transcript: &mut impl Transcript<C>,
ccs: &CCS<C>,
running_instances: &[LCCCS<C>],
new_instances: &[CCCS<C>],
@ -185,17 +190,19 @@ impl NIMFS {
}
// Step 1: Get some challenges
let gamma: C::ScalarField = transcript.get_and_append_challenge(b"gamma").unwrap();
let beta: Vec<C::ScalarField> = transcript
.get_and_append_challenge_vectors(b"beta", ccs.s)
.unwrap();
let gamma_scalar = C::ScalarField::from_le_bytes_mod_order(b"gamma");
let beta_scalar = C::ScalarField::from_le_bytes_mod_order(b"beta");
transcript.absorb(&gamma_scalar);
let gamma: C::ScalarField = transcript.get_challenge();
transcript.absorb(&beta_scalar);
let beta: Vec<C::ScalarField> = transcript.get_challenges(ccs.s);
// Compute g(x)
let g = compute_g(ccs, running_instances, &z_lcccs, &z_cccs, gamma, &beta);
// Step 3: Run the sumcheck prover
let sumcheck_proof =
<PolyIOP<C::ScalarField> as SumCheck<C::ScalarField>>::prove(&g, transcript).unwrap(); // XXX unwrap
let sumcheck_proof = IOPSumCheck::<C, T>::prove(&g, transcript)
.map_err(|err| Error::SumCheckProveError(err.to_string()))?;
// Note: The following two "sanity checks" are done for this prototype, in a final version
// they should be removed.
@ -209,8 +216,8 @@ impl NIMFS {
}
// note: this is the sum of g(x) over the whole boolean hypercube
let extracted_sum =
<PolyIOP<C::ScalarField> as SumCheck<C::ScalarField>>::extract_sum(&sumcheck_proof);
let extracted_sum = IOPSumCheck::<C, T>::extract_sum(&sumcheck_proof);
if extracted_sum != g_over_bhc {
return Err(Error::NotEqual);
}
@ -238,7 +245,9 @@ impl NIMFS {
let sigmas_thetas = compute_sigmas_and_thetas(ccs, &z_lcccs, &z_cccs, &r_x_prime);
// Step 6: Get the folding challenge
let rho: C::ScalarField = transcript.get_and_append_challenge(b"rho").unwrap();
let rho_scalar = C::ScalarField::from_le_bytes_mod_order(b"rho");
transcript.absorb(&rho_scalar);
let rho: C::ScalarField = transcript.get_challenge();
// Step 7: Create the folded instance
let folded_lcccs = Self::fold(
@ -266,7 +275,7 @@ impl NIMFS {
/// into a single LCCCS instance.
/// Returns the folded LCCCS instance.
pub fn verify(
transcript: &mut IOPTranscript<C::ScalarField>,
transcript: &mut impl Transcript<C>,
ccs: &CCS<C>,
running_instances: &[LCCCS<C>],
new_instances: &[CCCS<C>],
@ -282,10 +291,13 @@ impl NIMFS {
}
// Step 1: Get some challenges
let gamma: C::ScalarField = transcript.get_and_append_challenge(b"gamma").unwrap();
let beta: Vec<C::ScalarField> = transcript
.get_and_append_challenge_vectors(b"beta", ccs.s)
.unwrap();
let gamma_scalar = C::ScalarField::from_le_bytes_mod_order(b"gamma");
transcript.absorb(&gamma_scalar);
let gamma: C::ScalarField = transcript.get_challenge();
let beta_scalar = C::ScalarField::from_le_bytes_mod_order(b"beta");
transcript.absorb(&beta_scalar);
let beta: Vec<C::ScalarField> = transcript.get_challenges(ccs.s);
let vp_aux_info = VPAuxInfo::<C::ScalarField> {
max_degree: ccs.d + 1,
@ -304,13 +316,9 @@ impl NIMFS {
}
// Verify the interactive part of the sumcheck
let sumcheck_subclaim = <PolyIOP<C::ScalarField> as SumCheck<C::ScalarField>>::verify(
sum_v_j_gamma,
&proof.sc_proof,
&vp_aux_info,
transcript,
)
.unwrap();
let sumcheck_subclaim =
IOPSumCheck::<C, T>::verify(sum_v_j_gamma, &proof.sc_proof, &vp_aux_info, transcript)
.map_err(|err| Error::SumCheckVerifyError(err.to_string()))?;
// Step 2: Dig into the sumcheck claim and extract the randomness used
let r_x_prime = sumcheck_subclaim.point.clone();
@ -347,7 +355,9 @@ impl NIMFS {
}
// Step 6: Get the folding challenge
let rho: C::ScalarField = transcript.get_and_append_challenge(b"rho").unwrap();
let rho_scalar = C::ScalarField::from_le_bytes_mod_order(b"rho");
transcript.absorb(&rho_scalar);
let rho: C::ScalarField = transcript.get_challenge();
// Step 7: Compute the folded instance
Ok(Self::fold(
@ -364,6 +374,8 @@ impl NIMFS {
pub mod tests {
use super::*;
use crate::ccs::tests::{get_test_ccs, get_test_z};
use crate::transcript::poseidon::tests::poseidon_test_config;
use crate::transcript::poseidon::PoseidonTranscript;
use ark_std::test_rng;
use ark_std::UniformRand;
@ -395,9 +407,16 @@ pub mod tests {
let mut rng = test_rng();
let rho = Fr::rand(&mut rng);
let folded = NIMFS::<Projective>::fold(&[lcccs], &[cccs], &sigmas_thetas, r_x_prime, rho);
let folded = NIMFS::<Projective, PoseidonTranscript<Projective>>::fold(
&[lcccs],
&[cccs],
&sigmas_thetas,
r_x_prime,
rho,
);
let w_folded = NIMFS::<Projective>::fold_witness(&[w1], &[w2], rho);
let w_folded =
NIMFS::<Projective, PoseidonTranscript<Projective>>::fold_witness(&[w1], &[w2], rho);
// check lcccs relation
folded
@ -425,26 +444,30 @@ pub mod tests {
let (new_instance, w2) = ccs.to_cccs(&mut rng, &pedersen_params, &z_2).unwrap();
// Prover's transcript
let mut transcript_p = IOPTranscript::<Fr>::new(b"multifolding");
transcript_p.append_message(b"init", b"init").unwrap();
let poseidon_config = poseidon_test_config::<Fr>();
let mut transcript_p: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
// Run the prover side of the multifolding
let (proof, folded_lcccs, folded_witness) = NIMFS::<Projective>::prove(
&mut transcript_p,
&ccs,
&[running_instance.clone()],
&[new_instance.clone()],
&[w1],
&[w2],
)
.unwrap();
let (proof, folded_lcccs, folded_witness) =
NIMFS::<Projective, PoseidonTranscript<Projective>>::prove(
&mut transcript_p,
&ccs,
&[running_instance.clone()],
&[new_instance.clone()],
&[w1],
&[w2],
)
.unwrap();
// Verifier's transcript
let mut transcript_v = IOPTranscript::<Fr>::new(b"multifolding");
transcript_v.append_message(b"init", b"init").unwrap();
let mut transcript_v: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
// Run the verifier side of the multifolding
let folded_lcccs_v = NIMFS::<Projective>::verify(
let folded_lcccs_v = NIMFS::<Projective, PoseidonTranscript<Projective>>::verify(
&mut transcript_v,
&ccs,
&[running_instance.clone()],
@ -474,10 +497,15 @@ pub mod tests {
let (mut running_instance, mut w1) =
ccs.to_lcccs(&mut rng, &pedersen_params, &z_1).unwrap();
let mut transcript_p = IOPTranscript::<Fr>::new(b"multifolding");
let mut transcript_v = IOPTranscript::<Fr>::new(b"multifolding");
transcript_p.append_message(b"init", b"init").unwrap();
transcript_v.append_message(b"init", b"init").unwrap();
let poseidon_config = poseidon_test_config::<Fr>();
let mut transcript_p: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
let mut transcript_v: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
let n: usize = 10;
for i in 3..n {
@ -490,18 +518,19 @@ pub mod tests {
let (new_instance, w2) = ccs.to_cccs(&mut rng, &pedersen_params, &z_2).unwrap();
// run the prover side of the multifolding
let (proof, folded_lcccs, folded_witness) = NIMFS::<Projective>::prove(
&mut transcript_p,
&ccs,
&[running_instance.clone()],
&[new_instance.clone()],
&[w1],
&[w2],
)
.unwrap();
let (proof, folded_lcccs, folded_witness) =
NIMFS::<Projective, PoseidonTranscript<Projective>>::prove(
&mut transcript_p,
&ccs,
&[running_instance.clone()],
&[new_instance.clone()],
&[w1],
&[w2],
)
.unwrap();
// run the verifier side of the multifolding
let folded_lcccs_v = NIMFS::<Projective>::verify(
let folded_lcccs_v = NIMFS::<Projective, PoseidonTranscript<Projective>>::verify(
&mut transcript_v,
&ccs,
&[running_instance.clone()],
@ -509,7 +538,6 @@ pub mod tests {
proof,
)
.unwrap();
assert_eq!(folded_lcccs, folded_lcccs_v);
// check that the folded instance with the folded witness holds the LCCCS relation
@ -565,26 +593,30 @@ pub mod tests {
}
// Prover's transcript
let mut transcript_p = IOPTranscript::<Fr>::new(b"multifolding");
transcript_p.append_message(b"init", b"init").unwrap();
let poseidon_config = poseidon_test_config::<Fr>();
let mut transcript_p: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
// Run the prover side of the multifolding
let (proof, folded_lcccs, folded_witness) = NIMFS::<Projective>::prove(
&mut transcript_p,
&ccs,
&lcccs_instances,
&cccs_instances,
&w_lcccs,
&w_cccs,
)
.unwrap();
let (proof, folded_lcccs, folded_witness) =
NIMFS::<Projective, PoseidonTranscript<Projective>>::prove(
&mut transcript_p,
&ccs,
&lcccs_instances,
&cccs_instances,
&w_lcccs,
&w_cccs,
)
.unwrap();
// Verifier's transcript
let mut transcript_v = IOPTranscript::<Fr>::new(b"multifolding");
transcript_v.append_message(b"init", b"init").unwrap();
let mut transcript_v: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
// Run the verifier side of the multifolding
let folded_lcccs_v = NIMFS::<Projective>::verify(
let folded_lcccs_v = NIMFS::<Projective, PoseidonTranscript<Projective>>::verify(
&mut transcript_v,
&ccs,
&lcccs_instances,
@ -610,13 +642,16 @@ pub mod tests {
let ccs = get_test_ccs::<Projective>();
let pedersen_params = Pedersen::new_params(&mut rng, ccs.n - ccs.l - 1);
let poseidon_config = poseidon_test_config::<Fr>();
// Prover's transcript
let mut transcript_p = IOPTranscript::<Fr>::new(b"multifolding");
transcript_p.append_message(b"init", b"init").unwrap();
let mut transcript_p: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_p.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
// Verifier's transcript
let mut transcript_v = IOPTranscript::<Fr>::new(b"multifolding");
transcript_v.append_message(b"init", b"init").unwrap();
let mut transcript_v: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
transcript_v.absorb(&Fr::from_le_bytes_mod_order(b"init init"));
let n_steps = 3;
@ -655,18 +690,19 @@ pub mod tests {
}
// Run the prover side of the multifolding
let (proof, folded_lcccs, folded_witness) = NIMFS::<Projective>::prove(
&mut transcript_p,
&ccs,
&lcccs_instances,
&cccs_instances,
&w_lcccs,
&w_cccs,
)
.unwrap();
let (proof, folded_lcccs, folded_witness) =
NIMFS::<Projective, PoseidonTranscript<Projective>>::prove(
&mut transcript_p,
&ccs,
&lcccs_instances,
&cccs_instances,
&w_lcccs,
&w_cccs,
)
.unwrap();
// Run the verifier side of the multifolding
let folded_lcccs_v = NIMFS::<Projective>::verify(
let folded_lcccs_v = NIMFS::<Projective, PoseidonTranscript<Projective>>::verify(
&mut transcript_v,
&ccs,
&lcccs_instances,
@ -674,6 +710,7 @@ pub mod tests {
proof,
)
.unwrap();
assert_eq!(folded_lcccs, folded_lcccs_v);
// Check that the folded LCCCS instance is a valid instance with respect to the folded witness

+ 4
- 0
src/lib.rs

@ -43,6 +43,10 @@ pub enum Error {
R1CSUnrelaxedFail,
#[error("Could not find the inner ConstraintSystem")]
NoInnerConstraintSystem,
#[error("Sum-check prove failed: {0}")]
SumCheckProveError(String),
#[error("Sum-check verify failed: {0}")]
SumCheckVerifyError(String),
}
/// FoldingScheme defines trait that is implemented by the diverse folding schemes. It is defined

+ 113
- 69
src/utils/espresso/sum_check/mod.rs

@ -9,60 +9,56 @@
//! This module implements the sum check protocol.
use crate::utils::virtual_polynomial::{VPAuxInfo, VirtualPolynomial};
use crate::{
transcript::Transcript,
utils::virtual_polynomial::{VPAuxInfo, VirtualPolynomial},
};
use ark_ec::{CurveGroup, Group};
use ark_ff::PrimeField;
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer};
use std::{fmt::Debug, sync::Arc};
use std::{fmt::Debug, marker::PhantomData, sync::Arc};
use espresso_subroutines::poly_iop::{prelude::PolyIOPErrors, PolyIOP};
use espresso_transcript::IOPTranscript;
use structs::{IOPProof, IOPProverState, IOPVerifierState};
use crate::utils::sum_check::structs::IOPProverMessage;
use crate::utils::sum_check::structs::IOPVerifierState;
use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
use structs::{IOPProof, IOPProverState};
mod prover;
pub mod structs;
pub mod verifier;
/// Trait for doing sum check protocols.
pub trait SumCheck<F: PrimeField> {
/// A generic sum-check trait over a curve group
pub trait SumCheck<C: CurveGroup> {
type VirtualPolynomial;
type VPAuxInfo;
type MultilinearExtension;
type SumCheckProof: Clone + Debug + Default + PartialEq;
type Transcript;
type SumCheckSubClaim: Clone + Debug + Default + PartialEq;
/// Extract sum from the proof
fn extract_sum(proof: &Self::SumCheckProof) -> F;
/// Initialize the system with a transcript
///
/// This function is optional -- in the case where a SumCheck is
/// an building block for a more complex protocol, the transcript
/// may be initialized by this complex protocol, and passed to the
/// SumCheck prover/verifier.
fn init_transcript() -> Self::Transcript;
fn extract_sum(proof: &Self::SumCheckProof) -> C::ScalarField;
/// Generate proof of the sum of polynomial over {0,1}^`num_vars`
///
/// The polynomial is represented in the form of a VirtualPolynomial.
fn prove(
poly: &Self::VirtualPolynomial,
transcript: &mut Self::Transcript,
transcript: &mut impl Transcript<C>,
) -> Result<Self::SumCheckProof, PolyIOPErrors>;
/// Verify the claimed sum using the proof
fn verify(
sum: F,
sum: C::ScalarField,
proof: &Self::SumCheckProof,
aux_info: &Self::VPAuxInfo,
transcript: &mut Self::Transcript,
transcript: &mut impl Transcript<C>,
) -> Result<Self::SumCheckSubClaim, PolyIOPErrors>;
}
/// Trait for sum check protocol prover side APIs.
pub trait SumCheckProver<F: PrimeField>
pub trait SumCheckProver<C: CurveGroup>
where
Self: Sized,
{
@ -79,16 +75,15 @@ where
/// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
fn prove_round_and_update_state(
&mut self,
challenge: &Option<F>,
challenge: &Option<C::ScalarField>,
) -> Result<Self::ProverMessage, PolyIOPErrors>;
}
/// Trait for sum check protocol verifier side APIs.
pub trait SumCheckVerifier<F: PrimeField> {
pub trait SumCheckVerifier<C: CurveGroup> {
type VPAuxInfo;
type ProverMessage;
type Challenge;
type Transcript;
type SumCheckSubClaim;
/// Initialize the verifier's state.
@ -103,7 +98,7 @@ pub trait SumCheckVerifier {
fn verify_round_and_update_state(
&mut self,
prover_msg: &Self::ProverMessage,
transcript: &mut Self::Transcript,
transcript: &mut impl Transcript<C>,
) -> Result<Self::Challenge, PolyIOPErrors>;
/// This function verifies the deferred checks in the interactive version of
@ -116,7 +111,7 @@ pub trait SumCheckVerifier {
/// Larger field size guarantees smaller soundness error.
fn check_and_generate_subclaim(
&self,
asserted_sum: &F,
asserted_sum: &C::ScalarField,
) -> Result<Self::SumCheckSubClaim, PolyIOPErrors>;
}
@ -131,52 +126,52 @@ pub struct SumCheckSubClaim {
pub expected_evaluation: F,
}
impl<F: PrimeField> SumCheck<F> for PolyIOP<F> {
type SumCheckProof = IOPProof<F>;
type VirtualPolynomial = VirtualPolynomial<F>;
type VPAuxInfo = VPAuxInfo<F>;
type MultilinearExtension = Arc<DenseMultilinearExtension<F>>;
type SumCheckSubClaim = SumCheckSubClaim<F>;
type Transcript = IOPTranscript<F>;
#[derive(Clone, Debug, Default, Copy, PartialEq, Eq)]
pub struct IOPSumCheck<C: CurveGroup, T: Transcript<C>> {
#[doc(hidden)]
phantom: PhantomData<C>,
#[doc(hidden)]
phantom2: PhantomData<T>,
}
fn extract_sum(proof: &Self::SumCheckProof) -> F {
impl<C: CurveGroup, T: Transcript<C>> SumCheck<C> for IOPSumCheck<C, T> {
type SumCheckProof = IOPProof<C::ScalarField>;
type VirtualPolynomial = VirtualPolynomial<C::ScalarField>;
type VPAuxInfo = VPAuxInfo<C::ScalarField>;
type MultilinearExtension = Arc<DenseMultilinearExtension<C::ScalarField>>;
type SumCheckSubClaim = SumCheckSubClaim<C::ScalarField>;
fn extract_sum(proof: &Self::SumCheckProof) -> C::ScalarField {
let start = start_timer!(|| "extract sum");
let res = proof.proofs[0].evaluations[0] + proof.proofs[0].evaluations[1];
end_timer!(start);
res
}
fn init_transcript() -> Self::Transcript {
let start = start_timer!(|| "init transcript");
let res = IOPTranscript::<F>::new(b"Initializing SumCheck transcript");
end_timer!(start);
res
}
fn prove(
poly: &Self::VirtualPolynomial,
transcript: &mut Self::Transcript,
) -> Result<Self::SumCheckProof, PolyIOPErrors> {
let start = start_timer!(|| "sum check prove");
transcript.append_serializable_element(b"aux info", &poly.aux_info)?;
let mut prover_state = IOPProverState::prover_init(poly)?;
let mut challenge = None;
let mut prover_msgs = Vec::with_capacity(poly.aux_info.num_variables);
poly: &VirtualPolynomial<C::ScalarField>,
transcript: &mut impl Transcript<C>,
) -> Result<IOPProof<C::ScalarField>, PolyIOPErrors> {
transcript.absorb(&<C as Group>::ScalarField::from(
poly.aux_info.num_variables as u64,
));
transcript.absorb(&<C as Group>::ScalarField::from(
poly.aux_info.max_degree as u64,
));
let mut prover_state: IOPProverState<C> = IOPProverState::prover_init(poly)?;
let mut challenge: Option<C::ScalarField> = None;
let mut prover_msgs: Vec<IOPProverMessage<C::ScalarField>> =
Vec::with_capacity(poly.aux_info.num_variables);
for _ in 0..poly.aux_info.num_variables {
let prover_msg =
let prover_msg: IOPProverMessage<C::ScalarField> =
IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge)?;
transcript.append_serializable_element(b"prover msg", &prover_msg)?;
transcript.absorb_vec(&prover_msg.evaluations);
prover_msgs.push(prover_msg);
challenge = Some(transcript.get_and_append_challenge(b"Internal round")?);
challenge = Some(transcript.get_challenge());
}
// pushing the last challenge point to the state
if let Some(p) = challenge {
prover_state.challenges.push(p)
};
end_timer!(start);
Ok(IOPProof {
point: prover_state.challenges,
proofs: prover_msgs,
@ -184,18 +179,19 @@ impl SumCheck for PolyIOP {
}
fn verify(
claimed_sum: F,
proof: &Self::SumCheckProof,
aux_info: &Self::VPAuxInfo,
transcript: &mut Self::Transcript,
) -> Result<Self::SumCheckSubClaim, PolyIOPErrors> {
let start = start_timer!(|| "sum check verify");
transcript.append_serializable_element(b"aux info", aux_info)?;
claimed_sum: C::ScalarField,
proof: &IOPProof<C::ScalarField>,
aux_info: &VPAuxInfo<C::ScalarField>,
transcript: &mut impl Transcript<C>,
) -> Result<SumCheckSubClaim<C::ScalarField>, PolyIOPErrors> {
transcript.absorb(&<C as Group>::ScalarField::from(
aux_info.num_variables as u64,
));
transcript.absorb(&<C as Group>::ScalarField::from(aux_info.max_degree as u64));
let mut verifier_state = IOPVerifierState::verifier_init(aux_info);
for i in 0..aux_info.num_variables {
let prover_msg = proof.proofs.get(i).expect("proof is incomplete");
transcript.append_serializable_element(b"prover msg", prover_msg)?;
transcript.absorb_vec(&prover_msg.evaluations);
IOPVerifierState::verify_round_and_update_state(
&mut verifier_state,
prover_msg,
@ -203,9 +199,57 @@ impl SumCheck for PolyIOP {
)?;
}
let res = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &claimed_sum);
IOPVerifierState::check_and_generate_subclaim(&verifier_state, &claimed_sum)
}
}
end_timer!(start);
res
#[cfg(test)]
pub mod tests {
use std::sync::Arc;
use ark_ff::Field;
use ark_pallas::Fr;
use ark_pallas::Projective;
use ark_poly::DenseMultilinearExtension;
use ark_poly::MultilinearExtension;
use ark_std::test_rng;
use crate::transcript::poseidon::tests::poseidon_test_config;
use crate::transcript::poseidon::PoseidonTranscript;
use crate::transcript::Transcript;
use crate::utils::sum_check::SumCheck;
use crate::utils::virtual_polynomial::VirtualPolynomial;
use super::IOPSumCheck;
#[test]
pub fn sumcheck_poseidon() {
let mut rng = test_rng();
let poly_mle = DenseMultilinearExtension::rand(5, &mut rng);
let virtual_poly = VirtualPolynomial::new_from_mle(&Arc::new(poly_mle), Fr::ONE);
let poseidon_config = poseidon_test_config::<Fr>();
// sum-check prove
let mut poseidon_transcript_prove: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
let sum_check = IOPSumCheck::<Projective, PoseidonTranscript<Projective>>::prove(
&virtual_poly,
&mut poseidon_transcript_prove,
)
.unwrap();
// sum-check verify
let claimed_sum =
IOPSumCheck::<Projective, PoseidonTranscript<Projective>>::extract_sum(&sum_check);
let mut poseidon_transcript_verify: PoseidonTranscript<Projective> =
PoseidonTranscript::<Projective>::new(&poseidon_config);
let res_verify = IOPSumCheck::<Projective, PoseidonTranscript<Projective>>::verify(
claimed_sum,
&sum_check,
&virtual_poly.aux_info,
&mut poseidon_transcript_verify,
);
assert!(res_verify.is_ok());
}
}

+ 17
- 13
src/utils/espresso/sum_check/prover.rs

@ -12,6 +12,8 @@
use super::SumCheckProver;
use crate::utils::multilinear_polynomial::fix_variables;
use crate::utils::virtual_polynomial::VirtualPolynomial;
use ark_ec::CurveGroup;
use ark_ff::Field;
use ark_ff::{batch_inversion, PrimeField};
use ark_poly::DenseMultilinearExtension;
use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
@ -24,9 +26,9 @@ use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
// #[cfg(feature = "parallel")]
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
type VirtualPolynomial = VirtualPolynomial<F>;
type ProverMessage = IOPProverMessage<F>;
impl<C: CurveGroup> SumCheckProver<C> for IOPProverState<C> {
type VirtualPolynomial = VirtualPolynomial<C::ScalarField>;
type ProverMessage = IOPProverMessage<C::ScalarField>;
/// Initialize the prover state to argue for the sum of the input polynomial
/// over {0,1}^`num_vars`.
@ -45,7 +47,9 @@ impl SumCheckProver for IOPProverState {
poly: polynomial.clone(),
extrapolation_aux: (1..polynomial.aux_info.max_degree)
.map(|degree| {
let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
let points = (0..1 + degree as u64)
.map(C::ScalarField::from)
.collect::<Vec<_>>();
let weights = barycentric_weights(&points);
(points, weights)
})
@ -59,7 +63,7 @@ impl SumCheckProver for IOPProverState {
/// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
fn prove_round_and_update_state(
&mut self,
challenge: &Option<F>,
challenge: &Option<C::ScalarField>,
) -> Result<Self::ProverMessage, PolyIOPErrors> {
// let start =
// start_timer!(|| format!("sum check prove {}-th round and update state",
@ -84,7 +88,7 @@ impl SumCheckProver for IOPProverState {
// g(r_1, ..., r_{m-1}, x_m ... x_n)
//
// eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n)
let mut flattened_ml_extensions: Vec<DenseMultilinearExtension<F>> = self
let mut flattened_ml_extensions: Vec<DenseMultilinearExtension<C::ScalarField>> = self
.poly
.flattened_ml_extensions
.par_iter()
@ -118,7 +122,7 @@ impl SumCheckProver for IOPProverState {
self.round += 1;
let products_list = self.poly.products.clone();
let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
let mut products_sum = vec![C::ScalarField::ZERO; self.poly.aux_info.max_degree + 1];
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
@ -128,8 +132,8 @@ impl SumCheckProver for IOPProverState {
.fold(
|| {
(
vec![(F::zero(), F::zero()); products.len()],
vec![F::zero(); products.len() + 1],
vec![(C::ScalarField::ZERO, C::ScalarField::ZERO); products.len()],
vec![C::ScalarField::ZERO; products.len() + 1],
)
},
|(mut buf, mut acc), b| {
@ -140,17 +144,17 @@ impl SumCheckProver for IOPProverState {
*eval = table[b << 1];
*step = table[(b << 1) + 1] - table[b << 1];
});
acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
acc[0] += buf.iter().map(|(eval, _)| eval).product::<C::ScalarField>();
acc[1..].iter_mut().for_each(|acc| {
buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
*acc += buf.iter().map(|(eval, _)| eval).product::<F>();
*acc += buf.iter().map(|(eval, _)| eval).product::<C::ScalarField>();
});
(buf, acc)
},
)
.map(|(_, partial)| partial)
.reduce(
|| vec![F::zero(); products.len() + 1],
|| vec![C::ScalarField::ZERO; products.len() + 1],
|mut sum, partial| {
sum.iter_mut()
.zip(partial.iter())
@ -162,7 +166,7 @@ impl SumCheckProver for IOPProverState {
let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
.map(|i| {
let (points, weights) = &self.extrapolation_aux[products.len() - 1];
let at = F::from((products.len() + 1 + i) as u64);
let at = C::ScalarField::from((products.len() + 1 + i) as u64);
extrapolate(points, weights, &sum, &at)
})
.collect::<Vec<_>>();

+ 10
- 8
src/utils/espresso/sum_check/structs.rs

@ -10,6 +10,7 @@
//! This module defines structs that are shared by all sub protocols.
use crate::utils::virtual_polynomial::VirtualPolynomial;
use ark_ec::CurveGroup;
use ark_ff::PrimeField;
use ark_serialize::CanonicalSerialize;
@ -32,28 +33,29 @@ pub struct IOPProverMessage {
/// Prover State of a PolyIOP.
#[derive(Debug)]
pub struct IOPProverState<F: PrimeField> {
pub struct IOPProverState<C: CurveGroup> {
/// sampled randomness given by the verifier
pub challenges: Vec<F>,
pub challenges: Vec<C::ScalarField>,
/// the current round number
pub(crate) round: usize,
/// pointer to the virtual polynomial
pub(crate) poly: VirtualPolynomial<F>,
pub(crate) poly: VirtualPolynomial<C::ScalarField>,
/// points with precomputed barycentric weights for extrapolating smaller
/// degree uni-polys to `max_degree + 1` evaluations.
pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
#[allow(clippy::type_complexity)]
pub(crate) extrapolation_aux: Vec<(Vec<C::ScalarField>, Vec<C::ScalarField>)>,
}
/// Prover State of a PolyIOP
/// Verifier State of a PolyIOP, generic over a curve group
#[derive(Debug)]
pub struct IOPVerifierState<F: PrimeField> {
pub struct IOPVerifierState<C: CurveGroup> {
pub(crate) round: usize,
pub(crate) num_vars: usize,
pub(crate) max_degree: usize,
pub(crate) finished: bool,
/// a list storing the univariate polynomial in evaluation form sent by the
/// prover at each round
pub(crate) polynomials_received: Vec<Vec<F>>,
pub(crate) polynomials_received: Vec<Vec<C::ScalarField>>,
/// a list storing the randomness sampled by the verifier at each round
pub(crate) challenges: Vec<F>,
pub(crate) challenges: Vec<C::ScalarField>,
}

+ 20
- 31
src/utils/espresso/sum_check/verifier.rs

@ -9,24 +9,25 @@
//! Verifier subroutines for a SumCheck protocol.
use super::{SumCheckSubClaim, SumCheckVerifier};
use crate::utils::virtual_polynomial::VPAuxInfo;
use super::{
structs::{IOPProverMessage, IOPVerifierState},
SumCheckSubClaim, SumCheckVerifier,
};
use crate::{transcript::Transcript, utils::virtual_polynomial::VPAuxInfo};
use ark_ec::CurveGroup;
use ark_ff::PrimeField;
use ark_std::{end_timer, start_timer};
use super::structs::{IOPProverMessage, IOPVerifierState};
use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
use espresso_transcript::IOPTranscript;
#[cfg(feature = "parallel")]
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
type VPAuxInfo = VPAuxInfo<F>;
type ProverMessage = IOPProverMessage<F>;
type Challenge = F;
type Transcript = IOPTranscript<F>;
type SumCheckSubClaim = SumCheckSubClaim<F>;
impl<C: CurveGroup> SumCheckVerifier<C> for IOPVerifierState<C> {
type VPAuxInfo = VPAuxInfo<C::ScalarField>;
type ProverMessage = IOPProverMessage<C::ScalarField>;
type Challenge = C::ScalarField;
type SumCheckSubClaim = SumCheckSubClaim<C::ScalarField>;
/// Initialize the verifier's state.
fn verifier_init(index_info: &Self::VPAuxInfo) -> Self {
@ -43,17 +44,11 @@ impl SumCheckVerifier for IOPVerifierState {
res
}
/// Run verifier for the current round, given a prover message.
///
/// Note that `verify_round_and_update_state` only samples and stores
/// challenges; and update the verifier's state accordingly. The actual
/// verifications are deferred (in batch) to `check_and_generate_subclaim`
/// at the last step.
fn verify_round_and_update_state(
&mut self,
prover_msg: &Self::ProverMessage,
transcript: &mut Self::Transcript,
) -> Result<Self::Challenge, PolyIOPErrors> {
prover_msg: &<IOPVerifierState<C> as SumCheckVerifier<C>>::ProverMessage,
transcript: &mut impl Transcript<C>,
) -> Result<<IOPVerifierState<C> as SumCheckVerifier<C>>::Challenge, PolyIOPErrors> {
let start =
start_timer!(|| format!("sum check verify {}-th round and update state", self.round));
@ -70,8 +65,7 @@ impl SumCheckVerifier for IOPVerifierState {
//
// When we turn the protocol to a non-interactive one, it is sufficient to defer
// such checks to `check_and_generate_subclaim` after the last round.
let challenge = transcript.get_and_append_challenge(b"Internal round")?;
let challenge = transcript.get_challenge();
self.challenges.push(challenge);
self.polynomials_received
.push(prover_msg.evaluations.to_vec());
@ -88,17 +82,9 @@ impl SumCheckVerifier for IOPVerifierState {
Ok(challenge)
}
/// This function verifies the deferred checks in the interactive version of
/// the protocol; and generate the subclaim. Returns an error if the
/// proof failed to verify.
///
/// If the asserted sum is correct, then the multilinear polynomial
/// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`.
/// Otherwise, it is highly unlikely that those two will be equal.
/// Larger field size guarantees smaller soundness error.
fn check_and_generate_subclaim(
&self,
asserted_sum: &F,
asserted_sum: &C::ScalarField,
) -> Result<Self::SumCheckSubClaim, PolyIOPErrors> {
let start = start_timer!(|| "sum check check and generate subclaim");
if !self.finished {
@ -129,7 +115,7 @@ impl SumCheckVerifier for IOPVerifierState {
self.max_degree + 1
)));
}
interpolate_uni_poly::<F>(&evaluations, challenge)
interpolate_uni_poly::<C::ScalarField>(&evaluations, challenge)
})
.collect::<Result<Vec<_>, PolyIOPErrors>>()?;
@ -160,6 +146,9 @@ impl SumCheckVerifier for IOPVerifierState {
.zip(expected_vec.iter())
.take(self.num_vars)
{
let eval_: C::ScalarField = evaluations[0] + evaluations[1];
println!("evaluations: {:?}, expected: {:?}", eval_, expected);
// the deferred check during the interactive phase:
// 1. check if the received 'P(0) + P(1) = expected`.
if evaluations[0] + evaluations[1] != expected {

Loading…
Cancel
Save