From 5d6985b79984c4acb29d9c6e5880e42ab0a7673c Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Fri, 16 Dec 2022 10:03:30 -0500 Subject: [PATCH] more parallelization --- hyperplonk/src/snark.rs | 2 +- subroutines/src/poly_iop/sum_check/prover.rs | 15 +++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/hyperplonk/src/snark.rs b/hyperplonk/src/snark.rs index 816683c..ff18ec2 100644 --- a/hyperplonk/src/snark.rs +++ b/hyperplonk/src/snark.rs @@ -76,7 +76,7 @@ where .collect(); let selector_commitments = selector_oracles - .iter() + .par_iter() .map(|poly| PCS::commit(&pcs_prover_param, poly)) .collect::, _>>()?; diff --git a/subroutines/src/poly_iop/sum_check/prover.rs b/subroutines/src/poly_iop/sum_check/prover.rs index a3cdb52..4145dde 100644 --- a/subroutines/src/poly_iop/sum_check/prover.rs +++ b/subroutines/src/poly_iop/sum_check/prover.rs @@ -9,7 +9,7 @@ use arithmetic::{fix_variables, VirtualPolynomial}; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; use ark_std::{end_timer, start_timer, vec::Vec}; -use rayon::prelude::IntoParallelIterator; +use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator}; use std::sync::Arc; #[cfg(feature = "parallel")] @@ -71,7 +71,7 @@ impl SumCheckProver for IOPProverState { let mut flattened_ml_extensions: Vec> = self .poly .flattened_ml_extensions - .iter() + .par_iter() .map(|x| x.as_ref().clone()) .collect(); @@ -132,9 +132,7 @@ impl SumCheckProver for IOPProverState { } }) .collect::>(); - for val in evals.iter() { - *e += val - } + *e += evals.par_iter().sum::(); } } else { for (t, e) in products_sum.iter_mut().enumerate() { @@ -161,10 +159,7 @@ impl SumCheckProver for IOPProverState { tmp }) .collect::>(); - - for i in products.iter() { - *e += i - } + *e += products.par_iter().sum::(); } } } @@ -190,7 +185,7 @@ impl SumCheckProver for IOPProverState { // update prover's state to the partial evaluated polynomial self.poly.flattened_ml_extensions = flattened_ml_extensions - .iter() + .par_iter() .map(|x| Arc::new(x.clone())) .collect();