Browse Source

improve compute_sum parallelism (#14)

main
chancharles92 2 years ago
committed by GitHub
parent
commit
7c53790094
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 23 deletions
  1. +1
    -1
      poly-iop/src/sum_check/mod.rs
  2. +21
    -22
      poly-iop/src/sum_check/prover.rs

+ 1
- 1
poly-iop/src/sum_check/mod.rs

@ -212,7 +212,7 @@ impl SumCheck for PolyIOP {
domain_info: &Self::DomainInfo,
transcript: &mut Self::Transcript,
) -> Result<Self::SubClaim, PolyIOPErrors> {
let start = start_timer!(|| "sum check prove");
let start = start_timer!(|| "sum check verify");
transcript.append_domain_info(domain_info)?;
let mut verifier_state = VerifierState::verifier_init(domain_info);

+ 21
- 22
poly-iop/src/sum_check/prover.rs

@ -122,39 +122,38 @@ impl SumCheckProver for ProverState {
let compute_sum = start_timer!(|| "compute sum");
// generate sum
for b in 0..1 << (nv - i) {
#[cfg(feature = "parallel")]
products_sum
.par_iter_mut()
.take(degree + 1)
.enumerate()
.for_each(|(i, e)| {
// evaluate P_round(t)
for (coefficient, products) in products.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(i as u64))
+ table[(b << 1) + 1] * F::from(i as u64);
}
*e += product;
#[cfg(feature = "parallel")]
products_sum.par_iter_mut().enumerate().for_each(|(t, e)| {
for b in 0..1 << (nv - i) {
// evaluate P_round(t)
for (coefficient, products) in products.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
});
#[cfg(not(feature = "parallel"))]
*e += product;
}
}
});
#[cfg(not(feature = "parallel"))]
for b in 0..1 << (nv - i) {
products_sum
.iter_mut()
.take(degree + 1)
.enumerate()
.for_each(|(i, e)| {
.for_each(|(t, e)| {
// evaluate P_round(t)
for (coefficient, products) in products.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(i as u64))
+ table[(b << 1) + 1] * F::from(i as u64);
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
*e += product;
}

Loading…
Cancel
Save