Port Espresso/hyperplonk's `virtualpolynomial`, `multilinearpolynomial` and `sum_check` utils from https://github.com/EspressoSystems/hyperplonk/tree/main Each file contains the reference to the original file. Porting it into a subdirectory `src/utils/espresso`, to have it self-contained. In future iterations we might replace part of it but we can keep focusing on the folding schemes part for now.main
@ -0,0 +1,3 @@ | 
															
														|||||
 | 
																pub mod multilinear_polynomial;
 | 
															
														||||
 | 
																pub mod sum_check;
 | 
															
														||||
 | 
																pub mod virtual_polynomial;
 | 
															
														||||
@ -0,0 +1,200 @@ | 
															
														|||||
 | 
																// code forked from
 | 
															
														||||
 | 
																// https://github.com/EspressoSystems/hyperplonk/blob/main/arithmetic/src/multilinear_polynomial.rs
 | 
															
														||||
 | 
																//
 | 
															
														||||
 | 
																// Copyright (c) 2023 Espresso Systems (espressosys.com)
 | 
															
														||||
 | 
																// This file is part of the HyperPlonk library.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// You should have received a copy of the MIT License
 | 
															
														||||
 | 
																// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use ark_ff::Field;
 | 
															
														||||
 | 
																#[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																pub use ark_poly::DenseMultilinearExtension;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																pub fn fix_variables<F: Field>(
 | 
															
														||||
 | 
																    poly: &DenseMultilinearExtension<F>,
 | 
															
														||||
 | 
																    partial_point: &[F],
 | 
															
														||||
 | 
																) -> DenseMultilinearExtension<F> {
 | 
															
														||||
 | 
																    assert!(
 | 
															
														||||
 | 
																        partial_point.len() <= poly.num_vars,
 | 
															
														||||
 | 
																        "invalid size of partial point"
 | 
															
														||||
 | 
																    );
 | 
															
														||||
 | 
																    let nv = poly.num_vars;
 | 
															
														||||
 | 
																    let mut poly = poly.evaluations.to_vec();
 | 
															
														||||
 | 
																    let dim = partial_point.len();
 | 
															
														||||
 | 
																    // evaluate single variable of partial point from left to right
 | 
															
														||||
 | 
																    for (i, point) in partial_point.iter().enumerate().take(dim) {
 | 
															
														||||
 | 
																        poly = fix_one_variable_helper(&poly, nv - i, point);
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    DenseMultilinearExtension::<F>::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																fn fix_one_variable_helper<F: Field>(data: &[F], nv: usize, point: &F) -> Vec<F> {
 | 
															
														||||
 | 
																    let mut res = vec![F::zero(); 1 << (nv - 1)];
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    // evaluate single variable of partial point from left to right
 | 
															
														||||
 | 
																    #[cfg(not(feature = "parallel"))]
 | 
															
														||||
 | 
																    for i in 0..(1 << (nv - 1)) {
 | 
															
														||||
 | 
																        res[i] = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    #[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																    res.par_iter_mut().enumerate().for_each(|(i, x)| {
 | 
															
														||||
 | 
																        *x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
 | 
															
														||||
 | 
																    });
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    res
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																pub fn evaluate_no_par<F: Field>(poly: &DenseMultilinearExtension<F>, point: &[F]) -> F {
 | 
															
														||||
 | 
																    assert_eq!(poly.num_vars, point.len());
 | 
															
														||||
 | 
																    fix_variables_no_par(poly, point).evaluations[0]
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																fn fix_variables_no_par<F: Field>(
 | 
															
														||||
 | 
																    poly: &DenseMultilinearExtension<F>,
 | 
															
														||||
 | 
																    partial_point: &[F],
 | 
															
														||||
 | 
																) -> DenseMultilinearExtension<F> {
 | 
															
														||||
 | 
																    assert!(
 | 
															
														||||
 | 
																        partial_point.len() <= poly.num_vars,
 | 
															
														||||
 | 
																        "invalid size of partial point"
 | 
															
														||||
 | 
																    );
 | 
															
														||||
 | 
																    let nv = poly.num_vars;
 | 
															
														||||
 | 
																    let mut poly = poly.evaluations.to_vec();
 | 
															
														||||
 | 
																    let dim = partial_point.len();
 | 
															
														||||
 | 
																    // evaluate single variable of partial point from left to right
 | 
															
														||||
 | 
																    for i in 1..dim + 1 {
 | 
															
														||||
 | 
																        let r = partial_point[i - 1];
 | 
															
														||||
 | 
																        for b in 0..(1 << (nv - i)) {
 | 
															
														||||
 | 
																            poly[b] = poly[b << 1] + (poly[(b << 1) + 1] - poly[b << 1]) * r;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    DenseMultilinearExtension::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Given multilinear polynomial `p(x)` and s `s`, compute `s*p(x)`
 | 
															
														||||
 | 
																pub fn scalar_mul<F: Field>(
 | 
															
														||||
 | 
																    poly: &DenseMultilinearExtension<F>,
 | 
															
														||||
 | 
																    s: &F,
 | 
															
														||||
 | 
																) -> DenseMultilinearExtension<F> {
 | 
															
														||||
 | 
																    DenseMultilinearExtension {
 | 
															
														||||
 | 
																        evaluations: poly.evaluations.iter().map(|e| *e * s).collect(),
 | 
															
														||||
 | 
																        num_vars: poly.num_vars,
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Test-only methods used in virtual_polynomial.rs
 | 
															
														||||
 | 
																#[cfg(test)]
 | 
															
														||||
 | 
																pub mod tests {
 | 
															
														||||
 | 
																    use super::*;
 | 
															
														||||
 | 
																    use ark_ff::PrimeField;
 | 
															
														||||
 | 
																    use ark_std::rand::RngCore;
 | 
															
														||||
 | 
																    use ark_std::{end_timer, start_timer};
 | 
															
														||||
 | 
																    use std::sync::Arc;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    pub fn fix_last_variables<F: PrimeField>(
 | 
															
														||||
 | 
																        poly: &DenseMultilinearExtension<F>,
 | 
															
														||||
 | 
																        partial_point: &[F],
 | 
															
														||||
 | 
																    ) -> DenseMultilinearExtension<F> {
 | 
															
														||||
 | 
																        assert!(
 | 
															
														||||
 | 
																            partial_point.len() <= poly.num_vars,
 | 
															
														||||
 | 
																            "invalid size of partial point"
 | 
															
														||||
 | 
																        );
 | 
															
														||||
 | 
																        let nv = poly.num_vars;
 | 
															
														||||
 | 
																        let mut poly = poly.evaluations.to_vec();
 | 
															
														||||
 | 
																        let dim = partial_point.len();
 | 
															
														||||
 | 
																        // evaluate single variable of partial point from left to right
 | 
															
														||||
 | 
																        for (i, point) in partial_point.iter().rev().enumerate().take(dim) {
 | 
															
														||||
 | 
																            poly = fix_last_variable_helper(&poly, nv - i, point);
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        DenseMultilinearExtension::<F>::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))])
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    fn fix_last_variable_helper<F: Field>(data: &[F], nv: usize, point: &F) -> Vec<F> {
 | 
															
														||||
 | 
																        let half_len = 1 << (nv - 1);
 | 
															
														||||
 | 
																        let mut res = vec![F::zero(); half_len];
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // evaluate single variable of partial point from left to right
 | 
															
														||||
 | 
																        #[cfg(not(feature = "parallel"))]
 | 
															
														||||
 | 
																        for b in 0..half_len {
 | 
															
														||||
 | 
																            res[b] = data[b] + (data[b + half_len] - data[b]) * point;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        #[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																        res.par_iter_mut().enumerate().for_each(|(i, x)| {
 | 
															
														||||
 | 
																            *x = data[i] + (data[i + half_len] - data[i]) * point;
 | 
															
														||||
 | 
																        });
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        res
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Sample a random list of multilinear polynomials.
 | 
															
														||||
 | 
																    /// Returns
 | 
															
														||||
 | 
																    /// - the list of polynomials,
 | 
															
														||||
 | 
																    /// - its sum of polynomial evaluations over the boolean hypercube.
 | 
															
														||||
 | 
																    #[cfg(test)]
 | 
															
														||||
 | 
																    pub fn random_mle_list<F: PrimeField, R: RngCore>(
 | 
															
														||||
 | 
																        nv: usize,
 | 
															
														||||
 | 
																        degree: usize,
 | 
															
														||||
 | 
																        rng: &mut R,
 | 
															
														||||
 | 
																    ) -> (Vec<Arc<DenseMultilinearExtension<F>>>, F) {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sample random mle list");
 | 
															
														||||
 | 
																        let mut multiplicands = Vec::with_capacity(degree);
 | 
															
														||||
 | 
																        for _ in 0..degree {
 | 
															
														||||
 | 
																            multiplicands.push(Vec::with_capacity(1 << nv))
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        let mut sum = F::zero();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for _ in 0..(1 << nv) {
 | 
															
														||||
 | 
																            let mut product = F::one();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            for e in multiplicands.iter_mut() {
 | 
															
														||||
 | 
																                let val = F::rand(rng);
 | 
															
														||||
 | 
																                e.push(val);
 | 
															
														||||
 | 
																                product *= val;
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																            sum += product;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let list = multiplicands
 | 
															
														||||
 | 
																            .into_iter()
 | 
															
														||||
 | 
																            .map(|x| Arc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x)))
 | 
															
														||||
 | 
																            .collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        (list, sum)
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    // Build a randomize list of mle-s whose sum is zero.
 | 
															
														||||
 | 
																    #[cfg(test)]
 | 
															
														||||
 | 
																    pub fn random_zero_mle_list<F: PrimeField, R: RngCore>(
 | 
															
														||||
 | 
																        nv: usize,
 | 
															
														||||
 | 
																        degree: usize,
 | 
															
														||||
 | 
																        rng: &mut R,
 | 
															
														||||
 | 
																    ) -> Vec<Arc<DenseMultilinearExtension<F>>> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sample random zero mle list");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let mut multiplicands = Vec::with_capacity(degree);
 | 
															
														||||
 | 
																        for _ in 0..degree {
 | 
															
														||||
 | 
																            multiplicands.push(Vec::with_capacity(1 << nv))
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        for _ in 0..(1 << nv) {
 | 
															
														||||
 | 
																            multiplicands[0].push(F::zero());
 | 
															
														||||
 | 
																            for e in multiplicands.iter_mut().skip(1) {
 | 
															
														||||
 | 
																                e.push(F::rand(rng));
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let list = multiplicands
 | 
															
														||||
 | 
																            .into_iter()
 | 
															
														||||
 | 
																            .map(|x| Arc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x)))
 | 
															
														||||
 | 
																            .collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        list
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
@ -0,0 +1,211 @@ | 
															
														|||||
 | 
																// code forked from:
 | 
															
														||||
 | 
																// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check
 | 
															
														||||
 | 
																//
 | 
															
														||||
 | 
																// Copyright (c) 2023 Espresso Systems (espressosys.com)
 | 
															
														||||
 | 
																// This file is part of the HyperPlonk library.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// You should have received a copy of the MIT License
 | 
															
														||||
 | 
																// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																//! This module implements the sum check protocol.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use crate::utils::virtual_polynomial::{VPAuxInfo, VirtualPolynomial};
 | 
															
														||||
 | 
																use ark_ff::PrimeField;
 | 
															
														||||
 | 
																use ark_poly::DenseMultilinearExtension;
 | 
															
														||||
 | 
																use ark_std::{end_timer, start_timer};
 | 
															
														||||
 | 
																use std::{fmt::Debug, sync::Arc};
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use espresso_subroutines::poly_iop::{prelude::PolyIOPErrors, PolyIOP};
 | 
															
														||||
 | 
																use espresso_transcript::IOPTranscript;
 | 
															
														||||
 | 
																use structs::{IOPProof, IOPProverState, IOPVerifierState};
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																mod prover;
 | 
															
														||||
 | 
																pub mod structs;
 | 
															
														||||
 | 
																pub mod verifier;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Trait for doing sum check protocols.
 | 
															
														||||
 | 
																pub trait SumCheck<F: PrimeField> {
 | 
															
														||||
 | 
																    type VirtualPolynomial;
 | 
															
														||||
 | 
																    type VPAuxInfo;
 | 
															
														||||
 | 
																    type MultilinearExtension;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    type SumCheckProof: Clone + Debug + Default + PartialEq;
 | 
															
														||||
 | 
																    type Transcript;
 | 
															
														||||
 | 
																    type SumCheckSubClaim: Clone + Debug + Default + PartialEq;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Extract sum from the proof
 | 
															
														||||
 | 
																    fn extract_sum(proof: &Self::SumCheckProof) -> F;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Initialize the system with a transcript
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// This function is optional -- in the case where a SumCheck is
 | 
															
														||||
 | 
																    /// an building block for a more complex protocol, the transcript
 | 
															
														||||
 | 
																    /// may be initialized by this complex protocol, and passed to the
 | 
															
														||||
 | 
																    /// SumCheck prover/verifier.
 | 
															
														||||
 | 
																    fn init_transcript() -> Self::Transcript;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Generate proof of the sum of polynomial over {0,1}^`num_vars`
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// The polynomial is represented in the form of a VirtualPolynomial.
 | 
															
														||||
 | 
																    fn prove(
 | 
															
														||||
 | 
																        poly: &Self::VirtualPolynomial,
 | 
															
														||||
 | 
																        transcript: &mut Self::Transcript,
 | 
															
														||||
 | 
																    ) -> Result<Self::SumCheckProof, PolyIOPErrors>;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Verify the claimed sum using the proof
 | 
															
														||||
 | 
																    fn verify(
 | 
															
														||||
 | 
																        sum: F,
 | 
															
														||||
 | 
																        proof: &Self::SumCheckProof,
 | 
															
														||||
 | 
																        aux_info: &Self::VPAuxInfo,
 | 
															
														||||
 | 
																        transcript: &mut Self::Transcript,
 | 
															
														||||
 | 
																    ) -> Result<Self::SumCheckSubClaim, PolyIOPErrors>;
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Trait for sum check protocol prover side APIs.
 | 
															
														||||
 | 
																pub trait SumCheckProver<F: PrimeField>
 | 
															
														||||
 | 
																where
 | 
															
														||||
 | 
																    Self: Sized,
 | 
															
														||||
 | 
																{
 | 
															
														||||
 | 
																    type VirtualPolynomial;
 | 
															
														||||
 | 
																    type ProverMessage;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Initialize the prover state to argue for the sum of the input polynomial
 | 
															
														||||
 | 
																    /// over {0,1}^`num_vars`.
 | 
															
														||||
 | 
																    fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result<Self, PolyIOPErrors>;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Receive message from verifier, generate prover message, and proceed to
 | 
															
														||||
 | 
																    /// next round.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
 | 
															
														||||
 | 
																    fn prove_round_and_update_state(
 | 
															
														||||
 | 
																        &mut self,
 | 
															
														||||
 | 
																        challenge: &Option<F>,
 | 
															
														||||
 | 
																    ) -> Result<Self::ProverMessage, PolyIOPErrors>;
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Trait for sum check protocol verifier side APIs.
 | 
															
														||||
 | 
																pub trait SumCheckVerifier<F: PrimeField> {
 | 
															
														||||
 | 
																    type VPAuxInfo;
 | 
															
														||||
 | 
																    type ProverMessage;
 | 
															
														||||
 | 
																    type Challenge;
 | 
															
														||||
 | 
																    type Transcript;
 | 
															
														||||
 | 
																    type SumCheckSubClaim;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Initialize the verifier's state.
 | 
															
														||||
 | 
																    fn verifier_init(index_info: &Self::VPAuxInfo) -> Self;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Run verifier for the current round, given a prover message.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// Note that `verify_round_and_update_state` only samples and stores
 | 
															
														||||
 | 
																    /// challenges; and update the verifier's state accordingly. The actual
 | 
															
														||||
 | 
																    /// verifications are deferred (in batch) to `check_and_generate_subclaim`
 | 
															
														||||
 | 
																    /// at the last step.
 | 
															
														||||
 | 
																    fn verify_round_and_update_state(
 | 
															
														||||
 | 
																        &mut self,
 | 
															
														||||
 | 
																        prover_msg: &Self::ProverMessage,
 | 
															
														||||
 | 
																        transcript: &mut Self::Transcript,
 | 
															
														||||
 | 
																    ) -> Result<Self::Challenge, PolyIOPErrors>;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// This function verifies the deferred checks in the interactive version of
 | 
															
														||||
 | 
																    /// the protocol; and generate the subclaim. Returns an error if the
 | 
															
														||||
 | 
																    /// proof failed to verify.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// If the asserted sum is correct, then the multilinear polynomial
 | 
															
														||||
 | 
																    /// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`.
 | 
															
														||||
 | 
																    /// Otherwise, it is highly unlikely that those two will be equal.
 | 
															
														||||
 | 
																    /// Larger field size guarantees smaller soundness error.
 | 
															
														||||
 | 
																    fn check_and_generate_subclaim(
 | 
															
														||||
 | 
																        &self,
 | 
															
														||||
 | 
																        asserted_sum: &F,
 | 
															
														||||
 | 
																    ) -> Result<Self::SumCheckSubClaim, PolyIOPErrors>;
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// A SumCheckSubClaim is a claim generated by the verifier at the end of
 | 
															
														||||
 | 
																/// verification when it is convinced.
 | 
															
														||||
 | 
																#[derive(Clone, Debug, Default, PartialEq, Eq)]
 | 
															
														||||
 | 
																pub struct SumCheckSubClaim<F: PrimeField> {
 | 
															
														||||
 | 
																    /// the multi-dimensional point that this multilinear extension is evaluated
 | 
															
														||||
 | 
																    /// to
 | 
															
														||||
 | 
																    pub point: Vec<F>,
 | 
															
														||||
 | 
																    /// the expected evaluation
 | 
															
														||||
 | 
																    pub expected_evaluation: F,
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																impl<F: PrimeField> SumCheck<F> for PolyIOP<F> {
 | 
															
														||||
 | 
																    type SumCheckProof = IOPProof<F>;
 | 
															
														||||
 | 
																    type VirtualPolynomial = VirtualPolynomial<F>;
 | 
															
														||||
 | 
																    type VPAuxInfo = VPAuxInfo<F>;
 | 
															
														||||
 | 
																    type MultilinearExtension = Arc<DenseMultilinearExtension<F>>;
 | 
															
														||||
 | 
																    type SumCheckSubClaim = SumCheckSubClaim<F>;
 | 
															
														||||
 | 
																    type Transcript = IOPTranscript<F>;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    fn extract_sum(proof: &Self::SumCheckProof) -> F {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "extract sum");
 | 
															
														||||
 | 
																        let res = proof.proofs[0].evaluations[0] + proof.proofs[0].evaluations[1];
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        res
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    fn init_transcript() -> Self::Transcript {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "init transcript");
 | 
															
														||||
 | 
																        let res = IOPTranscript::<F>::new(b"Initializing SumCheck transcript");
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        res
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    fn prove(
 | 
															
														||||
 | 
																        poly: &Self::VirtualPolynomial,
 | 
															
														||||
 | 
																        transcript: &mut Self::Transcript,
 | 
															
														||||
 | 
																    ) -> Result<Self::SumCheckProof, PolyIOPErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sum check prove");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        transcript.append_serializable_element(b"aux info", &poly.aux_info)?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let mut prover_state = IOPProverState::prover_init(poly)?;
 | 
															
														||||
 | 
																        let mut challenge = None;
 | 
															
														||||
 | 
																        let mut prover_msgs = Vec::with_capacity(poly.aux_info.num_variables);
 | 
															
														||||
 | 
																        for _ in 0..poly.aux_info.num_variables {
 | 
															
														||||
 | 
																            let prover_msg =
 | 
															
														||||
 | 
																                IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge)?;
 | 
															
														||||
 | 
																            transcript.append_serializable_element(b"prover msg", &prover_msg)?;
 | 
															
														||||
 | 
																            prover_msgs.push(prover_msg);
 | 
															
														||||
 | 
																            challenge = Some(transcript.get_and_append_challenge(b"Internal round")?);
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        // pushing the last challenge point to the state
 | 
															
														||||
 | 
																        if let Some(p) = challenge {
 | 
															
														||||
 | 
																            prover_state.challenges.push(p)
 | 
															
														||||
 | 
																        };
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        Ok(IOPProof {
 | 
															
														||||
 | 
																            point: prover_state.challenges,
 | 
															
														||||
 | 
																            proofs: prover_msgs,
 | 
															
														||||
 | 
																        })
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    fn verify(
 | 
															
														||||
 | 
																        claimed_sum: F,
 | 
															
														||||
 | 
																        proof: &Self::SumCheckProof,
 | 
															
														||||
 | 
																        aux_info: &Self::VPAuxInfo,
 | 
															
														||||
 | 
																        transcript: &mut Self::Transcript,
 | 
															
														||||
 | 
																    ) -> Result<Self::SumCheckSubClaim, PolyIOPErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sum check verify");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        transcript.append_serializable_element(b"aux info", aux_info)?;
 | 
															
														||||
 | 
																        let mut verifier_state = IOPVerifierState::verifier_init(aux_info);
 | 
															
														||||
 | 
																        for i in 0..aux_info.num_variables {
 | 
															
														||||
 | 
																            let prover_msg = proof.proofs.get(i).expect("proof is incomplete");
 | 
															
														||||
 | 
																            transcript.append_serializable_element(b"prover msg", prover_msg)?;
 | 
															
														||||
 | 
																            IOPVerifierState::verify_round_and_update_state(
 | 
															
														||||
 | 
																                &mut verifier_state,
 | 
															
														||||
 | 
																                prover_msg,
 | 
															
														||||
 | 
																                transcript,
 | 
															
														||||
 | 
																            )?;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let res = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &claimed_sum);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        res
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
@ -0,0 +1,220 @@ | 
															
														|||||
 | 
																// code forked from:
 | 
															
														||||
 | 
																// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check
 | 
															
														||||
 | 
																//
 | 
															
														||||
 | 
																// Copyright (c) 2023 Espresso Systems (espressosys.com)
 | 
															
														||||
 | 
																// This file is part of the HyperPlonk library.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// You should have received a copy of the MIT License
 | 
															
														||||
 | 
																// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																//! Prover subroutines for a SumCheck protocol.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use super::SumCheckProver;
 | 
															
														||||
 | 
																use crate::utils::multilinear_polynomial::fix_variables;
 | 
															
														||||
 | 
																use crate::utils::virtual_polynomial::VirtualPolynomial;
 | 
															
														||||
 | 
																use ark_ff::{batch_inversion, PrimeField};
 | 
															
														||||
 | 
																use ark_poly::DenseMultilinearExtension;
 | 
															
														||||
 | 
																use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
 | 
															
														||||
 | 
																use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
 | 
															
														||||
 | 
																use std::sync::Arc;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use super::structs::{IOPProverMessage, IOPProverState};
 | 
															
														||||
 | 
																use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// #[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
 | 
															
														||||
 | 
																    type VirtualPolynomial = VirtualPolynomial<F>;
 | 
															
														||||
 | 
																    type ProverMessage = IOPProverMessage<F>;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Initialize the prover state to argue for the sum of the input polynomial
 | 
															
														||||
 | 
																    /// over {0,1}^`num_vars`.
 | 
															
														||||
 | 
																    fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result<Self, PolyIOPErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sum check prover init");
 | 
															
														||||
 | 
																        if polynomial.aux_info.num_variables == 0 {
 | 
															
														||||
 | 
																            return Err(PolyIOPErrors::InvalidParameters(
 | 
															
														||||
 | 
																                "Attempt to prove a constant.".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        Ok(Self {
 | 
															
														||||
 | 
																            challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
 | 
															
														||||
 | 
																            round: 0,
 | 
															
														||||
 | 
																            poly: polynomial.clone(),
 | 
															
														||||
 | 
																            extrapolation_aux: (1..polynomial.aux_info.max_degree)
 | 
															
														||||
 | 
																                .map(|degree| {
 | 
															
														||||
 | 
																                    let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
 | 
															
														||||
 | 
																                    let weights = barycentric_weights(&points);
 | 
															
														||||
 | 
																                    (points, weights)
 | 
															
														||||
 | 
																                })
 | 
															
														||||
 | 
																                .collect(),
 | 
															
														||||
 | 
																        })
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Receive message from verifier, generate prover message, and proceed to
 | 
															
														||||
 | 
																    /// next round.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
 | 
															
														||||
 | 
																    fn prove_round_and_update_state(
 | 
															
														||||
 | 
																        &mut self,
 | 
															
														||||
 | 
																        challenge: &Option<F>,
 | 
															
														||||
 | 
																    ) -> Result<Self::ProverMessage, PolyIOPErrors> {
 | 
															
														||||
 | 
																        // let start =
 | 
															
														||||
 | 
																        //     start_timer!(|| format!("sum check prove {}-th round and update state",
 | 
															
														||||
 | 
																        // self.round));
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if self.round >= self.poly.aux_info.num_variables {
 | 
															
														||||
 | 
																            return Err(PolyIOPErrors::InvalidProver(
 | 
															
														||||
 | 
																                "Prover is not active".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // let fix_argument = start_timer!(|| "fix argument");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // Step 1:
 | 
															
														||||
 | 
																        // fix argument and evaluate f(x) over x_m = r; where r is the challenge
 | 
															
														||||
 | 
																        // for the current round, and m is the round number, indexed from 1
 | 
															
														||||
 | 
																        //
 | 
															
														||||
 | 
																        // i.e.:
 | 
															
														||||
 | 
																        // at round m <= n, for each mle g(x_1, ... x_n) within the flattened_mle
 | 
															
														||||
 | 
																        // which has already been evaluated to
 | 
															
														||||
 | 
																        //
 | 
															
														||||
 | 
																        //    g(r_1, ..., r_{m-1}, x_m ... x_n)
 | 
															
														||||
 | 
																        //
 | 
															
														||||
 | 
																        // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n)
 | 
															
														||||
 | 
																        let mut flattened_ml_extensions: Vec<DenseMultilinearExtension<F>> = self
 | 
															
														||||
 | 
																            .poly
 | 
															
														||||
 | 
																            .flattened_ml_extensions
 | 
															
														||||
 | 
																            .par_iter()
 | 
															
														||||
 | 
																            .map(|x| x.as_ref().clone())
 | 
															
														||||
 | 
																            .collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if let Some(chal) = challenge {
 | 
															
														||||
 | 
																            if self.round == 0 {
 | 
															
														||||
 | 
																                return Err(PolyIOPErrors::InvalidProver(
 | 
															
														||||
 | 
																                    "first round should be prover first.".to_string(),
 | 
															
														||||
 | 
																                ));
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																            self.challenges.push(*chal);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            let r = self.challenges[self.round - 1];
 | 
															
														||||
 | 
																            // #[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																            flattened_ml_extensions
 | 
															
														||||
 | 
																                .par_iter_mut()
 | 
															
														||||
 | 
																                .for_each(|mle| *mle = fix_variables(mle, &[r]));
 | 
															
														||||
 | 
																            // #[cfg(not(feature = "parallel"))]
 | 
															
														||||
 | 
																            // flattened_ml_extensions
 | 
															
														||||
 | 
																            //     .iter_mut()
 | 
															
														||||
 | 
																            //     .for_each(|mle| *mle = fix_variables(mle, &[r]));
 | 
															
														||||
 | 
																        } else if self.round > 0 {
 | 
															
														||||
 | 
																            return Err(PolyIOPErrors::InvalidProver(
 | 
															
														||||
 | 
																                "verifier message is empty".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        // end_timer!(fix_argument);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        self.round += 1;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let products_list = self.poly.products.clone();
 | 
															
														||||
 | 
																        let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // Step 2: generate sum for the partial evaluated polynomial:
 | 
															
														||||
 | 
																        // f(r_1, ... r_m,, x_{m+1}... x_n)
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        products_list.iter().for_each(|(coefficient, products)| {
 | 
															
														||||
 | 
																            let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
 | 
															
														||||
 | 
																                .fold(
 | 
															
														||||
 | 
																                    || {
 | 
															
														||||
 | 
																                        (
 | 
															
														||||
 | 
																                            vec![(F::zero(), F::zero()); products.len()],
 | 
															
														||||
 | 
																                            vec![F::zero(); products.len() + 1],
 | 
															
														||||
 | 
																                        )
 | 
															
														||||
 | 
																                    },
 | 
															
														||||
 | 
																                    |(mut buf, mut acc), b| {
 | 
															
														||||
 | 
																                        buf.iter_mut()
 | 
															
														||||
 | 
																                            .zip(products.iter())
 | 
															
														||||
 | 
																                            .for_each(|((eval, step), f)| {
 | 
															
														||||
 | 
																                                let table = &flattened_ml_extensions[*f];
 | 
															
														||||
 | 
																                                *eval = table[b << 1];
 | 
															
														||||
 | 
																                                *step = table[(b << 1) + 1] - table[b << 1];
 | 
															
														||||
 | 
																                            });
 | 
															
														||||
 | 
																                        acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
 | 
															
														||||
 | 
																                        acc[1..].iter_mut().for_each(|acc| {
 | 
															
														||||
 | 
																                            buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
 | 
															
														||||
 | 
																                            *acc += buf.iter().map(|(eval, _)| eval).product::<F>();
 | 
															
														||||
 | 
																                        });
 | 
															
														||||
 | 
																                        (buf, acc)
 | 
															
														||||
 | 
																                    },
 | 
															
														||||
 | 
																                )
 | 
															
														||||
 | 
																                .map(|(_, partial)| partial)
 | 
															
														||||
 | 
																                .reduce(
 | 
															
														||||
 | 
																                    || vec![F::zero(); products.len() + 1],
 | 
															
														||||
 | 
																                    |mut sum, partial| {
 | 
															
														||||
 | 
																                        sum.iter_mut()
 | 
															
														||||
 | 
																                            .zip(partial.iter())
 | 
															
														||||
 | 
																                            .for_each(|(sum, partial)| *sum += partial);
 | 
															
														||||
 | 
																                        sum
 | 
															
														||||
 | 
																                    },
 | 
															
														||||
 | 
																                );
 | 
															
														||||
 | 
																            sum.iter_mut().for_each(|sum| *sum *= coefficient);
 | 
															
														||||
 | 
																            let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
 | 
															
														||||
 | 
																                .map(|i| {
 | 
															
														||||
 | 
																                    let (points, weights) = &self.extrapolation_aux[products.len() - 1];
 | 
															
														||||
 | 
																                    let at = F::from((products.len() + 1 + i) as u64);
 | 
															
														||||
 | 
																                    extrapolate(points, weights, &sum, &at)
 | 
															
														||||
 | 
																                })
 | 
															
														||||
 | 
																                .collect::<Vec<_>>();
 | 
															
														||||
 | 
																            products_sum
 | 
															
														||||
 | 
																                .iter_mut()
 | 
															
														||||
 | 
																                .zip(sum.iter().chain(extraploation.iter()))
 | 
															
														||||
 | 
																                .for_each(|(products_sum, sum)| *products_sum += sum);
 | 
															
														||||
 | 
																        });
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // update prover's state to the partial evaluated polynomial
 | 
															
														||||
 | 
																        self.poly.flattened_ml_extensions = flattened_ml_extensions
 | 
															
														||||
 | 
																            .par_iter()
 | 
															
														||||
 | 
																            .map(|x| Arc::new(x.clone()))
 | 
															
														||||
 | 
																            .collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        Ok(IOPProverMessage {
 | 
															
														||||
 | 
																            evaluations: products_sum,
 | 
															
														||||
 | 
																        })
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
 | 
															
														||||
 | 
																    let mut weights = points
 | 
															
														||||
 | 
																        .iter()
 | 
															
														||||
 | 
																        .enumerate()
 | 
															
														||||
 | 
																        .map(|(j, point_j)| {
 | 
															
														||||
 | 
																            points
 | 
															
														||||
 | 
																                .iter()
 | 
															
														||||
 | 
																                .enumerate()
 | 
															
														||||
 | 
																                .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
 | 
															
														||||
 | 
																                .reduce(|acc, value| acc * value)
 | 
															
														||||
 | 
																                .unwrap_or_else(F::one)
 | 
															
														||||
 | 
																        })
 | 
															
														||||
 | 
																        .collect::<Vec<_>>();
 | 
															
														||||
 | 
																    batch_inversion(&mut weights);
 | 
															
														||||
 | 
																    weights
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
 | 
															
														||||
 | 
																    let (coeffs, sum_inv) = {
 | 
															
														||||
 | 
																        let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
 | 
															
														||||
 | 
																        batch_inversion(&mut coeffs);
 | 
															
														||||
 | 
																        coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
 | 
															
														||||
 | 
																            *coeff *= weight;
 | 
															
														||||
 | 
																        });
 | 
															
														||||
 | 
																        let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
 | 
															
														||||
 | 
																        (coeffs, sum_inv)
 | 
															
														||||
 | 
																    };
 | 
															
														||||
 | 
																    coeffs
 | 
															
														||||
 | 
																        .iter()
 | 
															
														||||
 | 
																        .zip(evals)
 | 
															
														||||
 | 
																        .map(|(coeff, eval)| *coeff * eval)
 | 
															
														||||
 | 
																        .sum::<F>()
 | 
															
														||||
 | 
																        * sum_inv
 | 
															
														||||
 | 
																}
 | 
															
														||||
@ -0,0 +1,59 @@ | 
															
														|||||
 | 
																// code forked from:
 | 
															
														||||
 | 
																// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check
 | 
															
														||||
 | 
																//
 | 
															
														||||
 | 
																// Copyright (c) 2023 Espresso Systems (espressosys.com)
 | 
															
														||||
 | 
																// This file is part of the HyperPlonk library.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// You should have received a copy of the MIT License
 | 
															
														||||
 | 
																// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																//! This module defines structs that are shared by all sub protocols.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use crate::utils::virtual_polynomial::VirtualPolynomial;
 | 
															
														||||
 | 
																use ark_ff::PrimeField;
 | 
															
														||||
 | 
																use ark_serialize::CanonicalSerialize;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// An IOP proof is a collections of
 | 
															
														||||
 | 
																/// - messages from prover to verifier at each round through the interactive
 | 
															
														||||
 | 
																///   protocol.
 | 
															
														||||
 | 
																/// - a point that is generated by the transcript for evaluation
 | 
															
														||||
 | 
																#[derive(Clone, Debug, Default, PartialEq, Eq)]
 | 
															
														||||
 | 
																pub struct IOPProof<F: PrimeField> {
 | 
															
														||||
 | 
																    pub point: Vec<F>,
 | 
															
														||||
 | 
																    pub proofs: Vec<IOPProverMessage<F>>,
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// A message from the prover to the verifier at a given round
 | 
															
														||||
 | 
																/// is a list of evaluations.
 | 
															
														||||
 | 
																#[derive(Clone, Debug, Default, PartialEq, Eq, CanonicalSerialize)]
 | 
															
														||||
 | 
																pub struct IOPProverMessage<F: PrimeField> {
 | 
															
														||||
 | 
																    pub(crate) evaluations: Vec<F>,
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Prover State of a PolyIOP.
 | 
															
														||||
 | 
																#[derive(Debug)]
 | 
															
														||||
 | 
																pub struct IOPProverState<F: PrimeField> {
 | 
															
														||||
 | 
																    /// sampled randomness given by the verifier
 | 
															
														||||
 | 
																    pub challenges: Vec<F>,
 | 
															
														||||
 | 
																    /// the current round number
 | 
															
														||||
 | 
																    pub(crate) round: usize,
 | 
															
														||||
 | 
																    /// pointer to the virtual polynomial
 | 
															
														||||
 | 
																    pub(crate) poly: VirtualPolynomial<F>,
 | 
															
														||||
 | 
																    /// points with precomputed barycentric weights for extrapolating smaller
 | 
															
														||||
 | 
																    /// degree uni-polys to `max_degree + 1` evaluations.
 | 
															
														||||
 | 
																    pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Prover State of a PolyIOP
 | 
															
														||||
 | 
																#[derive(Debug)]
 | 
															
														||||
 | 
																pub struct IOPVerifierState<F: PrimeField> {
 | 
															
														||||
 | 
																    pub(crate) round: usize,
 | 
															
														||||
 | 
																    pub(crate) num_vars: usize,
 | 
															
														||||
 | 
																    pub(crate) max_degree: usize,
 | 
															
														||||
 | 
																    pub(crate) finished: bool,
 | 
															
														||||
 | 
																    /// a list storing the univariate polynomial in evaluation form sent by the
 | 
															
														||||
 | 
																    /// prover at each round
 | 
															
														||||
 | 
																    pub(crate) polynomials_received: Vec<Vec<F>>,
 | 
															
														||||
 | 
																    /// a list storing the randomness sampled by the verifier at each round
 | 
															
														||||
 | 
																    pub(crate) challenges: Vec<F>,
 | 
															
														||||
 | 
																}
 | 
															
														||||
@ -0,0 +1,362 @@ | 
															
														|||||
 | 
																// code forked from:
 | 
															
														||||
 | 
																// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check
 | 
															
														||||
 | 
																//
 | 
															
														||||
 | 
																// Copyright (c) 2023 Espresso Systems (espressosys.com)
 | 
															
														||||
 | 
																// This file is part of the HyperPlonk library.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// You should have received a copy of the MIT License
 | 
															
														||||
 | 
																// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																//! Verifier subroutines for a SumCheck protocol.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use super::{SumCheckSubClaim, SumCheckVerifier};
 | 
															
														||||
 | 
																use crate::utils::virtual_polynomial::VPAuxInfo;
 | 
															
														||||
 | 
																use ark_ff::PrimeField;
 | 
															
														||||
 | 
																use ark_std::{end_timer, start_timer};
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use super::structs::{IOPProverMessage, IOPVerifierState};
 | 
															
														||||
 | 
																use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
 | 
															
														||||
 | 
																use espresso_transcript::IOPTranscript;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																#[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
 | 
															
														||||
 | 
																    type VPAuxInfo = VPAuxInfo<F>;
 | 
															
														||||
 | 
																    type ProverMessage = IOPProverMessage<F>;
 | 
															
														||||
 | 
																    type Challenge = F;
 | 
															
														||||
 | 
																    type Transcript = IOPTranscript<F>;
 | 
															
														||||
 | 
																    type SumCheckSubClaim = SumCheckSubClaim<F>;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Initialize the verifier's state.
 | 
															
														||||
 | 
																    fn verifier_init(index_info: &Self::VPAuxInfo) -> Self {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sum check verifier init");
 | 
															
														||||
 | 
																        let res = Self {
 | 
															
														||||
 | 
																            round: 1,
 | 
															
														||||
 | 
																            num_vars: index_info.num_variables,
 | 
															
														||||
 | 
																            max_degree: index_info.max_degree,
 | 
															
														||||
 | 
																            finished: false,
 | 
															
														||||
 | 
																            polynomials_received: Vec::with_capacity(index_info.num_variables),
 | 
															
														||||
 | 
																            challenges: Vec::with_capacity(index_info.num_variables),
 | 
															
														||||
 | 
																        };
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        res
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Run verifier for the current round, given a prover message.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// Note that `verify_round_and_update_state` only samples and stores
 | 
															
														||||
 | 
																    /// challenges; and update the verifier's state accordingly. The actual
 | 
															
														||||
 | 
																    /// verifications are deferred (in batch) to `check_and_generate_subclaim`
 | 
															
														||||
 | 
																    /// at the last step.
 | 
															
														||||
 | 
																    fn verify_round_and_update_state(
 | 
															
														||||
 | 
																        &mut self,
 | 
															
														||||
 | 
																        prover_msg: &Self::ProverMessage,
 | 
															
														||||
 | 
																        transcript: &mut Self::Transcript,
 | 
															
														||||
 | 
																    ) -> Result<Self::Challenge, PolyIOPErrors> {
 | 
															
														||||
 | 
																        let start =
 | 
															
														||||
 | 
																            start_timer!(|| format!("sum check verify {}-th round and update state", self.round));
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if self.finished {
 | 
															
														||||
 | 
																            return Err(PolyIOPErrors::InvalidVerifier(
 | 
															
														||||
 | 
																                "Incorrect verifier state: Verifier is already finished.".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // In an interactive protocol, the verifier should
 | 
															
														||||
 | 
																        //
 | 
															
														||||
 | 
																        // 1. check if the received 'P(0) + P(1) = expected`.
 | 
															
														||||
 | 
																        // 2. set `expected` to P(r)`
 | 
															
														||||
 | 
																        //
 | 
															
														||||
 | 
																        // When we turn the protocol to a non-interactive one, it is sufficient to defer
 | 
															
														||||
 | 
																        // such checks to `check_and_generate_subclaim` after the last round.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let challenge = transcript.get_and_append_challenge(b"Internal round")?;
 | 
															
														||||
 | 
																        self.challenges.push(challenge);
 | 
															
														||||
 | 
																        self.polynomials_received
 | 
															
														||||
 | 
																            .push(prover_msg.evaluations.to_vec());
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if self.round == self.num_vars {
 | 
															
														||||
 | 
																            // accept and close
 | 
															
														||||
 | 
																            self.finished = true;
 | 
															
														||||
 | 
																        } else {
 | 
															
														||||
 | 
																            // proceed to the next round
 | 
															
														||||
 | 
																            self.round += 1;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        Ok(challenge)
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// This function verifies the deferred checks in the interactive version of
 | 
															
														||||
 | 
																    /// the protocol; and generate the subclaim. Returns an error if the
 | 
															
														||||
 | 
																    /// proof failed to verify.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// If the asserted sum is correct, then the multilinear polynomial
 | 
															
														||||
 | 
																    /// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`.
 | 
															
														||||
 | 
																    /// Otherwise, it is highly unlikely that those two will be equal.
 | 
															
														||||
 | 
																    /// Larger field size guarantees smaller soundness error.
 | 
															
														||||
 | 
																    fn check_and_generate_subclaim(
 | 
															
														||||
 | 
																        &self,
 | 
															
														||||
 | 
																        asserted_sum: &F,
 | 
															
														||||
 | 
																    ) -> Result<Self::SumCheckSubClaim, PolyIOPErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "sum check check and generate subclaim");
 | 
															
														||||
 | 
																        if !self.finished {
 | 
															
														||||
 | 
																            return Err(PolyIOPErrors::InvalidVerifier(
 | 
															
														||||
 | 
																                "Incorrect verifier state: Verifier has not finished.".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if self.polynomials_received.len() != self.num_vars {
 | 
															
														||||
 | 
																            return Err(PolyIOPErrors::InvalidVerifier(
 | 
															
														||||
 | 
																                "insufficient rounds".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // the deferred check during the interactive phase:
 | 
															
														||||
 | 
																        // 2. set `expected` to P(r)`
 | 
															
														||||
 | 
																        #[cfg(feature = "parallel")]
 | 
															
														||||
 | 
																        let mut expected_vec = self
 | 
															
														||||
 | 
																            .polynomials_received
 | 
															
														||||
 | 
																            .clone()
 | 
															
														||||
 | 
																            .into_par_iter()
 | 
															
														||||
 | 
																            .zip(self.challenges.clone().into_par_iter())
 | 
															
														||||
 | 
																            .map(|(evaluations, challenge)| {
 | 
															
														||||
 | 
																                if evaluations.len() != self.max_degree + 1 {
 | 
															
														||||
 | 
																                    return Err(PolyIOPErrors::InvalidVerifier(format!(
 | 
															
														||||
 | 
																                        "incorrect number of evaluations: {} vs {}",
 | 
															
														||||
 | 
																                        evaluations.len(),
 | 
															
														||||
 | 
																                        self.max_degree + 1
 | 
															
														||||
 | 
																                    )));
 | 
															
														||||
 | 
																                }
 | 
															
														||||
 | 
																                interpolate_uni_poly::<F>(&evaluations, challenge)
 | 
															
														||||
 | 
																            })
 | 
															
														||||
 | 
																            .collect::<Result<Vec<_>, PolyIOPErrors>>()?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        #[cfg(not(feature = "parallel"))]
 | 
															
														||||
 | 
																        let mut expected_vec = self
 | 
															
														||||
 | 
																            .polynomials_received
 | 
															
														||||
 | 
																            .clone()
 | 
															
														||||
 | 
																            .into_iter()
 | 
															
														||||
 | 
																            .zip(self.challenges.clone().into_iter())
 | 
															
														||||
 | 
																            .map(|(evaluations, challenge)| {
 | 
															
														||||
 | 
																                if evaluations.len() != self.max_degree + 1 {
 | 
															
														||||
 | 
																                    return Err(PolyIOPErrors::InvalidVerifier(format!(
 | 
															
														||||
 | 
																                        "incorrect number of evaluations: {} vs {}",
 | 
															
														||||
 | 
																                        evaluations.len(),
 | 
															
														||||
 | 
																                        self.max_degree + 1
 | 
															
														||||
 | 
																                    )));
 | 
															
														||||
 | 
																                }
 | 
															
														||||
 | 
																                interpolate_uni_poly::<F>(&evaluations, challenge)
 | 
															
														||||
 | 
																            })
 | 
															
														||||
 | 
																            .collect::<Result<Vec<_>, PolyIOPErrors>>()?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // insert the asserted_sum to the first position of the expected vector
 | 
															
														||||
 | 
																        expected_vec.insert(0, *asserted_sum);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for (evaluations, &expected) in self
 | 
															
														||||
 | 
																            .polynomials_received
 | 
															
														||||
 | 
																            .iter()
 | 
															
														||||
 | 
																            .zip(expected_vec.iter())
 | 
															
														||||
 | 
																            .take(self.num_vars)
 | 
															
														||||
 | 
																        {
 | 
															
														||||
 | 
																            // the deferred check during the interactive phase:
 | 
															
														||||
 | 
																            // 1. check if the received 'P(0) + P(1) = expected`.
 | 
															
														||||
 | 
																            if evaluations[0] + evaluations[1] != expected {
 | 
															
														||||
 | 
																                return Err(PolyIOPErrors::InvalidProof(
 | 
															
														||||
 | 
																                    "Prover message is not consistent with the claim.".to_string(),
 | 
															
														||||
 | 
																                ));
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        Ok(SumCheckSubClaim {
 | 
															
														||||
 | 
																            point: self.challenges.clone(),
 | 
															
														||||
 | 
																            // the last expected value (not checked within this function) will be included in the
 | 
															
														||||
 | 
																            // subclaim
 | 
															
														||||
 | 
																            expected_evaluation: expected_vec[self.num_vars],
 | 
															
														||||
 | 
																        })
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this
 | 
															
														||||
 | 
																/// polynomial at `eval_at`:
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																///   \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) )
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// This implementation is linear in number of inputs in terms of field
 | 
															
														||||
 | 
																/// operations. It also has a quadratic term in primitive operations which is
 | 
															
														||||
 | 
																/// negligible compared to field operations.
 | 
															
														||||
 | 
																/// TODO: The quadratic term can be removed by precomputing the lagrange
 | 
															
														||||
 | 
																/// coefficients.
 | 
															
														||||
 | 
																pub fn interpolate_uni_poly<F: PrimeField>(p_i: &[F], eval_at: F) -> Result<F, PolyIOPErrors> {
 | 
															
														||||
 | 
																    let start = start_timer!(|| "sum check interpolate uni poly opt");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    let len = p_i.len();
 | 
															
														||||
 | 
																    let mut evals = vec![];
 | 
															
														||||
 | 
																    let mut prod = eval_at;
 | 
															
														||||
 | 
																    evals.push(eval_at);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    // `prod = \prod_{j} (eval_at - j)`
 | 
															
														||||
 | 
																    for e in 1..len {
 | 
															
														||||
 | 
																        let tmp = eval_at - F::from(e as u64);
 | 
															
														||||
 | 
																        evals.push(tmp);
 | 
															
														||||
 | 
																        prod *= tmp;
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    let mut res = F::zero();
 | 
															
														||||
 | 
																    // we want to compute \prod (j!=i) (i-j) for a given i
 | 
															
														||||
 | 
																    //
 | 
															
														||||
 | 
																    // we start from the last step, which is
 | 
															
														||||
 | 
																    //  denom[len-1] = (len-1) * (len-2) *... * 2 * 1
 | 
															
														||||
 | 
																    // the step before that is
 | 
															
														||||
 | 
																    //  denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1
 | 
															
														||||
 | 
																    // and the step before that is
 | 
															
														||||
 | 
																    //  denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2
 | 
															
														||||
 | 
																    //
 | 
															
														||||
 | 
																    // i.e., for any i, the one before this will be derived from
 | 
															
														||||
 | 
																    //  denom[i-1] = denom[i] * (len-i) / i
 | 
															
														||||
 | 
																    //
 | 
															
														||||
 | 
																    // that is, we only need to store
 | 
															
														||||
 | 
																    // - the last denom for i = len-1, and
 | 
															
														||||
 | 
																    // - the ratio between current step and fhe last step, which is the product of
 | 
															
														||||
 | 
																    //   (len-i) / i from all previous steps and we store this product as a fraction
 | 
															
														||||
 | 
																    //   number to reduce field divisions.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    // We know
 | 
															
														||||
 | 
																    //  - 2^61 < factorial(20) < 2^62
 | 
															
														||||
 | 
																    //  - 2^122 < factorial(33) < 2^123
 | 
															
														||||
 | 
																    // so we will be able to compute the ratio
 | 
															
														||||
 | 
																    //  - for len <= 20 with i64
 | 
															
														||||
 | 
																    //  - for len <= 33 with i128
 | 
															
														||||
 | 
																    //  - for len >  33 with BigInt
 | 
															
														||||
 | 
																    if p_i.len() <= 20 {
 | 
															
														||||
 | 
																        let last_denominator = F::from(u64_factorial(len - 1));
 | 
															
														||||
 | 
																        let mut ratio_numerator = 1i64;
 | 
															
														||||
 | 
																        let mut ratio_denominator = 1u64;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for i in (0..len).rev() {
 | 
															
														||||
 | 
																            let ratio_numerator_f = if ratio_numerator < 0 {
 | 
															
														||||
 | 
																                -F::from((-ratio_numerator) as u64)
 | 
															
														||||
 | 
																            } else {
 | 
															
														||||
 | 
																                F::from(ratio_numerator as u64)
 | 
															
														||||
 | 
																            };
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            res += p_i[i] * prod * F::from(ratio_denominator)
 | 
															
														||||
 | 
																                / (last_denominator * ratio_numerator_f * evals[i]);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            // compute denom for the next step is current_denom * (len-i)/i
 | 
															
														||||
 | 
																            if i != 0 {
 | 
															
														||||
 | 
																                ratio_numerator *= -(len as i64 - i as i64);
 | 
															
														||||
 | 
																                ratio_denominator *= i as u64;
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    } else if p_i.len() <= 33 {
 | 
															
														||||
 | 
																        let last_denominator = F::from(u128_factorial(len - 1));
 | 
															
														||||
 | 
																        let mut ratio_numerator = 1i128;
 | 
															
														||||
 | 
																        let mut ratio_denominator = 1u128;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for i in (0..len).rev() {
 | 
															
														||||
 | 
																            let ratio_numerator_f = if ratio_numerator < 0 {
 | 
															
														||||
 | 
																                -F::from((-ratio_numerator) as u128)
 | 
															
														||||
 | 
																            } else {
 | 
															
														||||
 | 
																                F::from(ratio_numerator as u128)
 | 
															
														||||
 | 
																            };
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            res += p_i[i] * prod * F::from(ratio_denominator)
 | 
															
														||||
 | 
																                / (last_denominator * ratio_numerator_f * evals[i]);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            // compute denom for the next step is current_denom * (len-i)/i
 | 
															
														||||
 | 
																            if i != 0 {
 | 
															
														||||
 | 
																                ratio_numerator *= -(len as i128 - i as i128);
 | 
															
														||||
 | 
																                ratio_denominator *= i as u128;
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    } else {
 | 
															
														||||
 | 
																        let mut denom_up = field_factorial::<F>(len - 1);
 | 
															
														||||
 | 
																        let mut denom_down = F::one();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for i in (0..len).rev() {
 | 
															
														||||
 | 
																            res += p_i[i] * prod * denom_down / (denom_up * evals[i]);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            // compute denom for the next step is current_denom * (len-i)/i
 | 
															
														||||
 | 
																            if i != 0 {
 | 
															
														||||
 | 
																                denom_up *= -F::from((len - i) as u64);
 | 
															
														||||
 | 
																                denom_down *= F::from(i as u64);
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    end_timer!(start);
 | 
															
														||||
 | 
																    Ok(res)
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// compute the factorial(a) = 1 * 2 * ... * a
 | 
															
														||||
 | 
																#[inline]
 | 
															
														||||
 | 
																fn field_factorial<F: PrimeField>(a: usize) -> F {
 | 
															
														||||
 | 
																    let mut res = F::one();
 | 
															
														||||
 | 
																    for i in 2..=a {
 | 
															
														||||
 | 
																        res *= F::from(i as u64);
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    res
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// compute the factorial(a) = 1 * 2 * ... * a
 | 
															
														||||
 | 
																#[inline]
 | 
															
														||||
 | 
																fn u128_factorial(a: usize) -> u128 {
 | 
															
														||||
 | 
																    let mut res = 1u128;
 | 
															
														||||
 | 
																    for i in 2..=a {
 | 
															
														||||
 | 
																        res *= i as u128;
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    res
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// compute the factorial(a) = 1 * 2 * ... * a
 | 
															
														||||
 | 
																#[inline]
 | 
															
														||||
 | 
																fn u64_factorial(a: usize) -> u64 {
 | 
															
														||||
 | 
																    let mut res = 1u64;
 | 
															
														||||
 | 
																    for i in 2..=a {
 | 
															
														||||
 | 
																        res *= i as u64;
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    res
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																#[cfg(test)]
 | 
															
														||||
 | 
																mod test {
 | 
															
														||||
 | 
																    use super::interpolate_uni_poly;
 | 
															
														||||
 | 
																    use ark_bls12_377::Fr;
 | 
															
														||||
 | 
																    use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial};
 | 
															
														||||
 | 
																    use ark_std::{vec::Vec, UniformRand};
 | 
															
														||||
 | 
																    use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    #[test]
 | 
															
														||||
 | 
																    fn test_interpolation() -> Result<(), PolyIOPErrors> {
 | 
															
														||||
 | 
																        let mut prng = ark_std::test_rng();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // test a polynomial with 20 known points, i.e., with degree 19
 | 
															
														||||
 | 
																        let poly = DensePolynomial::<Fr>::rand(20 - 1, &mut prng);
 | 
															
														||||
 | 
																        let evals = (0..20)
 | 
															
														||||
 | 
																            .map(|i| poly.evaluate(&Fr::from(i)))
 | 
															
														||||
 | 
																            .collect::<Vec<Fr>>();
 | 
															
														||||
 | 
																        let query = Fr::rand(&mut prng);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)?);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // test a polynomial with 33 known points, i.e., with degree 32
 | 
															
														||||
 | 
																        let poly = DensePolynomial::<Fr>::rand(33 - 1, &mut prng);
 | 
															
														||||
 | 
																        let evals = (0..33)
 | 
															
														||||
 | 
																            .map(|i| poly.evaluate(&Fr::from(i)))
 | 
															
														||||
 | 
																            .collect::<Vec<Fr>>();
 | 
															
														||||
 | 
																        let query = Fr::rand(&mut prng);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)?);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // test a polynomial with 64 known points, i.e., with degree 63
 | 
															
														||||
 | 
																        let poly = DensePolynomial::<Fr>::rand(64 - 1, &mut prng);
 | 
															
														||||
 | 
																        let evals = (0..64)
 | 
															
														||||
 | 
																            .map(|i| poly.evaluate(&Fr::from(i)))
 | 
															
														||||
 | 
																            .collect::<Vec<Fr>>();
 | 
															
														||||
 | 
																        let query = Fr::rand(&mut prng);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)?);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        Ok(())
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
@ -0,0 +1,550 @@ | 
															
														|||||
 | 
																// code forked from
 | 
															
														||||
 | 
																// https://github.com/privacy-scaling-explorations/multifolding-poc/blob/main/src/espresso/virtual_polynomial.rs
 | 
															
														||||
 | 
																//
 | 
															
														||||
 | 
																// Copyright (c) 2023 Espresso Systems (espressosys.com)
 | 
															
														||||
 | 
																// This file is part of the HyperPlonk library.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// You should have received a copy of the MIT License
 | 
															
														||||
 | 
																// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																//! This module defines our main mathematical object `VirtualPolynomial`; and
 | 
															
														||||
 | 
																//! various functions associated with it.
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use ark_ff::PrimeField;
 | 
															
														||||
 | 
																use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
 | 
															
														||||
 | 
																use ark_serialize::CanonicalSerialize;
 | 
															
														||||
 | 
																use ark_std::{end_timer, start_timer};
 | 
															
														||||
 | 
																use rayon::prelude::*;
 | 
															
														||||
 | 
																use std::{cmp::max, collections::HashMap, marker::PhantomData, ops::Add, sync::Arc};
 | 
															
														||||
 | 
																use thiserror::Error;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																use ark_std::string::String;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																//-- aritherrors
 | 
															
														||||
 | 
																/// A `enum` specifying the possible failure modes of the arithmetics.
 | 
															
														||||
 | 
																#[derive(Error, Debug)]
 | 
															
														||||
 | 
																pub enum ArithErrors {
 | 
															
														||||
 | 
																    #[error("Invalid parameters: {0}")]
 | 
															
														||||
 | 
																    InvalidParameters(String),
 | 
															
														||||
 | 
																    #[error("Should not arrive to this point")]
 | 
															
														||||
 | 
																    ShouldNotArrive,
 | 
															
														||||
 | 
																    #[error("An error during (de)serialization: {0}")]
 | 
															
														||||
 | 
																    SerializationErrors(ark_serialize::SerializationError),
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																impl From<ark_serialize::SerializationError> for ArithErrors {
 | 
															
														||||
 | 
																    fn from(e: ark_serialize::SerializationError) -> Self {
 | 
															
														||||
 | 
																        Self::SerializationErrors(e)
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																//-- aritherrors
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																#[rustfmt::skip]
 | 
															
														||||
 | 
																/// A virtual polynomial is a sum of products of multilinear polynomials;
 | 
															
														||||
 | 
																/// where the multilinear polynomials are stored via their multilinear
 | 
															
														||||
 | 
																/// extensions:  `(coefficient, DenseMultilinearExtension)`
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// * Number of products n = `polynomial.products.len()`,
 | 
															
														||||
 | 
																/// * Number of multiplicands of ith product m_i =
 | 
															
														||||
 | 
																///   `polynomial.products[i].1.len()`,
 | 
															
														||||
 | 
																/// * Coefficient of ith product c_i = `polynomial.products[i].0`
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// The resulting polynomial is
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// $$ \sum_{i=0}^{n} c_i \cdot \prod_{j=0}^{m_i} P_{ij} $$
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// Example:
 | 
															
														||||
 | 
																///  f = c0 * f0 * f1 * f2 + c1 * f3 * f4
 | 
															
														||||
 | 
																/// where f0 ... f4 are multilinear polynomials
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// - flattened_ml_extensions stores the multilinear extension representation of
 | 
															
														||||
 | 
																///   f0, f1, f2, f3 and f4
 | 
															
														||||
 | 
																/// - products is
 | 
															
														||||
 | 
																///     \[
 | 
															
														||||
 | 
																///         (c0, \[0, 1, 2\]),
 | 
															
														||||
 | 
																///         (c1, \[3, 4\])
 | 
															
														||||
 | 
																///     \]
 | 
															
														||||
 | 
																/// - raw_pointers_lookup_table maps fi to i
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																#[derive(Clone, Debug, Default, PartialEq)]
 | 
															
														||||
 | 
																pub struct VirtualPolynomial<F: PrimeField> {
 | 
															
														||||
 | 
																    /// Aux information about the multilinear polynomial
 | 
															
														||||
 | 
																    pub aux_info: VPAuxInfo<F>,
 | 
															
														||||
 | 
																    /// list of reference to products (as usize) of multilinear extension
 | 
															
														||||
 | 
																    pub products: Vec<(F, Vec<usize>)>,
 | 
															
														||||
 | 
																    /// Stores multilinear extensions in which product multiplicand can refer
 | 
															
														||||
 | 
																    /// to.
 | 
															
														||||
 | 
																    pub flattened_ml_extensions: Vec<Arc<DenseMultilinearExtension<F>>>,
 | 
															
														||||
 | 
																    /// Pointers to the above poly extensions
 | 
															
														||||
 | 
																    raw_pointers_lookup_table: HashMap<*const DenseMultilinearExtension<F>, usize>,
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																#[derive(Clone, Debug, Default, PartialEq, Eq, CanonicalSerialize)]
 | 
															
														||||
 | 
																/// Auxiliary information about the multilinear polynomial
 | 
															
														||||
 | 
																pub struct VPAuxInfo<F: PrimeField> {
 | 
															
														||||
 | 
																    /// max number of multiplicands in each product
 | 
															
														||||
 | 
																    pub max_degree: usize,
 | 
															
														||||
 | 
																    /// number of variables of the polynomial
 | 
															
														||||
 | 
																    pub num_variables: usize,
 | 
															
														||||
 | 
																    /// Associated field
 | 
															
														||||
 | 
																    #[doc(hidden)]
 | 
															
														||||
 | 
																    pub phantom: PhantomData<F>,
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																impl<F: PrimeField> Add for &VirtualPolynomial<F> {
 | 
															
														||||
 | 
																    type Output = VirtualPolynomial<F>;
 | 
															
														||||
 | 
																    fn add(self, other: &VirtualPolynomial<F>) -> Self::Output {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "virtual poly add");
 | 
															
														||||
 | 
																        let mut res = self.clone();
 | 
															
														||||
 | 
																        for products in other.products.iter() {
 | 
															
														||||
 | 
																            let cur: Vec<Arc<DenseMultilinearExtension<F>>> = products
 | 
															
														||||
 | 
																                .1
 | 
															
														||||
 | 
																                .iter()
 | 
															
														||||
 | 
																                .map(|&x| other.flattened_ml_extensions[x].clone())
 | 
															
														||||
 | 
																                .collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            res.add_mle_list(cur, products.0)
 | 
															
														||||
 | 
																                .expect("add product failed");
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        res
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// TODO: convert this into a trait
 | 
															
														||||
 | 
																impl<F: PrimeField> VirtualPolynomial<F> {
 | 
															
														||||
 | 
																    /// Creates an empty virtual polynomial with `num_variables`.
 | 
															
														||||
 | 
																    pub fn new(num_variables: usize) -> Self {
 | 
															
														||||
 | 
																        VirtualPolynomial {
 | 
															
														||||
 | 
																            aux_info: VPAuxInfo {
 | 
															
														||||
 | 
																                max_degree: 0,
 | 
															
														||||
 | 
																                num_variables,
 | 
															
														||||
 | 
																                phantom: PhantomData,
 | 
															
														||||
 | 
																            },
 | 
															
														||||
 | 
																            products: Vec::new(),
 | 
															
														||||
 | 
																            flattened_ml_extensions: Vec::new(),
 | 
															
														||||
 | 
																            raw_pointers_lookup_table: HashMap::new(),
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Creates an new virtual polynomial from a MLE and its coefficient.
 | 
															
														||||
 | 
																    pub fn new_from_mle(mle: &Arc<DenseMultilinearExtension<F>>, coefficient: F) -> Self {
 | 
															
														||||
 | 
																        let mle_ptr: *const DenseMultilinearExtension<F> = Arc::as_ptr(mle);
 | 
															
														||||
 | 
																        let mut hm = HashMap::new();
 | 
															
														||||
 | 
																        hm.insert(mle_ptr, 0);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        VirtualPolynomial {
 | 
															
														||||
 | 
																            aux_info: VPAuxInfo {
 | 
															
														||||
 | 
																                // The max degree is the max degree of any individual variable
 | 
															
														||||
 | 
																                max_degree: 1,
 | 
															
														||||
 | 
																                num_variables: mle.num_vars,
 | 
															
														||||
 | 
																                phantom: PhantomData,
 | 
															
														||||
 | 
																            },
 | 
															
														||||
 | 
																            // here `0` points to the first polynomial of `flattened_ml_extensions`
 | 
															
														||||
 | 
																            products: vec![(coefficient, vec![0])],
 | 
															
														||||
 | 
																            flattened_ml_extensions: vec![mle.clone()],
 | 
															
														||||
 | 
																            raw_pointers_lookup_table: hm,
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Add a product of list of multilinear extensions to self
 | 
															
														||||
 | 
																    /// Returns an error if the list is empty, or the MLE has a different
 | 
															
														||||
 | 
																    /// `num_vars` from self.
 | 
															
														||||
 | 
																    ///
 | 
															
														||||
 | 
																    /// The MLEs will be multiplied together, and then multiplied by the scalar
 | 
															
														||||
 | 
																    /// `coefficient`.
 | 
															
														||||
 | 
																    pub fn add_mle_list(
 | 
															
														||||
 | 
																        &mut self,
 | 
															
														||||
 | 
																        mle_list: impl IntoIterator<Item = Arc<DenseMultilinearExtension<F>>>,
 | 
															
														||||
 | 
																        coefficient: F,
 | 
															
														||||
 | 
																    ) -> Result<(), ArithErrors> {
 | 
															
														||||
 | 
																        let mle_list: Vec<Arc<DenseMultilinearExtension<F>>> = mle_list.into_iter().collect();
 | 
															
														||||
 | 
																        let mut indexed_product = Vec::with_capacity(mle_list.len());
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if mle_list.is_empty() {
 | 
															
														||||
 | 
																            return Err(ArithErrors::InvalidParameters(
 | 
															
														||||
 | 
																                "input mle_list is empty".to_string(),
 | 
															
														||||
 | 
																            ));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len());
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for mle in mle_list {
 | 
															
														||||
 | 
																            if mle.num_vars != self.aux_info.num_variables {
 | 
															
														||||
 | 
																                return Err(ArithErrors::InvalidParameters(format!(
 | 
															
														||||
 | 
																                    "product has a multiplicand with wrong number of variables {} vs {}",
 | 
															
														||||
 | 
																                    mle.num_vars, self.aux_info.num_variables
 | 
															
														||||
 | 
																                )));
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            let mle_ptr: *const DenseMultilinearExtension<F> = Arc::as_ptr(&mle);
 | 
															
														||||
 | 
																            if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) {
 | 
															
														||||
 | 
																                indexed_product.push(*index)
 | 
															
														||||
 | 
																            } else {
 | 
															
														||||
 | 
																                let curr_index = self.flattened_ml_extensions.len();
 | 
															
														||||
 | 
																                self.flattened_ml_extensions.push(mle.clone());
 | 
															
														||||
 | 
																                self.raw_pointers_lookup_table.insert(mle_ptr, curr_index);
 | 
															
														||||
 | 
																                indexed_product.push(curr_index);
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																        self.products.push((coefficient, indexed_product));
 | 
															
														||||
 | 
																        Ok(())
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Multiple the current VirtualPolynomial by an MLE:
 | 
															
														||||
 | 
																    /// - add the MLE to the MLE list;
 | 
															
														||||
 | 
																    /// - multiple each product by MLE and its coefficient.
 | 
															
														||||
 | 
																    /// Returns an error if the MLE has a different `num_vars` from self.
 | 
															
														||||
 | 
																    pub fn mul_by_mle(
 | 
															
														||||
 | 
																        &mut self,
 | 
															
														||||
 | 
																        mle: Arc<DenseMultilinearExtension<F>>,
 | 
															
														||||
 | 
																        coefficient: F,
 | 
															
														||||
 | 
																    ) -> Result<(), ArithErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "mul by mle");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if mle.num_vars != self.aux_info.num_variables {
 | 
															
														||||
 | 
																            return Err(ArithErrors::InvalidParameters(format!(
 | 
															
														||||
 | 
																                "product has a multiplicand with wrong number of variables {} vs {}",
 | 
															
														||||
 | 
																                mle.num_vars, self.aux_info.num_variables
 | 
															
														||||
 | 
																            )));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let mle_ptr: *const DenseMultilinearExtension<F> = Arc::as_ptr(&mle);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // check if this mle already exists in the virtual polynomial
 | 
															
														||||
 | 
																        let mle_index = match self.raw_pointers_lookup_table.get(&mle_ptr) {
 | 
															
														||||
 | 
																            Some(&p) => p,
 | 
															
														||||
 | 
																            None => {
 | 
															
														||||
 | 
																                self.raw_pointers_lookup_table
 | 
															
														||||
 | 
																                    .insert(mle_ptr, self.flattened_ml_extensions.len());
 | 
															
														||||
 | 
																                self.flattened_ml_extensions.push(mle);
 | 
															
														||||
 | 
																                self.flattened_ml_extensions.len() - 1
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        };
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for (prod_coef, indices) in self.products.iter_mut() {
 | 
															
														||||
 | 
																            // - add the MLE to the MLE list;
 | 
															
														||||
 | 
																            // - multiple each product by MLE and its coefficient.
 | 
															
														||||
 | 
																            indices.push(mle_index);
 | 
															
														||||
 | 
																            *prod_coef *= coefficient;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // increase the max degree by one as the MLE has degree 1.
 | 
															
														||||
 | 
																        self.aux_info.max_degree += 1;
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        Ok(())
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Given virtual polynomial `p(x)` and scalar `s`, compute `s*p(x)`
 | 
															
														||||
 | 
																    pub fn scalar_mul(&mut self, s: &F) {
 | 
															
														||||
 | 
																        for (prod_coef, _) in self.products.iter_mut() {
 | 
															
														||||
 | 
																            *prod_coef *= s;
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Evaluate the virtual polynomial at point `point`.
 | 
															
														||||
 | 
																    /// Returns an error is point.len() does not match `num_variables`.
 | 
															
														||||
 | 
																    pub fn evaluate(&self, point: &[F]) -> Result<F, ArithErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "evaluation");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if self.aux_info.num_variables != point.len() {
 | 
															
														||||
 | 
																            return Err(ArithErrors::InvalidParameters(format!(
 | 
															
														||||
 | 
																                "wrong number of variables {} vs {}",
 | 
															
														||||
 | 
																                self.aux_info.num_variables,
 | 
															
														||||
 | 
																                point.len()
 | 
															
														||||
 | 
																            )));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // Evaluate all the MLEs at `point`
 | 
															
														||||
 | 
																        let evals: Vec<F> = self
 | 
															
														||||
 | 
																            .flattened_ml_extensions
 | 
															
														||||
 | 
																            .iter()
 | 
															
														||||
 | 
																            .map(|x| {
 | 
															
														||||
 | 
																                x.evaluate(point).unwrap() // safe unwrap here since we have
 | 
															
														||||
 | 
																                                           // already checked that num_var
 | 
															
														||||
 | 
																                                           // matches
 | 
															
														||||
 | 
																            })
 | 
															
														||||
 | 
																            .collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let res = self
 | 
															
														||||
 | 
																            .products
 | 
															
														||||
 | 
																            .iter()
 | 
															
														||||
 | 
																            .map(|(c, p)| *c * p.iter().map(|&i| evals[i]).product::<F>())
 | 
															
														||||
 | 
																            .sum();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        Ok(res)
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    // Input poly f(x) and a random vector r, output
 | 
															
														||||
 | 
																    //      \hat f(x) = \sum_{x_i \in eval_x} f(x_i) eq(x, r)
 | 
															
														||||
 | 
																    // where
 | 
															
														||||
 | 
																    //      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
 | 
															
														||||
 | 
																    //
 | 
															
														||||
 | 
																    // This function is used in ZeroCheck.
 | 
															
														||||
 | 
																    pub fn build_f_hat(&self, r: &[F]) -> Result<Self, ArithErrors> {
 | 
															
														||||
 | 
																        let start = start_timer!(|| "zero check build hat f");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        if self.aux_info.num_variables != r.len() {
 | 
															
														||||
 | 
																            return Err(ArithErrors::InvalidParameters(format!(
 | 
															
														||||
 | 
																                "r.len() is different from number of variables: {} vs {}",
 | 
															
														||||
 | 
																                r.len(),
 | 
															
														||||
 | 
																                self.aux_info.num_variables
 | 
															
														||||
 | 
																            )));
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let eq_x_r = build_eq_x_r(r)?;
 | 
															
														||||
 | 
																        let mut res = self.clone();
 | 
															
														||||
 | 
																        res.mul_by_mle(eq_x_r, F::one())?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        end_timer!(start);
 | 
															
														||||
 | 
																        Ok(res)
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Evaluate eq polynomial.
 | 
															
														||||
 | 
																pub fn eq_eval<F: PrimeField>(x: &[F], y: &[F]) -> Result<F, ArithErrors> {
 | 
															
														||||
 | 
																    if x.len() != y.len() {
 | 
															
														||||
 | 
																        return Err(ArithErrors::InvalidParameters(
 | 
															
														||||
 | 
																            "x and y have different length".to_string(),
 | 
															
														||||
 | 
																        ));
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    let start = start_timer!(|| "eq_eval");
 | 
															
														||||
 | 
																    let mut res = F::one();
 | 
															
														||||
 | 
																    for (&xi, &yi) in x.iter().zip(y.iter()) {
 | 
															
														||||
 | 
																        let xi_yi = xi * yi;
 | 
															
														||||
 | 
																        res *= xi_yi + xi_yi - xi - yi + F::one();
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    end_timer!(start);
 | 
															
														||||
 | 
																    Ok(res)
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// This function build the eq(x, r) polynomial for any given r.
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// Evaluate
 | 
															
														||||
 | 
																///      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
 | 
															
														||||
 | 
																/// over r, which is
 | 
															
														||||
 | 
																///      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
 | 
															
														||||
 | 
																fn build_eq_x_r<F: PrimeField>(r: &[F]) -> Result<Arc<DenseMultilinearExtension<F>>, ArithErrors> {
 | 
															
														||||
 | 
																    let evals = build_eq_x_r_vec(r)?;
 | 
															
														||||
 | 
																    let mle = DenseMultilinearExtension::from_evaluations_vec(r.len(), evals);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    Ok(Arc::new(mle))
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																/// This function build the eq(x, r) polynomial for any given r, and output the
 | 
															
														||||
 | 
																/// evaluation of eq(x, r) in its vector form.
 | 
															
														||||
 | 
																///
 | 
															
														||||
 | 
																/// Evaluate
 | 
															
														||||
 | 
																///      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
 | 
															
														||||
 | 
																/// over r, which is
 | 
															
														||||
 | 
																///      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
 | 
															
														||||
 | 
																fn build_eq_x_r_vec<F: PrimeField>(r: &[F]) -> Result<Vec<F>, ArithErrors> {
 | 
															
														||||
 | 
																    // we build eq(x,r) from its evaluations
 | 
															
														||||
 | 
																    // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars
 | 
															
														||||
 | 
																    // for example, with num_vars = 4, x is a binary vector of 4, then
 | 
															
														||||
 | 
																    //  0 0 0 0 -> (1-r0)   * (1-r1)    * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																    //  1 0 0 0 -> r0       * (1-r1)    * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																    //  0 1 0 0 -> (1-r0)   * r1        * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																    //  1 1 0 0 -> r0       * r1        * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																    //  ....
 | 
															
														||||
 | 
																    //  1 1 1 1 -> r0       * r1        * r2        * r3
 | 
															
														||||
 | 
																    // we will need 2^num_var evaluations
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    let mut eval = Vec::new();
 | 
															
														||||
 | 
																    build_eq_x_r_helper(r, &mut eval)?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    Ok(eval)
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// A helper function to build eq(x, r) recursively.
 | 
															
														||||
 | 
																/// This function takes `r.len()` steps, and for each step it requires a maximum
 | 
															
														||||
 | 
																/// `r.len()-1` multiplications.
 | 
															
														||||
 | 
																fn build_eq_x_r_helper<F: PrimeField>(r: &[F], buf: &mut Vec<F>) -> Result<(), ArithErrors> {
 | 
															
														||||
 | 
																    if r.is_empty() {
 | 
															
														||||
 | 
																        return Err(ArithErrors::InvalidParameters("r length is 0".to_string()));
 | 
															
														||||
 | 
																    } else if r.len() == 1 {
 | 
															
														||||
 | 
																        // initializing the buffer with [1-r_0, r_0]
 | 
															
														||||
 | 
																        buf.push(F::one() - r[0]);
 | 
															
														||||
 | 
																        buf.push(r[0]);
 | 
															
														||||
 | 
																    } else {
 | 
															
														||||
 | 
																        build_eq_x_r_helper(&r[1..], buf)?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // suppose at the previous step we received [b_1, ..., b_k]
 | 
															
														||||
 | 
																        // for the current step we will need
 | 
															
														||||
 | 
																        // if x_0 = 0:   (1-r0) * [b_1, ..., b_k]
 | 
															
														||||
 | 
																        // if x_0 = 1:   r0 * [b_1, ..., b_k]
 | 
															
														||||
 | 
																        // let mut res = vec![];
 | 
															
														||||
 | 
																        // for &b_i in buf.iter() {
 | 
															
														||||
 | 
																        //     let tmp = r[0] * b_i;
 | 
															
														||||
 | 
																        //     res.push(b_i - tmp);
 | 
															
														||||
 | 
																        //     res.push(tmp);
 | 
															
														||||
 | 
																        // }
 | 
															
														||||
 | 
																        // *buf = res;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let mut res = vec![F::zero(); buf.len() << 1];
 | 
															
														||||
 | 
																        res.par_iter_mut().enumerate().for_each(|(i, val)| {
 | 
															
														||||
 | 
																            let bi = buf[i >> 1];
 | 
															
														||||
 | 
																            let tmp = r[0] * bi;
 | 
															
														||||
 | 
																            if i & 1 == 0 {
 | 
															
														||||
 | 
																                *val = bi - tmp;
 | 
															
														||||
 | 
																            } else {
 | 
															
														||||
 | 
																                *val = tmp;
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        });
 | 
															
														||||
 | 
																        *buf = res;
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    Ok(())
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																/// Decompose an integer into a binary vector in little endian.
 | 
															
														||||
 | 
																pub fn bit_decompose(input: u64, num_var: usize) -> Vec<bool> {
 | 
															
														||||
 | 
																    let mut res = Vec::with_capacity(num_var);
 | 
															
														||||
 | 
																    let mut i = input;
 | 
															
														||||
 | 
																    for _ in 0..num_var {
 | 
															
														||||
 | 
																        res.push(i & 1 == 1);
 | 
															
														||||
 | 
																        i >>= 1;
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																    res
 | 
															
														||||
 | 
																}
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																#[cfg(test)]
 | 
															
														||||
 | 
																mod test {
 | 
															
														||||
 | 
																    use super::*;
 | 
															
														||||
 | 
																    use crate::utils::multilinear_polynomial::tests::random_mle_list;
 | 
															
														||||
 | 
																    use ark_bls12_377::Fr;
 | 
															
														||||
 | 
																    use ark_ff::UniformRand;
 | 
															
														||||
 | 
																    use ark_std::{
 | 
															
														||||
 | 
																        rand::{Rng, RngCore},
 | 
															
														||||
 | 
																        test_rng,
 | 
															
														||||
 | 
																    };
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    impl<F: PrimeField> VirtualPolynomial<F> {
 | 
															
														||||
 | 
																        /// Sample a random virtual polynomial, return the polynomial and its sum.
 | 
															
														||||
 | 
																        fn rand<R: RngCore>(
 | 
															
														||||
 | 
																            nv: usize,
 | 
															
														||||
 | 
																            num_multiplicands_range: (usize, usize),
 | 
															
														||||
 | 
																            num_products: usize,
 | 
															
														||||
 | 
																            rng: &mut R,
 | 
															
														||||
 | 
																        ) -> Result<(Self, F), ArithErrors> {
 | 
															
														||||
 | 
																            let start = start_timer!(|| "sample random virtual polynomial");
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            let mut sum = F::zero();
 | 
															
														||||
 | 
																            let mut poly = VirtualPolynomial::new(nv);
 | 
															
														||||
 | 
																            for _ in 0..num_products {
 | 
															
														||||
 | 
																                let num_multiplicands =
 | 
															
														||||
 | 
																                    rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1);
 | 
															
														||||
 | 
																                let (product, product_sum) = random_mle_list(nv, num_multiplicands, rng);
 | 
															
														||||
 | 
																                let coefficient = F::rand(rng);
 | 
															
														||||
 | 
																                poly.add_mle_list(product.into_iter(), coefficient)?;
 | 
															
														||||
 | 
																                sum += product_sum * coefficient;
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            end_timer!(start);
 | 
															
														||||
 | 
																            Ok((poly, sum))
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    #[test]
 | 
															
														||||
 | 
																    fn test_virtual_polynomial_additions() -> Result<(), ArithErrors> {
 | 
															
														||||
 | 
																        let mut rng = test_rng();
 | 
															
														||||
 | 
																        for nv in 2..5 {
 | 
															
														||||
 | 
																            for num_products in 2..5 {
 | 
															
														||||
 | 
																                let base: Vec<Fr> = (0..nv).map(|_| Fr::rand(&mut rng)).collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																                let (a, _a_sum) =
 | 
															
														||||
 | 
																                    VirtualPolynomial::<Fr>::rand(nv, (2, 3), num_products, &mut rng)?;
 | 
															
														||||
 | 
																                let (b, _b_sum) =
 | 
															
														||||
 | 
																                    VirtualPolynomial::<Fr>::rand(nv, (2, 3), num_products, &mut rng)?;
 | 
															
														||||
 | 
																                let c = &a + &b;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																                assert_eq!(
 | 
															
														||||
 | 
																                    a.evaluate(base.as_ref())? + b.evaluate(base.as_ref())?,
 | 
															
														||||
 | 
																                    c.evaluate(base.as_ref())?
 | 
															
														||||
 | 
																                );
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        Ok(())
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    #[test]
 | 
															
														||||
 | 
																    fn test_virtual_polynomial_mul_by_mle() -> Result<(), ArithErrors> {
 | 
															
														||||
 | 
																        let mut rng = test_rng();
 | 
															
														||||
 | 
																        for nv in 2..5 {
 | 
															
														||||
 | 
																            for num_products in 2..5 {
 | 
															
														||||
 | 
																                let base: Vec<Fr> = (0..nv).map(|_| Fr::rand(&mut rng)).collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																                let (a, _a_sum) =
 | 
															
														||||
 | 
																                    VirtualPolynomial::<Fr>::rand(nv, (2, 3), num_products, &mut rng)?;
 | 
															
														||||
 | 
																                let (b, _b_sum) = random_mle_list(nv, 1, &mut rng);
 | 
															
														||||
 | 
																                let b_mle = b[0].clone();
 | 
															
														||||
 | 
																                let coeff = Fr::rand(&mut rng);
 | 
															
														||||
 | 
																                let b_vp = VirtualPolynomial::new_from_mle(&b_mle, coeff);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																                let mut c = a.clone();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																                c.mul_by_mle(b_mle, coeff)?;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																                assert_eq!(
 | 
															
														||||
 | 
																                    a.evaluate(base.as_ref())? * b_vp.evaluate(base.as_ref())?,
 | 
															
														||||
 | 
																                    c.evaluate(base.as_ref())?
 | 
															
														||||
 | 
																                );
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        Ok(())
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    #[test]
 | 
															
														||||
 | 
																    fn test_eq_xr() {
 | 
															
														||||
 | 
																        let mut rng = test_rng();
 | 
															
														||||
 | 
																        for nv in 4..10 {
 | 
															
														||||
 | 
																            let r: Vec<Fr> = (0..nv).map(|_| Fr::rand(&mut rng)).collect();
 | 
															
														||||
 | 
																            let eq_x_r = build_eq_x_r(r.as_ref()).unwrap();
 | 
															
														||||
 | 
																            let eq_x_r2 = build_eq_x_r_for_test(r.as_ref());
 | 
															
														||||
 | 
																            assert_eq!(eq_x_r, eq_x_r2);
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																    /// Naive method to build eq(x, r).
 | 
															
														||||
 | 
																    /// Only used for testing purpose.
 | 
															
														||||
 | 
																    // Evaluate
 | 
															
														||||
 | 
																    //      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
 | 
															
														||||
 | 
																    // over r, which is
 | 
															
														||||
 | 
																    //      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
 | 
															
														||||
 | 
																    fn build_eq_x_r_for_test<F: PrimeField>(r: &[F]) -> Arc<DenseMultilinearExtension<F>> {
 | 
															
														||||
 | 
																        // we build eq(x,r) from its evaluations
 | 
															
														||||
 | 
																        // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars
 | 
															
														||||
 | 
																        // for example, with num_vars = 4, x is a binary vector of 4, then
 | 
															
														||||
 | 
																        //  0 0 0 0 -> (1-r0)   * (1-r1)    * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																        //  1 0 0 0 -> r0       * (1-r1)    * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																        //  0 1 0 0 -> (1-r0)   * r1        * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																        //  1 1 0 0 -> r0       * r1        * (1-r2)    * (1-r3)
 | 
															
														||||
 | 
																        //  ....
 | 
															
														||||
 | 
																        //  1 1 1 1 -> r0       * r1        * r2        * r3
 | 
															
														||||
 | 
																        // we will need 2^num_var evaluations
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        // First, we build array for {1 - r_i}
 | 
															
														||||
 | 
																        let one_minus_r: Vec<F> = r.iter().map(|ri| F::one() - ri).collect();
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let num_var = r.len();
 | 
															
														||||
 | 
																        let mut eval = vec![];
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        for i in 0..1 << num_var {
 | 
															
														||||
 | 
																            let mut current_eval = F::one();
 | 
															
														||||
 | 
																            let bit_sequence = bit_decompose(i, num_var);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																            for (&bit, (ri, one_minus_ri)) in
 | 
															
														||||
 | 
																                bit_sequence.iter().zip(r.iter().zip(one_minus_r.iter()))
 | 
															
														||||
 | 
																            {
 | 
															
														||||
 | 
																                current_eval *= if bit { *ri } else { *one_minus_ri };
 | 
															
														||||
 | 
																            }
 | 
															
														||||
 | 
																            eval.push(current_eval);
 | 
															
														||||
 | 
																        }
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        let mle = DenseMultilinearExtension::from_evaluations_vec(num_var, eval);
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																        Arc::new(mle)
 | 
															
														||||
 | 
																    }
 | 
															
														||||
 | 
																}
 | 
															
														||||
@ -1 +1,7 @@ | 
															
														|||||
pub mod vec;
 | 
																pub mod vec;
 | 
															
														||||
 | 
																
 | 
															
														||||
 | 
																// expose espresso local modules
 | 
															
														||||
 | 
																pub mod espresso;
 | 
															
														||||
 | 
																pub use crate::utils::espresso::multilinear_polynomial;
 | 
															
														||||
 | 
																pub use crate::utils::espresso::sum_check;
 | 
															
														||||
 | 
																pub use crate::utils::espresso::virtual_polynomial;
 | 
															
														||||