// code forked from: // https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check // // Copyright (c) 2023 Espresso Systems (espressosys.com) // This file is part of the HyperPlonk library. // You should have received a copy of the MIT License // along with the HyperPlonk library. If not, see . //! Prover subroutines for a SumCheck protocol. use super::SumCheckProver; use crate::utils::{ lagrange_poly::compute_lagrange_interpolated_poly, multilinear_polynomial::fix_variables, virtual_polynomial::VirtualPolynomial, }; use ark_ff::{batch_inversion, PrimeField}; use ark_poly::DenseMultilinearExtension; use ark_std::{cfg_into_iter, end_timer, start_timer}; use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator}; use std::sync::Arc; use super::structs::{IOPProverMessage, IOPProverState}; use espresso_subroutines::poly_iop::prelude::PolyIOPErrors; // #[cfg(feature = "parallel")] use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; impl SumCheckProver for IOPProverState { type VirtualPolynomial = VirtualPolynomial; type ProverMessage = IOPProverMessage; /// Initialize the prover state to argue for the sum of the input polynomial /// over {0,1}^`num_vars`. fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result { let start = start_timer!(|| "sum check prover init"); if polynomial.aux_info.num_variables == 0 { return Err(PolyIOPErrors::InvalidParameters( "Attempt to prove a constant.".to_string(), )); } end_timer!(start); Ok(Self { challenges: Vec::with_capacity(polynomial.aux_info.num_variables), round: 0, poly: polynomial.clone(), extrapolation_aux: (1..polynomial.aux_info.max_degree) .map(|degree| { let points = (0..1 + degree as u64).map(F::from).collect::>(); let weights = barycentric_weights(&points); (points, weights) }) .collect(), }) } /// Receive message from verifier, generate prover message, and proceed to /// next round. /// /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). fn prove_round_and_update_state( &mut self, challenge: &Option, ) -> Result { // let start = // start_timer!(|| format!("sum check prove {}-th round and update state", // self.round)); if self.round >= self.poly.aux_info.num_variables { return Err(PolyIOPErrors::InvalidProver( "Prover is not active".to_string(), )); } // let fix_argument = start_timer!(|| "fix argument"); // Step 1: // fix argument and evaluate f(x) over x_m = r; where r is the challenge // for the current round, and m is the round number, indexed from 1 // // i.e.: // at round m <= n, for each mle g(x_1, ... x_n) within the flattened_mle // which has already been evaluated to // // g(r_1, ..., r_{m-1}, x_m ... x_n) // // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) let mut flattened_ml_extensions: Vec> = self .poly .flattened_ml_extensions .par_iter() .map(|x| x.as_ref().clone()) .collect(); if let Some(chal) = challenge { if self.round == 0 { return Err(PolyIOPErrors::InvalidProver( "first round should be prover first.".to_string(), )); } self.challenges.push(*chal); let r = self.challenges[self.round - 1]; // #[cfg(feature = "parallel")] flattened_ml_extensions .par_iter_mut() .for_each(|mle| *mle = fix_variables(mle, &[r])); // #[cfg(not(feature = "parallel"))] // flattened_ml_extensions // .iter_mut() // .for_each(|mle| *mle = fix_variables(mle, &[r])); } else if self.round > 0 { return Err(PolyIOPErrors::InvalidProver( "verifier message is empty".to_string(), )); } // end_timer!(fix_argument); self.round += 1; let products_list = self.poly.products.clone(); let mut products_sum = vec![F::ZERO; self.poly.aux_info.max_degree + 1]; // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) products_list.iter().for_each(|(coefficient, products)| { let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round)) .fold( || { ( vec![(F::ZERO, F::ZERO); products.len()], vec![F::ZERO; products.len() + 1], ) }, |(mut buf, mut acc), b| { buf.iter_mut() .zip(products.iter()) .for_each(|((eval, step), f)| { let table = &flattened_ml_extensions[*f]; *eval = table[b << 1]; *step = table[(b << 1) + 1] - table[b << 1]; }); acc[0] += buf.iter().map(|(eval, _)| eval).product::(); acc[1..].iter_mut().for_each(|acc| { buf.iter_mut().for_each(|(eval, step)| *eval += step as &_); *acc += buf.iter().map(|(eval, _)| eval).product::(); }); (buf, acc) }, ) .map(|(_, partial)| partial) .reduce( || vec![F::ZERO; products.len() + 1], |mut sum, partial| { sum.iter_mut() .zip(partial.iter()) .for_each(|(sum, partial)| *sum += partial); sum }, ); sum.iter_mut().for_each(|sum| *sum *= coefficient); let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len()) .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; let at = F::from((products.len() + 1 + i) as u64); extrapolate(points, weights, &sum, &at) }) .collect::>(); products_sum .iter_mut() .zip(sum.iter().chain(extraploation.iter())) .for_each(|(products_sum, sum)| *products_sum += sum); }); // update prover's state to the partial evaluated polynomial self.poly.flattened_ml_extensions = flattened_ml_extensions .par_iter() .map(|x| Arc::new(x.clone())) .collect(); let prover_poly = compute_lagrange_interpolated_poly::(&products_sum); Ok(IOPProverMessage { coeffs: prover_poly.coeffs, }) } } #[allow(clippy::filter_map_bool_then)] fn barycentric_weights(points: &[F]) -> Vec { let mut weights = points .iter() .enumerate() .map(|(j, point_j)| { points .iter() .enumerate() .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) .reduce(|acc, value| acc * value) .unwrap_or_else(F::one) }) .collect::>(); batch_inversion(&mut weights); weights } fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { let (coeffs, sum_inv) = { let mut coeffs = points.iter().map(|point| *at - point).collect::>(); batch_inversion(&mut coeffs); coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { *coeff *= weight; }); let sum_inv = coeffs.iter().sum::().inverse().unwrap_or_default(); (coeffs, sum_inv) }; coeffs .iter() .zip(evals) .map(|(coeff, eval)| *coeff * eval) .sum::() * sum_inv }