Improve sum check in general and preprocess for sum check in mlkzg multi_open (#123)

* feat: faster sum check prover and multilinear kzg batching open

* fix: add comment about why we combine polys that have the same opening point

* fix: remove the unnecessary last eval increment
This commit is contained in:
Han
2023-02-15 00:49:58 +08:00
committed by GitHub
parent 70b0df5c52
commit f64bfe6c2a
3 changed files with 129 additions and 103 deletions

View File

@@ -35,6 +35,9 @@ pub struct IOPProverState<F: PrimeField> {
pub(crate) round: usize,
/// pointer to the virtual polynomial
pub(crate) poly: VirtualPolynomial<F>,
/// points with precomputed barycentric weights for extrapolating smaller
/// degree uni-polys to `max_degree + 1` evaluations.
pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
}
/// Prover State of a PolyIOP

View File

@@ -12,9 +12,9 @@ use crate::poly_iop::{
structs::{IOPProverMessage, IOPProverState},
};
use arithmetic::{fix_variables, VirtualPolynomial};
use ark_ff::PrimeField;
use ark_ff::{batch_inversion, PrimeField};
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer, vec::Vec};
use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
use std::sync::Arc;
@@ -40,6 +40,13 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
round: 0,
poly: polynomial.clone(),
extrapolation_aux: (1..polynomial.aux_info.max_degree)
.map(|degree| {
let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
let weights = barycentric_weights(&points);
(points, weights)
})
.collect(),
})
}
@@ -110,83 +117,56 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
let products_list = self.poly.products.clone();
let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
// let compute_sum = start_timer!(|| "compute sum");
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
#[cfg(feature = "parallel")]
{
let flag = (self.poly.aux_info.max_degree == 2)
&& (products_list.len() == 1)
&& (products_list[0].0 == F::one());
if flag {
for (t, e) in products_sum.iter_mut().enumerate() {
let evals = (0..1 << (self.poly.aux_info.num_variables - self.round))
.into_par_iter()
.map(|b| {
// evaluate P_round(t)
let table0 = &flattened_ml_extensions[products_list[0].1[0]];
let table1 = &flattened_ml_extensions[products_list[0].1[1]];
if t == 0 {
table0[b << 1] * table1[b << 1]
} else if t == 1 {
table0[(b << 1) + 1] * table1[(b << 1) + 1]
} else {
(table0[(b << 1) + 1] + table0[(b << 1) + 1] - table0[b << 1])
* (table1[(b << 1) + 1] + table1[(b << 1) + 1] - table1[b << 1])
}
})
.collect::<Vec<F>>();
*e += evals.par_iter().sum::<F>();
}
} else {
for (t, e) in products_sum.iter_mut().enumerate() {
let t = F::from(t as u128);
let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
.into_par_iter()
.map(|b| {
// evaluate P_round(t)
let mut tmp = F::zero();
products_list.iter().for_each(|(coefficient, products)| {
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
// TODO: Could be done faster by cashing the results from the
// previous t and adding the diff
// Also possible to use Karatsuba multiplication
product *=
table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
}
tmp += product;
products_list.iter().for_each(|(coefficient, products)| {
let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
.fold(
|| {
(
vec![(F::zero(), F::zero()); products.len()],
vec![F::zero(); products.len() + 1],
)
},
|(mut buf, mut acc), b| {
buf.iter_mut()
.zip(products.iter())
.for_each(|((eval, step), f)| {
let table = &flattened_ml_extensions[*f];
*eval = table[b << 1];
*step = table[(b << 1) + 1] - table[b << 1];
});
tmp
})
.collect::<Vec<F>>();
*e += products.par_iter().sum::<F>();
}
}
}
#[cfg(not(feature = "parallel"))]
products_sum.iter_mut().enumerate().for_each(|(t, e)| {
let t = F::from(t as u64);
let one_minus_t = F::one() - t;
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
// evaluate P_round(t)
for (coefficient, products) in products_list.iter() {
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
}
*e += product;
}
}
acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
acc[1..].iter_mut().for_each(|acc| {
buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
*acc += buf.iter().map(|(eval, _)| eval).product::<F>();
});
(buf, acc)
},
)
.map(|(_, partial)| partial)
.reduce(
|| vec![F::zero(); products.len() + 1],
|mut sum, partial| {
sum.iter_mut()
.zip(partial.iter())
.for_each(|(sum, partial)| *sum += partial);
sum
},
);
sum.iter_mut().for_each(|sum| *sum *= coefficient);
let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
.map(|i| {
let (points, weights) = &self.extrapolation_aux[products.len() - 1];
let at = F::from((products.len() + 1 + i) as u64);
extrapolate(points, weights, &sum, &at)
})
.collect::<Vec<_>>();
products_sum
.iter_mut()
.zip(sum.iter().chain(extraploation.iter()))
.for_each(|(products_sum, sum)| *products_sum += sum);
});
// update prover's state to the partial evaluated polynomial
@@ -195,10 +175,43 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
.map(|x| Arc::new(x.clone()))
.collect();
// end_timer!(compute_sum);
// end_timer!(start);
Ok(IOPProverMessage {
evaluations: products_sum,
})
}
}
fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
let mut weights = points
.iter()
.enumerate()
.map(|(j, point_j)| {
points
.iter()
.enumerate()
.filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
.reduce(|acc, value| acc * value)
.unwrap_or_else(F::one)
})
.collect::<Vec<_>>();
batch_inversion(&mut weights);
weights
}
fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
let (coeffs, sum_inv) = {
let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
batch_inversion(&mut coeffs);
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
*coeff *= weight;
});
let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
(coeffs, sum_inv)
};
coeffs
.iter()
.zip(evals)
.map(|(coeff, eval)| *coeff * eval)
.sum::<F>()
* sum_inv
}