mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-07 22:51:29 +01:00
Add parallelism to aggregate_non_interactive_multi_party_server_key_shares
For example, on a `12 cores` server with `48GB RAM`, the call to `aggregate_server_key_shares`: - for `examples/if_and_else`: - prior to this commit it took `47.56s` - with this commit it takes `4.66s` - for `examples/non_interactive_fheuint8`: - prior to this commit it took `158.15s` - with this commit it takes `14.96s`, so about a `~10x` reduction. On a `4 cores` laptop with `8GB RAM` (a low-capacity laptop with multiple other apps running), the call to `aggregate_server_key_shares`: - for `examples/if_and_else`: - prior to this commit it took `48.65s` - with this commit it takes `23.11s`, so about a `~2x` reduction. Co-authored-by: Carlos Pérez <c.perezbaro@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@ num-traits = "0.2.18"
|
||||
rand = "0.8.5"
|
||||
rand_chacha = "0.3.1"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.10.0"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5.1"
|
||||
@@ -59,4 +60,4 @@ required-features = ["non_interactive_mp"]
|
||||
[[example]]
|
||||
name = "if_and_else"
|
||||
path = "./examples/if_and_else.rs"
|
||||
required-features = ["non_interactive_mp"]
|
||||
required-features = ["non_interactive_mp"]
|
||||
|
||||
@@ -9,6 +9,10 @@ use itertools::{izip, Itertools};
|
||||
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
|
||||
use rand_distr::uniform::SampleUniform;
|
||||
|
||||
use rayon::iter::FlatMap;
|
||||
use rayon::prelude::*;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{
|
||||
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
|
||||
bool::parameters::ParameterVariant,
|
||||
@@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment(
|
||||
|
||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
|
||||
where
|
||||
M: MatrixEntity + MatrixMut,
|
||||
M: MatrixEntity + MatrixMut + Send + Sync,
|
||||
M::MatElement: PrimInt
|
||||
+ Debug
|
||||
+ Display
|
||||
@@ -636,16 +640,21 @@ where
|
||||
+ WrappingSub
|
||||
+ WrappingAdd
|
||||
+ SampleUniform
|
||||
+ From<bool>,
|
||||
NttOp: Ntt<Element = M::MatElement>,
|
||||
+ From<bool>
|
||||
+ Send
|
||||
+ Sync,
|
||||
NttOp: Ntt<Element = M::MatElement> + Send + Sync,
|
||||
RlweModOp: ArithmeticOps<Element = M::MatElement>
|
||||
+ VectorOps<Element = M::MatElement>
|
||||
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>
|
||||
+ ShoupMatrixFMA<M::R>,
|
||||
+ ShoupMatrixFMA<M::R>
|
||||
+ Send
|
||||
+ Sync,
|
||||
LweModOp: ArithmeticOps<Element = M::MatElement>
|
||||
+ VectorOps<Element = M::MatElement>
|
||||
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
|
||||
M::R: TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug,
|
||||
M::R:
|
||||
TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug + Send + Sync,
|
||||
<M as Matrix>::R: RowMut,
|
||||
{
|
||||
pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self
|
||||
@@ -1378,6 +1387,10 @@ where
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
// clone self.rlwe_n & self.parameters so that we don't access &self inside the
|
||||
// threads
|
||||
let rlwe_n = self.parameters().rlwe_n().0.clone();
|
||||
let parameters = self.parameters().clone();
|
||||
// Note: Each user is assigned a contiguous LWE segment and the LWE dimension is
|
||||
// split approximately uniformly across all users. Hence, concatenation of all
|
||||
// user specific lwe segments will give LWE dimension.
|
||||
@@ -1385,8 +1398,9 @@ where
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.flat_map(|(user_id, lwe_segment)| {
|
||||
let mut rgsws = Vec::new();
|
||||
(lwe_segment.0..lwe_segment.1)
|
||||
.into_iter()
|
||||
.into_par_iter()
|
||||
.map(|lwe_index| {
|
||||
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
|
||||
// sampling we need to puncture a_prng d_max - d_b times to align
|
||||
@@ -1396,7 +1410,7 @@ where
|
||||
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
|
||||
);
|
||||
|
||||
let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
|
||||
let mut scratch = M::R::zeros(rlwe_n);
|
||||
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
|
||||
.for_each(|_| {
|
||||
RandomFillUniformInModulus::random_fill(
|
||||
@@ -1420,7 +1434,7 @@ where
|
||||
|
||||
let mut decomp_neg_ai = M::zeros(
|
||||
ni_uj_to_s_decomposer.decomposition_count().0,
|
||||
self.parameters().rlwe_n().0,
|
||||
rlwe_n,
|
||||
);
|
||||
scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
|
||||
ni_uj_to_s_decomposer
|
||||
@@ -1447,41 +1461,40 @@ where
|
||||
// then use to produce RLWE'(-sX^{s_{lwe}[l]}).
|
||||
// Hence, after aggregation we decompose a_{i, l} * s + e to
|
||||
// prepare for key switching
|
||||
let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer
|
||||
.a()
|
||||
.decomposition_count()
|
||||
.0)
|
||||
.map(|i| {
|
||||
let mut sum = M::R::zeros(self.parameters().rlwe_n().0);
|
||||
key_shares.iter().for_each(|k| {
|
||||
let to_add_ref = k
|
||||
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
|
||||
.get_row_slice(i);
|
||||
assert!(to_add_ref.len() == self.parameters().rlwe_n().0);
|
||||
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
|
||||
});
|
||||
let ni_rgsw_zero_encs =
|
||||
(0..rgsw_x_rgsw_decomposer.a().decomposition_count().0)
|
||||
.map(|i| {
|
||||
let mut sum = M::R::zeros(rlwe_n);
|
||||
key_shares.iter().for_each(|k| {
|
||||
let to_add_ref = k
|
||||
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
|
||||
.get_row_slice(i);
|
||||
assert!(to_add_ref.len() == rlwe_n);
|
||||
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
|
||||
});
|
||||
|
||||
// decompose
|
||||
let mut decomp_sum = M::zeros(
|
||||
ni_uj_to_s_decomposer.decomposition_count().0,
|
||||
self.parameters().rlwe_n().0,
|
||||
);
|
||||
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
|
||||
ni_uj_to_s_decomposer
|
||||
.decompose_iter(el)
|
||||
.enumerate()
|
||||
.for_each(|(row_j, d_el)| {
|
||||
(decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el;
|
||||
});
|
||||
});
|
||||
// decompose
|
||||
let mut decomp_sum = M::zeros(
|
||||
ni_uj_to_s_decomposer.decomposition_count().0,
|
||||
rlwe_n,
|
||||
);
|
||||
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
|
||||
ni_uj_to_s_decomposer
|
||||
.decompose_iter(el)
|
||||
.enumerate()
|
||||
.for_each(|(row_j, d_el)| {
|
||||
(decomp_sum.as_mut()[row_j]).as_mut()[index] =
|
||||
d_el;
|
||||
});
|
||||
});
|
||||
|
||||
decomp_sum
|
||||
.iter_rows_mut()
|
||||
.for_each(|r| nttop.forward(r.as_mut()));
|
||||
decomp_sum
|
||||
.iter_rows_mut()
|
||||
.for_each(|r| nttop.forward(r.as_mut()));
|
||||
|
||||
decomp_sum
|
||||
})
|
||||
.collect_vec();
|
||||
decomp_sum
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
|
||||
// leader, ie user's id = user_id.
|
||||
@@ -1503,7 +1516,7 @@ where
|
||||
.0
|
||||
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
|
||||
&rlwe_x_rgsw_decomposer,
|
||||
self.parameters(),
|
||||
&parameters,
|
||||
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
|
||||
rlwe_modop,
|
||||
nttop,
|
||||
@@ -1524,7 +1537,7 @@ where
|
||||
&ni_rgsw_zero_encs,
|
||||
&decomp_neg_ais,
|
||||
&rgsw_x_rgsw_decomposer,
|
||||
self.parameters(),
|
||||
&parameters,
|
||||
(
|
||||
&uj_to_s_ksks[other_user_id],
|
||||
&uj_to_s_ksks_part_a_eval[other_user_id],
|
||||
@@ -1548,7 +1561,7 @@ where
|
||||
&rlwe_x_rgsw_decomposer,
|
||||
&rgsw_x_rgsw_decomposer,
|
||||
&mut RuntimeScratchMutRef::new(
|
||||
scratch_rgsw_x_rgsw.as_mut(),
|
||||
scratch_rgsw_x_rgsw.clone().as_mut(),
|
||||
),
|
||||
nttop,
|
||||
rlwe_modop,
|
||||
@@ -1557,7 +1570,8 @@ where
|
||||
|
||||
rgsw_i
|
||||
})
|
||||
.collect_vec()
|
||||
.collect_into_vec(&mut rgsws);
|
||||
rgsws
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user