// code forked from:
|
|
// https://github.com/EspressoSystems/hyperplonk/tree/main/subroutines/src/poly_iop/sum_check
|
|
//
|
|
// Copyright (c) 2023 Espresso Systems (espressosys.com)
|
|
// This file is part of the HyperPlonk library.
|
|
|
|
// You should have received a copy of the MIT License
|
|
// along with the HyperPlonk library. If not, see <https://mit-license.org/>.
|
|
|
|
//! Prover subroutines for a SumCheck protocol.
|
|
|
|
use super::SumCheckProver;
|
|
use crate::utils::{
|
|
lagrange_poly::compute_lagrange_interpolated_poly, multilinear_polynomial::fix_variables,
|
|
virtual_polynomial::VirtualPolynomial,
|
|
};
|
|
use ark_ff::{batch_inversion, PrimeField};
|
|
use ark_poly::DenseMultilinearExtension;
|
|
use ark_std::{cfg_into_iter, end_timer, start_timer};
|
|
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
|
|
use std::sync::Arc;
|
|
|
|
use super::structs::{IOPProverMessage, IOPProverState};
|
|
use espresso_subroutines::poly_iop::prelude::PolyIOPErrors;
|
|
|
|
// #[cfg(feature = "parallel")]
|
|
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
|
|
|
impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
|
type VirtualPolynomial = VirtualPolynomial<F>;
|
|
type ProverMessage = IOPProverMessage<F>;
|
|
|
|
/// Initialize the prover state to argue for the sum of the input polynomial
|
|
/// over {0,1}^`num_vars`.
|
|
fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result<Self, PolyIOPErrors> {
|
|
let start = start_timer!(|| "sum check prover init");
|
|
if polynomial.aux_info.num_variables == 0 {
|
|
return Err(PolyIOPErrors::InvalidParameters(
|
|
"Attempt to prove a constant.".to_string(),
|
|
));
|
|
}
|
|
end_timer!(start);
|
|
|
|
Ok(Self {
|
|
challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
|
|
round: 0,
|
|
poly: polynomial.clone(),
|
|
extrapolation_aux: (1..polynomial.aux_info.max_degree)
|
|
.map(|degree| {
|
|
let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
|
|
let weights = barycentric_weights(&points);
|
|
(points, weights)
|
|
})
|
|
.collect(),
|
|
})
|
|
}
|
|
|
|
/// Receive message from verifier, generate prover message, and proceed to
|
|
/// next round.
|
|
///
|
|
/// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
|
|
fn prove_round_and_update_state(
|
|
&mut self,
|
|
challenge: &Option<F>,
|
|
) -> Result<Self::ProverMessage, PolyIOPErrors> {
|
|
// let start =
|
|
// start_timer!(|| format!("sum check prove {}-th round and update state",
|
|
// self.round));
|
|
|
|
if self.round >= self.poly.aux_info.num_variables {
|
|
return Err(PolyIOPErrors::InvalidProver(
|
|
"Prover is not active".to_string(),
|
|
));
|
|
}
|
|
|
|
// let fix_argument = start_timer!(|| "fix argument");
|
|
|
|
// Step 1:
|
|
// fix argument and evaluate f(x) over x_m = r; where r is the challenge
|
|
// for the current round, and m is the round number, indexed from 1
|
|
//
|
|
// i.e.:
|
|
// at round m <= n, for each mle g(x_1, ... x_n) within the flattened_mle
|
|
// which has already been evaluated to
|
|
//
|
|
// g(r_1, ..., r_{m-1}, x_m ... x_n)
|
|
//
|
|
// eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n)
|
|
let mut flattened_ml_extensions: Vec<DenseMultilinearExtension<F>> = self
|
|
.poly
|
|
.flattened_ml_extensions
|
|
.par_iter()
|
|
.map(|x| x.as_ref().clone())
|
|
.collect();
|
|
|
|
if let Some(chal) = challenge {
|
|
if self.round == 0 {
|
|
return Err(PolyIOPErrors::InvalidProver(
|
|
"first round should be prover first.".to_string(),
|
|
));
|
|
}
|
|
self.challenges.push(*chal);
|
|
|
|
let r = self.challenges[self.round - 1];
|
|
// #[cfg(feature = "parallel")]
|
|
flattened_ml_extensions
|
|
.par_iter_mut()
|
|
.for_each(|mle| *mle = fix_variables(mle, &[r]));
|
|
// #[cfg(not(feature = "parallel"))]
|
|
// flattened_ml_extensions
|
|
// .iter_mut()
|
|
// .for_each(|mle| *mle = fix_variables(mle, &[r]));
|
|
} else if self.round > 0 {
|
|
return Err(PolyIOPErrors::InvalidProver(
|
|
"verifier message is empty".to_string(),
|
|
));
|
|
}
|
|
// end_timer!(fix_argument);
|
|
|
|
self.round += 1;
|
|
|
|
let products_list = self.poly.products.clone();
|
|
let mut products_sum = vec![F::ZERO; self.poly.aux_info.max_degree + 1];
|
|
|
|
// Step 2: generate sum for the partial evaluated polynomial:
|
|
// f(r_1, ... r_m,, x_{m+1}... x_n)
|
|
|
|
products_list.iter().for_each(|(coefficient, products)| {
|
|
let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
|
|
.fold(
|
|
|| {
|
|
(
|
|
vec![(F::ZERO, F::ZERO); products.len()],
|
|
vec![F::ZERO; products.len() + 1],
|
|
)
|
|
},
|
|
|(mut buf, mut acc), b| {
|
|
buf.iter_mut()
|
|
.zip(products.iter())
|
|
.for_each(|((eval, step), f)| {
|
|
let table = &flattened_ml_extensions[*f];
|
|
*eval = table[b << 1];
|
|
*step = table[(b << 1) + 1] - table[b << 1];
|
|
});
|
|
acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
|
|
acc[1..].iter_mut().for_each(|acc| {
|
|
buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
|
|
*acc += buf.iter().map(|(eval, _)| eval).product::<F>();
|
|
});
|
|
(buf, acc)
|
|
},
|
|
)
|
|
.map(|(_, partial)| partial)
|
|
.reduce(
|
|
|| vec![F::ZERO; products.len() + 1],
|
|
|mut sum, partial| {
|
|
sum.iter_mut()
|
|
.zip(partial.iter())
|
|
.for_each(|(sum, partial)| *sum += partial);
|
|
sum
|
|
},
|
|
);
|
|
sum.iter_mut().for_each(|sum| *sum *= coefficient);
|
|
let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
|
|
.map(|i| {
|
|
let (points, weights) = &self.extrapolation_aux[products.len() - 1];
|
|
let at = F::from((products.len() + 1 + i) as u64);
|
|
extrapolate(points, weights, &sum, &at)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
products_sum
|
|
.iter_mut()
|
|
.zip(sum.iter().chain(extraploation.iter()))
|
|
.for_each(|(products_sum, sum)| *products_sum += sum);
|
|
});
|
|
|
|
// update prover's state to the partial evaluated polynomial
|
|
self.poly.flattened_ml_extensions = flattened_ml_extensions
|
|
.par_iter()
|
|
.map(|x| Arc::new(x.clone()))
|
|
.collect();
|
|
|
|
let prover_poly = compute_lagrange_interpolated_poly::<F>(&products_sum);
|
|
Ok(IOPProverMessage {
|
|
coeffs: prover_poly.coeffs,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::filter_map_bool_then)]
|
|
fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
|
|
let mut weights = points
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(j, point_j)| {
|
|
points
|
|
.iter()
|
|
.enumerate()
|
|
.filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
|
|
.reduce(|acc, value| acc * value)
|
|
.unwrap_or_else(F::one)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
batch_inversion(&mut weights);
|
|
weights
|
|
}
|
|
|
|
fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
|
|
let (coeffs, sum_inv) = {
|
|
let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
|
|
batch_inversion(&mut coeffs);
|
|
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
|
|
*coeff *= weight;
|
|
});
|
|
let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
|
|
(coeffs, sum_inv)
|
|
};
|
|
coeffs
|
|
.iter()
|
|
.zip(evals)
|
|
.map(|(coeff, eval)| *coeff * eval)
|
|
.sum::<F>()
|
|
* sum_inv
|
|
}
|