improve compute_sum parallelism (#14)

This commit is contained in:
chancharles92
2022-05-12 16:53:01 -07:00
committed by GitHub
parent 213ce7fa3f
commit 7c53790094
2 changed files with 22 additions and 23 deletions

View File

@@ -212,7 +212,7 @@ impl<F: PrimeField> SumCheck<F> for PolyIOP<F> {
domain_info: &Self::DomainInfo,
transcript: &mut Self::Transcript,
) -> Result<Self::SubClaim, PolyIOPErrors> {
let start = start_timer!(|| "sum check prove");
let start = start_timer!(|| "sum check verify");
transcript.append_domain_info(domain_info)?;
let mut verifier_state = VerifierState::verifier_init(domain_info);

View File

@@ -122,39 +122,38 @@ impl<F: PrimeField> SumCheckProver<F> for ProverState<F> {
let compute_sum = start_timer!(|| "compute sum");
// generate sum
for b in 0..1 << (nv - i) {
#[cfg(feature = "parallel")]
products_sum
.par_iter_mut()
.take(degree + 1)
.enumerate()
.for_each(|(i, e)| {
// evaluate P_round(t)
for (coefficient, products) in products.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(i as u64))
+ table[(b << 1) + 1] * F::from(i as u64);
}
*e += product;
#[cfg(feature = "parallel")]
products_sum.par_iter_mut().enumerate().for_each(|(t, e)| {
for b in 0..1 << (nv - i) {
// evaluate P_round(t)
for (coefficient, products) in products.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
});
#[cfg(not(feature = "parallel"))]
*e += product;
}
}
});
#[cfg(not(feature = "parallel"))]
for b in 0..1 << (nv - i) {
products_sum
.iter_mut()
.take(degree + 1)
.enumerate()
.for_each(|(i, e)| {
.for_each(|(t, e)| {
// evaluate P_round(t)
for (coefficient, products) in products.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(i as u64))
+ table[(b << 1) + 1] * F::from(i as u64);
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
*e += product;
}