Browse Source

Add parallelism to aggregate_non_interactive_multi_party_server_key_shares

For example,

In a `12 cores` server with `48GB RAM`,
the call to `aggregate_server_key_shares`:
- for `examples/if_and_else`:
  - prior to this commit it took `47.56s`
  - with this commit it takes `4.66s`
- for `examples/non_interactive_fheuint8`:
  - prior to this commit it took `158.15s`
  - with this commit it takes `14.96s`

so about `~10x` reduction.

In a `4 cores` laptop with `8GB RAM` (low capacity laptop, with multiple other apps running    ),
the call to `aggregate_server_key_shares`:
- for `examples/if_and_else`:
  - prior to this commit it took `48.65s`
  - with this commit it takes `23.11s`

so about `~2x` reduction.

Co-authored-by: Carlos Pérez <c.perezbaro@gmail.com>
par-agg-key-shares
arnaucube 10 months ago
parent
commit
720b13faba
2 changed files with 63 additions and 48 deletions
  1. +2
    -1
      Cargo.toml
  2. +61
    -47
      src/bool/evaluator.rs

+ 2
- 1
Cargo.toml

@ -14,6 +14,7 @@ num-traits = "0.2.18"
rand = "0.8.5" rand = "0.8.5"
rand_chacha = "0.3.1" rand_chacha = "0.3.1"
rand_distr = "0.4.3" rand_distr = "0.4.3"
rayon = "1.10.0"
[dev-dependencies] [dev-dependencies]
criterion = "0.5.1" criterion = "0.5.1"
@ -59,4 +60,4 @@ required-features = ["non_interactive_mp"]
[[example]] [[example]]
name = "if_and_else" name = "if_and_else"
path = "./examples/if_and_else.rs" path = "./examples/if_and_else.rs"
required-features = ["non_interactive_mp"]
required-features = ["non_interactive_mp"]

+ 61
- 47
src/bool/evaluator.rs

@ -9,6 +9,10 @@ use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero}; use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
use rand_distr::uniform::SampleUniform; use rand_distr::uniform::SampleUniform;
use rayon::iter::FlatMap;
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
use crate::{ use crate::{
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps}, backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
bool::parameters::ParameterVariant, bool::parameters::ParameterVariant,
@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment(
impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey> impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
where where
M: MatrixEntity + MatrixMut,
M: MatrixEntity + MatrixMut + Send + Sync,
M::MatElement: PrimInt M::MatElement: PrimInt
+ Debug + Debug
+ Display + Display
@ -636,16 +640,21 @@ where
+ WrappingSub + WrappingSub
+ WrappingAdd + WrappingAdd
+ SampleUniform + SampleUniform
+ From<bool>,
NttOp: Ntt<Element = M::MatElement>,
+ From<bool>
+ Send
+ Sync,
NttOp: Ntt<Element = M::MatElement> + Send + Sync,
RlweModOp: ArithmeticOps<Element = M::MatElement> RlweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>> + GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>
+ ShoupMatrixFMA<M::R>,
+ ShoupMatrixFMA<M::R>
+ Send
+ Sync,
LweModOp: ArithmeticOps<Element = M::MatElement> LweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>, + GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
M::R: TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug,
M::R:
TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug + Send + Sync,
<M as Matrix>::R: RowMut, <M as Matrix>::R: RowMut,
{ {
pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self
@ -1378,6 +1387,10 @@ where
) )
}) })
.collect_vec(); .collect_vec();
// clone self.rlwe_n & self.parameters so that we don't access to &self inside the
// threads
let rlwe_n = self.parameters().rlwe_n().0.clone();
let parameters = self.parameters().clone();
// Note: Each user is assigned a contigous LWE segement and the LWE dimension is // Note: Each user is assigned a contigous LWE segement and the LWE dimension is
// split approximately uniformly across all users. Hence, concatenation of all // split approximately uniformly across all users. Hence, concatenation of all
// user specific lwe segments will give LWE dimension. // user specific lwe segments will give LWE dimension.
@ -1385,8 +1398,9 @@ where
.into_iter() .into_iter()
.enumerate() .enumerate()
.flat_map(|(user_id, lwe_segment)| { .flat_map(|(user_id, lwe_segment)| {
let mut rgsws = Vec::new();
(lwe_segment.0..lwe_segment.1) (lwe_segment.0..lwe_segment.1)
.into_iter()
.into_par_iter()
.map(|lwe_index| { .map(|lwe_index| {
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before // We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
// we sampling we need to puncture a_prng d_max - d_b times to align // we sampling we need to puncture a_prng d_max - d_b times to align
@ -1396,7 +1410,7 @@ where
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index), cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
); );
let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
let mut scratch = M::R::zeros(rlwe_n);
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0) (0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
.for_each(|_| { .for_each(|_| {
RandomFillUniformInModulus::random_fill( RandomFillUniformInModulus::random_fill(
@ -1420,7 +1434,7 @@ where
let mut decomp_neg_ai = M::zeros( let mut decomp_neg_ai = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0, ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0,
rlwe_n,
); );
scratch.as_ref().iter().enumerate().for_each(|(index, el)| { scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer ni_uj_to_s_decomposer
@ -1447,41 +1461,40 @@ where
// then use to produce RLWE'(-sX^{s_{lwe}[l]}). // then use to produce RLWE'(-sX^{s_{lwe}[l]}).
// Hence, after aggregation we decompose a_{i, l} * s + e to // Hence, after aggregation we decompose a_{i, l} * s + e to
// prepare for key switching // prepare for key switching
let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer
.a()
.decomposition_count()
.0)
.map(|i| {
let mut sum = M::R::zeros(self.parameters().rlwe_n().0);
key_shares.iter().for_each(|k| {
let to_add_ref = k
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
.get_row_slice(i);
assert!(to_add_ref.len() == self.parameters().rlwe_n().0);
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
});
// decompose
let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0,
);
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
.decompose_iter(el)
.enumerate()
.for_each(|(row_j, d_el)| {
(decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el;
});
});
decomp_sum
.iter_rows_mut()
.for_each(|r| nttop.forward(r.as_mut()));
decomp_sum
})
.collect_vec();
let ni_rgsw_zero_encs =
(0..rgsw_x_rgsw_decomposer.a().decomposition_count().0)
.map(|i| {
let mut sum = M::R::zeros(rlwe_n);
key_shares.iter().for_each(|k| {
let to_add_ref = k
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
.get_row_slice(i);
assert!(to_add_ref.len() == rlwe_n);
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
});
// decompose
let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
rlwe_n,
);
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
.decompose_iter(el)
.enumerate()
.for_each(|(row_j, d_el)| {
(decomp_sum.as_mut()[row_j]).as_mut()[index] =
d_el;
});
});
decomp_sum
.iter_rows_mut()
.for_each(|r| nttop.forward(r.as_mut()));
decomp_sum
})
.collect_vec();
// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the // Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
// leader, ie user's id = user_id. // leader, ie user's id = user_id.
@ -1503,7 +1516,7 @@ where
.0 .0
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..], - rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
&rlwe_x_rgsw_decomposer, &rlwe_x_rgsw_decomposer,
self.parameters(),
&parameters,
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]), (&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
rlwe_modop, rlwe_modop,
nttop, nttop,
@ -1524,7 +1537,7 @@ where
&ni_rgsw_zero_encs, &ni_rgsw_zero_encs,
&decomp_neg_ais, &decomp_neg_ais,
&rgsw_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer,
self.parameters(),
&parameters,
( (
&uj_to_s_ksks[other_user_id], &uj_to_s_ksks[other_user_id],
&uj_to_s_ksks_part_a_eval[other_user_id], &uj_to_s_ksks_part_a_eval[other_user_id],
@ -1548,7 +1561,7 @@ where
&rlwe_x_rgsw_decomposer, &rlwe_x_rgsw_decomposer,
&rgsw_x_rgsw_decomposer, &rgsw_x_rgsw_decomposer,
&mut RuntimeScratchMutRef::new( &mut RuntimeScratchMutRef::new(
scratch_rgsw_x_rgsw.as_mut(),
scratch_rgsw_x_rgsw.clone().as_mut(),
), ),
nttop, nttop,
rlwe_modop, rlwe_modop,
@ -1557,7 +1570,8 @@ where
rgsw_i rgsw_i
}) })
.collect_vec()
.collect_into_vec(&mut rgsws);
rgsws
}) })
.collect_vec(); .collect_vec();

Loading…
Cancel
Save