mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-09 15:41:30 +01:00
Add parallelism to aggregate_non_interactive_multi_party_server_key_shares
For example, In a `12 cores` server with `48GB RAM`, the call to `aggregate_server_key_shares`: - for `examples/if_and_else`: - prior to this commit it took `47.56s` - with this commit it takes `4.66s` - for `examples/non_interactive_fheuint8`: - prior to this commit it took `158.15s` - with this commit it takes `14.96s` so about `~10x` reduction. In a `4 cores` laptop with `8GB RAM` (low capacity laptop, with multiple other apps running ), the call to `aggregate_server_key_shares`: - for `examples/if_and_else`: - prior to this commit it took `48.65s` - with this commit it takes `23.11s` so about `~2x` reduction. Co-authored-by: Carlos Pérez <c.perezbaro@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@ num-traits = "0.2.18"
|
|||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
rand_chacha = "0.3.1"
|
rand_chacha = "0.3.1"
|
||||||
rand_distr = "0.4.3"
|
rand_distr = "0.4.3"
|
||||||
|
rayon = "1.10.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.5.1"
|
criterion = "0.5.1"
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ use itertools::{izip, Itertools};
|
|||||||
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
|
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
|
||||||
use rand_distr::uniform::SampleUniform;
|
use rand_distr::uniform::SampleUniform;
|
||||||
|
|
||||||
|
use rayon::iter::FlatMap;
|
||||||
|
use rayon::prelude::*;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
|
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
|
||||||
bool::parameters::ParameterVariant,
|
bool::parameters::ParameterVariant,
|
||||||
@@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment(
|
|||||||
|
|
||||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
|
impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
|
||||||
where
|
where
|
||||||
M: MatrixEntity + MatrixMut,
|
M: MatrixEntity + MatrixMut + Send + Sync,
|
||||||
M::MatElement: PrimInt
|
M::MatElement: PrimInt
|
||||||
+ Debug
|
+ Debug
|
||||||
+ Display
|
+ Display
|
||||||
@@ -636,16 +640,21 @@ where
|
|||||||
+ WrappingSub
|
+ WrappingSub
|
||||||
+ WrappingAdd
|
+ WrappingAdd
|
||||||
+ SampleUniform
|
+ SampleUniform
|
||||||
+ From<bool>,
|
+ From<bool>
|
||||||
NttOp: Ntt<Element = M::MatElement>,
|
+ Send
|
||||||
|
+ Sync,
|
||||||
|
NttOp: Ntt<Element = M::MatElement> + Send + Sync,
|
||||||
RlweModOp: ArithmeticOps<Element = M::MatElement>
|
RlweModOp: ArithmeticOps<Element = M::MatElement>
|
||||||
+ VectorOps<Element = M::MatElement>
|
+ VectorOps<Element = M::MatElement>
|
||||||
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>
|
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>
|
||||||
+ ShoupMatrixFMA<M::R>,
|
+ ShoupMatrixFMA<M::R>
|
||||||
|
+ Send
|
||||||
|
+ Sync,
|
||||||
LweModOp: ArithmeticOps<Element = M::MatElement>
|
LweModOp: ArithmeticOps<Element = M::MatElement>
|
||||||
+ VectorOps<Element = M::MatElement>
|
+ VectorOps<Element = M::MatElement>
|
||||||
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
|
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
|
||||||
M::R: TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug,
|
M::R:
|
||||||
|
TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug + Send + Sync,
|
||||||
<M as Matrix>::R: RowMut,
|
<M as Matrix>::R: RowMut,
|
||||||
{
|
{
|
||||||
pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self
|
pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self
|
||||||
@@ -1378,6 +1387,10 @@ where
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
// clone self.rlwe_n & self.parameters so that we don't access to &self inside the
|
||||||
|
// threads
|
||||||
|
let rlwe_n = self.parameters().rlwe_n().0.clone();
|
||||||
|
let parameters = self.parameters().clone();
|
||||||
// Note: Each user is assigned a contigous LWE segement and the LWE dimension is
|
// Note: Each user is assigned a contigous LWE segement and the LWE dimension is
|
||||||
// split approximately uniformly across all users. Hence, concatenation of all
|
// split approximately uniformly across all users. Hence, concatenation of all
|
||||||
// user specific lwe segments will give LWE dimension.
|
// user specific lwe segments will give LWE dimension.
|
||||||
@@ -1385,8 +1398,9 @@ where
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.flat_map(|(user_id, lwe_segment)| {
|
.flat_map(|(user_id, lwe_segment)| {
|
||||||
|
let mut rgsws = Vec::new();
|
||||||
(lwe_segment.0..lwe_segment.1)
|
(lwe_segment.0..lwe_segment.1)
|
||||||
.into_iter()
|
.into_par_iter()
|
||||||
.map(|lwe_index| {
|
.map(|lwe_index| {
|
||||||
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
|
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
|
||||||
// we sampling we need to puncture a_prng d_max - d_b times to align
|
// we sampling we need to puncture a_prng d_max - d_b times to align
|
||||||
@@ -1396,7 +1410,7 @@ where
|
|||||||
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
|
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
|
let mut scratch = M::R::zeros(rlwe_n);
|
||||||
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
|
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
|
||||||
.for_each(|_| {
|
.for_each(|_| {
|
||||||
RandomFillUniformInModulus::random_fill(
|
RandomFillUniformInModulus::random_fill(
|
||||||
@@ -1420,7 +1434,7 @@ where
|
|||||||
|
|
||||||
let mut decomp_neg_ai = M::zeros(
|
let mut decomp_neg_ai = M::zeros(
|
||||||
ni_uj_to_s_decomposer.decomposition_count().0,
|
ni_uj_to_s_decomposer.decomposition_count().0,
|
||||||
self.parameters().rlwe_n().0,
|
rlwe_n,
|
||||||
);
|
);
|
||||||
scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
|
scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
|
||||||
ni_uj_to_s_decomposer
|
ni_uj_to_s_decomposer
|
||||||
@@ -1447,41 +1461,40 @@ where
|
|||||||
// then use to produce RLWE'(-sX^{s_{lwe}[l]}).
|
// then use to produce RLWE'(-sX^{s_{lwe}[l]}).
|
||||||
// Hence, after aggregation we decompose a_{i, l} * s + e to
|
// Hence, after aggregation we decompose a_{i, l} * s + e to
|
||||||
// prepare for key switching
|
// prepare for key switching
|
||||||
let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer
|
let ni_rgsw_zero_encs =
|
||||||
.a()
|
(0..rgsw_x_rgsw_decomposer.a().decomposition_count().0)
|
||||||
.decomposition_count()
|
.map(|i| {
|
||||||
.0)
|
let mut sum = M::R::zeros(rlwe_n);
|
||||||
.map(|i| {
|
key_shares.iter().for_each(|k| {
|
||||||
let mut sum = M::R::zeros(self.parameters().rlwe_n().0);
|
let to_add_ref = k
|
||||||
key_shares.iter().for_each(|k| {
|
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
|
||||||
let to_add_ref = k
|
.get_row_slice(i);
|
||||||
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
|
assert!(to_add_ref.len() == rlwe_n);
|
||||||
.get_row_slice(i);
|
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
|
||||||
assert!(to_add_ref.len() == self.parameters().rlwe_n().0);
|
});
|
||||||
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
|
|
||||||
});
|
|
||||||
|
|
||||||
// decompose
|
// decompose
|
||||||
let mut decomp_sum = M::zeros(
|
let mut decomp_sum = M::zeros(
|
||||||
ni_uj_to_s_decomposer.decomposition_count().0,
|
ni_uj_to_s_decomposer.decomposition_count().0,
|
||||||
self.parameters().rlwe_n().0,
|
rlwe_n,
|
||||||
);
|
);
|
||||||
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
|
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
|
||||||
ni_uj_to_s_decomposer
|
ni_uj_to_s_decomposer
|
||||||
.decompose_iter(el)
|
.decompose_iter(el)
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(row_j, d_el)| {
|
.for_each(|(row_j, d_el)| {
|
||||||
(decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el;
|
(decomp_sum.as_mut()[row_j]).as_mut()[index] =
|
||||||
});
|
d_el;
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
decomp_sum
|
decomp_sum
|
||||||
.iter_rows_mut()
|
.iter_rows_mut()
|
||||||
.for_each(|r| nttop.forward(r.as_mut()));
|
.for_each(|r| nttop.forward(r.as_mut()));
|
||||||
|
|
||||||
decomp_sum
|
decomp_sum
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
|
// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
|
||||||
// leader, ie user's id = user_id.
|
// leader, ie user's id = user_id.
|
||||||
@@ -1503,7 +1516,7 @@ where
|
|||||||
.0
|
.0
|
||||||
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
|
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
|
||||||
&rlwe_x_rgsw_decomposer,
|
&rlwe_x_rgsw_decomposer,
|
||||||
self.parameters(),
|
¶meters,
|
||||||
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
|
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
|
||||||
rlwe_modop,
|
rlwe_modop,
|
||||||
nttop,
|
nttop,
|
||||||
@@ -1524,7 +1537,7 @@ where
|
|||||||
&ni_rgsw_zero_encs,
|
&ni_rgsw_zero_encs,
|
||||||
&decomp_neg_ais,
|
&decomp_neg_ais,
|
||||||
&rgsw_x_rgsw_decomposer,
|
&rgsw_x_rgsw_decomposer,
|
||||||
self.parameters(),
|
¶meters,
|
||||||
(
|
(
|
||||||
&uj_to_s_ksks[other_user_id],
|
&uj_to_s_ksks[other_user_id],
|
||||||
&uj_to_s_ksks_part_a_eval[other_user_id],
|
&uj_to_s_ksks_part_a_eval[other_user_id],
|
||||||
@@ -1548,7 +1561,7 @@ where
|
|||||||
&rlwe_x_rgsw_decomposer,
|
&rlwe_x_rgsw_decomposer,
|
||||||
&rgsw_x_rgsw_decomposer,
|
&rgsw_x_rgsw_decomposer,
|
||||||
&mut RuntimeScratchMutRef::new(
|
&mut RuntimeScratchMutRef::new(
|
||||||
scratch_rgsw_x_rgsw.as_mut(),
|
scratch_rgsw_x_rgsw.clone().as_mut(),
|
||||||
),
|
),
|
||||||
nttop,
|
nttop,
|
||||||
rlwe_modop,
|
rlwe_modop,
|
||||||
@@ -1557,7 +1570,8 @@ where
|
|||||||
|
|
||||||
rgsw_i
|
rgsw_i
|
||||||
})
|
})
|
||||||
.collect_vec()
|
.collect_into_vec(&mut rgsws);
|
||||||
|
rgsws
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user