From 7c537900949eb16fb07dff4439024415dba167da Mon Sep 17 00:00:00 2001 From: chancharles92 Date: Thu, 12 May 2022 16:53:01 -0700 Subject: [PATCH] improve compute_sum parallelism (#14) --- poly-iop/src/sum_check/mod.rs | 2 +- poly-iop/src/sum_check/prover.rs | 43 ++++++++++++++++---------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/poly-iop/src/sum_check/mod.rs b/poly-iop/src/sum_check/mod.rs index 14aa635..e678d52 100644 --- a/poly-iop/src/sum_check/mod.rs +++ b/poly-iop/src/sum_check/mod.rs @@ -212,7 +212,7 @@ impl SumCheck for PolyIOP { domain_info: &Self::DomainInfo, transcript: &mut Self::Transcript, ) -> Result { - let start = start_timer!(|| "sum check prove"); + let start = start_timer!(|| "sum check verify"); transcript.append_domain_info(domain_info)?; let mut verifier_state = VerifierState::verifier_init(domain_info); diff --git a/poly-iop/src/sum_check/prover.rs b/poly-iop/src/sum_check/prover.rs index 6dc02c6..8eb3472 100644 --- a/poly-iop/src/sum_check/prover.rs +++ b/poly-iop/src/sum_check/prover.rs @@ -122,39 +122,38 @@ impl SumCheckProver for ProverState { let compute_sum = start_timer!(|| "compute sum"); // generate sum - for b in 0..1 << (nv - i) { - #[cfg(feature = "parallel")] - products_sum - .par_iter_mut() - .take(degree + 1) - .enumerate() - .for_each(|(i, e)| { - // evaluate P_round(t) - for (coefficient, products) in products.iter() { - let num_multiplicands = products.len(); - let mut product = *coefficient; - for &f in products.iter().take(num_multiplicands) { - let table = &flattened_ml_extensions[f]; // f's range is checked in init - product *= table[b << 1] * (F::one() - F::from(i as u64)) - + table[(b << 1) + 1] * F::from(i as u64); - } - *e += product; + #[cfg(feature = "parallel")] + products_sum.par_iter_mut().enumerate().for_each(|(t, e)| { + for b in 0..1 << (nv - i) { + // evaluate P_round(t) + for (coefficient, products) in products.iter() { + let num_multiplicands = products.len(); + let mut product = *coefficient; + for &f in products.iter().take(num_multiplicands) { + let table = &flattened_ml_extensions[f]; // f's range is checked in init + product *= table[b << 1] * (F::one() - F::from(t as u64)) + + table[(b << 1) + 1] * F::from(t as u64); } - }); - #[cfg(not(feature = "parallel"))] + *e += product; + } + } + }); + + #[cfg(not(feature = "parallel"))] + for b in 0..1 << (nv - i) { products_sum .iter_mut() .take(degree + 1) .enumerate() - .for_each(|(i, e)| { + .for_each(|(t, e)| { // evaluate P_round(t) for (coefficient, products) in products.iter() { let num_multiplicands = products.len(); let mut product = *coefficient; for &f in products.iter().take(num_multiplicands) { let table = &flattened_ml_extensions[f]; // f's range is checked in init - product *= table[b << 1] * (F::one() - F::from(i as u64)) - + table[(b << 1) + 1] * F::from(i as u64); + product *= table[b << 1] * (F::one() - F::from(t as u64)) + + table[(b << 1) + 1] * F::from(t as u64); } *e += product; }