From c48c1b97a574852b68e4fd803db40a01c57a3a9d Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Fri, 16 Dec 2022 11:12:44 -0500 Subject: [PATCH] more parallelization --- arithmetic/src/multilinear_polynomial.rs | 26 +++++-------------- .../src/pcs/multilinear_kzg/batching.rs | 17 +++++++----- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/arithmetic/src/multilinear_polynomial.rs b/arithmetic/src/multilinear_polynomial.rs index e7d569c..42e2fa3 100644 --- a/arithmetic/src/multilinear_polynomial.rs +++ b/arithmetic/src/multilinear_polynomial.rs @@ -160,16 +160,9 @@ fn fix_one_variable_helper(data: &[F], nv: usize, point: &F) -> Vec } #[cfg(feature = "parallel")] - if nv >= 13 { - // on my computer we parallelization doesn't help till nv >= 13 - res.par_iter_mut().enumerate().for_each(|(i, x)| { - *x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point; - }); - } else { - for i in 0..(1 << (nv - 1)) { - res[i] = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point; - } - } + res.par_iter_mut().enumerate().for_each(|(i, x)| { + *x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point; + }); res } @@ -279,16 +272,9 @@ fn fix_last_variable_helper(data: &[F], nv: usize, point: &F) -> Vec= 13 { - // on my computer we parallelization doesn't help till nv >= 13 - res.par_iter_mut().enumerate().for_each(|(i, x)| { - *x = data[i] + (data[i + half_len] - data[i]) * point; - }); - } else { - for b in 0..(1 << (nv - 1)) { - res[b] = data[b] + (data[b + half_len] - data[b]) * point; - } - } + res.par_iter_mut().enumerate().for_each(|(i, x)| { + *x = data[i] + (data[i + half_len] - data[i]) * point; + }); res } diff --git a/subroutines/src/pcs/multilinear_kzg/batching.rs b/subroutines/src/pcs/multilinear_kzg/batching.rs index c2f6d33..238c82b 100644 --- a/subroutines/src/pcs/multilinear_kzg/batching.rs +++ b/subroutines/src/pcs/multilinear_kzg/batching.rs @@ -17,6 +17,7 @@ use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, Virtual use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve}; use ark_ff::PrimeField; use ark_std::{end_timer, log2, start_timer, One, Zero}; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; use std::{marker::PhantomData, sync::Arc}; use transcript::IOPTranscript; @@ -87,13 +88,15 @@ where end_timer!(timer); let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len())); - let mut tilde_eqs = vec![]; - for point in points.iter() { - let eq_b_zi = build_eq_x_r_vec(point)?; - tilde_eqs.push(Arc::new(DenseMultilinearExtension::from_evaluations_vec( - num_var, eq_b_zi, - ))); - } + let tilde_eqs: Vec::Fr>>> = points + .par_iter() + .map(|point| { + let eq_b_zi = build_eq_x_r_vec(point).unwrap(); + Arc::new(DenseMultilinearExtension::from_evaluations_vec( + num_var, eq_b_zi, + )) + }) + .collect(); end_timer!(timer); // built the virtual polynomial for SumCheck