From 720b13faba25f198c7b5268795357dd53819be80 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Fri, 26 Jul 2024 00:15:44 +0200 Subject: [PATCH] Add parallelism to aggregate_non_interactive_multi_party_server_key_shares MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For example, In a `12 cores` server with `48GB RAM`, the call to `aggregate_server_key_shares`: - for `examples/if_and_else`: - prior to this commit it took `47.56s` - with this commit it takes `4.66s` - for `examples/non_interactive_fheuint8`: - prior to this commit it took `158.15s` - with this commit it takes `14.96s` so about `~10x` reduction. In a `4 cores` laptop with `8GB RAM` (low capacity laptop, with multiple other apps running ), the call to `aggregate_server_key_shares`: - for `examples/if_and_else`: - prior to this commit it took `48.65s` - with this commit it takes `23.11s` so about `~2x` reduction. Co-authored-by: Carlos PĂ©rez --- Cargo.toml | 3 +- src/bool/evaluator.rs | 108 ++++++++++++++++++++++++------------------ 2 files changed, 63 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8e34081..2fb18a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ num-traits = "0.2.18" rand = "0.8.5" rand_chacha = "0.3.1" rand_distr = "0.4.3" +rayon = "1.10.0" [dev-dependencies] criterion = "0.5.1" @@ -59,4 +60,4 @@ required-features = ["non_interactive_mp"] [[example]] name = "if_and_else" path = "./examples/if_and_else.rs" -required-features = ["non_interactive_mp"] \ No newline at end of file +required-features = ["non_interactive_mp"] diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 6b3f0b9..6591f45 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -9,6 +9,10 @@ use itertools::{izip, Itertools}; use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero}; use rand_distr::uniform::SampleUniform; +use rayon::iter::FlatMap; +use rayon::prelude::*; +use std::sync::{Arc, Mutex}; + use crate::{ backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps}, bool::parameters::ParameterVariant, @@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment( impl BoolEvaluator where - M: MatrixEntity + MatrixMut, + M: MatrixEntity + MatrixMut + Send + Sync, M::MatElement: PrimInt + Debug + Display @@ -636,16 +640,21 @@ where + WrappingSub + WrappingAdd + SampleUniform - + From, - NttOp: Ntt, + + From + + Send + + Sync, + NttOp: Ntt + Send + Sync, RlweModOp: ArithmeticOps + VectorOps + GetModulus> - + ShoupMatrixFMA, + + ShoupMatrixFMA + + Send + + Sync, LweModOp: ArithmeticOps + VectorOps + GetModulus>, - M::R: TryConvertFrom1<[i32], CiphertextModulus> + RowEntity + Debug, + M::R: + TryConvertFrom1<[i32], CiphertextModulus> + RowEntity + Debug + Send + Sync, ::R: RowMut, { pub(super) fn new(parameters: BoolParameters) -> Self @@ -1378,6 +1387,10 @@ where ) }) .collect_vec(); + // clone self.rlwe_n & self.parameters so that we don't access to &self inside the + // threads + let rlwe_n = self.parameters().rlwe_n().0.clone(); + let parameters = self.parameters().clone(); // Note: Each user is assigned a contigous LWE segement and the LWE dimension is // split approximately uniformly across all users. Hence, concatenation of all // user specific lwe segments will give LWE dimension. @@ -1385,8 +1398,9 @@ where .into_iter() .enumerate() .flat_map(|(user_id, lwe_segment)| { + let mut rgsws = Vec::new(); (lwe_segment.0..lwe_segment.1) - .into_iter() + .into_par_iter() .map(|lwe_index| { // We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before // we sampling we need to puncture a_prng d_max - d_b times to align @@ -1396,7 +1410,7 @@ where cr_seed.ni_rgsw_ct_seed_for_index::(lwe_index), ); - let mut scratch = M::R::zeros(self.parameters().rlwe_n().0); + let mut scratch = M::R::zeros(rlwe_n); (0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0) .for_each(|_| { RandomFillUniformInModulus::random_fill( @@ -1420,7 +1434,7 @@ where let mut decomp_neg_ai = M::zeros( ni_uj_to_s_decomposer.decomposition_count().0, - self.parameters().rlwe_n().0, + rlwe_n, ); scratch.as_ref().iter().enumerate().for_each(|(index, el)| { ni_uj_to_s_decomposer @@ -1447,41 +1461,40 @@ where // then use to produce RLWE'(-sX^{s_{lwe}[l]}). // Hence, after aggregation we decompose a_{i, l} * s + e to // prepare for key switching - let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer - .a() - .decomposition_count() - .0) - .map(|i| { - let mut sum = M::R::zeros(self.parameters().rlwe_n().0); - key_shares.iter().for_each(|k| { - let to_add_ref = k - .ni_rgsw_zero_enc_for_lwe_index(lwe_index) - .get_row_slice(i); - assert!(to_add_ref.len() == self.parameters().rlwe_n().0); - rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref); - }); - - // decompose - let mut decomp_sum = M::zeros( - ni_uj_to_s_decomposer.decomposition_count().0, - self.parameters().rlwe_n().0, - ); - sum.as_ref().iter().enumerate().for_each(|(index, el)| { - ni_uj_to_s_decomposer - .decompose_iter(el) - .enumerate() - .for_each(|(row_j, d_el)| { - (decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el; - }); - }); - - decomp_sum - .iter_rows_mut() - .for_each(|r| nttop.forward(r.as_mut())); - - decomp_sum - }) - .collect_vec(); + let ni_rgsw_zero_encs = + (0..rgsw_x_rgsw_decomposer.a().decomposition_count().0) + .map(|i| { + let mut sum = M::R::zeros(rlwe_n); + key_shares.iter().for_each(|k| { + let to_add_ref = k + .ni_rgsw_zero_enc_for_lwe_index(lwe_index) + .get_row_slice(i); + assert!(to_add_ref.len() == rlwe_n); + rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref); + }); + + // decompose + let mut decomp_sum = M::zeros( + ni_uj_to_s_decomposer.decomposition_count().0, + rlwe_n, + ); + sum.as_ref().iter().enumerate().for_each(|(index, el)| { + ni_uj_to_s_decomposer + .decompose_iter(el) + .enumerate() + .for_each(|(row_j, d_el)| { + (decomp_sum.as_mut()[row_j]).as_mut()[index] = + d_el; + }); + }); + + decomp_sum + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())); + + decomp_sum + }) + .collect_vec(); // Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the // leader, ie user's id = user_id. @@ -1503,7 +1516,7 @@ where .0 - rlwe_x_rgsw_decomposer.b().decomposition_count().0..], &rlwe_x_rgsw_decomposer, - self.parameters(), + ¶meters, (&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]), rlwe_modop, nttop, @@ -1524,7 +1537,7 @@ where &ni_rgsw_zero_encs, &decomp_neg_ais, &rgsw_x_rgsw_decomposer, - self.parameters(), + ¶meters, ( &uj_to_s_ksks[other_user_id], &uj_to_s_ksks_part_a_eval[other_user_id], @@ -1548,7 +1561,7 @@ where &rlwe_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer, &mut RuntimeScratchMutRef::new( - scratch_rgsw_x_rgsw.as_mut(), + scratch_rgsw_x_rgsw.clone().as_mut(), ), nttop, rlwe_modop, @@ -1557,7 +1570,8 @@ where rgsw_i }) - .collect_vec() + .collect_into_vec(&mut rgsws); + rgsws }) .collect_vec();