mirror of
https://github.com/arnaucube/hyperplonk.git
synced 2026-01-12 17:01:28 +01:00
Improve sum check in general and preprocess for sum check in mlkzg multi_open (#123)
* feat: faster sum check prover and multilinear kzg batching open * fix: add comment about why we combine polys that have the same opening point * fix: remove the unnecessary last eval increment
This commit is contained in:
@@ -35,6 +35,9 @@ pub struct IOPProverState<F: PrimeField> {
|
||||
pub(crate) round: usize,
|
||||
/// pointer to the virtual polynomial
|
||||
pub(crate) poly: VirtualPolynomial<F>,
|
||||
/// points with precomputed barycentric weights for extrapolating smaller
|
||||
/// degree uni-polys to `max_degree + 1` evaluations.
|
||||
pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
|
||||
}
|
||||
|
||||
/// Prover State of a PolyIOP
|
||||
|
||||
@@ -12,9 +12,9 @@ use crate::poly_iop::{
|
||||
structs::{IOPProverMessage, IOPProverState},
|
||||
};
|
||||
use arithmetic::{fix_variables, VirtualPolynomial};
|
||||
use ark_ff::PrimeField;
|
||||
use ark_ff::{batch_inversion, PrimeField};
|
||||
use ark_poly::DenseMultilinearExtension;
|
||||
use ark_std::{end_timer, start_timer, vec::Vec};
|
||||
use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
|
||||
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -40,6 +40,13 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
||||
challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
|
||||
round: 0,
|
||||
poly: polynomial.clone(),
|
||||
extrapolation_aux: (1..polynomial.aux_info.max_degree)
|
||||
.map(|degree| {
|
||||
let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
|
||||
let weights = barycentric_weights(&points);
|
||||
(points, weights)
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -110,83 +117,56 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
||||
let products_list = self.poly.products.clone();
|
||||
let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
|
||||
|
||||
// let compute_sum = start_timer!(|| "compute sum");
|
||||
|
||||
// Step 2: generate sum for the partial evaluated polynomial:
|
||||
// f(r_1, ... r_m,, x_{m+1}... x_n)
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
{
|
||||
let flag = (self.poly.aux_info.max_degree == 2)
|
||||
&& (products_list.len() == 1)
|
||||
&& (products_list[0].0 == F::one());
|
||||
if flag {
|
||||
for (t, e) in products_sum.iter_mut().enumerate() {
|
||||
let evals = (0..1 << (self.poly.aux_info.num_variables - self.round))
|
||||
.into_par_iter()
|
||||
.map(|b| {
|
||||
// evaluate P_round(t)
|
||||
let table0 = &flattened_ml_extensions[products_list[0].1[0]];
|
||||
let table1 = &flattened_ml_extensions[products_list[0].1[1]];
|
||||
if t == 0 {
|
||||
table0[b << 1] * table1[b << 1]
|
||||
} else if t == 1 {
|
||||
table0[(b << 1) + 1] * table1[(b << 1) + 1]
|
||||
} else {
|
||||
(table0[(b << 1) + 1] + table0[(b << 1) + 1] - table0[b << 1])
|
||||
* (table1[(b << 1) + 1] + table1[(b << 1) + 1] - table1[b << 1])
|
||||
}
|
||||
})
|
||||
.collect::<Vec<F>>();
|
||||
*e += evals.par_iter().sum::<F>();
|
||||
}
|
||||
} else {
|
||||
for (t, e) in products_sum.iter_mut().enumerate() {
|
||||
let t = F::from(t as u128);
|
||||
let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
|
||||
.into_par_iter()
|
||||
.map(|b| {
|
||||
// evaluate P_round(t)
|
||||
let mut tmp = F::zero();
|
||||
products_list.iter().for_each(|(coefficient, products)| {
|
||||
let num_mles = products.len();
|
||||
let mut product = *coefficient;
|
||||
for &f in products.iter().take(num_mles) {
|
||||
let table = &flattened_ml_extensions[f]; // f's range is checked in init
|
||||
// TODO: Could be done faster by cashing the results from the
|
||||
// previous t and adding the diff
|
||||
// Also possible to use Karatsuba multiplication
|
||||
product *=
|
||||
table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
|
||||
}
|
||||
tmp += product;
|
||||
products_list.iter().for_each(|(coefficient, products)| {
|
||||
let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
|
||||
.fold(
|
||||
|| {
|
||||
(
|
||||
vec![(F::zero(), F::zero()); products.len()],
|
||||
vec![F::zero(); products.len() + 1],
|
||||
)
|
||||
},
|
||||
|(mut buf, mut acc), b| {
|
||||
buf.iter_mut()
|
||||
.zip(products.iter())
|
||||
.for_each(|((eval, step), f)| {
|
||||
let table = &flattened_ml_extensions[*f];
|
||||
*eval = table[b << 1];
|
||||
*step = table[(b << 1) + 1] - table[b << 1];
|
||||
});
|
||||
|
||||
tmp
|
||||
})
|
||||
.collect::<Vec<F>>();
|
||||
*e += products.par_iter().sum::<F>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "parallel"))]
|
||||
products_sum.iter_mut().enumerate().for_each(|(t, e)| {
|
||||
let t = F::from(t as u64);
|
||||
let one_minus_t = F::one() - t;
|
||||
|
||||
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
|
||||
// evaluate P_round(t)
|
||||
for (coefficient, products) in products_list.iter() {
|
||||
let num_mles = products.len();
|
||||
let mut product = *coefficient;
|
||||
for &f in products.iter().take(num_mles) {
|
||||
let table = &flattened_ml_extensions[f]; // f's range is checked in init
|
||||
product *= table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
|
||||
}
|
||||
*e += product;
|
||||
}
|
||||
}
|
||||
acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
|
||||
acc[1..].iter_mut().for_each(|acc| {
|
||||
buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
|
||||
*acc += buf.iter().map(|(eval, _)| eval).product::<F>();
|
||||
});
|
||||
(buf, acc)
|
||||
},
|
||||
)
|
||||
.map(|(_, partial)| partial)
|
||||
.reduce(
|
||||
|| vec![F::zero(); products.len() + 1],
|
||||
|mut sum, partial| {
|
||||
sum.iter_mut()
|
||||
.zip(partial.iter())
|
||||
.for_each(|(sum, partial)| *sum += partial);
|
||||
sum
|
||||
},
|
||||
);
|
||||
sum.iter_mut().for_each(|sum| *sum *= coefficient);
|
||||
let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
|
||||
.map(|i| {
|
||||
let (points, weights) = &self.extrapolation_aux[products.len() - 1];
|
||||
let at = F::from((products.len() + 1 + i) as u64);
|
||||
extrapolate(points, weights, &sum, &at)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
products_sum
|
||||
.iter_mut()
|
||||
.zip(sum.iter().chain(extraploation.iter()))
|
||||
.for_each(|(products_sum, sum)| *products_sum += sum);
|
||||
});
|
||||
|
||||
// update prover's state to the partial evaluated polynomial
|
||||
@@ -195,10 +175,43 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
||||
.map(|x| Arc::new(x.clone()))
|
||||
.collect();
|
||||
|
||||
// end_timer!(compute_sum);
|
||||
// end_timer!(start);
|
||||
Ok(IOPProverMessage {
|
||||
evaluations: products_sum,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
|
||||
let mut weights = points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, point_j)| {
|
||||
points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
|
||||
.reduce(|acc, value| acc * value)
|
||||
.unwrap_or_else(F::one)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
batch_inversion(&mut weights);
|
||||
weights
|
||||
}
|
||||
|
||||
fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
|
||||
let (coeffs, sum_inv) = {
|
||||
let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
|
||||
batch_inversion(&mut coeffs);
|
||||
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
|
||||
*coeff *= weight;
|
||||
});
|
||||
let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
|
||||
(coeffs, sum_inv)
|
||||
};
|
||||
coeffs
|
||||
.iter()
|
||||
.zip(evals)
|
||||
.map(|(coeff, eval)| *coeff * eval)
|
||||
.sum::<F>()
|
||||
* sum_inv
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user