mirror of
https://github.com/arnaucube/hyperplonk.git
synced 2026-01-09 23:51:28 +01:00
Improve sum check in general and preprocess for sum check in mlkzg multi_open (#123)
* feat: faster sum check prover and multilinear kzg batching open * fix: add comment about why we combine polys that have the same opening point * fix: remove the unnecessary last eval increment
This commit is contained in:
@@ -23,8 +23,7 @@ use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, Virtual
|
||||
use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve};
|
||||
use ark_ff::PrimeField;
|
||||
use ark_std::{end_timer, log2, start_timer, One, Zero};
|
||||
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
use std::{marker::PhantomData, sync::Arc};
|
||||
use std::{collections::BTreeMap, iter, marker::PhantomData, ops::Deref, sync::Arc};
|
||||
use transcript::IOPTranscript;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
||||
@@ -80,22 +79,39 @@ where
|
||||
|
||||
// \tilde g_i(b) = eq(t, i) * f_i(b)
|
||||
let timer = start_timer!(|| format!("compute tilde g for {} points", points.len()));
|
||||
let mut tilde_gs = vec![];
|
||||
for (index, f_i) in polynomials.iter().enumerate() {
|
||||
let mut tilde_g_eval = vec![E::Fr::zero(); 1 << num_var];
|
||||
for (j, &f_i_eval) in f_i.iter().enumerate() {
|
||||
tilde_g_eval[j] = f_i_eval * eq_t_i_list[index];
|
||||
}
|
||||
tilde_gs.push(Arc::new(DenseMultilinearExtension::from_evaluations_vec(
|
||||
num_var,
|
||||
tilde_g_eval,
|
||||
)));
|
||||
}
|
||||
// combine the polynomials that have same opening point first to reduce the
|
||||
// cost of sum check later.
|
||||
let point_indices = points
|
||||
.iter()
|
||||
.fold(BTreeMap::<_, _>::new(), |mut indices, point| {
|
||||
let idx = indices.len();
|
||||
indices.entry(point).or_insert(idx);
|
||||
indices
|
||||
});
|
||||
let deduped_points =
|
||||
BTreeMap::from_iter(point_indices.iter().map(|(point, idx)| (*idx, *point)))
|
||||
.into_values()
|
||||
.collect::<Vec<_>>();
|
||||
let merged_tilde_gs = polynomials
|
||||
.iter()
|
||||
.zip(points.iter())
|
||||
.zip(eq_t_i_list.iter())
|
||||
.fold(
|
||||
iter::repeat_with(DenseMultilinearExtension::zero)
|
||||
.map(Arc::new)
|
||||
.take(point_indices.len())
|
||||
.collect::<Vec<_>>(),
|
||||
|mut merged_tilde_gs, ((poly, point), coeff)| {
|
||||
*Arc::make_mut(&mut merged_tilde_gs[point_indices[point]]) +=
|
||||
(*coeff, poly.deref());
|
||||
merged_tilde_gs
|
||||
},
|
||||
);
|
||||
end_timer!(timer);
|
||||
|
||||
let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len()));
|
||||
let tilde_eqs: Vec<Arc<DenseMultilinearExtension<<E as PairingEngine>::Fr>>> = points
|
||||
.par_iter()
|
||||
let tilde_eqs: Vec<_> = deduped_points
|
||||
.iter()
|
||||
.map(|point| {
|
||||
let eq_b_zi = build_eq_x_r_vec(point).unwrap();
|
||||
Arc::new(DenseMultilinearExtension::from_evaluations_vec(
|
||||
@@ -110,8 +126,8 @@ where
|
||||
|
||||
let step = start_timer!(|| "add mle");
|
||||
let mut sum_check_vp = VirtualPolynomial::new(num_var);
|
||||
for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) {
|
||||
sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?;
|
||||
for (merged_tilde_g, tilde_eq) in merged_tilde_gs.iter().zip(tilde_eqs.into_iter()) {
|
||||
sum_check_vp.add_mle_list([merged_tilde_g.clone(), tilde_eq], E::Fr::one())?;
|
||||
}
|
||||
end_timer!(step);
|
||||
|
||||
@@ -133,17 +149,11 @@ where
|
||||
// build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the
|
||||
// sumcheck's point \tilde eq_i(a2) = eq(a2, point_i)
|
||||
let step = start_timer!(|| "evaluate at a2");
|
||||
let mut g_prime_evals = vec![E::Fr::zero(); 1 << num_var];
|
||||
for (tilde_g, point) in tilde_gs.iter().zip(points.iter()) {
|
||||
let mut g_prime = Arc::new(DenseMultilinearExtension::zero());
|
||||
for (merged_tilde_g, point) in merged_tilde_gs.iter().zip(deduped_points.iter()) {
|
||||
let eq_i_a2 = eq_eval(a2, point)?;
|
||||
for (j, &tilde_g_eval) in tilde_g.iter().enumerate() {
|
||||
g_prime_evals[j] += tilde_g_eval * eq_i_a2;
|
||||
}
|
||||
*Arc::make_mut(&mut g_prime) += (eq_i_a2, merged_tilde_g.deref());
|
||||
}
|
||||
let g_prime = Arc::new(DenseMultilinearExtension::from_evaluations_vec(
|
||||
num_var,
|
||||
g_prime_evals,
|
||||
));
|
||||
end_timer!(step);
|
||||
|
||||
let step = start_timer!(|| "pcs open");
|
||||
|
||||
@@ -35,6 +35,9 @@ pub struct IOPProverState<F: PrimeField> {
|
||||
pub(crate) round: usize,
|
||||
/// pointer to the virtual polynomial
|
||||
pub(crate) poly: VirtualPolynomial<F>,
|
||||
/// points with precomputed barycentric weights for extrapolating smaller
|
||||
/// degree uni-polys to `max_degree + 1` evaluations.
|
||||
pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
|
||||
}
|
||||
|
||||
/// Prover State of a PolyIOP
|
||||
|
||||
@@ -12,9 +12,9 @@ use crate::poly_iop::{
|
||||
structs::{IOPProverMessage, IOPProverState},
|
||||
};
|
||||
use arithmetic::{fix_variables, VirtualPolynomial};
|
||||
use ark_ff::PrimeField;
|
||||
use ark_ff::{batch_inversion, PrimeField};
|
||||
use ark_poly::DenseMultilinearExtension;
|
||||
use ark_std::{end_timer, start_timer, vec::Vec};
|
||||
use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
|
||||
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -40,6 +40,13 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
||||
challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
|
||||
round: 0,
|
||||
poly: polynomial.clone(),
|
||||
extrapolation_aux: (1..polynomial.aux_info.max_degree)
|
||||
.map(|degree| {
|
||||
let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
|
||||
let weights = barycentric_weights(&points);
|
||||
(points, weights)
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -110,83 +117,56 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
||||
let products_list = self.poly.products.clone();
|
||||
let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
|
||||
|
||||
// let compute_sum = start_timer!(|| "compute sum");
|
||||
|
||||
// Step 2: generate sum for the partial evaluated polynomial:
|
||||
// f(r_1, ... r_m,, x_{m+1}... x_n)
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
{
|
||||
let flag = (self.poly.aux_info.max_degree == 2)
|
||||
&& (products_list.len() == 1)
|
||||
&& (products_list[0].0 == F::one());
|
||||
if flag {
|
||||
for (t, e) in products_sum.iter_mut().enumerate() {
|
||||
let evals = (0..1 << (self.poly.aux_info.num_variables - self.round))
|
||||
.into_par_iter()
|
||||
.map(|b| {
|
||||
// evaluate P_round(t)
|
||||
let table0 = &flattened_ml_extensions[products_list[0].1[0]];
|
||||
let table1 = &flattened_ml_extensions[products_list[0].1[1]];
|
||||
if t == 0 {
|
||||
table0[b << 1] * table1[b << 1]
|
||||
} else if t == 1 {
|
||||
table0[(b << 1) + 1] * table1[(b << 1) + 1]
|
||||
} else {
|
||||
(table0[(b << 1) + 1] + table0[(b << 1) + 1] - table0[b << 1])
|
||||
* (table1[(b << 1) + 1] + table1[(b << 1) + 1] - table1[b << 1])
|
||||
}
|
||||
})
|
||||
.collect::<Vec<F>>();
|
||||
*e += evals.par_iter().sum::<F>();
|
||||
}
|
||||
} else {
|
||||
for (t, e) in products_sum.iter_mut().enumerate() {
|
||||
let t = F::from(t as u128);
|
||||
let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
|
||||
.into_par_iter()
|
||||
.map(|b| {
|
||||
// evaluate P_round(t)
|
||||
let mut tmp = F::zero();
|
||||
products_list.iter().for_each(|(coefficient, products)| {
|
||||
let num_mles = products.len();
|
||||
let mut product = *coefficient;
|
||||
for &f in products.iter().take(num_mles) {
|
||||
let table = &flattened_ml_extensions[f]; // f's range is checked in init
|
||||
// TODO: Could be done faster by cashing the results from the
|
||||
// previous t and adding the diff
|
||||
// Also possible to use Karatsuba multiplication
|
||||
product *=
|
||||
table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
|
||||
}
|
||||
tmp += product;
|
||||
products_list.iter().for_each(|(coefficient, products)| {
|
||||
let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
|
||||
.fold(
|
||||
|| {
|
||||
(
|
||||
vec![(F::zero(), F::zero()); products.len()],
|
||||
vec![F::zero(); products.len() + 1],
|
||||
)
|
||||
},
|
||||
|(mut buf, mut acc), b| {
|
||||
buf.iter_mut()
|
||||
.zip(products.iter())
|
||||
.for_each(|((eval, step), f)| {
|
||||
let table = &flattened_ml_extensions[*f];
|
||||
*eval = table[b << 1];
|
||||
*step = table[(b << 1) + 1] - table[b << 1];
|
||||
});
|
||||
|
||||
tmp
|
||||
})
|
||||
.collect::<Vec<F>>();
|
||||
*e += products.par_iter().sum::<F>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "parallel"))]
|
||||
products_sum.iter_mut().enumerate().for_each(|(t, e)| {
|
||||
let t = F::from(t as u64);
|
||||
let one_minus_t = F::one() - t;
|
||||
|
||||
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
|
||||
// evaluate P_round(t)
|
||||
for (coefficient, products) in products_list.iter() {
|
||||
let num_mles = products.len();
|
||||
let mut product = *coefficient;
|
||||
for &f in products.iter().take(num_mles) {
|
||||
let table = &flattened_ml_extensions[f]; // f's range is checked in init
|
||||
product *= table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
|
||||
}
|
||||
*e += product;
|
||||
}
|
||||
}
|
||||
acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
|
||||
acc[1..].iter_mut().for_each(|acc| {
|
||||
buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
|
||||
*acc += buf.iter().map(|(eval, _)| eval).product::<F>();
|
||||
});
|
||||
(buf, acc)
|
||||
},
|
||||
)
|
||||
.map(|(_, partial)| partial)
|
||||
.reduce(
|
||||
|| vec![F::zero(); products.len() + 1],
|
||||
|mut sum, partial| {
|
||||
sum.iter_mut()
|
||||
.zip(partial.iter())
|
||||
.for_each(|(sum, partial)| *sum += partial);
|
||||
sum
|
||||
},
|
||||
);
|
||||
sum.iter_mut().for_each(|sum| *sum *= coefficient);
|
||||
let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
|
||||
.map(|i| {
|
||||
let (points, weights) = &self.extrapolation_aux[products.len() - 1];
|
||||
let at = F::from((products.len() + 1 + i) as u64);
|
||||
extrapolate(points, weights, &sum, &at)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
products_sum
|
||||
.iter_mut()
|
||||
.zip(sum.iter().chain(extraploation.iter()))
|
||||
.for_each(|(products_sum, sum)| *products_sum += sum);
|
||||
});
|
||||
|
||||
// update prover's state to the partial evaluated polynomial
|
||||
@@ -195,10 +175,43 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
|
||||
.map(|x| Arc::new(x.clone()))
|
||||
.collect();
|
||||
|
||||
// end_timer!(compute_sum);
|
||||
// end_timer!(start);
|
||||
Ok(IOPProverMessage {
|
||||
evaluations: products_sum,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
|
||||
let mut weights = points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, point_j)| {
|
||||
points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
|
||||
.reduce(|acc, value| acc * value)
|
||||
.unwrap_or_else(F::one)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
batch_inversion(&mut weights);
|
||||
weights
|
||||
}
|
||||
|
||||
fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
|
||||
let (coeffs, sum_inv) = {
|
||||
let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
|
||||
batch_inversion(&mut coeffs);
|
||||
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
|
||||
*coeff *= weight;
|
||||
});
|
||||
let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
|
||||
(coeffs, sum_inv)
|
||||
};
|
||||
coeffs
|
||||
.iter()
|
||||
.zip(evals)
|
||||
.map(|(coeff, eval)| *coeff * eval)
|
||||
.sum::<F>()
|
||||
* sum_inv
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user