@ -9,6 +9,10 @@ use itertools::{izip, Itertools};
use num_traits ::{ FromPrimitive , One , PrimInt , ToPrimitive , WrappingAdd , WrappingSub , Zero } ;
use rand_distr ::uniform ::SampleUniform ;
use rayon ::iter ::FlatMap ;
use rayon ::prelude ::* ;
use std ::sync ::{ Arc , Mutex } ;
use crate ::{
backend ::{ ArithmeticOps , GetModulus , ModInit , Modulus , ShoupMatrixFMA , VectorOps } ,
bool ::parameters ::ParameterVariant ,
@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment(
impl < M : Matrix , NttOp , RlweModOp , LweModOp , SKey > BoolEvaluator < M , NttOp , RlweModOp , LweModOp , SKey >
where
M : MatrixEntity + MatrixMut ,
M : MatrixEntity + MatrixMut + Send + Sync ,
M ::MatElement : PrimInt
+ Debug
+ Display
@ -636,16 +640,21 @@ where
+ WrappingSub
+ WrappingAdd
+ SampleUniform
+ From < bool > ,
NttOp : Ntt < Element = M ::MatElement > ,
+ From < bool >
+ Send
+ Sync ,
NttOp : Ntt < Element = M ::MatElement > + Send + Sync ,
RlweModOp : ArithmeticOps < Element = M ::MatElement >
+ VectorOps < Element = M ::MatElement >
+ GetModulus < Element = M ::MatElement , M = CiphertextModulus < M ::MatElement > >
+ ShoupMatrixFMA < M ::R > ,
+ ShoupMatrixFMA < M ::R >
+ Send
+ Sync ,
LweModOp : ArithmeticOps < Element = M ::MatElement >
+ VectorOps < Element = M ::MatElement >
+ GetModulus < Element = M ::MatElement , M = CiphertextModulus < M ::MatElement > > ,
M ::R : TryConvertFrom1 < [ i32 ] , CiphertextModulus < M ::MatElement > > + RowEntity + Debug ,
M ::R :
TryConvertFrom1 < [ i32 ] , CiphertextModulus < M ::MatElement > > + RowEntity + Debug + Send + Sync ,
< M as Matrix > ::R : RowMut ,
{
pub ( super ) fn new ( parameters : BoolParameters < M ::MatElement > ) -> Self
@ -1378,6 +1387,10 @@ where
)
} )
. collect_vec ( ) ;
// clone self.rlwe_n & self.parameters so that we don't access to &self inside the
// threads
let rlwe_n = self . parameters ( ) . rlwe_n ( ) . 0. clone ( ) ;
let parameters = self . parameters ( ) . clone ( ) ;
// Note: Each user is assigned a contigous LWE segement and the LWE dimension is
// split approximately uniformly across all users. Hence, concatenation of all
// user specific lwe segments will give LWE dimension.
@ -1385,8 +1398,9 @@ where
. into_iter ( )
. enumerate ( )
. flat_map ( | ( user_id , lwe_segment ) | {
let mut rgsws = Vec ::new ( ) ;
( lwe_segment . 0 . . lwe_segment . 1 )
. into_iter ( )
. into_par_ iter ( )
. map ( | lwe_index | {
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
// we sampling we need to puncture a_prng d_max - d_b times to align
@ -1396,7 +1410,7 @@ where
cr_seed . ni_rgsw_ct_seed_for_index ::< DefaultSecureRng > ( lwe_index ) ,
) ;
let mut scratch = M ::R ::zeros ( self . parameters ( ) . rlwe_n ( ) . 0 ) ;
let mut scratch = M ::R ::zeros ( rlwe_n ) ;
( 0 . . d_max - rgsw_x_rgsw_decomposer . b ( ) . decomposition_count ( ) . 0 )
. for_each ( | _ | {
RandomFillUniformInModulus ::random_fill (
@ -1420,7 +1434,7 @@ where
let mut decomp_neg_ai = M ::zeros (
ni_uj_to_s_decomposer . decomposition_count ( ) . 0 ,
self . parameters ( ) . rlwe_n ( ) . 0 ,
rlwe_n ,
) ;
scratch . as_ref ( ) . iter ( ) . enumerate ( ) . for_each ( | ( index , el ) | {
ni_uj_to_s_decomposer
@ -1447,41 +1461,40 @@ where
// then use to produce RLWE'(-sX^{s_{lwe}[l]}).
// Hence, after aggregation we decompose a_{i, l} * s + e to
// prepare for key switching
let ni_rgsw_zero_encs = ( 0 . . rgsw_x_rgsw_decomposer
. a ( )
. decomposition_count ( )
. 0 )
. map ( | i | {
let mut sum = M ::R ::zeros ( self . parameters ( ) . rlwe_n ( ) . 0 ) ;
key_shares . iter ( ) . for_each ( | k | {
let to_add_ref = k
. ni_rgsw_zero_enc_for_lwe_index ( lwe_index )
. get_row_slice ( i ) ;
assert ! ( to_add_ref . len ( ) = = self . parameters ( ) . rlwe_n ( ) . 0 ) ;
rlwe_modop . elwise_add_mut ( sum . as_mut ( ) , to_add_ref ) ;
} ) ;
// decompose
let mut decomp_sum = M ::zeros (
ni_uj_to_s_decomposer . decomposition_count ( ) . 0 ,
self . parameters ( ) . rlwe_n ( ) . 0 ,
) ;
sum . as_ref ( ) . iter ( ) . enumerate ( ) . for_each ( | ( index , el ) | {
ni_uj_to_s_decomposer
. decompose_iter ( el )
. enumerate ( )
. for_each ( | ( row_j , d_el ) | {
( decomp_sum . as_mut ( ) [ row_j ] ) . as_mut ( ) [ index ] = d_el ;
} ) ;
} ) ;
decomp_sum
. iter_rows_mut ( )
. for_each ( | r | nttop . forward ( r . as_mut ( ) ) ) ;
decomp_sum
} )
. collect_vec ( ) ;
let ni_rgsw_zero_encs =
( 0 . . rgsw_x_rgsw_decomposer . a ( ) . decomposition_count ( ) . 0 )
. map ( | i | {
let mut sum = M ::R ::zeros ( rlwe_n ) ;
key_shares . iter ( ) . for_each ( | k | {
let to_add_ref = k
. ni_rgsw_zero_enc_for_lwe_index ( lwe_index )
. get_row_slice ( i ) ;
assert ! ( to_add_ref . len ( ) = = rlwe_n ) ;
rlwe_modop . elwise_add_mut ( sum . as_mut ( ) , to_add_ref ) ;
} ) ;
// decompose
let mut decomp_sum = M ::zeros (
ni_uj_to_s_decomposer . decomposition_count ( ) . 0 ,
rlwe_n ,
) ;
sum . as_ref ( ) . iter ( ) . enumerate ( ) . for_each ( | ( index , el ) | {
ni_uj_to_s_decomposer
. decompose_iter ( el )
. enumerate ( )
. for_each ( | ( row_j , d_el ) | {
( decomp_sum . as_mut ( ) [ row_j ] ) . as_mut ( ) [ index ] =
d_el ;
} ) ;
} ) ;
decomp_sum
. iter_rows_mut ( )
. for_each ( | r | nttop . forward ( r . as_mut ( ) ) ) ;
decomp_sum
} )
. collect_vec ( ) ;
// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
// leader, ie user's id = user_id.
@ -1503,7 +1516,7 @@ where
. 0
- rlwe_x_rgsw_decomposer . b ( ) . decomposition_count ( ) . 0 . . ] ,
& rlwe_x_rgsw_decomposer ,
self . parameters ( ) ,
& parameters ,
( & uj_to_s_ksks [ user_id ] , & uj_to_s_ksks_part_a_eval [ user_id ] ) ,
rlwe_modop ,
nttop ,
@ -1524,7 +1537,7 @@ where
& ni_rgsw_zero_encs ,
& decomp_neg_ais ,
& rgsw_x_rgsw_decomposer ,
self . parameters ( ) ,
& parameters ,
(
& uj_to_s_ksks [ other_user_id ] ,
& uj_to_s_ksks_part_a_eval [ other_user_id ] ,
@ -1548,7 +1561,7 @@ where
& rlwe_x_rgsw_decomposer ,
& rgsw_x_rgsw_decomposer ,
& mut RuntimeScratchMutRef ::new (
scratch_rgsw_x_rgsw . as_mut ( ) ,
scratch_rgsw_x_rgsw . clone ( ) . as_mut ( ) ,
) ,
nttop ,
rlwe_modop ,
@ -1557,7 +1570,8 @@ where
rgsw_i
} )
. collect_vec ( )
. collect_into_vec ( & mut rgsws ) ;
rgsws
} )
. collect_vec ( ) ;