diff --git a/subroutines/src/pcs/multilinear_kzg/batching.rs b/subroutines/src/pcs/multilinear_kzg/batching.rs
index dd8783d..f511ade 100644
--- a/subroutines/src/pcs/multilinear_kzg/batching.rs
+++ b/subroutines/src/pcs/multilinear_kzg/batching.rs
@@ -23,8 +23,7 @@ use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, Virtual
 use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve};
 use ark_ff::PrimeField;
 use ark_std::{end_timer, log2, start_timer, One, Zero};
-use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
-use std::{marker::PhantomData, sync::Arc};
+use std::{collections::BTreeMap, iter, marker::PhantomData, ops::Deref, sync::Arc};
 use transcript::IOPTranscript;
 
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
@@ -80,22 +79,39 @@ where
     // \tilde g_i(b) = eq(t, i) * f_i(b)
     let timer = start_timer!(|| format!("compute tilde g for {} points", points.len()));
-    let mut tilde_gs = vec![];
-    for (index, f_i) in polynomials.iter().enumerate() {
-        let mut tilde_g_eval = vec![E::Fr::zero(); 1 << num_var];
-        for (j, &f_i_eval) in f_i.iter().enumerate() {
-            tilde_g_eval[j] = f_i_eval * eq_t_i_list[index];
-        }
-        tilde_gs.push(Arc::new(DenseMultilinearExtension::from_evaluations_vec(
-            num_var,
-            tilde_g_eval,
-        )));
-    }
+    // combine the polynomials that have same opening point first to reduce the
+    // cost of sum check later.
+    let point_indices = points
+        .iter()
+        .fold(BTreeMap::<_, _>::new(), |mut indices, point| {
+            let idx = indices.len();
+            indices.entry(point).or_insert(idx);
+            indices
+        });
+    let deduped_points =
+        BTreeMap::from_iter(point_indices.iter().map(|(point, idx)| (*idx, *point)))
+            .into_values()
+            .collect::<Vec<_>>();
+    let merged_tilde_gs = polynomials
+        .iter()
+        .zip(points.iter())
+        .zip(eq_t_i_list.iter())
+        .fold(
+            iter::repeat_with(DenseMultilinearExtension::zero)
+                .map(Arc::new)
+                .take(point_indices.len())
+                .collect::<Vec<_>>(),
+            |mut merged_tilde_gs, ((poly, point), coeff)| {
+                *Arc::make_mut(&mut merged_tilde_gs[point_indices[point]]) +=
+                    (*coeff, poly.deref());
+                merged_tilde_gs
+            },
+        );
     end_timer!(timer);
 
     let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len()));
-    let tilde_eqs: Vec<Arc<DenseMultilinearExtension<E::Fr>>> = points
-        .par_iter()
+    let tilde_eqs: Vec<_> = deduped_points
+        .iter()
         .map(|point| {
             let eq_b_zi = build_eq_x_r_vec(point).unwrap();
             Arc::new(DenseMultilinearExtension::from_evaluations_vec(
@@ -110,8 +126,8 @@ where
     let step = start_timer!(|| "add mle");
     let mut sum_check_vp = VirtualPolynomial::new(num_var);
-    for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) {
-        sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?;
+    for (merged_tilde_g, tilde_eq) in merged_tilde_gs.iter().zip(tilde_eqs.into_iter()) {
+        sum_check_vp.add_mle_list([merged_tilde_g.clone(), tilde_eq], E::Fr::one())?;
     }
     end_timer!(step);
 
@@ -133,17 +149,11 @@ where
     // build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the
     // sumcheck's point \tilde eq_i(a2) = eq(a2, point_i)
     let step = start_timer!(|| "evaluate at a2");
-    let mut g_prime_evals = vec![E::Fr::zero(); 1 << num_var];
-    for (tilde_g, point) in tilde_gs.iter().zip(points.iter()) {
+    let mut g_prime = Arc::new(DenseMultilinearExtension::zero());
+    for (merged_tilde_g, point) in merged_tilde_gs.iter().zip(deduped_points.iter()) {
         let eq_i_a2 = eq_eval(a2, point)?;
-        for (j, &tilde_g_eval) in tilde_g.iter().enumerate() {
-            g_prime_evals[j] += tilde_g_eval * eq_i_a2;
-        }
+        *Arc::make_mut(&mut g_prime) += (eq_i_a2, merged_tilde_g.deref());
     }
-    let g_prime = Arc::new(DenseMultilinearExtension::from_evaluations_vec(
-        num_var,
-        g_prime_evals,
-    ));
     end_timer!(step);
 
     let step = start_timer!(|| "pcs open");
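Aside, not part of the patch: the new code above deduplicates the opening points and folds eq(t, i) * f_i into one merged MLE per distinct point, so the sum check that follows carries one product per distinct point instead of one per polynomial. Below is a minimal standalone sketch of that grouping step, using plain u64 evaluation tables in place of field elements and the hypothetical helper name `merge_by_point`; it illustrates the idea, it is not the library code.

    use std::collections::BTreeMap;

    /// Toy version of the grouping in the patch: returns the deduplicated
    /// points and, for each of them, the table sum_i coeff_i * f_i over the
    /// polynomials opened at that point.
    fn merge_by_point(
        polys: &[Vec<u64>],  // f_i as evaluation tables
        points: &[Vec<u64>], // opening point of each f_i
        coeffs: &[u64],      // eq(t, i) coefficient of each f_i
    ) -> (Vec<Vec<u64>>, Vec<Vec<u64>>) {
        // Assign a stable index to each distinct point, as the patch does with a BTreeMap.
        let mut point_indices = BTreeMap::new();
        for point in points {
            let idx = point_indices.len();
            point_indices.entry(point.clone()).or_insert(idx);
        }
        let mut deduped_points = vec![Vec::new(); point_indices.len()];
        for (point, &idx) in &point_indices {
            deduped_points[idx] = point.clone();
        }
        // One merged table per distinct point: merged[idx] += coeff * f_i.
        let len = polys.first().map_or(0, |p| p.len());
        let mut merged = vec![vec![0u64; len]; point_indices.len()];
        for ((poly, point), &coeff) in polys.iter().zip(points).zip(coeffs) {
            let idx = point_indices[point];
            for (acc, &eval) in merged[idx].iter_mut().zip(poly) {
                *acc += coeff * eval; // field addition/multiplication in the real code
            }
        }
        (deduped_points, merged)
    }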
diff --git a/subroutines/src/poly_iop/structs.rs b/subroutines/src/poly_iop/structs.rs
index 441bf0d..b8e0848 100644
--- a/subroutines/src/poly_iop/structs.rs
+++ b/subroutines/src/poly_iop/structs.rs
@@ -35,6 +35,9 @@ pub struct IOPProverState<F: PrimeField> {
     pub(crate) round: usize,
     /// pointer to the virtual polynomial
    pub(crate) poly: VirtualPolynomial<F>,
+    /// points with precomputed barycentric weights for extrapolating smaller
+    /// degree uni-polys to `max_degree + 1` evaluations.
+    pub(crate) extrapolation_aux: Vec<(Vec<F>, Vec<F>)>,
 }
 
 /// Prover State of a PolyIOP
diff --git a/subroutines/src/poly_iop/sum_check/prover.rs b/subroutines/src/poly_iop/sum_check/prover.rs
index 448d8dd..f59e0ab 100644
--- a/subroutines/src/poly_iop/sum_check/prover.rs
+++ b/subroutines/src/poly_iop/sum_check/prover.rs
@@ -12,9 +12,9 @@ use crate::poly_iop::{
     structs::{IOPProverMessage, IOPProverState},
 };
 use arithmetic::{fix_variables, VirtualPolynomial};
-use ark_ff::PrimeField;
+use ark_ff::{batch_inversion, PrimeField};
 use ark_poly::DenseMultilinearExtension;
-use ark_std::{end_timer, start_timer, vec::Vec};
+use ark_std::{cfg_into_iter, end_timer, start_timer, vec::Vec};
 use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator};
 use std::sync::Arc;
 
@@ -40,6 +40,13 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
             challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
             round: 0,
             poly: polynomial.clone(),
+            extrapolation_aux: (1..polynomial.aux_info.max_degree)
+                .map(|degree| {
+                    let points = (0..1 + degree as u64).map(F::from).collect::<Vec<_>>();
+                    let weights = barycentric_weights(&points);
+                    (points, weights)
+                })
+                .collect(),
         })
     }
 
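A note outside the patch: `extrapolation_aux[d - 1]` caches, for each degree d from 1 to max_degree - 1, the evaluation points {0, 1, ..., d} together with their barycentric weights w_j = 1 / prod_{i != j} (x_j - x_i); a degree-d round polynomial is determined by its values at those d + 1 points and can later be extended to the max_degree + 1 evaluations carried in the prover message. As a worked example, for d = 2 the points are {0, 1, 2} and the weights, computed in the field with the inversions batched, are w_0 = 1/((0-1)(0-2)) = 1/2, w_1 = 1/((1-0)(1-2)) = -1, w_2 = 1/((2-0)(2-1)) = 1/2.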
@@ -110,83 +117,56 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
         let products_list = self.poly.products.clone();
         let mut products_sum = vec![F::zero(); self.poly.aux_info.max_degree + 1];
 
-        // let compute_sum = start_timer!(|| "compute sum");
-
         // Step 2: generate sum for the partial evaluated polynomial:
         // f(r_1, ... r_m,, x_{m+1}... x_n)
 
-        #[cfg(feature = "parallel")]
-        {
-            let flag = (self.poly.aux_info.max_degree == 2)
-                && (products_list.len() == 1)
-                && (products_list[0].0 == F::one());
-            if flag {
-                for (t, e) in products_sum.iter_mut().enumerate() {
-                    let evals = (0..1 << (self.poly.aux_info.num_variables - self.round))
-                        .into_par_iter()
-                        .map(|b| {
-                            // evaluate P_round(t)
-                            let table0 = &flattened_ml_extensions[products_list[0].1[0]];
-                            let table1 = &flattened_ml_extensions[products_list[0].1[1]];
-                            if t == 0 {
-                                table0[b << 1] * table1[b << 1]
-                            } else if t == 1 {
-                                table0[(b << 1) + 1] * table1[(b << 1) + 1]
-                            } else {
-                                (table0[(b << 1) + 1] + table0[(b << 1) + 1] - table0[b << 1])
-                                    * (table1[(b << 1) + 1] + table1[(b << 1) + 1] - table1[b << 1])
-                            }
-                        })
-                        .collect::<Vec<F>>();
-                    *e += evals.par_iter().sum::<F>();
-                }
-            } else {
-                for (t, e) in products_sum.iter_mut().enumerate() {
-                    let t = F::from(t as u128);
-                    let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
-                        .into_par_iter()
-                        .map(|b| {
-                            // evaluate P_round(t)
-                            let mut tmp = F::zero();
-                            products_list.iter().for_each(|(coefficient, products)| {
-                                let num_mles = products.len();
-                                let mut product = *coefficient;
-                                for &f in products.iter().take(num_mles) {
-                                    let table = &flattened_ml_extensions[f]; // f's range is checked in init
-                                    // TODO: Could be done faster by cashing the results from the
-                                    // previous t and adding the diff
-                                    // Also possible to use Karatsuba multiplication
-                                    product *=
-                                        table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
-                                }
-                                tmp += product;
+        products_list.iter().for_each(|(coefficient, products)| {
+            let mut sum = cfg_into_iter!(0..1 << (self.poly.aux_info.num_variables - self.round))
+                .fold(
+                    || {
+                        (
+                            vec![(F::zero(), F::zero()); products.len()],
+                            vec![F::zero(); products.len() + 1],
+                        )
+                    },
+                    |(mut buf, mut acc), b| {
+                        buf.iter_mut()
+                            .zip(products.iter())
+                            .for_each(|((eval, step), f)| {
+                                let table = &flattened_ml_extensions[*f];
+                                *eval = table[b << 1];
+                                *step = table[(b << 1) + 1] - table[b << 1];
                             });
-
-                            tmp
-                        })
-                        .collect::<Vec<F>>();
-                    *e += products.par_iter().sum::<F>();
-                }
-            }
-        }
-
-        #[cfg(not(feature = "parallel"))]
-        products_sum.iter_mut().enumerate().for_each(|(t, e)| {
-            let t = F::from(t as u64);
-            let one_minus_t = F::one() - t;
-
-            for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
-                // evaluate P_round(t)
-                for (coefficient, products) in products_list.iter() {
-                    let num_mles = products.len();
-                    let mut product = *coefficient;
-                    for &f in products.iter().take(num_mles) {
-                        let table = &flattened_ml_extensions[f]; // f's range is checked in init
-                        product *= table[b << 1] + (table[(b << 1) + 1] - table[b << 1]) * t;
-                    }
-                    *e += product;
-                }
-            }
+                        acc[0] += buf.iter().map(|(eval, _)| eval).product::<F>();
+                        acc[1..].iter_mut().for_each(|acc| {
+                            buf.iter_mut().for_each(|(eval, step)| *eval += step as &_);
+                            *acc += buf.iter().map(|(eval, _)| eval).product::<F>();
+                        });
+                        (buf, acc)
+                    },
+                )
+                .map(|(_, partial)| partial)
+                .reduce(
+                    || vec![F::zero(); products.len() + 1],
+                    |mut sum, partial| {
+                        sum.iter_mut()
+                            .zip(partial.iter())
+                            .for_each(|(sum, partial)| *sum += partial);
+                        sum
+                    },
+                );
+            sum.iter_mut().for_each(|sum| *sum *= coefficient);
+            let extraploation = cfg_into_iter!(0..self.poly.aux_info.max_degree - products.len())
+                .map(|i| {
+                    let (points, weights) = &self.extrapolation_aux[products.len() - 1];
+                    let at = F::from((products.len() + 1 + i) as u64);
+                    extrapolate(points, weights, &sum, &at)
+                })
+                .collect::<Vec<_>>();
+            products_sum
+                .iter_mut()
+                .zip(sum.iter().chain(extraploation.iter()))
+                .for_each(|(products_sum, sum)| *products_sum += sum);
         });
 
         // update prover's state to the partial evaluated polynomial
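A note outside the patch: for a product of d = products.len() tables, each linear in the round variable, the fold above obtains the round polynomial's values at t = 0, 1, ..., d without multiplying by t: a table's value at t + 1 is its value at t plus the fixed step table[(b << 1) + 1] - table[b << 1], so only the running product over the d tables is recomputed per t. A toy sketch of that recurrence, with i64 values in place of field elements and the hypothetical name `product_evals`:

    /// Toy version of the per-entry recurrence in the fold above:
    /// `evals_at_0` / `evals_at_1` play the role of table[b << 1] / table[(b << 1) + 1].
    fn product_evals(evals_at_0: &[i64], evals_at_1: &[i64], degree: usize) -> Vec<i64> {
        let mut eval: Vec<i64> = evals_at_0.to_vec();
        let step: Vec<i64> = evals_at_1
            .iter()
            .zip(evals_at_0)
            .map(|(e1, e0)| e1 - e0)
            .collect();
        let mut acc = vec![0i64; degree + 1];
        acc[0] = eval.iter().product(); // value of the product at t = 0
        for a in acc.iter_mut().skip(1) {
            // each factor is linear in t, so moving t -> t + 1 just adds its step
            for (e, s) in eval.iter_mut().zip(&step) {
                *e += s;
            }
            *a = eval.iter().product(); // value of the product at the next t
        }
        acc
    }

    fn main() {
        // two factors restricted to one boolean prefix b: factor_k(t) = e0_k + t * (e1_k - e0_k)
        let at_0 = [3, 5];
        let at_1 = [7, 2];
        // degree-2 product evaluated at t = 0, 1, 2
        assert_eq!(product_evals(&at_0, &at_1, 2), vec![15, 14, -11]);
    }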
@@ -195,10 +175,43 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
             .map(|x| Arc::new(x.clone()))
             .collect();
 
-        // end_timer!(compute_sum);
-        // end_timer!(start);
         Ok(IOPProverMessage {
             evaluations: products_sum,
         })
     }
 }
+
+fn barycentric_weights<F: PrimeField>(points: &[F]) -> Vec<F> {
+    let mut weights = points
+        .iter()
+        .enumerate()
+        .map(|(j, point_j)| {
+            points
+                .iter()
+                .enumerate()
+                .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i))
+                .reduce(|acc, value| acc * value)
+                .unwrap_or_else(F::one)
+        })
+        .collect::<Vec<_>>();
+    batch_inversion(&mut weights);
+    weights
+}
+
+fn extrapolate<F: PrimeField>(points: &[F], weights: &[F], evals: &[F], at: &F) -> F {
+    let (coeffs, sum_inv) = {
+        let mut coeffs = points.iter().map(|point| *at - point).collect::<Vec<_>>();
+        batch_inversion(&mut coeffs);
+        coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
+            *coeff *= weight;
+        });
+        let sum_inv = coeffs.iter().sum::<F>().inverse().unwrap_or_default();
+        (coeffs, sum_inv)
+    };
+    coeffs
+        .iter()
+        .zip(evals)
+        .map(|(coeff, eval)| *coeff * eval)
+        .sum::<F>()
+        * sum_inv
+}
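A closing note outside the patch: `extrapolate` implements the second (true) barycentric form p(at) = (sum_j w_j / (at - x_j) * p(x_j)) / (sum_j w_j / (at - x_j)), with the divisions realized through `batch_inversion`. A toy f64 version of the same formula, using the {0, 1, 2} weights worked out earlier and checked against direct evaluation of p(x) = 2x^2 + 3x + 1; a sketch for intuition only, not the field implementation above:

    fn extrapolate(points: &[f64], weights: &[f64], evals: &[f64], at: f64) -> f64 {
        // p(at) = (sum_j w_j / (at - x_j) * p(x_j)) / (sum_j w_j / (at - x_j))
        let coeffs: Vec<f64> = points
            .iter()
            .zip(weights)
            .map(|(x, w)| w / (at - x))
            .collect();
        let num: f64 = coeffs.iter().zip(evals).map(|(c, e)| c * e).sum();
        let den: f64 = coeffs.iter().sum();
        num / den
    }

    fn main() {
        let points = [0.0, 1.0, 2.0];
        let weights = [0.5, -1.0, 0.5]; // barycentric weights of {0, 1, 2}, as worked out above
        let evals = [1.0, 6.0, 15.0]; // p(0), p(1), p(2) for p(x) = 2x^2 + 3x + 1
        let p3 = extrapolate(&points, &weights, &evals, 3.0);
        assert!((p3 - 28.0).abs() < 1e-9); // p(3) = 2*9 + 3*3 + 1 = 28
    }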