
Improve sum check in general and preprocess for sum check in mlkzg `multi_open` (#123)

* feat: faster sum check prover and multilinear kzg batching open

* fix: add comment about why we combine polys that have the same opening point

* fix: remove the unnecessary last eval increment
Han committed 1 year ago via GitHub · branch main · commit f64bfe6c2a (signed, GPG key ID 4AEE18F83AFDEB23)
3 changed files with 129 additions and 103 deletions
  1. subroutines/src/pcs/multilinear_kzg/batching.rs (+36, -26)
  2. subroutines/src/poly_iop/structs.rs (+3, -0)
  3. subroutines/src/poly_iop/sum_check/prover.rs (+90, -77)

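The change is easiest to read against the relation the batching argument proves. With k polynomials f_1, ..., f_k opened at points z_1, ..., z_k, batching challenge t, and eq the multilinear equality polynomial, the sum check runs over \tilde g_i(b) = eq(t, i) * f_i(b) (the comment in batching.rs below uses exactly this notation), and the regrouping this commit introduces is the identity

    \sum_{i=1}^{k} eq(t, i)\, f_i(b)\, eq(b, z_i)
        = \sum_{z \in \{z_1, \dots, z_k\}} \Big( \sum_{i \,:\, z_i = z} eq(t, i)\, f_i(b) \Big)\, eq(b, z),

summed over b in the Boolean hypercube. Merging the inner sums up front leaves one degree-2 product per distinct opening point in the sum-check virtual polynomial instead of one per polynomial, which is the cost reduction the commit message refers to.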
subroutines/src/pcs/multilinear_kzg/batching.rs (+36, -26)

@@ -23,8 +23,7 @@ use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, VirtualPolynomial};
 use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve};
 use ark_ff::PrimeField;
 use ark_std::{end_timer, log2, start_timer, One, Zero};
-use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
-use std::{marker::PhantomData, sync::Arc};
+use std::{collections::BTreeMap, iter, marker::PhantomData, ops::Deref, sync::Arc};
 use transcript::IOPTranscript;
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
@@ -80,22 +79,39 @@ where
     // \tilde g_i(b) = eq(t, i) * f_i(b)
     let timer = start_timer!(|| format!("compute tilde g for {} points", points.len()));
-    let mut tilde_gs = vec![];
-    for (index, f_i) in polynomials.iter().enumerate() {
-        let mut tilde_g_eval = vec![E::Fr::zero(); 1 << num_var];
-        for (j, &f_i_eval) in f_i.iter().enumerate() {
-            tilde_g_eval[j] = f_i_eval * eq_t_i_list[index];
-        }
-        tilde_gs.push(Arc::new(DenseMultilinearExtension::from_evaluations_vec(
-            num_var,
-            tilde_g_eval,
-        )));
-    }
+    // combine the polynomials that have same opening point first to reduce the
+    // cost of sum check later.
+    let point_indices = points
+        .iter()
+        .fold(BTreeMap::<_, _>::new(), |mut indices, point| {
+            let idx = indices.len();
+            indices.entry(point).or_insert(idx);
+            indices
+        });
+    let deduped_points =
+        BTreeMap::from_iter(point_indices.iter().map(|(point, idx)| (*idx, *point)))
+            .into_values()
+            .collect::<Vec<_>>();
+    let merged_tilde_gs = polynomials
+        .iter()
+        .zip(points.iter())
+        .zip(eq_t_i_list.iter())
+        .fold(
+            iter::repeat_with(DenseMultilinearExtension::zero)
+                .map(Arc::new)
+                .take(point_indices.len())
+                .collect::<Vec<_>>(),
+            |mut merged_tilde_gs, ((poly, point), coeff)| {
+                *Arc::make_mut(&mut merged_tilde_gs[point_indices[point]]) +=
+                    (*coeff, poly.deref());
+                merged_tilde_gs
+            },
+        );
     end_timer!(timer);

     let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len()));
-    let tilde_eqs: Vec<Arc<DenseMultilinearExtension<<E as PairingEngine>::Fr>>> = points
-        .par_iter()
+    let tilde_eqs: Vec<_> = deduped_points
+        .iter()
         .map(|point| {
             let eq_b_zi = build_eq_x_r_vec(point).unwrap();
             Arc::new(DenseMultilinearExtension::from_evaluations_vec(
@@ -110,8 +126,8 @@ where
     let step = start_timer!(|| "add mle");
     let mut sum_check_vp = VirtualPolynomial::new(num_var);
-    for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) {
-        sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?;
+    for (merged_tilde_g, tilde_eq) in merged_tilde_gs.iter().zip(tilde_eqs.into_iter()) {
+        sum_check_vp.add_mle_list([merged_tilde_g.clone(), tilde_eq], E::Fr::one())?;
     }
     end_timer!(step);
@@ -133,17 +149,11 @@ where
     // build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the
     // sumcheck's point \tilde eq_i(a2) = eq(a2, point_i)
     let step = start_timer!(|| "evaluate at a2");
-    let mut g_prime_evals = vec![E::Fr::zero(); 1 << num_var];
-    for (tilde_g, point) in tilde_gs.iter().zip(points.iter()) {
+    let mut g_prime = Arc::new(DenseMultilinearExtension::zero());
+    for (merged_tilde_g, point) in merged_tilde_gs.iter().zip(deduped_points.iter()) {
         let eq_i_a2 = eq_eval(a2, point)?;
-        for (j, &tilde_g_eval) in tilde_g.iter().enumerate() {
-            g_prime_evals[j] += tilde_g_eval * eq_i_a2;
-        }
+        *Arc::make_mut(&mut g_prime) += (eq_i_a2, merged_tilde_g.deref());
     }
-    let g_prime = Arc::new(DenseMultilinearExtension::from_evaluations_vec(
-        num_var,
-        g_prime_evals,
-    ));
     end_timer!(step);

     let step = start_timer!(|| "pcs open");

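The batching.rs change above replaces the per-polynomial \tilde g_i construction with a two-pass merge keyed on the opening point: a BTreeMap assigns every distinct point an index in order of first appearance, and each polynomial's evaluation table is scaled by its eq(t, i) coefficient and accumulated into the slot of its point, so both the sum check and the later g'(X) combination run over deduped_points only. A minimal standalone sketch of that grouping, with i128 standing in for a field element and Vec<u8> for a point (merge_by_point and all other names below are illustrative, not from the repo):

use std::collections::BTreeMap;

// Mirrors the point_indices / deduped_points / merged_tilde_gs construction in
// the diff above, without the Arc / DenseMultilinearExtension machinery.
fn merge_by_point(
    points: &[Vec<u8>],
    coeffs: &[i128],      // plays the role of eq(t, i)
    tables: &[Vec<i128>], // evaluations of f_i over the hypercube
) -> (Vec<Vec<u8>>, Vec<Vec<i128>>) {
    // Pass 1: give every distinct point an index in order of first appearance.
    let mut point_indices = BTreeMap::new();
    for point in points {
        let idx = point_indices.len();
        point_indices.entry(point.clone()).or_insert(idx);
    }
    let mut deduped_points = vec![Vec::new(); point_indices.len()];
    for (point, &idx) in &point_indices {
        deduped_points[idx] = point.clone();
    }
    // Pass 2: accumulate coeff_i * f_i into the slot of f_i's opening point.
    let len = tables.first().map_or(0, Vec::len);
    let mut merged = vec![vec![0i128; len]; point_indices.len()];
    for ((table, point), coeff) in tables.iter().zip(points).zip(coeffs) {
        let slot = &mut merged[point_indices[point]];
        for (acc, eval) in slot.iter_mut().zip(table) {
            *acc += coeff * eval;
        }
    }
    (deduped_points, merged)
}

fn main() {
    // Two of the three polynomials share the opening point [0, 1] and collapse
    // into a single merged table, so the sum check sees 2 products instead of 3.
    let points = vec![vec![0u8, 1], vec![0, 1], vec![1, 1]];
    let coeffs = vec![2, 3, 5];
    let tables = vec![
        vec![1, 2, 3, 4],
        vec![10, 20, 30, 40],
        vec![100, 200, 300, 400],
    ];
    let (deduped, merged) = merge_by_point(&points, &coeffs, &tables);
    assert_eq!(deduped, vec![vec![0, 1], vec![1, 1]]);
    assert_eq!(merged[0], vec![32, 64, 96, 128]); // 2 * f_0 + 3 * f_1
    assert_eq!(merged[1], vec![500, 1000, 1500, 2000]); // 5 * f_2
    println!("{deduped:?} {merged:?}");
}

In the real code the accumulation happens through Arc::make_mut on freshly created, uniquely owned Arcs, so the merged extensions are updated in place without cloning.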
subroutines/src/poly_iop/structs.rs (+3, -0)

@@ -35,6 +35,9 @@ pub struct IOPProverState<F: PrimeField> {
     pub(crate) round: usize,
     /// pointer to the virtual polynomial
     pub(crate) poly: VirtualPolynomial<F>,
+    /// points with precomputed barycentric weights for extrapolating smaller
+    /// degree uni-polys to `max_degree + 1` evaluations.
+    pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
 }

 /// Prover State of a PolyIOP

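The new extrapolation_aux field caches, for every product degree d from 1 up to max_degree - 1, the interpolation points x_j = j for j = 0..d together with their barycentric weights. With the weights precomputed once at init, extending a degree-d round polynomial from its d + 1 computed evaluations y_0, ..., y_d to any further point x costs O(d) field operations; this is the standard second barycentric interpolation form, which is what the barycentric_weights and extrapolate helpers added to prover.rs below compute:

    w_j = \frac{1}{\prod_{i \neq j} (x_j - x_i)}, \qquad
    p(x) = \frac{\sum_{j=0}^{d} \frac{w_j}{x - x_j}\, y_j}{\sum_{j=0}^{d} \frac{w_j}{x - x_j}}.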
subroutines/src/poly_iop/sum_check/prover.rs (+90, -77)

@@ -12,9 +12,9 @@ use crate::poly_iop::{
     structs::{IOPProverMessage, IOPProverState},
 };
 use arithmetic::{fix_variables, VirtualPolynomial};
-use ark_ff::PrimeField;
+use ark_ff::{batch_inversion, PrimeField};
 use ark_poly::DenseMultilinearExtension;
-use ark_std::{end_timer, start_timer, vec::Vec};
+use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
 use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
 use std::sync::Arc;
@@ -40,6 +40,13 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
             challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
             round: 0,
             poly: polynomial.clone(),
+            extrapolation_aux: (1..polynomial.aux_info.max_degree)
+                .map(|degree| {
+                    let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
+                    let weights = barycentric_weights(&points);
+                    (points, weights)
+                })
+                .collect(),
         })
     }
@@ -110,83 +117,56 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
         let products_list = self.poly.products.clone();
         let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];

         // let compute_sum = start_timer!(|| "compute sum");

         // Step 2: generate sum for the partial evaluated polynomial:
         // f(r_1, ... r_m,, x_{m+1}... x_n)
-        #[cfg(feature = "parallel")]
-        {
-            let flag = (self.poly.aux_info.max_degree == 2)
-                && (products_list.len() == 1)
-                && (products_list[0].0 == F::one());
-            if flag {
-                for (t, e) in products_sum.iter_mut().enumerate() {
-                    let evals = (0..1 << (self.poly.aux_info.num_variables - self.round))
-                        .into_par_iter()
-                        .map(|b| {
-                            // evaluate P_round(t)
-                            let table0 = &flattened_ml_extensions[products_list[0].1[0]];
-                            let table1 = &flattened_ml_extensions[products_list[0].1[1]];
-                            if t == 0 {
-                                table0[b << 1] * table1[b << 1]
-                            } else if t == 1 {
-                                table0[(b << 1) + 1] * table1[(b << 1) + 1]
-                            } else {
-                                (table0[(b << 1) + 1] + table0[(b << 1) + 1] - table0[b << 1])
-                                    * (table1[(b << 1) + 1] + table1[(b << 1) + 1] - table1[b << 1])
-                            }
-                        })
-                        .collect::<Vec<F>>();
-                    *e += evals.par_iter().sum::<F>();
-                }
-            } else {
-                for (t, e) in products_sum.iter_mut().enumerate() {
-                    let t = F::from(t as u128);
-                    let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
-                        .into_par_iter()
-                        .map(|b| {
-                            // evaluate P_round(t)
-                            let mut tmp = F::zero();
-                            products_list.iter().for_each(|(coefficient, products)| {
-                                let num_mles = products.len();
-                                let mut product = *coefficient;
-                                for &f in products.iter().take(num_mles) {
-                                    let table = &flattened_ml_extensions[f]; // f's range is checked in init
-                                    // TODO: Could be done faster by cashing the results from the
-                                    // previous t and adding the diff
-                                    // Also possible to use Karatsuba multiplication
-                                    product *=
-                                        table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
-                                }
-                                tmp += product;
+        products_list.iter().for_each(|(coefficient, products)| {
+            let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
+                .fold(
+                    || {
+                        (
+                            vec![(F::zero(), F::zero()); products.len()],
+                            vec![F::zero(); products.len() + 1],
+                        )
+                    },
+                    |(mut buf, mut acc), b| {
+                        buf.iter_mut()
+                            .zip(products.iter())
+                            .for_each(|((eval, step), f)| {
+                                let table = &flattened_ml_extensions[*f];
+                                *eval = table[b << 1];
+                                *step = table[(b << 1) + 1] - table[b << 1];
                             });
-                            tmp
-                        })
-                        .collect::<Vec<F>>();
-                    *e += products.par_iter().sum::<F>();
-                }
-            }
-        }
-        #[cfg(not(feature = "parallel"))]
-        products_sum.iter_mut().enumerate().for_each(|(t, e)| {
-            let t = F::from(t as u64);
-            let one_minus_t = F::one() - t;
-            for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
-                // evaluate P_round(t)
-                for (coefficient, products) in products_list.iter() {
-                    let num_mles = products.len();
-                    let mut product = *coefficient;
-                    for &f in products.iter().take(num_mles) {
-                        let table = &flattened_ml_extensions[f]; // f's range is checked in init
-                        product *= table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
-                    }
-                    *e += product;
-                }
-            }
+                        acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
+                        acc[1..].iter_mut().for_each(|acc| {
+                            buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
+                            *acc += buf.iter().map(|(eval, _)| eval).product::<F>();
+                        });
+                        (buf, acc)
+                    },
+                )
+                .map(|(_, partial)| partial)
+                .reduce(
+                    || vec![F::zero(); products.len() + 1],
+                    |mut sum, partial| {
+                        sum.iter_mut()
+                            .zip(partial.iter())
+                            .for_each(|(sum, partial)| *sum += partial);
+                        sum
+                    },
+                );
+            sum.iter_mut().for_each(|sum| *sum *= coefficient);
+            let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
+                .map(|i| {
+                    let (points, weights) = &self.extrapolation_aux[products.len() - 1];
+                    let at = F::from((products.len() + 1 + i) as u64);
+                    extrapolate(points, weights, &sum, &at)
+                })
+                .collect::<Vec<_>>();
+            products_sum
+                .iter_mut()
+                .zip(sum.iter().chain(extraploation.iter()))
+                .for_each(|(products_sum, sum)| *products_sum += sum);
         });

         // update prover's state to the partial evaluated polynomial
@@ -195,10 +175,43 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
             .map(|x| Arc::new(x.clone()))
             .collect();

         // end_timer!(compute_sum);
         // end_timer!(start);
         Ok(IOPProverMessage {
             evaluations: products_sum,
         })
     }
 }
+
+fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
+    let mut weights = points
+        .iter()
+        .enumerate()
+        .map(|(j, point_j)| {
+            points
+                .iter()
+                .enumerate()
+                .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
+                .reduce(|acc, value| acc * value)
+                .unwrap_or_else(F::one)
+        })
+        .collect::<Vec<_>>();
+    batch_inversion(&mut weights);
+    weights
+}
+
+fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
+    let (coeffs, sum_inv) = {
+        let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
+        batch_inversion(&mut coeffs);
+        coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
+            *coeff *= weight;
+        });
+        let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
+        (coeffs, sum_inv)
+    };
+    coeffs
+        .iter()
+        .zip(evals)
+        .map(|(coeff, eval)| *coeff * eval)
+        .sum::<F>()
+        * sum_inv
+}

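The rewritten prover loop exploits the fact that, once an MLE table is split into pairs (table[b << 1], table[(b << 1) + 1]), its restriction to the round variable is affine in t: the value at t + 1 equals the value at t plus the fixed step table[(b << 1) + 1] - table[b << 1]. Each product of d tables is therefore evaluated at t = 0..d with one multiplication pass per t, with no per-t interpolation (the old code recomputed table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t for every t), and the remaining evaluations up to max_degree come from extrapolate. A self-contained sketch of the evaluation trick, with i128 standing in for a field element and illustrative names not taken from the repo:

// Round-polynomial evaluations P(t) = sum_b prod_j f_j(t, b) for t = 0..=degree,
// computed with the same eval/step update the new fold uses.
fn round_poly_evals(tables: &[Vec<i128>], degree: usize) -> Vec<i128> {
    let half = tables[0].len() / 2; // (f_j(0, b), f_j(1, b)) sit at (b << 1, (b << 1) + 1)
    let mut sums = vec![0i128; degree + 1];
    for b in 0..half {
        // eval_j = f_j(0, b), step_j = f_j(1, b) - f_j(0, b)
        let mut buf: Vec<(i128, i128)> = tables
            .iter()
            .map(|t| (t[b << 1], t[(b << 1) + 1] - t[b << 1]))
            .collect();
        sums[0] += buf.iter().map(|(eval, _)| eval).product::<i128>();
        for acc in sums[1..].iter_mut() {
            // advance every factor from t to t + 1 by adding its step, then multiply
            for (eval, step) in buf.iter_mut() {
                *eval += *step;
            }
            *acc += buf.iter().map(|(eval, _)| eval).product::<i128>();
        }
    }
    sums
}

fn main() {
    // Two 2-variable tables; their product has degree 2 in the round variable.
    let f0 = vec![1i128, 2, 3, 4];
    let f1 = vec![5i128, 6, 7, 8];
    let sums = round_poly_evals(&[f0.clone(), f1.clone()], 2);
    // Cross-check against the direct formula f_j(t, b) = f_j(0, b) + t * (f_j(1, b) - f_j(0, b)).
    for (t, got) in sums.iter().enumerate() {
        let expect: i128 = (0..2)
            .map(|b| {
                let g0 = f0[b << 1] + t as i128 * (f0[(b << 1) + 1] - f0[b << 1]);
                let g1 = f1[b << 1] + t as i128 * (f1[(b << 1) + 1] - f1[b << 1]);
                g0 * g1
            })
            .sum();
        assert_eq!(*got, expect);
    }
    println!("{sums:?}");
}

In the real prover, each per-product vector of d + 1 evaluations is additionally scaled by the product's coefficient and extended to max_degree + 1 entries via the cached barycentric weights before being added into products_sum.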