Browse Source

more parallelization

main
Charles Chen 1 year ago
parent
commit
c48c1b97a5
2 changed files with 16 additions and 27 deletions
  1. +6
    -20
      arithmetic/src/multilinear_polynomial.rs
  2. +10
    -7
      subroutines/src/pcs/multilinear_kzg/batching.rs

+ 6
- 20
arithmetic/src/multilinear_polynomial.rs

@ -160,16 +160,9 @@ fn fix_one_variable_helper(data: &[F], nv: usize, point: &F) -> Vec
} }
#[cfg(feature = "parallel")] #[cfg(feature = "parallel")]
if nv >= 13 {
// on my computer we parallelization doesn't help till nv >= 13
res.par_iter_mut().enumerate().for_each(|(i, x)| {
*x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
});
} else {
for i in 0..(1 << (nv - 1)) {
res[i] = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
}
}
res.par_iter_mut().enumerate().for_each(|(i, x)| {
*x = data[i << 1] + (data[(i << 1) + 1] - data[i << 1]) * point;
});
res res
} }
@ -279,16 +272,9 @@ fn fix_last_variable_helper(data: &[F], nv: usize, point: &F) -> Vec
} }
#[cfg(feature = "parallel")] #[cfg(feature = "parallel")]
if nv >= 13 {
// on my computer we parallelization doesn't help till nv >= 13
res.par_iter_mut().enumerate().for_each(|(i, x)| {
*x = data[i] + (data[i + half_len] - data[i]) * point;
});
} else {
for b in 0..(1 << (nv - 1)) {
res[b] = data[b] + (data[b + half_len] - data[b]) * point;
}
}
res.par_iter_mut().enumerate().for_each(|(i, x)| {
*x = data[i] + (data[i + half_len] - data[i]) * point;
});
res res
} }

+ 10
- 7
subroutines/src/pcs/multilinear_kzg/batching.rs

@ -17,6 +17,7 @@ use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, Virtual
use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve}; use ark_ec::{msm::VariableBaseMSM, PairingEngine, ProjectiveCurve};
use ark_ff::PrimeField; use ark_ff::PrimeField;
use ark_std::{end_timer, log2, start_timer, One, Zero}; use ark_std::{end_timer, log2, start_timer, One, Zero};
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use std::{marker::PhantomData, sync::Arc}; use std::{marker::PhantomData, sync::Arc};
use transcript::IOPTranscript; use transcript::IOPTranscript;
@ -87,13 +88,15 @@ where
end_timer!(timer); end_timer!(timer);
let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len())); let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len()));
let mut tilde_eqs = vec![];
for point in points.iter() {
let eq_b_zi = build_eq_x_r_vec(point)?;
tilde_eqs.push(Arc::new(DenseMultilinearExtension::from_evaluations_vec(
num_var, eq_b_zi,
)));
}
let tilde_eqs: Vec<Arc<DenseMultilinearExtension<<E as PairingEngine>::Fr>>> = points
.par_iter()
.map(|point| {
let eq_b_zi = build_eq_x_r_vec(point).unwrap();
Arc::new(DenseMultilinearExtension::from_evaluations_vec(
num_var, eq_b_zi,
))
})
.collect();
end_timer!(timer); end_timer!(timer);
// built the virtual polynomial for SumCheck // built the virtual polynomial for SumCheck

Loading…
Cancel
Save