From b4a0b50618ba9d9ab62c02b462fa35f7ac258918 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 29 Aug 2023 08:49:43 +0200 Subject: [PATCH] Port Espresso's VirtualPoly, MLE and SumCheck (#8) Port Espresso/hyperplonk's `virtualpolynomial`, `multilinearpolynomial` and `sum_check` utils from https://github.com/EspressoSystems/hyperplonk/tree/main Each file contains the reference to the original file. Porting it into a subdirectory `src/utils/espresso`, to have it self-contained. In future iterations we might replace part of it but we can keep focusing on the folding schemes part for now. --- Cargo.toml | 21 +- src/utils/espresso/mod.rs | 3 + src/utils/espresso/multilinear_polynomial.rs | 200 +++++++ src/utils/espresso/sum_check/mod.rs | 211 +++++++ src/utils/espresso/sum_check/prover.rs | 220 ++++++++ src/utils/espresso/sum_check/structs.rs | 59 ++ src/utils/espresso/sum_check/verifier.rs | 362 ++++++++++++ src/utils/espresso/virtual_polynomial.rs | 550 +++++++++++++++++++ src/utils/mod.rs | 6 + src/utils/vec.rs | 1 + 10 files changed, 1629 insertions(+), 4 deletions(-) create mode 100644 src/utils/espresso/mod.rs create mode 100644 src/utils/espresso/multilinear_polynomial.rs create mode 100644 src/utils/espresso/sum_check/mod.rs create mode 100644 src/utils/espresso/sum_check/prover.rs create mode 100644 src/utils/espresso/sum_check/structs.rs create mode 100644 src/utils/espresso/sum_check/verifier.rs create mode 100644 src/utils/espresso/virtual_polynomial.rs diff --git a/Cargo.toml b/Cargo.toml index ca5a9f2..816de33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,17 +5,30 @@ edition = "2021" [dependencies] ark-ec = "0.4.2" -ark-ff = "0.4.2" -ark-std = "0.4.0" -ark-poly = "0.4.0" +ark-ff = "^0.4.0" +ark-poly = "^0.4.0" +ark-std = "^0.4.0" ark-crypto-primitives = { version = "^0.4.0", default-features = false, features = ["r1cs", "sponge"] } ark-relations = { version = "^0.4.0", default-features = false } ark-r1cs-std = { version = "^0.4.0", default-features = false } thiserror = "1.0" +rayon = "1.7.0" + +# tmp imports for espresso's sumcheck +ark-serialize = "0.4.2" +espresso_subroutines = {git="https://github.com/EspressoSystems/hyperplonk", package="subroutines"} +espresso_transcript = {git="https://github.com/EspressoSystems/hyperplonk", package="transcript"} + [dev-dependencies] ark-bls12-377 = "0.4.0" ark-bw6-761 = "0.4.0" [features] -default = [] +default = ["parallel"] + +parallel = [ + "ark-std/parallel", + "ark-ff/parallel", + "ark-poly/parallel", + ] diff --git a/src/utils/espresso/mod.rs b/src/utils/espresso/mod.rs new file mode 100644 index 0000000..8c11fd0 --- /dev/null +++ b/src/utils/espresso/mod.rs @@ -0,0 +1,3 @@ +pub mod multilinear_polynomial; +pub mod sum_check; +pub mod virtual_polynomial; diff --git a/src/utils/espresso/multilinear_polynomial.rs b/src/utils/espresso/multilinear_polynomial.rs new file mode 100644 index 0000000..da5d39a --- /dev/null +++ b/src/utils/espresso/multilinear_polynomial.rs @@ -0,0 +1,200 @@ +// code forked from +// https://github.com/EspressoSystems/hyperplonk/blob/main/arithmetic/src/multilinear_polynomial.rs +// +// Copyright (c) 2023 Espresso Systems (espressosys.com) +// This file is part of the HyperPlonk library. + +// You should have received a copy of the MIT License +// along with the HyperPlonk library. If not, see . + +use ark_ff::Field; +#[cfg(feature = "parallel")] +use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; + +pub use ark_poly::DenseMultilinearExtension; + +pub fn fix_variables( + poly: &DenseMultilinearExtension, + partial_point: &[F], +) -> DenseMultilinearExtension { + assert!( + partial_point.len() <= poly.num_vars, + "invalid size of partial point" + ); + let nv = poly.num_vars; + let mut poly = poly.evaluations.to_vec(); + let dim = partial_point.len(); + // evaluate single variable of partial point from left to right + for (i, point) in partial_point.iter().enumerate().take(dim) { + poly = fix_one_variable_helper(&poly, nv - i, point); + } + + DenseMultilinearExtension::::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) +} + +fn fix_one_variable_helper(data: &[F], nv: usize, point: &F) -> Vec { + let mut res = vec![F::zero(); 1 << (nv - 1)]; + + // evaluate single variable of partial point from left to right + #[cfg(not(feature = "parallel"))] + for i in 0..(1 << (nv - 1)) { + res[i] = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point; + } + + #[cfg(feature = "parallel")] + res.par_iter_mut().enumerate().for_each(|(i, x)| { + *x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point; + }); + + res +} + +pub fn evaluate_no_par(poly: &DenseMultilinearExtension, point: &[F]) -> F { + assert_eq!(poly.num_vars, point.len()); + fix_variables_no_par(poly, point).evaluations[0] +} + +fn fix_variables_no_par( + poly: &DenseMultilinearExtension, + partial_point: &[F], +) -> DenseMultilinearExtension { + assert!( + partial_point.len() <= poly.num_vars, + "invalid size of partial point" + ); + let nv = poly.num_vars; + let mut poly = poly.evaluations.to_vec(); + let dim = partial_point.len(); + // evaluate single variable of partial point from left to right + for i in 1..dim + 1 { + let r = partial_point[i - 1]; + for b in 0..(1 << (nv - i)) { + poly[b] = poly[b << 1] + (poly[(b << 1) + 1] - poly[b << 1]) * r; + } + } + DenseMultilinearExtension::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) +} + +/// Given multilinear polynomial `p(x)` and s `s`, compute `s*p(x)` +pub fn scalar_mul( + poly: &DenseMultilinearExtension, + s: &F, +) -> DenseMultilinearExtension { + DenseMultilinearExtension { + evaluations: poly.evaluations.iter().map(|e| *e * s).collect(), + num_vars: poly.num_vars, + } +} + +/// Test-only methods used in virtual_polynomial.rs +#[cfg(test)] +pub mod tests { + use super::*; + use ark_ff::PrimeField; + use ark_std::rand::RngCore; + use ark_std::{end_timer, start_timer}; + use std::sync::Arc; + + pub fn fix_last_variables( + poly: &DenseMultilinearExtension, + partial_point: &[F], + ) -> DenseMultilinearExtension { + assert!( + partial_point.len() <= poly.num_vars, + "invalid size of partial point" + ); + let nv = poly.num_vars; + let mut poly = poly.evaluations.to_vec(); + let dim = partial_point.len(); + // evaluate single variable of partial point from left to right + for (i, point) in partial_point.iter().rev().enumerate().take(dim) { + poly = fix_last_variable_helper(&poly, nv - i, point); + } + + DenseMultilinearExtension::::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) + } + + fn fix_last_variable_helper(data: &[F], nv: usize, point: &F) -> Vec { + let half_len = 1 << (nv - 1); + let mut res = vec![F::zero(); half_len]; + + // evaluate single variable of partial point from left to right + #[cfg(not(feature = "parallel"))] + for b in 0..half_len { + res[b] = data[b] + (data[b + half_len] - data[b]) * point; + } + + #[cfg(feature = "parallel")] + res.par_iter_mut().enumerate().for_each(|(i, x)| { + *x = data[i] + (data[i + half_len] - data[i]) * point; + }); + + res + } + + /// Sample a random list of multilinear polynomials. + /// Returns + /// - the list of polynomials, + /// - its sum of polynomial evaluations over the boolean hypercube. + #[cfg(test)] + pub fn random_mle_list( + nv: usize, + degree: usize, + rng: &mut R, + ) -> (Vec>>, F) { + let start = start_timer!(|| "sample random mle list"); + let mut multiplicands = Vec::with_capacity(degree); + for _ in 0..degree { + multiplicands.push(Vec::with_capacity(1 << nv)) + } + let mut sum = F::zero(); + + for _ in 0..(1 << nv) { + let mut product = F::one(); + + for e in multiplicands.iter_mut() { + let val = F::rand(rng); + e.push(val); + product *= val; + } + sum += product; + } + + let list = multiplicands + .into_iter() + .map(|x| Arc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) + .collect(); + + end_timer!(start); + (list, sum) + } + + // Build a randomize list of mle-s whose sum is zero. + #[cfg(test)] + pub fn random_zero_mle_list( + nv: usize, + degree: usize, + rng: &mut R, + ) -> Vec>> { + let start = start_timer!(|| "sample random zero mle list"); + + let mut multiplicands = Vec::with_capacity(degree); + for _ in 0..degree { + multiplicands.push(Vec::with_capacity(1 << nv)) + } + for _ in 0..(1 << nv) { + multiplicands[0].push(F::zero()); + for e in multiplicands.iter_mut().skip(1) { + e.push(F::rand(rng)); + } + } + + let list = multiplicands + .into_iter() + .map(|x| Arc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) + .collect(); + + end_timer!(start); + list + } +} diff --git a/src/utils/espresso/sum_check/mod.rs b/src/utils/espresso/sum_check/mod.rs new file mode 100644 index 0000000..9126cda --- /dev/null +++ b/src/utils/espresso/sum_check/mod.rs @@ -0,0 +1,211 @@ +// code forked from: +// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check +// +// Copyright (c) 2023 Espresso Systems (espressosys.com) +// This file is part of the HyperPlonk library. + +// You should have received a copy of the MIT License +// along with the HyperPlonk library. If not, see . + +//! This module implements the sum check protocol. + +use crate::utils::virtual_polynomial::{VPAuxInfo, VirtualPolynomial}; +use ark_ff::PrimeField; +use ark_poly::DenseMultilinearExtension; +use ark_std::{end_timer, start_timer}; +use std::{fmt::Debug, sync::Arc}; + +use espresso_subroutines::poly_iop::{prelude::PolyIOPErrors, PolyIOP}; +use espresso_transcript::IOPTranscript; +use structs::{IOPProof, IOPProverState, IOPVerifierState}; + +mod prover; +pub mod structs; +pub mod verifier; + +/// Trait for doing sum check protocols. +pub trait SumCheck { + type VirtualPolynomial; + type VPAuxInfo; + type MultilinearExtension; + + type SumCheckProof: Clone + Debug + Default + PartialEq; + type Transcript; + type SumCheckSubClaim: Clone + Debug + Default + PartialEq; + + /// Extract sum from the proof + fn extract_sum(proof: &Self::SumCheckProof) -> F; + + /// Initialize the system with a transcript + /// + /// This function is optional -- in the case where a SumCheck is + /// an building block for a more complex protocol, the transcript + /// may be initialized by this complex protocol, and passed to the + /// SumCheck prover/verifier. + fn init_transcript() -> Self::Transcript; + + /// Generate proof of the sum of polynomial over {0,1}^`num_vars` + /// + /// The polynomial is represented in the form of a VirtualPolynomial. + fn prove( + poly: &Self::VirtualPolynomial, + transcript: &mut Self::Transcript, + ) -> Result; + + /// Verify the claimed sum using the proof + fn verify( + sum: F, + proof: &Self::SumCheckProof, + aux_info: &Self::VPAuxInfo, + transcript: &mut Self::Transcript, + ) -> Result; +} + +/// Trait for sum check protocol prover side APIs. +pub trait SumCheckProver +where + Self: Sized, +{ + type VirtualPolynomial; + type ProverMessage; + + /// Initialize the prover state to argue for the sum of the input polynomial + /// over {0,1}^`num_vars`. + fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result; + + /// Receive message from verifier, generate prover message, and proceed to + /// next round. + /// + /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). + fn prove_round_and_update_state( + &mut self, + challenge: &Option, + ) -> Result; +} + +/// Trait for sum check protocol verifier side APIs. +pub trait SumCheckVerifier { + type VPAuxInfo; + type ProverMessage; + type Challenge; + type Transcript; + type SumCheckSubClaim; + + /// Initialize the verifier's state. + fn verifier_init(index_info: &Self::VPAuxInfo) -> Self; + + /// Run verifier for the current round, given a prover message. + /// + /// Note that `verify_round_and_update_state` only samples and stores + /// challenges; and update the verifier's state accordingly. The actual + /// verifications are deferred (in batch) to `check_and_generate_subclaim` + /// at the last step. + fn verify_round_and_update_state( + &mut self, + prover_msg: &Self::ProverMessage, + transcript: &mut Self::Transcript, + ) -> Result; + + /// This function verifies the deferred checks in the interactive version of + /// the protocol; and generate the subclaim. Returns an error if the + /// proof failed to verify. + /// + /// If the asserted sum is correct, then the multilinear polynomial + /// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`. + /// Otherwise, it is highly unlikely that those two will be equal. + /// Larger field size guarantees smaller soundness error. + fn check_and_generate_subclaim( + &self, + asserted_sum: &F, + ) -> Result; +} + +/// A SumCheckSubClaim is a claim generated by the verifier at the end of +/// verification when it is convinced. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct SumCheckSubClaim { + /// the multi-dimensional point that this multilinear extension is evaluated + /// to + pub point: Vec, + /// the expected evaluation + pub expected_evaluation: F, +} + +impl SumCheck for PolyIOP { + type SumCheckProof = IOPProof; + type VirtualPolynomial = VirtualPolynomial; + type VPAuxInfo = VPAuxInfo; + type MultilinearExtension = Arc>; + type SumCheckSubClaim = SumCheckSubClaim; + type Transcript = IOPTranscript; + + fn extract_sum(proof: &Self::SumCheckProof) -> F { + let start = start_timer!(|| "extract sum"); + let res = proof.proofs[0].evaluations[0] + proof.proofs[0].evaluations[1]; + end_timer!(start); + res + } + + fn init_transcript() -> Self::Transcript { + let start = start_timer!(|| "init transcript"); + let res = IOPTranscript::::new(b"Initializing SumCheck transcript"); + end_timer!(start); + res + } + + fn prove( + poly: &Self::VirtualPolynomial, + transcript: &mut Self::Transcript, + ) -> Result { + let start = start_timer!(|| "sum check prove"); + + transcript.append_serializable_element(b"aux info", &poly.aux_info)?; + + let mut prover_state = IOPProverState::prover_init(poly)?; + let mut challenge = None; + let mut prover_msgs = Vec::with_capacity(poly.aux_info.num_variables); + for _ in 0..poly.aux_info.num_variables { + let prover_msg = + IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge)?; + transcript.append_serializable_element(b"prover msg", &prover_msg)?; + prover_msgs.push(prover_msg); + challenge = Some(transcript.get_and_append_challenge(b"Internal round")?); + } + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p) + }; + + end_timer!(start); + Ok(IOPProof { + point: prover_state.challenges, + proofs: prover_msgs, + }) + } + + fn verify( + claimed_sum: F, + proof: &Self::SumCheckProof, + aux_info: &Self::VPAuxInfo, + transcript: &mut Self::Transcript, + ) -> Result { + let start = start_timer!(|| "sum check verify"); + + transcript.append_serializable_element(b"aux info", aux_info)?; + let mut verifier_state = IOPVerifierState::verifier_init(aux_info); + for i in 0..aux_info.num_variables { + let prover_msg = proof.proofs.get(i).expect("proof is incomplete"); + transcript.append_serializable_element(b"prover msg", prover_msg)?; + IOPVerifierState::verify_round_and_update_state( + &mut verifier_state, + prover_msg, + transcript, + )?; + } + + let res = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &claimed_sum); + + end_timer!(start); + res + } +} diff --git a/src/utils/espresso/sum_check/prover.rs b/src/utils/espresso/sum_check/prover.rs new file mode 100644 index 0000000..3b57283 --- /dev/null +++ b/src/utils/espresso/sum_check/prover.rs @@ -0,0 +1,220 @@ +// code forked from: +// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check +// +// Copyright (c) 2023 Espresso Systems (espressosys.com) +// This file is part of the HyperPlonk library. + +// You should have received a copy of the MIT License +// along with the HyperPlonk library. If not, see . + +//! Prover subroutines for a SumCheck protocol. + +use super::SumCheckProver; +use crate::utils::multilinear_polynomial::fix_variables; +use crate::utils::virtual_polynomial::VirtualPolynomial; +use ark_ff::{batch_inversion, PrimeField}; +use ark_poly::DenseMultilinearExtension; +use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec}; +use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator}; +use std::sync::Arc; + +use super::structs::{IOPProverMessage, IOPProverState}; +use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; + +// #[cfg(feature = "parallel")] +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; + +impl SumCheckProver for IOPProverState { + type VirtualPolynomial = VirtualPolynomial; + type ProverMessage = IOPProverMessage; + + /// Initialize the prover state to argue for the sum of the input polynomial + /// over {0,1}^`num_vars`. + fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result { + let start = start_timer!(|| "sum check prover init"); + if polynomial.aux_info.num_variables == 0 { + return Err(PolyIOPErrors::InvalidParameters( + "Attempt to prove a constant.".to_string(), + )); + } + end_timer!(start); + + Ok(Self { + challenges: Vec::with_capacity(polynomial.aux_info.num_variables), + round: 0, + poly: polynomial.clone(), + extrapolation_aux: (1..polynomial.aux_info.max_degree) + .map(|degree| { + let points = (0..1 + degree as u64).map(F::from).collect::>(); + let weights = barycentric_weights(&points); + (points, weights) + }) + .collect(), + }) + } + + /// Receive message from verifier, generate prover message, and proceed to + /// next round. + /// + /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). + fn prove_round_and_update_state( + &mut self, + challenge: &Option, + ) -> Result { + // let start = + // start_timer!(|| format!("sum check prove {}-th round and update state", + // self.round)); + + if self.round >= self.poly.aux_info.num_variables { + return Err(PolyIOPErrors::InvalidProver( + "Prover is not active".to_string(), + )); + } + + // let fix_argument = start_timer!(|| "fix argument"); + + // Step 1: + // fix argument and evaluate f(x) over x_m = r; where r is the challenge + // for the current round, and m is the round number, indexed from 1 + // + // i.e.: + // at round m <= n, for each mle g(x_1, ... x_n) within the flattened_mle + // which has already been evaluated to + // + // g(r_1, ..., r_{m-1}, x_m ... x_n) + // + // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) + let mut flattened_ml_extensions: Vec> = self + .poly + .flattened_ml_extensions + .par_iter() + .map(|x| x.as_ref().clone()) + .collect(); + + if let Some(chal) = challenge { + if self.round == 0 { + return Err(PolyIOPErrors::InvalidProver( + "first round should be prover first.".to_string(), + )); + } + self.challenges.push(*chal); + + let r = self.challenges[self.round - 1]; + // #[cfg(feature = "parallel")] + flattened_ml_extensions + .par_iter_mut() + .for_each(|mle| *mle = fix_variables(mle, &[r])); + // #[cfg(not(feature = "parallel"))] + // flattened_ml_extensions + // .iter_mut() + // .for_each(|mle| *mle = fix_variables(mle, &[r])); + } else if self.round > 0 { + return Err(PolyIOPErrors::InvalidProver( + "verifier message is empty".to_string(), + )); + } + // end_timer!(fix_argument); + + self.round += 1; + + let products_list = self.poly.products.clone(); + let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1]; + + // Step 2: generate sum for the partial evaluated polynomial: + // f(r_1, ... r_m,, x_{m+1}... x_n) + + products_list.iter().for_each(|(coefficient, products)| { + let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round)) + .fold( + || { + ( + vec![(F::zero(), F::zero()); products.len()], + vec![F::zero(); products.len() + 1], + ) + }, + |(mut buf, mut acc), b| { + buf.iter_mut() + .zip(products.iter()) + .for_each(|((eval, step), f)| { + let table = &flattened_ml_extensions[*f]; + *eval = table[b << 1]; + *step = table[(b << 1) + 1] - table[b << 1]; + }); + acc[0] += buf.iter().map(|(eval, _)| eval).product::(); + acc[1..].iter_mut().for_each(|acc| { + buf.iter_mut().for_each(|(eval, step)| *eval += step as &_); + *acc += buf.iter().map(|(eval, _)| eval).product::(); + }); + (buf, acc) + }, + ) + .map(|(_, partial)| partial) + .reduce( + || vec![F::zero(); products.len() + 1], + |mut sum, partial| { + sum.iter_mut() + .zip(partial.iter()) + .for_each(|(sum, partial)| *sum += partial); + sum + }, + ); + sum.iter_mut().for_each(|sum| *sum *= coefficient); + let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len()) + .map(|i| { + let (points, weights) = &self.extrapolation_aux[products.len() - 1]; + let at = F::from((products.len() + 1 + i) as u64); + extrapolate(points, weights, &sum, &at) + }) + .collect::>(); + products_sum + .iter_mut() + .zip(sum.iter().chain(extraploation.iter())) + .for_each(|(products_sum, sum)| *products_sum += sum); + }); + + // update prover's state to the partial evaluated polynomial + self.poly.flattened_ml_extensions = flattened_ml_extensions + .par_iter() + .map(|x| Arc::new(x.clone())) + .collect(); + + Ok(IOPProverMessage { + evaluations: products_sum, + }) + } +} + +fn barycentric_weights(points: &[F]) -> Vec { + let mut weights = points + .iter() + .enumerate() + .map(|(j, point_j)| { + points + .iter() + .enumerate() + .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) + .reduce(|acc, value| acc * value) + .unwrap_or_else(F::one) + }) + .collect::>(); + batch_inversion(&mut weights); + weights +} + +fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { + let (coeffs, sum_inv) = { + let mut coeffs = points.iter().map(|point| *at - point).collect::>(); + batch_inversion(&mut coeffs); + coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { + *coeff *= weight; + }); + let sum_inv = coeffs.iter().sum::().inverse().unwrap_or_default(); + (coeffs, sum_inv) + }; + coeffs + .iter() + .zip(evals) + .map(|(coeff, eval)| *coeff * eval) + .sum::() + * sum_inv +} diff --git a/src/utils/espresso/sum_check/structs.rs b/src/utils/espresso/sum_check/structs.rs new file mode 100644 index 0000000..88b855a --- /dev/null +++ b/src/utils/espresso/sum_check/structs.rs @@ -0,0 +1,59 @@ +// code forked from: +// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check +// +// Copyright (c) 2023 Espresso Systems (espressosys.com) +// This file is part of the HyperPlonk library. + +// You should have received a copy of the MIT License +// along with the HyperPlonk library. If not, see . + +//! This module defines structs that are shared by all sub protocols. + +use crate::utils::virtual_polynomial::VirtualPolynomial; +use ark_ff::PrimeField; +use ark_serialize::CanonicalSerialize; + +/// An IOP proof is a collections of +/// - messages from prover to verifier at each round through the interactive +/// protocol. +/// - a point that is generated by the transcript for evaluation +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct IOPProof { + pub point: Vec, + pub proofs: Vec>, +} + +/// A message from the prover to the verifier at a given round +/// is a list of evaluations. +#[derive(Clone, Debug, Default, PartialEq, Eq, CanonicalSerialize)] +pub struct IOPProverMessage { + pub(crate) evaluations: Vec, +} + +/// Prover State of a PolyIOP. +#[derive(Debug)] +pub struct IOPProverState { + /// sampled randomness given by the verifier + pub challenges: Vec, + /// the current round number + pub(crate) round: usize, + /// pointer to the virtual polynomial + pub(crate) poly: VirtualPolynomial, + /// points with precomputed barycentric weights for extrapolating smaller + /// degree uni-polys to `max_degree + 1` evaluations. + pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, +} + +/// Prover State of a PolyIOP +#[derive(Debug)] +pub struct IOPVerifierState { + pub(crate) round: usize, + pub(crate) num_vars: usize, + pub(crate) max_degree: usize, + pub(crate) finished: bool, + /// a list storing the univariate polynomial in evaluation form sent by the + /// prover at each round + pub(crate) polynomials_received: Vec>, + /// a list storing the randomness sampled by the verifier at each round + pub(crate) challenges: Vec, +} diff --git a/src/utils/espresso/sum_check/verifier.rs b/src/utils/espresso/sum_check/verifier.rs new file mode 100644 index 0000000..cb60928 --- /dev/null +++ b/src/utils/espresso/sum_check/verifier.rs @@ -0,0 +1,362 @@ +// code forked from: +// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check +// +// Copyright (c) 2023 Espresso Systems (espressosys.com) +// This file is part of the HyperPlonk library. + +// You should have received a copy of the MIT License +// along with the HyperPlonk library. If not, see . + +//! Verifier subroutines for a SumCheck protocol. + +use super::{SumCheckSubClaim, SumCheckVerifier}; +use crate::utils::virtual_polynomial::VPAuxInfo; +use ark_ff::PrimeField; +use ark_std::{end_timer, start_timer}; + +use super::structs::{IOPProverMessage, IOPVerifierState}; +use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; +use espresso_transcript::IOPTranscript; + +#[cfg(feature = "parallel")] +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + +impl SumCheckVerifier for IOPVerifierState { + type VPAuxInfo = VPAuxInfo; + type ProverMessage = IOPProverMessage; + type Challenge = F; + type Transcript = IOPTranscript; + type SumCheckSubClaim = SumCheckSubClaim; + + /// Initialize the verifier's state. + fn verifier_init(index_info: &Self::VPAuxInfo) -> Self { + let start = start_timer!(|| "sum check verifier init"); + let res = Self { + round: 1, + num_vars: index_info.num_variables, + max_degree: index_info.max_degree, + finished: false, + polynomials_received: Vec::with_capacity(index_info.num_variables), + challenges: Vec::with_capacity(index_info.num_variables), + }; + end_timer!(start); + res + } + + /// Run verifier for the current round, given a prover message. + /// + /// Note that `verify_round_and_update_state` only samples and stores + /// challenges; and update the verifier's state accordingly. The actual + /// verifications are deferred (in batch) to `check_and_generate_subclaim` + /// at the last step. + fn verify_round_and_update_state( + &mut self, + prover_msg: &Self::ProverMessage, + transcript: &mut Self::Transcript, + ) -> Result { + let start = + start_timer!(|| format!("sum check verify {}-th round and update state", self.round)); + + if self.finished { + return Err(PolyIOPErrors::InvalidVerifier( + "Incorrect verifier state: Verifier is already finished.".to_string(), + )); + } + + // In an interactive protocol, the verifier should + // + // 1. check if the received 'P(0) + P(1) = expected`. + // 2. set `expected` to P(r)` + // + // When we turn the protocol to a non-interactive one, it is sufficient to defer + // such checks to `check_and_generate_subclaim` after the last round. + + let challenge = transcript.get_and_append_challenge(b"Internal round")?; + self.challenges.push(challenge); + self.polynomials_received + .push(prover_msg.evaluations.to_vec()); + + if self.round == self.num_vars { + // accept and close + self.finished = true; + } else { + // proceed to the next round + self.round += 1; + } + + end_timer!(start); + Ok(challenge) + } + + /// This function verifies the deferred checks in the interactive version of + /// the protocol; and generate the subclaim. Returns an error if the + /// proof failed to verify. + /// + /// If the asserted sum is correct, then the multilinear polynomial + /// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`. + /// Otherwise, it is highly unlikely that those two will be equal. + /// Larger field size guarantees smaller soundness error. + fn check_and_generate_subclaim( + &self, + asserted_sum: &F, + ) -> Result { + let start = start_timer!(|| "sum check check and generate subclaim"); + if !self.finished { + return Err(PolyIOPErrors::InvalidVerifier( + "Incorrect verifier state: Verifier has not finished.".to_string(), + )); + } + + if self.polynomials_received.len() != self.num_vars { + return Err(PolyIOPErrors::InvalidVerifier( + "insufficient rounds".to_string(), + )); + } + + // the deferred check during the interactive phase: + // 2. set `expected` to P(r)` + #[cfg(feature = "parallel")] + let mut expected_vec = self + .polynomials_received + .clone() + .into_par_iter() + .zip(self.challenges.clone().into_par_iter()) + .map(|(evaluations, challenge)| { + if evaluations.len() != self.max_degree + 1 { + return Err(PolyIOPErrors::InvalidVerifier(format!( + "incorrect number of evaluations: {} vs {}", + evaluations.len(), + self.max_degree + 1 + ))); + } + interpolate_uni_poly::(&evaluations, challenge) + }) + .collect::, PolyIOPErrors>>()?; + + #[cfg(not(feature = "parallel"))] + let mut expected_vec = self + .polynomials_received + .clone() + .into_iter() + .zip(self.challenges.clone().into_iter()) + .map(|(evaluations, challenge)| { + if evaluations.len() != self.max_degree + 1 { + return Err(PolyIOPErrors::InvalidVerifier(format!( + "incorrect number of evaluations: {} vs {}", + evaluations.len(), + self.max_degree + 1 + ))); + } + interpolate_uni_poly::(&evaluations, challenge) + }) + .collect::, PolyIOPErrors>>()?; + + // insert the asserted_sum to the first position of the expected vector + expected_vec.insert(0, *asserted_sum); + + for (evaluations, &expected) in self + .polynomials_received + .iter() + .zip(expected_vec.iter()) + .take(self.num_vars) + { + // the deferred check during the interactive phase: + // 1. check if the received 'P(0) + P(1) = expected`. + if evaluations[0] + evaluations[1] != expected { + return Err(PolyIOPErrors::InvalidProof( + "Prover message is not consistent with the claim.".to_string(), + )); + } + } + end_timer!(start); + Ok(SumCheckSubClaim { + point: self.challenges.clone(), + // the last expected value (not checked within this function) will be included in the + // subclaim + expected_evaluation: expected_vec[self.num_vars], + }) + } +} + +/// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this +/// polynomial at `eval_at`: +/// +/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) ) +/// +/// This implementation is linear in number of inputs in terms of field +/// operations. It also has a quadratic term in primitive operations which is +/// negligible compared to field operations. +/// TODO: The quadratic term can be removed by precomputing the lagrange +/// coefficients. +pub fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> Result { + let start = start_timer!(|| "sum check interpolate uni poly opt"); + + let len = p_i.len(); + let mut evals = vec![]; + let mut prod = eval_at; + evals.push(eval_at); + + // `prod = \prod_{j} (eval_at - j)` + for e in 1..len { + let tmp = eval_at - F::from(e as u64); + evals.push(tmp); + prod *= tmp; + } + let mut res = F::zero(); + // we want to compute \prod (j!=i) (i-j) for a given i + // + // we start from the last step, which is + // denom[len-1] = (len-1) * (len-2) *... * 2 * 1 + // the step before that is + // denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1 + // and the step before that is + // denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2 + // + // i.e., for any i, the one before this will be derived from + // denom[i-1] = denom[i] * (len-i) / i + // + // that is, we only need to store + // - the last denom for i = len-1, and + // - the ratio between current step and fhe last step, which is the product of + // (len-i) / i from all previous steps and we store this product as a fraction + // number to reduce field divisions. + + // We know + // - 2^61 < factorial(20) < 2^62 + // - 2^122 < factorial(33) < 2^123 + // so we will be able to compute the ratio + // - for len <= 20 with i64 + // - for len <= 33 with i128 + // - for len > 33 with BigInt + if p_i.len() <= 20 { + let last_denominator = F::from(u64_factorial(len - 1)); + let mut ratio_numerator = 1i64; + let mut ratio_denominator = 1u64; + + for i in (0..len).rev() { + let ratio_numerator_f = if ratio_numerator < 0 { + -F::from((-ratio_numerator) as u64) + } else { + F::from(ratio_numerator as u64) + }; + + res += p_i[i] * prod * F::from(ratio_denominator) + / (last_denominator * ratio_numerator_f * evals[i]); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + ratio_numerator *= -(len as i64 - i as i64); + ratio_denominator *= i as u64; + } + } + } else if p_i.len() <= 33 { + let last_denominator = F::from(u128_factorial(len - 1)); + let mut ratio_numerator = 1i128; + let mut ratio_denominator = 1u128; + + for i in (0..len).rev() { + let ratio_numerator_f = if ratio_numerator < 0 { + -F::from((-ratio_numerator) as u128) + } else { + F::from(ratio_numerator as u128) + }; + + res += p_i[i] * prod * F::from(ratio_denominator) + / (last_denominator * ratio_numerator_f * evals[i]); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + ratio_numerator *= -(len as i128 - i as i128); + ratio_denominator *= i as u128; + } + } + } else { + let mut denom_up = field_factorial::(len - 1); + let mut denom_down = F::one(); + + for i in (0..len).rev() { + res += p_i[i] * prod * denom_down / (denom_up * evals[i]); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + denom_up *= -F::from((len - i) as u64); + denom_down *= F::from(i as u64); + } + } + } + end_timer!(start); + Ok(res) +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn field_factorial(a: usize) -> F { + let mut res = F::one(); + for i in 2..=a { + res *= F::from(i as u64); + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn u128_factorial(a: usize) -> u128 { + let mut res = 1u128; + for i in 2..=a { + res *= i as u128; + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn u64_factorial(a: usize) -> u64 { + let mut res = 1u64; + for i in 2..=a { + res *= i as u64; + } + res +} + +#[cfg(test)] +mod test { + use super::interpolate_uni_poly; + use ark_bls12_377::Fr; + use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial}; + use ark_std::{vec::Vec, UniformRand}; + use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; + + #[test] + fn test_interpolation() -> Result<(), PolyIOPErrors> { + let mut prng = ark_std::test_rng(); + + // test a polynomial with 20 known points, i.e., with degree 19 + let poly = DensePolynomial::::rand(20 - 1, &mut prng); + let evals = (0..20) + .map(|i| poly.evaluate(&Fr::from(i))) + .collect::>(); + let query = Fr::rand(&mut prng); + + assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)?); + + // test a polynomial with 33 known points, i.e., with degree 32 + let poly = DensePolynomial::::rand(33 - 1, &mut prng); + let evals = (0..33) + .map(|i| poly.evaluate(&Fr::from(i))) + .collect::>(); + let query = Fr::rand(&mut prng); + + assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)?); + + // test a polynomial with 64 known points, i.e., with degree 63 + let poly = DensePolynomial::::rand(64 - 1, &mut prng); + let evals = (0..64) + .map(|i| poly.evaluate(&Fr::from(i))) + .collect::>(); + let query = Fr::rand(&mut prng); + + assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query)?); + + Ok(()) + } +} diff --git a/src/utils/espresso/virtual_polynomial.rs b/src/utils/espresso/virtual_polynomial.rs new file mode 100644 index 0000000..25a1f0f --- /dev/null +++ b/src/utils/espresso/virtual_polynomial.rs @@ -0,0 +1,550 @@ +// code forked from +// https://github.com/privacy-scaling-explorations/multifolding-poc/blob/main/src/espresso/virtual_polynomial.rs +// +// Copyright (c) 2023 Espresso Systems (espressosys.com) +// This file is part of the HyperPlonk library. + +// You should have received a copy of the MIT License +// along with the HyperPlonk library. If not, see . + +//! This module defines our main mathematical object `VirtualPolynomial`; and +//! various functions associated with it. + +use ark_ff::PrimeField; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_serialize::CanonicalSerialize; +use ark_std::{end_timer, start_timer}; +use rayon::prelude::*; +use std::{cmp::max, collections::HashMap, marker::PhantomData, ops::Add, sync::Arc}; +use thiserror::Error; + +use ark_std::string::String; + +//-- aritherrors +/// A `enum` specifying the possible failure modes of the arithmetics. +#[derive(Error, Debug)] +pub enum ArithErrors { + #[error("Invalid parameters: {0}")] + InvalidParameters(String), + #[error("Should not arrive to this point")] + ShouldNotArrive, + #[error("An error during (de)serialization: {0}")] + SerializationErrors(ark_serialize::SerializationError), +} + +impl From for ArithErrors { + fn from(e: ark_serialize::SerializationError) -> Self { + Self::SerializationErrors(e) + } +} +//-- aritherrors + +#[rustfmt::skip] +/// A virtual polynomial is a sum of products of multilinear polynomials; +/// where the multilinear polynomials are stored via their multilinear +/// extensions: `(coefficient, DenseMultilinearExtension)` +/// +/// * Number of products n = `polynomial.products.len()`, +/// * Number of multiplicands of ith product m_i = +/// `polynomial.products[i].1.len()`, +/// * Coefficient of ith product c_i = `polynomial.products[i].0` +/// +/// The resulting polynomial is +/// +/// $$ \sum_{i=0}^{n} c_i \cdot \prod_{j=0}^{m_i} P_{ij} $$ +/// +/// Example: +/// f = c0 * f0 * f1 * f2 + c1 * f3 * f4 +/// where f0 ... f4 are multilinear polynomials +/// +/// - flattened_ml_extensions stores the multilinear extension representation of +/// f0, f1, f2, f3 and f4 +/// - products is +/// \[ +/// (c0, \[0, 1, 2\]), +/// (c1, \[3, 4\]) +/// \] +/// - raw_pointers_lookup_table maps fi to i +/// +#[derive(Clone, Debug, Default, PartialEq)] +pub struct VirtualPolynomial { + /// Aux information about the multilinear polynomial + pub aux_info: VPAuxInfo, + /// list of reference to products (as usize) of multilinear extension + pub products: Vec<(F, Vec)>, + /// Stores multilinear extensions in which product multiplicand can refer + /// to. + pub flattened_ml_extensions: Vec>>, + /// Pointers to the above poly extensions + raw_pointers_lookup_table: HashMap<*const DenseMultilinearExtension, usize>, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, CanonicalSerialize)] +/// Auxiliary information about the multilinear polynomial +pub struct VPAuxInfo { + /// max number of multiplicands in each product + pub max_degree: usize, + /// number of variables of the polynomial + pub num_variables: usize, + /// Associated field + #[doc(hidden)] + pub phantom: PhantomData, +} + +impl Add for &VirtualPolynomial { + type Output = VirtualPolynomial; + fn add(self, other: &VirtualPolynomial) -> Self::Output { + let start = start_timer!(|| "virtual poly add"); + let mut res = self.clone(); + for products in other.products.iter() { + let cur: Vec>> = products + .1 + .iter() + .map(|&x| other.flattened_ml_extensions[x].clone()) + .collect(); + + res.add_mle_list(cur, products.0) + .expect("add product failed"); + } + end_timer!(start); + res + } +} + +// TODO: convert this into a trait +impl VirtualPolynomial { + /// Creates an empty virtual polynomial with `num_variables`. + pub fn new(num_variables: usize) -> Self { + VirtualPolynomial { + aux_info: VPAuxInfo { + max_degree: 0, + num_variables, + phantom: PhantomData, + }, + products: Vec::new(), + flattened_ml_extensions: Vec::new(), + raw_pointers_lookup_table: HashMap::new(), + } + } + + /// Creates an new virtual polynomial from a MLE and its coefficient. + pub fn new_from_mle(mle: &Arc>, coefficient: F) -> Self { + let mle_ptr: *const DenseMultilinearExtension = Arc::as_ptr(mle); + let mut hm = HashMap::new(); + hm.insert(mle_ptr, 0); + + VirtualPolynomial { + aux_info: VPAuxInfo { + // The max degree is the max degree of any individual variable + max_degree: 1, + num_variables: mle.num_vars, + phantom: PhantomData, + }, + // here `0` points to the first polynomial of `flattened_ml_extensions` + products: vec![(coefficient, vec![0])], + flattened_ml_extensions: vec![mle.clone()], + raw_pointers_lookup_table: hm, + } + } + + /// Add a product of list of multilinear extensions to self + /// Returns an error if the list is empty, or the MLE has a different + /// `num_vars` from self. + /// + /// The MLEs will be multiplied together, and then multiplied by the scalar + /// `coefficient`. + pub fn add_mle_list( + &mut self, + mle_list: impl IntoIterator>>, + coefficient: F, + ) -> Result<(), ArithErrors> { + let mle_list: Vec>> = mle_list.into_iter().collect(); + let mut indexed_product = Vec::with_capacity(mle_list.len()); + + if mle_list.is_empty() { + return Err(ArithErrors::InvalidParameters( + "input mle_list is empty".to_string(), + )); + } + + self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len()); + + for mle in mle_list { + if mle.num_vars != self.aux_info.num_variables { + return Err(ArithErrors::InvalidParameters(format!( + "product has a multiplicand with wrong number of variables {} vs {}", + mle.num_vars, self.aux_info.num_variables + ))); + } + + let mle_ptr: *const DenseMultilinearExtension = Arc::as_ptr(&mle); + if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) { + indexed_product.push(*index) + } else { + let curr_index = self.flattened_ml_extensions.len(); + self.flattened_ml_extensions.push(mle.clone()); + self.raw_pointers_lookup_table.insert(mle_ptr, curr_index); + indexed_product.push(curr_index); + } + } + self.products.push((coefficient, indexed_product)); + Ok(()) + } + + /// Multiple the current VirtualPolynomial by an MLE: + /// - add the MLE to the MLE list; + /// - multiple each product by MLE and its coefficient. + /// Returns an error if the MLE has a different `num_vars` from self. + pub fn mul_by_mle( + &mut self, + mle: Arc>, + coefficient: F, + ) -> Result<(), ArithErrors> { + let start = start_timer!(|| "mul by mle"); + + if mle.num_vars != self.aux_info.num_variables { + return Err(ArithErrors::InvalidParameters(format!( + "product has a multiplicand with wrong number of variables {} vs {}", + mle.num_vars, self.aux_info.num_variables + ))); + } + + let mle_ptr: *const DenseMultilinearExtension = Arc::as_ptr(&mle); + + // check if this mle already exists in the virtual polynomial + let mle_index = match self.raw_pointers_lookup_table.get(&mle_ptr) { + Some(&p) => p, + None => { + self.raw_pointers_lookup_table + .insert(mle_ptr, self.flattened_ml_extensions.len()); + self.flattened_ml_extensions.push(mle); + self.flattened_ml_extensions.len() - 1 + } + }; + + for (prod_coef, indices) in self.products.iter_mut() { + // - add the MLE to the MLE list; + // - multiple each product by MLE and its coefficient. + indices.push(mle_index); + *prod_coef *= coefficient; + } + + // increase the max degree by one as the MLE has degree 1. + self.aux_info.max_degree += 1; + end_timer!(start); + Ok(()) + } + + /// Given virtual polynomial `p(x)` and scalar `s`, compute `s*p(x)` + pub fn scalar_mul(&mut self, s: &F) { + for (prod_coef, _) in self.products.iter_mut() { + *prod_coef *= s; + } + } + + /// Evaluate the virtual polynomial at point `point`. + /// Returns an error is point.len() does not match `num_variables`. + pub fn evaluate(&self, point: &[F]) -> Result { + let start = start_timer!(|| "evaluation"); + + if self.aux_info.num_variables != point.len() { + return Err(ArithErrors::InvalidParameters(format!( + "wrong number of variables {} vs {}", + self.aux_info.num_variables, + point.len() + ))); + } + + // Evaluate all the MLEs at `point` + let evals: Vec = self + .flattened_ml_extensions + .iter() + .map(|x| { + x.evaluate(point).unwrap() // safe unwrap here since we have + // already checked that num_var + // matches + }) + .collect(); + + let res = self + .products + .iter() + .map(|(c, p)| *c * p.iter().map(|&i| evals[i]).product::()) + .sum(); + + end_timer!(start); + Ok(res) + } + + // Input poly f(x) and a random vector r, output + // \hat f(x) = \sum_{x_i \in eval_x} f(x_i) eq(x, r) + // where + // eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) + // + // This function is used in ZeroCheck. + pub fn build_f_hat(&self, r: &[F]) -> Result { + let start = start_timer!(|| "zero check build hat f"); + + if self.aux_info.num_variables != r.len() { + return Err(ArithErrors::InvalidParameters(format!( + "r.len() is different from number of variables: {} vs {}", + r.len(), + self.aux_info.num_variables + ))); + } + + let eq_x_r = build_eq_x_r(r)?; + let mut res = self.clone(); + res.mul_by_mle(eq_x_r, F::one())?; + + end_timer!(start); + Ok(res) + } +} + +/// Evaluate eq polynomial. +pub fn eq_eval(x: &[F], y: &[F]) -> Result { + if x.len() != y.len() { + return Err(ArithErrors::InvalidParameters( + "x and y have different length".to_string(), + )); + } + let start = start_timer!(|| "eq_eval"); + let mut res = F::one(); + for (&xi, &yi) in x.iter().zip(y.iter()) { + let xi_yi = xi * yi; + res *= xi_yi + xi_yi - xi - yi + F::one(); + } + end_timer!(start); + Ok(res) +} + +/// This function build the eq(x, r) polynomial for any given r. +/// +/// Evaluate +/// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) +/// over r, which is +/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) +fn build_eq_x_r(r: &[F]) -> Result>, ArithErrors> { + let evals = build_eq_x_r_vec(r)?; + let mle = DenseMultilinearExtension::from_evaluations_vec(r.len(), evals); + + Ok(Arc::new(mle)) +} +/// This function build the eq(x, r) polynomial for any given r, and output the +/// evaluation of eq(x, r) in its vector form. +/// +/// Evaluate +/// eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) +/// over r, which is +/// eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) +fn build_eq_x_r_vec(r: &[F]) -> Result, ArithErrors> { + // we build eq(x,r) from its evaluations + // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars + // for example, with num_vars = 4, x is a binary vector of 4, then + // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) + // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) + // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) + // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) + // .... + // 1 1 1 1 -> r0 * r1 * r2 * r3 + // we will need 2^num_var evaluations + + let mut eval = Vec::new(); + build_eq_x_r_helper(r, &mut eval)?; + + Ok(eval) +} + +/// A helper function to build eq(x, r) recursively. +/// This function takes `r.len()` steps, and for each step it requires a maximum +/// `r.len()-1` multiplications. +fn build_eq_x_r_helper(r: &[F], buf: &mut Vec) -> Result<(), ArithErrors> { + if r.is_empty() { + return Err(ArithErrors::InvalidParameters("r length is 0".to_string())); + } else if r.len() == 1 { + // initializing the buffer with [1-r_0, r_0] + buf.push(F::one() - r[0]); + buf.push(r[0]); + } else { + build_eq_x_r_helper(&r[1..], buf)?; + + // suppose at the previous step we received [b_1, ..., b_k] + // for the current step we will need + // if x_0 = 0: (1-r0) * [b_1, ..., b_k] + // if x_0 = 1: r0 * [b_1, ..., b_k] + // let mut res = vec![]; + // for &b_i in buf.iter() { + // let tmp = r[0] * b_i; + // res.push(b_i - tmp); + // res.push(tmp); + // } + // *buf = res; + + let mut res = vec![F::zero(); buf.len() << 1]; + res.par_iter_mut().enumerate().for_each(|(i, val)| { + let bi = buf[i >> 1]; + let tmp = r[0] * bi; + if i & 1 == 0 { + *val = bi - tmp; + } else { + *val = tmp; + } + }); + *buf = res; + } + + Ok(()) +} + +/// Decompose an integer into a binary vector in little endian. +pub fn bit_decompose(input: u64, num_var: usize) -> Vec { + let mut res = Vec::with_capacity(num_var); + let mut i = input; + for _ in 0..num_var { + res.push(i & 1 == 1); + i >>= 1; + } + res +} + +#[cfg(test)] +mod test { + use super::*; + use crate::utils::multilinear_polynomial::tests::random_mle_list; + use ark_bls12_377::Fr; + use ark_ff::UniformRand; + use ark_std::{ + rand::{Rng, RngCore}, + test_rng, + }; + + impl VirtualPolynomial { + /// Sample a random virtual polynomial, return the polynomial and its sum. + fn rand( + nv: usize, + num_multiplicands_range: (usize, usize), + num_products: usize, + rng: &mut R, + ) -> Result<(Self, F), ArithErrors> { + let start = start_timer!(|| "sample random virtual polynomial"); + + let mut sum = F::zero(); + let mut poly = VirtualPolynomial::new(nv); + for _ in 0..num_products { + let num_multiplicands = + rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); + let (product, product_sum) = random_mle_list(nv, num_multiplicands, rng); + let coefficient = F::rand(rng); + poly.add_mle_list(product.into_iter(), coefficient)?; + sum += product_sum * coefficient; + } + + end_timer!(start); + Ok((poly, sum)) + } + } + + #[test] + fn test_virtual_polynomial_additions() -> Result<(), ArithErrors> { + let mut rng = test_rng(); + for nv in 2..5 { + for num_products in 2..5 { + let base: Vec = (0..nv).map(|_| Fr::rand(&mut rng)).collect(); + + let (a, _a_sum) = + VirtualPolynomial::::rand(nv, (2, 3), num_products, &mut rng)?; + let (b, _b_sum) = + VirtualPolynomial::::rand(nv, (2, 3), num_products, &mut rng)?; + let c = &a + &b; + + assert_eq!( + a.evaluate(base.as_ref())? + b.evaluate(base.as_ref())?, + c.evaluate(base.as_ref())? + ); + } + } + + Ok(()) + } + + #[test] + fn test_virtual_polynomial_mul_by_mle() -> Result<(), ArithErrors> { + let mut rng = test_rng(); + for nv in 2..5 { + for num_products in 2..5 { + let base: Vec = (0..nv).map(|_| Fr::rand(&mut rng)).collect(); + + let (a, _a_sum) = + VirtualPolynomial::::rand(nv, (2, 3), num_products, &mut rng)?; + let (b, _b_sum) = random_mle_list(nv, 1, &mut rng); + let b_mle = b[0].clone(); + let coeff = Fr::rand(&mut rng); + let b_vp = VirtualPolynomial::new_from_mle(&b_mle, coeff); + + let mut c = a.clone(); + + c.mul_by_mle(b_mle, coeff)?; + + assert_eq!( + a.evaluate(base.as_ref())? * b_vp.evaluate(base.as_ref())?, + c.evaluate(base.as_ref())? + ); + } + } + + Ok(()) + } + + #[test] + fn test_eq_xr() { + let mut rng = test_rng(); + for nv in 4..10 { + let r: Vec = (0..nv).map(|_| Fr::rand(&mut rng)).collect(); + let eq_x_r = build_eq_x_r(r.as_ref()).unwrap(); + let eq_x_r2 = build_eq_x_r_for_test(r.as_ref()); + assert_eq!(eq_x_r, eq_x_r2); + } + } + + /// Naive method to build eq(x, r). + /// Only used for testing purpose. + // Evaluate + // eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i)) + // over r, which is + // eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i)) + fn build_eq_x_r_for_test(r: &[F]) -> Arc> { + // we build eq(x,r) from its evaluations + // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars + // for example, with num_vars = 4, x is a binary vector of 4, then + // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) + // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) + // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) + // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) + // .... + // 1 1 1 1 -> r0 * r1 * r2 * r3 + // we will need 2^num_var evaluations + + // First, we build array for {1 - r_i} + let one_minus_r: Vec = r.iter().map(|ri| F::one() - ri).collect(); + + let num_var = r.len(); + let mut eval = vec![]; + + for i in 0..1 << num_var { + let mut current_eval = F::one(); + let bit_sequence = bit_decompose(i, num_var); + + for (&bit, (ri, one_minus_ri)) in + bit_sequence.iter().zip(r.iter().zip(one_minus_r.iter())) + { + current_eval *= if bit { *ri } else { *one_minus_ri }; + } + eval.push(current_eval); + } + + let mle = DenseMultilinearExtension::from_evaluations_vec(num_var, eval); + + Arc::new(mle) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 63e55ca..f2c0503 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1 +1,7 @@ pub mod vec; + +// expose espresso local modules +pub mod espresso; +pub use crate::utils::espresso::multilinear_polynomial; +pub use crate::utils::espresso::sum_check; +pub use crate::utils::espresso::virtual_polynomial; diff --git a/src/utils/vec.rs b/src/utils/vec.rs index 1b3d1b5..837310f 100644 --- a/src/utils/vec.rs +++ b/src/utils/vec.rs @@ -1,5 +1,6 @@ use ark_ff::PrimeField; use ark_std::cfg_iter; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct SparseMatrix {