Improve sum check in general and preprocess for sum check in mlkzg multi_open (#123)

* feat: faster sum check prover and multilinear kzg batching open

* fix: add comment about why we combine polys that have the same opening point

* fix: remove the unnecessary last eval increment
This commit is contained in:
Han
2023-02-15 00:49:58 +08:00
committed by GitHub
parent 70b0df5c52
commit f64bfe6c2a
3 changed files with 129 additions and 103 deletions

View File

@@ -23,8 +23,7 @@ use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, Virtual
use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve}; use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve};
use ark_ff::PrimeField; use ark_ff::PrimeField;
use ark_std::{end_timer, log2, start_timer, One, Zero}; use ark_std::{end_timer, log2, start_timer, One, Zero};
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; use std::{collections::BTreeMap, iter, marker::PhantomData, ops::Deref, sync::Arc};
use std::{marker::PhantomData, sync::Arc};
use transcript::IOPTranscript; use transcript::IOPTranscript;
#[derive(Clone, Debug, Default, PartialEq, Eq)] #[derive(Clone, Debug, Default, PartialEq, Eq)]
@@ -80,22 +79,39 @@ where
// \tilde g_i(b) = eq(t, i) * f_i(b) // \tilde g_i(b) = eq(t, i) * f_i(b)
let timer = start_timer!(|| format!("compute tilde g for {} points", points.len())); let timer = start_timer!(|| format!("compute tilde g for {} points", points.len()));
let mut tilde_gs = vec![]; // combine the polynomials that have same opening point first to reduce the
for (index, f_i) in polynomials.iter().enumerate() { // cost of sum check later.
let mut tilde_g_eval = vec![E::Fr::zero(); 1 << num_var]; let point_indices = points
for (j, &f_i_eval) in f_i.iter().enumerate() { .iter()
tilde_g_eval[j] = f_i_eval * eq_t_i_list[index]; .fold(BTreeMap::<_, _>::new(), |mut indices, point| {
} let idx = indices.len();
tilde_gs.push(Arc::new(DenseMultilinearExtension::from_evaluations_vec( indices.entry(point).or_insert(idx);
num_var, indices
tilde_g_eval, });
))); let deduped_points =
} BTreeMap::from_iter(point_indices.iter().map(|(point, idx)| (*idx, *point)))
.into_values()
.collect::<Vec<_>>();
let merged_tilde_gs = polynomials
.iter()
.zip(points.iter())
.zip(eq_t_i_list.iter())
.fold(
iter::repeat_with(DenseMultilinearExtension::zero)
.map(Arc::new)
.take(point_indices.len())
.collect::<Vec<_>>(),
|mut merged_tilde_gs, ((poly, point), coeff)| {
*Arc::make_mut(&mut merged_tilde_gs[point_indices[point]]) +=
(*coeff, poly.deref());
merged_tilde_gs
},
);
end_timer!(timer); end_timer!(timer);
let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len())); let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len()));
let tilde_eqs: Vec<Arc<DenseMultilinearExtension<<E as PairingEngine>::Fr>>> = points let tilde_eqs: Vec<_> = deduped_points
.par_iter() .iter()
.map(|point| { .map(|point| {
let eq_b_zi = build_eq_x_r_vec(point).unwrap(); let eq_b_zi = build_eq_x_r_vec(point).unwrap();
Arc::new(DenseMultilinearExtension::from_evaluations_vec( Arc::new(DenseMultilinearExtension::from_evaluations_vec(
@@ -110,8 +126,8 @@ where
let step = start_timer!(|| "add mle"); let step = start_timer!(|| "add mle");
let mut sum_check_vp = VirtualPolynomial::new(num_var); let mut sum_check_vp = VirtualPolynomial::new(num_var);
for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) { for (merged_tilde_g, tilde_eq) in merged_tilde_gs.iter().zip(tilde_eqs.into_iter()) {
sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?; sum_check_vp.add_mle_list([merged_tilde_g.clone(), tilde_eq], E::Fr::one())?;
} }
end_timer!(step); end_timer!(step);
@@ -133,17 +149,11 @@ where
// build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the // build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the
// sumcheck's point \tilde eq_i(a2) = eq(a2, point_i) // sumcheck's point \tilde eq_i(a2) = eq(a2, point_i)
let step = start_timer!(|| "evaluate at a2"); let step = start_timer!(|| "evaluate at a2");
let mut g_prime_evals = vec![E::Fr::zero(); 1 << num_var]; let mut g_prime = Arc::new(DenseMultilinearExtension::zero());
for (tilde_g, point) in tilde_gs.iter().zip(points.iter()) { for (merged_tilde_g, point) in merged_tilde_gs.iter().zip(deduped_points.iter()) {
let eq_i_a2 = eq_eval(a2, point)?; let eq_i_a2 = eq_eval(a2, point)?;
for (j, &tilde_g_eval) in tilde_g.iter().enumerate() { *Arc::make_mut(&mut g_prime) += (eq_i_a2, merged_tilde_g.deref());
g_prime_evals[j] += tilde_g_eval * eq_i_a2;
}
} }
let g_prime = Arc::new(DenseMultilinearExtension::from_evaluations_vec(
num_var,
g_prime_evals,
));
end_timer!(step); end_timer!(step);
let step = start_timer!(|| "pcs open"); let step = start_timer!(|| "pcs open");

View File

@@ -35,6 +35,9 @@ pub struct IOPProverState<F: PrimeField> {
pub(crate) round: usize, pub(crate) round: usize,
/// pointer to the virtual polynomial /// pointer to the virtual polynomial
pub(crate) poly: VirtualPolynomial<F>, pub(crate) poly: VirtualPolynomial<F>,
/// points with precomputed barycentric weights for extrapolating smaller
/// degree uni-polys to `max_degree + 1` evaluations.
pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
} }
/// Prover State of a PolyIOP /// Prover State of a PolyIOP

View File

@@ -12,9 +12,9 @@ use crate::poly_iop::{
structs::{IOPProverMessage, IOPProverState}, structs::{IOPProverMessage, IOPProverState},
}; };
use arithmetic::{fix_variables, VirtualPolynomial}; use arithmetic::{fix_variables, VirtualPolynomial};
use ark_ff::PrimeField; use ark_ff::{batch_inversion, PrimeField};
use ark_poly::DenseMultilinearExtension; use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer, vec::Vec}; use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator}; use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
use std::sync::Arc; use std::sync::Arc;
@@ -40,6 +40,13 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
challenges: Vec::with_capacity(polynomial.aux_info.num_variables), challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
round: 0, round: 0,
poly: polynomial.clone(), poly: polynomial.clone(),
extrapolation_aux: (1..polynomial.aux_info.max_degree)
.map(|degree| {
let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
let weights = barycentric_weights(&points);
(points, weights)
})
.collect(),
}) })
} }
@@ -110,83 +117,56 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
let products_list = self.poly.products.clone(); let products_list = self.poly.products.clone();
let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1]; let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
// let compute_sum = start_timer!(|| "compute sum");
// Step 2: generate sum for the partial evaluated polynomial: // Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n) // f(r_1, ... r_m,, x_{m+1}... x_n)
#[cfg(feature = "parallel")] products_list.iter().for_each(|(coefficient, products)| {
{ let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
let flag = (self.poly.aux_info.max_degree == 2) .fold(
&& (products_list.len() == 1) || {
&& (products_list[0].0 == F::one()); (
if flag { vec![(F::zero(), F::zero()); products.len()],
for (t, e) in products_sum.iter_mut().enumerate() { vec![F::zero(); products.len() + 1],
let evals = (0..1 << (self.poly.aux_info.num_variables - self.round)) )
.into_par_iter() },
.map(|b| { |(mut buf, mut acc), b| {
// evaluate P_round(t) buf.iter_mut()
let table0 = &flattened_ml_extensions[products_list[0].1[0]]; .zip(products.iter())
let table1 = &flattened_ml_extensions[products_list[0].1[1]]; .for_each(|((eval, step), f)| {
if t == 0 { let table = &flattened_ml_extensions[*f];
table0[b << 1] * table1[b << 1] *eval = table[b << 1];
} else if t == 1 { *step = table[(b << 1) + 1] - table[b << 1];
table0[(b << 1) + 1] * table1[(b << 1) + 1]
} else {
(table0[(b << 1) + 1] + table0[(b << 1) + 1] - table0[b << 1])
* (table1[(b << 1) + 1] + table1[(b << 1) + 1] - table1[b << 1])
}
})
.collect::<Vec<F>>();
*e += evals.par_iter().sum::<F>();
}
} else {
for (t, e) in products_sum.iter_mut().enumerate() {
let t = F::from(t as u128);
let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
.into_par_iter()
.map(|b| {
// evaluate P_round(t)
let mut tmp = F::zero();
products_list.iter().for_each(|(coefficient, products)| {
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
// TODO: Could be done faster by cashing the results from the
// previous t and adding the diff
// Also possible to use Karatsuba multiplication
product *=
table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
}
tmp += product;
}); });
acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
tmp acc[1..].iter_mut().for_each(|acc| {
}) buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
.collect::<Vec<F>>(); *acc += buf.iter().map(|(eval, _)| eval).product::<F>();
*e += products.par_iter().sum::<F>(); });
} (buf, acc)
} },
} )
.map(|(_, partial)| partial)
#[cfg(not(feature = "parallel"))] .reduce(
products_sum.iter_mut().enumerate().for_each(|(t, e)| { || vec![F::zero(); products.len() + 1],
let t = F::from(t as u64); |mut sum, partial| {
let one_minus_t = F::one() - t; sum.iter_mut()
.zip(partial.iter())
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) { .for_each(|(sum, partial)| *sum += partial);
// evaluate P_round(t) sum
for (coefficient, products) in products_list.iter() { },
let num_mles = products.len(); );
let mut product = *coefficient; sum.iter_mut().for_each(|sum| *sum *= coefficient);
for &f in products.iter().take(num_mles) { let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
let table = &flattened_ml_extensions[f]; // f's range is checked in init .map(|i| {
product *= table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t; let (points, weights) = &self.extrapolation_aux[products.len() - 1];
} let at = F::from((products.len() + 1 + i) as u64);
*e += product; extrapolate(points, weights, &sum, &at)
} })
} .collect::<Vec<_>>();
products_sum
.iter_mut()
.zip(sum.iter().chain(extraploation.iter()))
.for_each(|(products_sum, sum)| *products_sum += sum);
}); });
// update prover's state to the partial evaluated polynomial // update prover's state to the partial evaluated polynomial
@@ -195,10 +175,43 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
.map(|x| Arc::new(x.clone())) .map(|x| Arc::new(x.clone()))
.collect(); .collect();
// end_timer!(compute_sum);
// end_timer!(start);
Ok(IOPProverMessage { Ok(IOPProverMessage {
evaluations: products_sum, evaluations: products_sum,
}) })
} }
} }
fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
let mut weights = points
.iter()
.enumerate()
.map(|(j, point_j)| {
points
.iter()
.enumerate()
.filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
.reduce(|acc, value| acc * value)
.unwrap_or_else(F::one)
})
.collect::<Vec<_>>();
batch_inversion(&mut weights);
weights
}
fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
let (coeffs, sum_inv) = {
let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
batch_inversion(&mut coeffs);
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
*coeff *= weight;
});
let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
(coeffs, sum_inv)
};
coeffs
.iter()
.zip(evals)
.map(|(coeff, eval)| *coeff * eval)
.sum::<F>()
* sum_inv
}