From 4a0d96d7a4f596c2107cc0f5b770aedfcdbcdce9 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sat, 1 Jun 2024 15:17:30 +0530 Subject: [PATCH] move keys into different file --- src/bool/evaluator.rs | 1051 ++++------------------------------------ src/bool/keys.rs | 661 +++++++++++++++++++++++++ src/bool/mod.rs | 172 +++++++ src/bool/parameters.rs | 4 +- src/pbs.rs | 2 +- src/shortint/mod.rs | 17 +- src/shortint/ops.rs | 16 +- 7 files changed, 947 insertions(+), 976 deletions(-) create mode 100644 src/bool/keys.rs diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 9ea8081..c7f938f 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -15,7 +15,6 @@ use rand_distr::uniform::SampleUniform; use crate::{ backend::{ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, - bool::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}, decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, multi_party::public_key_share, @@ -38,97 +37,15 @@ use crate::{ RowMut, Secret, }; -use super::parameters::{self, BoolParameters, CiphertextModulus}; - -thread_local! { - pub(crate) static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); - -} -pub(crate) static BOOL_SERVER_KEY: OnceLock< - ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>, -> = OnceLock::new(); - -pub(crate) static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); - -pub fn set_parameter_set(parameter: &BoolParameters) { - BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone())) -} - -pub fn set_mp_seed(seed: [u8; 32]) { - assert!( - MULTI_PARTY_CRS.set(MultiPartyCrs { seed: seed }).is_ok(), - "Attempted to set MP SEED twice." - ) -} - -fn set_server_key(key: ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>) { - assert!( - BOOL_SERVER_KEY.set(key).is_ok(), - "Attempted to set server key twice." - ); -} - -pub(crate) fn gen_keys() -> ( - ClientKey, - SeededServerKey>, BoolParameters, [u8; 32]>, -) { - BoolEvaluator::with_local_mut(|e| { - let ck = e.client_key(); - let sk = e.server_key(&ck); - - (ck, sk) - }) -} - -pub fn gen_client_key() -> ClientKey { - BoolEvaluator::with_local(|e| e.client_key()) -} - -pub fn gen_mp_keys_phase1( - ck: &ClientKey, -) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { - let seed = MultiPartyCrs::global().public_key_share_seed::(); - BoolEvaluator::with_local(|e| { - let pk_share = e.multi_party_public_key_share(seed, &ck); - pk_share - }) -} - -pub fn gen_mp_keys_phase2( - ck: &ClientKey, - pk: &PublicKey>, R, ModOp>, -) -> CommonReferenceSeededMultiPartyServerKeyShare>, BoolParameters, [u8; 32]> { - let seed = MultiPartyCrs::global().server_key_share_seed::(); - BoolEvaluator::with_local_mut(|e| { - let server_key_share = e.multi_party_server_key_share(seed, &pk.key, ck); - server_key_share - }) -} - -pub fn aggregate_public_key_shares( - shares: &[CommonReferenceSeededCollectivePublicKeyShare< - Vec, - [u8; 32], - BoolParameters, - >], -) -> PublicKey>, DefaultSecureRng, ModularOpsU64>> { - PublicKey::from(shares) -} - -pub fn aggregate_server_key_shares( - shares: &[CommonReferenceSeededMultiPartyServerKeyShare< - Vec>, - BoolParameters, - [u8; 32], - >], -) -> SeededMultiPartyServerKey>, [u8; 32], BoolParameters> { - BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) -} - -// GENERIC BELOW +use super::{ + keys::ClientKey, + parameters::{BoolParameters, CiphertextModulus}, + CommonReferenceSeededCollectivePublicKeyShare, CommonReferenceSeededMultiPartyServerKeyShare, + SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, +}; pub struct MultiPartyCrs { - seed: S, + pub(super) seed: S, } impl MultiPartyCrs { @@ -136,7 +53,7 @@ impl MultiPartyCrs { /// /// Public key seed equals the 1st seed extracted from PRNG Seeded with /// MiltiPartyCrs's seed. - fn public_key_share_seed + RandomFill>(&self) -> S { + pub(super) fn public_key_share_seed + RandomFill>(&self) -> S { let mut prng = Rng::new_with_seed(self.seed); let mut seed = S::default(); @@ -148,7 +65,7 @@ impl MultiPartyCrs { /// /// Server key seed equals the 2nd seed extracted from PRNG Seeded with /// MiltiPartyCrs's seed. - fn server_key_share_seed + RandomFill>(&self) -> S { + pub(super) fn server_key_share_seed + RandomFill>(&self) -> S { let mut prng = Rng::new_with_seed(self.seed); let mut seed = S::default(); @@ -158,14 +75,6 @@ impl MultiPartyCrs { } } -impl Global for MultiPartyCrs<[u8; 32]> { - fn global() -> &'static Self { - MULTI_PARTY_CRS - .get() - .expect("Multi Party Common Reference String not set") - } -} - pub(crate) trait BooleanGates { type Ciphertext: RowEntity; type Key; @@ -217,42 +126,6 @@ pub(crate) trait BooleanGates { fn not(&mut self, c: &Self::Ciphertext) -> Self::Ciphertext; } -impl WithLocal - for BoolEvaluator< - Vec>, - NttBackendU64, - ModularOpsU64>, - ModularOpsU64>, - > -{ - fn with_local(func: F) -> R - where - F: Fn(&Self) -> R, - { - BOOL_EVALUATOR.with_borrow(|s| func(s)) - } - - fn with_local_mut(func: F) -> R - where - F: Fn(&mut Self) -> R, - { - BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) - } - - fn with_local_mut_mut(func: &mut F) -> R - where - F: FnMut(&mut Self) -> R, - { - BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) - } -} - -impl Global for ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64> { - fn global() -> &'static Self { - BOOL_SERVER_KEY.get().unwrap() - } -} - struct ScratchMemory where M: Matrix, @@ -287,10 +160,6 @@ where } } -// thread_local! { -// pub(crate) static CLIENT_KEY: RefCell = -// RefCell::new(ClientKey::random()); } - trait BoolEncoding { type Element; fn true_el(&self) -> Self::Element; @@ -341,541 +210,6 @@ where } } -#[derive(Clone)] -pub struct ClientKey { - sk_rlwe: RlweSecret, - sk_lwe: LweSecret, -} - -impl ClientKey { - fn random() -> Self { - let sk_rlwe = RlweSecret::random(0, 0); - let sk_lwe = LweSecret::random(0, 0); - Self { sk_rlwe, sk_lwe } - } -} - -impl Encryptor> for ClientKey { - fn encrypt(&self, m: &bool) -> Vec { - BoolEvaluator::with_local(|e| e.sk_encrypt(*m, self)) - } -} - -impl Decryptor> for ClientKey { - fn decrypt(&self, c: &Vec) -> bool { - BoolEvaluator::with_local(|e| e.sk_decrypt(c, self)) - } -} - -impl MultiPartyDecryptor> for ClientKey { - type DecryptionShare = u64; - - fn gen_decryption_share(&self, c: &Vec) -> Self::DecryptionShare { - BoolEvaluator::with_local(|e| e.multi_party_decryption_share(c, &self)) - } - - fn aggregate_decryption_shares(&self, c: &Vec, shares: &[Self::DecryptionShare]) -> bool { - BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c)) - } -} - -pub struct CommonReferenceSeededCollectivePublicKeyShare { - share: R, - cr_seed: S, - parameters: P, -} - -struct SeededPublicKey { - part_b: R, - seed: S, - parameters: P, - _phantom: PhantomData, -} - -impl - From<&[CommonReferenceSeededCollectivePublicKeyShare>]> - for SeededPublicKey, ModOp> -where - ModOp: VectorOps + ModInit>, - S: PartialEq + Clone, - R: RowMut + RowEntity + Clone, - R::Element: Clone + PartialEq, -{ - fn from( - value: &[CommonReferenceSeededCollectivePublicKeyShare>], - ) -> Self { - assert!(value.len() > 0); - - let parameters = &value[0].parameters; - let cr_seed = value[0].cr_seed.clone(); - - // Sum all Bs - let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); - let mut part_b = value[0].share.clone(); - value.iter().skip(1).for_each(|share_i| { - assert!(&share_i.cr_seed == &cr_seed); - assert!(&share_i.parameters == parameters); - - rlweq_modop.elwise_add_mut(part_b.as_mut(), share_i.share.as_ref()); - }); - - Self { - part_b, - seed: cr_seed, - parameters: parameters.clone(), - _phantom: PhantomData, - } - } -} - -pub struct PublicKey { - key: M, - _phantom: PhantomData<(Rng, ModOp)>, -} - -impl Encryptor> for PublicKey>, Rng, ModOp> { - fn encrypt(&self, m: &bool) -> Vec { - BoolEvaluator::with_local(|e| e.pk_encrypt(&self.key, *m)) - } -} - -impl Encryptor<[bool], Vec>> for PublicKey>, Rng, ModOp> { - fn encrypt(&self, m: &[bool]) -> Vec> { - BoolEvaluator::with_local(|e| e.pk_encrypt_batched(&self.key, m)) - } -} - -impl< - M: MatrixMut + MatrixEntity, - Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, - ModOp, - > From, ModOp>> - for PublicKey -where - ::R: RowMut, - M::MatElement: Copy, -{ - fn from(value: SeededPublicKey, ModOp>) -> Self { - let mut prng = Rng::new_with_seed(value.seed); - - let mut key = M::zeros(2, value.parameters.rlwe_n().0); - // sample A - RandomFillUniformInModulus::random_fill( - &mut prng, - value.parameters.rlwe_q(), - key.get_row_mut(0), - ); - // Copy over B - key.get_row_mut(1).copy_from_slice(value.part_b.as_ref()); - - PublicKey { - key, - _phantom: PhantomData, - } - } -} - -impl< - M: MatrixMut + MatrixEntity, - Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, - ModOp: VectorOps + ModInit>, - > - From< - &[CommonReferenceSeededCollectivePublicKeyShare< - M::R, - Rng::Seed, - BoolParameters, - >], - > for PublicKey -where - ::R: RowMut, - Rng::Seed: Copy + PartialEq, - M::MatElement: PartialEq + Copy, -{ - fn from( - value: &[CommonReferenceSeededCollectivePublicKeyShare< - M::R, - Rng::Seed, - BoolParameters, - >], - ) -> Self { - assert!(value.len() > 0); - - let parameters = &value[0].parameters; - let mut key = M::zeros(2, parameters.rlwe_n().0); - - // sample A - let seed = value[0].cr_seed; - let mut main_rng = Rng::new_with_seed(seed); - RandomFillUniformInModulus::random_fill( - &mut main_rng, - parameters.rlwe_q(), - key.get_row_mut(0), - ); - - // Sum all Bs - let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); - value.iter().for_each(|share_i| { - assert!(share_i.cr_seed == seed); - assert!(&share_i.parameters == parameters); - - rlweq_modop.elwise_add_mut(key.get_row_mut(1), share_i.share.as_ref()); - }); - - PublicKey { - key, - _phantom: PhantomData, - } - } -} - -pub struct CommonReferenceSeededMultiPartyServerKeyShare { - rgsw_cts: Vec, - /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding - /// to -g is at 0 - auto_keys: HashMap, - lwe_ksk: M::R, - /// Common reference seed - cr_seed: S, - parameters: P, -} -pub struct SeededMultiPartyServerKey { - rgsw_cts: Vec, - /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding - /// to -g is at 0 - auto_keys: HashMap, - lwe_ksk: M::R, - cr_seed: S, - parameters: P, -} - -impl - SeededMultiPartyServerKey< - Vec>, - ::Seed, - BoolParameters, - > -{ - pub fn set_server_key(&self) { - set_server_key(ServerKeyEvaluationDomain::< - Vec>, - DefaultSecureRng, - NttBackendU64, - >::from(self)) - } -} - -/// Seeded single party server key -pub struct SeededServerKey { - /// Rgsw cts of LWE secret elements - pub(crate) rgsw_cts: Vec, - /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding - /// to -g is at 0 - pub(crate) auto_keys: HashMap, - /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret - pub(crate) lwe_ksk: M::R, - /// Parameters - pub(crate) parameters: P, - /// Main seed - pub(crate) seed: S, -} - -impl SeededServerKey, S> { - pub(crate) fn from_raw( - auto_keys: HashMap, - rgsw_cts: Vec, - lwe_ksk: M::R, - parameters: BoolParameters, - seed: S, - ) -> Self { - // sanity checks - auto_keys.iter().for_each(|v| { - assert!( - v.1.dimension() - == ( - parameters.auto_decomposition_count().0, - parameters.rlwe_n().0 - ) - ) - }); - - let (part_a_d, part_b_d) = parameters.rlwe_rgsw_decomposition_count(); - rgsw_cts.iter().for_each(|v| { - assert!(v.dimension() == (part_a_d.0 * 2 + part_b_d.0, parameters.rlwe_n().0)) - }); - assert!( - lwe_ksk.as_ref().len() - == (parameters.lwe_decomposition_count().0 * parameters.rlwe_n().0) - ); - - SeededServerKey { - rgsw_cts, - auto_keys, - lwe_ksk, - parameters, - seed, - } - } -} - -impl SeededServerKey>, BoolParameters, [u8; 32]> { - pub fn set_server_key(&self) { - set_server_key(ServerKeyEvaluationDomain::< - _, - DefaultSecureRng, - NttBackendU64, - >::from(self)); - } -} - -/// Server key in evaluation domain -pub(crate) struct ServerKeyEvaluationDomain { - /// Rgsw cts of LWE secret elements - rgsw_cts: Vec, - /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding - /// to -g is at 0 - galois_keys: HashMap, - /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret - lwe_ksk: M, - _phanton: PhantomData<(R, N)>, -} - -impl< - M: MatrixMut + MatrixEntity, - R: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + NewWithSeed, - N: NttInit> + Ntt, - > From<&SeededServerKey, R::Seed>> - for ServerKeyEvaluationDomain -where - ::R: RowMut, - M::MatElement: Copy, - R::Seed: Clone, -{ - fn from(value: &SeededServerKey, R::Seed>) -> Self { - let mut main_prng = R::new_with_seed(value.seed.clone()); - let parameters = &value.parameters; - let g = parameters.g() as isize; - let ring_size = value.parameters.rlwe_n().0; - let lwe_n = value.parameters.lwe_n().0; - let rlwe_q = value.parameters.rlwe_q(); - let lwq_q = value.parameters.lwe_q(); - - let nttop = N::new(rlwe_q, ring_size); - - // galois keys - let mut auto_keys = HashMap::new(); - let auto_decomp_count = parameters.auto_decomposition_count().0; - let auto_element_dlogs = parameters.auto_element_dlogs(); - for i in auto_element_dlogs.into_iter() { - let seeded_auto_key = value.auto_keys.get(&i).unwrap(); - assert!(seeded_auto_key.dimension() == (auto_decomp_count, ring_size)); - - let mut data = M::zeros(auto_decomp_count * 2, ring_size); - - // sample RLWE'_A(-s(X^k)) - data.iter_rows_mut().take(auto_decomp_count).for_each(|ri| { - RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) - }); - - // copy over RLWE'B_(-s(X^k)) - izip!( - data.iter_rows_mut().skip(auto_decomp_count), - seeded_auto_key.iter_rows() - ) - .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); - - // Send to Evaluation domain - data.iter_rows_mut() - .for_each(|ri| nttop.forward(ri.as_mut())); - - auto_keys.insert(i, data); - } - - // RGSW ciphertexts - let (rlrg_a_decomp, rlrg_b_decomp) = parameters.rlwe_rgsw_decomposition_count(); - let rgsw_cts = value - .rgsw_cts - .iter() - .map(|seeded_rgsw_si| { - assert!( - seeded_rgsw_si.dimension() - == (rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0, ring_size) - ); - - let mut data = M::zeros(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0 * 2, ring_size); - - // copy over RLWE'(-sm) - izip!( - data.iter_rows_mut().take(rlrg_a_decomp.0 * 2), - seeded_rgsw_si.iter_rows().take(rlrg_a_decomp.0 * 2) - ) - .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); - - // sample RLWE'_A(m) - data.iter_rows_mut() - .skip(rlrg_a_decomp.0 * 2) - .take(rlrg_b_decomp.0) - .for_each(|ri| { - RandomFillUniformInModulus::random_fill( - &mut main_prng, - &rlwe_q, - ri.as_mut(), - ) - }); - - // copy over RLWE'_B(m) - izip!( - data.iter_rows_mut() - .skip(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0), - seeded_rgsw_si.iter_rows().skip(rlrg_a_decomp.0 * 2) - ) - .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); - - // send polynomials to evaluation domain - data.iter_rows_mut() - .for_each(|ri| nttop.forward(ri.as_mut())); - - data - }) - .collect_vec(); - - // LWE ksk - let lwe_ksk = { - let d = parameters.lwe_decomposition_count().0; - assert!(value.lwe_ksk.as_ref().len() == d * ring_size); - - let mut data = M::zeros(d * ring_size, lwe_n + 1); - izip!(data.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| { - RandomFillUniformInModulus::random_fill( - &mut main_prng, - &lwq_q, - &mut lwe_i.as_mut()[1..], - ); - lwe_i.as_mut()[0] = *bi; - }); - - data - }; - - ServerKeyEvaluationDomain { - rgsw_cts, - galois_keys: auto_keys, - lwe_ksk, - _phanton: PhantomData, - } - } -} - -impl< - M: MatrixMut + MatrixEntity, - Rng: NewWithSeed, - N: NttInit> + Ntt, - > From<&SeededMultiPartyServerKey>> - for ServerKeyEvaluationDomain -where - ::R: RowMut, - Rng::Seed: Copy, - Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, - M::MatElement: Copy, -{ - fn from( - value: &SeededMultiPartyServerKey>, - ) -> Self { - let g = value.parameters.g() as isize; - let rlwe_n = value.parameters.rlwe_n().0; - let lwe_n = value.parameters.lwe_n().0; - let rlwe_q = value.parameters.rlwe_q(); - let lwe_q = value.parameters.lwe_q(); - - let mut main_prng = Rng::new_with_seed(value.cr_seed); - - let rlwe_nttop = N::new(rlwe_q, rlwe_n); - - // auto keys - let mut auto_keys = HashMap::new(); - let auto_d_count = value.parameters.auto_decomposition_count().0; - let auto_element_dlogs = value.parameters.auto_element_dlogs(); - for i in auto_element_dlogs.into_iter() { - let mut key = M::zeros(auto_d_count * 2, rlwe_n); - - // sample a - key.iter_rows_mut().take(auto_d_count).for_each(|ri| { - RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) - }); - - let key_part_b = value.auto_keys.get(&i).unwrap(); - assert!(key_part_b.dimension() == (auto_d_count, rlwe_n)); - izip!( - key.iter_rows_mut().skip(auto_d_count), - key_part_b.iter_rows() - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // send to evaluation domain - key.iter_rows_mut() - .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); - - auto_keys.insert(i, key); - } - - // rgsw cts - let (rlrg_d_a, rlrg_d_b) = value.parameters.rlwe_rgsw_decomposition_count(); - let rgsw_ct_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; - let rgsw_cts = value - .rgsw_cts - .iter() - .map(|ct_i_in| { - assert!(ct_i_in.dimension() == (rgsw_ct_out, rlwe_n)); - let mut eval_ct_i_out = M::zeros(rgsw_ct_out, rlwe_n); - - izip!(eval_ct_i_out.iter_rows_mut(), ct_i_in.iter_rows()).for_each( - |(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - rlwe_nttop.forward(to_ri.as_mut()); - }, - ); - - eval_ct_i_out - }) - .collect_vec(); - - // lwe ksk - let d_lwe = value.parameters.lwe_decomposition_count().0; - let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1); - izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each(|(lwe_i, bi)| { - RandomFillUniformInModulus::random_fill( - &mut main_prng, - &lwe_q, - &mut lwe_i.as_mut()[1..], - ); - lwe_i.as_mut()[0] = *bi; - }); - - ServerKeyEvaluationDomain { - rgsw_cts, - galois_keys: auto_keys, - lwe_ksk, - _phanton: PhantomData, - } - } -} - -impl PbsKey for ServerKeyEvaluationDomain { - type M = M; - fn galois_key_for_auto(&self, k: usize) -> &Self::M { - self.galois_keys.get(&k).unwrap() - } - fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::M { - &self.rgsw_cts[si] - } - - fn lwe_ksk(&self) -> &Self::M { - &self.lwe_ksk - } -} - struct BoolPbsInfo { auto_decomposer: DefaultDecomposer, rlwe_rgsw_decomposer: ( @@ -988,7 +322,7 @@ where M::R: TryConvertFrom1<[i32], CiphertextModulus> + RowEntity + Debug, ::R: RowMut, { - fn new(parameters: BoolParameters) -> Self + pub(super) fn new(parameters: BoolParameters) -> Self where RlweModOp: ModInit>, LweModOp: ModInit>, @@ -1113,7 +447,7 @@ where } } - fn client_key(&self) -> ClientKey { + pub(super) fn client_key(&self) -> ClientKey { let sk_lwe = LweSecret::random( self.pbs_info.parameters.lwe_n().0 >> 1, self.pbs_info.parameters.lwe_n().0, @@ -1122,10 +456,10 @@ where self.pbs_info.parameters.rlwe_n().0 >> 1, self.pbs_info.parameters.rlwe_n().0, ); - ClientKey { sk_rlwe, sk_lwe } + ClientKey::new(sk_rlwe, sk_lwe) } - fn server_key( + pub(super) fn server_key( &self, client_key: &ClientKey, ) -> SeededServerKey, [u8; 32]> { @@ -1136,8 +470,8 @@ where let mut main_prng = DefaultSecureRng::new_seeded(main_seed); let rlwe_n = self.pbs_info.parameters.rlwe_n().0; - let sk_rlwe = &client_key.sk_rlwe; - let sk_lwe = &client_key.sk_lwe; + let sk_rlwe = client_key.sk_rlwe(); + let sk_lwe = client_key.sk_lwe(); // generate auto keys let mut auto_keys = HashMap::new(); @@ -1231,7 +565,7 @@ where }) } - fn multi_party_server_key_share( + pub(super) fn multi_party_server_key_share( &self, cr_seed: [u8; 32], collective_pk: &M, @@ -1241,8 +575,8 @@ where DefaultSecureRng::with_local_mut(|rng| { let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); - let sk_rlwe = &client_key.sk_rlwe; - let sk_lwe = &client_key.sk_lwe; + let sk_rlwe = client_key.sk_rlwe(); + let sk_lwe = client_key.sk_lwe(); let g = self.pbs_info.parameters.g(); let ring_size = self.pbs_info.parameters.rlwe_n().0; @@ -1348,17 +682,17 @@ where rng, ); - CommonReferenceSeededMultiPartyServerKeyShare { - auto_keys, + CommonReferenceSeededMultiPartyServerKeyShare::new( rgsw_cts, + auto_keys, lwe_ksk, cr_seed, - parameters: self.pbs_info.parameters.clone(), - } + self.pbs_info.parameters.clone(), + ) }) } - fn multi_party_public_key_share( + pub(super) fn multi_party_public_key_share( &self, cr_seed: [u8; 32], client_key: &ClientKey, @@ -1374,22 +708,21 @@ where let mut main_prng = DefaultSecureRng::new_seeded(cr_seed); public_key_share( &mut share_out, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), modop, nttop, &mut main_prng, rng, ); - - CommonReferenceSeededCollectivePublicKeyShare { - share: share_out, - cr_seed: cr_seed, - parameters: self.pbs_info.parameters.clone(), - } + CommonReferenceSeededCollectivePublicKeyShare::new( + share_out, + cr_seed, + self.pbs_info.parameters.clone(), + ) }) } - fn multi_party_decryption_share( + pub(super) fn multi_party_decryption_share( &self, lwe_ct: &M::R, client_key: &ClientKey, @@ -1397,7 +730,7 @@ where assert!(lwe_ct.as_ref().len() == self.pbs_info.parameters.rlwe_n().0 + 1); let modop = &self.pbs_info.rlwe_modop; let mut neg_s = M::R::try_convert_from( - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &self.pbs_info.parameters.rlwe_q(), ); modop.elwise_neg_mut(neg_s.as_mut()); @@ -1532,7 +865,7 @@ where encrypt_lwe( &mut lwe_out, &m, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &self.pbs_info.rlwe_modop, rng, ); @@ -1543,13 +876,13 @@ where pub fn sk_decrypt(&self, lwe_ct: &M::R, client_key: &ClientKey) -> bool { let m = decrypt_lwe( lwe_ct, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &self.pbs_info.rlwe_modop, ); self.pbs_info.rlwe_q().decode(m) } - fn aggregate_multi_party_server_key_shares( + pub(super) fn aggregate_multi_party_server_key_shares( &self, shares: &[CommonReferenceSeededMultiPartyServerKeyShare< M, @@ -1562,8 +895,8 @@ where M: Clone, { assert!(shares.len() > 0); - let parameters = shares[0].parameters.clone(); - let cr_seed = &shares[0].cr_seed; + let parameters = shares[0].parameters().clone(); + let cr_seed = shares[0].cr_seed(); let rlwe_n = parameters.rlwe_n().0; let g = parameters.g() as isize; @@ -1572,8 +905,8 @@ where // sanity checks shares.iter().skip(1).for_each(|s| { - assert!(s.parameters == parameters); - assert!(&s.cr_seed == cr_seed); + assert!(s.parameters() == ¶meters); + assert!(s.cr_seed() == cr_seed); }); let rlweq_modop = &self.pbs_info.rlwe_modop; @@ -1586,7 +919,7 @@ where let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n); shares.iter().for_each(|s| { - let auto_key_share_i = s.auto_keys.get(&i).expect("Auto key {i} missing"); + let auto_key_share_i = s.auto_keys().get(&i).expect("Auto key {i} missing"); assert!( auto_key_share_i.dimension() == (parameters.auto_decomposition_count().0, rlwe_n) @@ -1619,11 +952,11 @@ where let rgsw_cts = (0..lwe_n).into_iter().map(|index| { // copy over rgsw ciphertext for index^th secret element from first share and // treat it as accumulating rgsw ciphertext - let mut rgsw_i = shares[0].rgsw_cts[index].clone(); + let mut rgsw_i = shares[0].rgsw_cts()[index].clone(); shares.iter().skip(1).for_each(|si| { // copy over si's RGSW[index] ciphertext and send to evaluation domain - izip!(tmp_rgsw.iter_rows_mut(), si.rgsw_cts[index].iter_rows()).for_each( + izip!(tmp_rgsw.iter_rows_mut(), si.rgsw_cts()[index].iter_rows()).for_each( |(to_ri, from_ri)| { to_ri.as_mut().copy_from_slice(from_ri.as_ref()); rlweq_nttop.forward(to_ri.as_mut()) @@ -1722,17 +1055,11 @@ where let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0); let lweq_modop = &self.pbs_info.lwe_modop; shares.iter().for_each(|si| { - assert!(si.lwe_ksk.as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); - lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk.as_ref()) + assert!(si.lwe_ksk().as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); + lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk().as_ref()) }); - SeededMultiPartyServerKey { - rgsw_cts, - auto_keys, - lwe_ksk, - cr_seed: cr_seed.clone(), - parameters: parameters, - } + SeededMultiPartyServerKey::new(rgsw_cts, auto_keys, lwe_ksk, cr_seed.clone(), parameters) } } @@ -1976,82 +1303,18 @@ where } } -thread_local! { - static PBS_TRACER: RefCell>>> = -RefCell::new(PBSTracer::default()); } - -#[derive(Default)] -struct PBSTracer -where - M: Matrix + Default, -{ - pub(crate) ct_rlwe_q_mod: M::R, - pub(crate) ct_lwe_q_mod: M::R, - pub(crate) ct_lwe_q_mod_after_ksk: M::R, - pub(crate) ct_br_q_mod: Vec, -} - -impl PBSTracer>> { - fn trace(&self, parameters: &BoolParameters, sk_lwe: &[i32], sk_rlwe: &[i32]) { - assert!(parameters.rlwe_n().0 == sk_rlwe.len()); - assert!(parameters.lwe_n().0 == sk_lwe.len()); - - let modop_rlweq = ModularOpsU64::new(*parameters.rlwe_q()); - // noise after mod down Q -> Q_ks - let m_back0 = decrypt_lwe(&self.ct_rlwe_q_mod, sk_rlwe, &modop_rlweq); - - let modop_lweq = ModularOpsU64::>::new(*parameters.lwe_q()); - // noise after mod down Q -> Q_ks - let m_back1 = decrypt_lwe(&self.ct_lwe_q_mod, sk_rlwe, &modop_lweq); - // noise after key switch from RLWE -> LWE - let m_back2 = decrypt_lwe(&self.ct_lwe_q_mod_after_ksk, sk_lwe, &modop_lweq); - - // noise after mod down odd from Q_ks -> q - let modop_br_q = ModularOpsU64::::new(*parameters.br_q() as u64); - let m_back3 = decrypt_lwe(&self.ct_br_q_mod, sk_lwe, &modop_br_q); - - println!( - " - M initial mod Q: {m_back0}, - M after mod down Q -> Q_ks: {m_back1}, - M after key switch from RLWE -> LWE: {m_back2}, - M after mod dwon Q_ks -> q: {m_back3} - " - ); - } -} - -impl WithLocal for PBSTracer>> { - fn with_local(func: F) -> R - where - F: Fn(&Self) -> R, - { - PBS_TRACER.with_borrow(|t| func(t)) - } - - fn with_local_mut(func: F) -> R - where - F: Fn(&mut Self) -> R, - { - PBS_TRACER.with_borrow_mut(|t| func(t)) - } - - fn with_local_mut_mut(func: &mut F) -> R - where - F: FnMut(&mut Self) -> R, - { - PBS_TRACER.with_borrow_mut(|t| func(t)) - } -} - #[cfg(test)] mod tests { + use bool::parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}; use rand::{thread_rng, Rng}; use rand_distr::Uniform; use crate::{ backend::{GetModulus, ModInit, ModularOpsU64, WordSizeModulus}, - bool, + bool::{ + self, CommonReferenceSeededMultiPartyServerKeyShare, PublicKey, + SeededMultiPartyServerKey, + }, ntt::NttBackendU64, random::{RandomElementInModulus, DEFAULT_RNG}, rgsw::{ @@ -2065,98 +1328,6 @@ mod tests { use super::*; - #[test] - fn tri() { - let bool_evaluator = BoolEvaluator::< - Vec>, - NttBackendU64, - ModularOpsU64>, - ModularOpsU64>, - >::new(SP_BOOL_PARAMS); - let mut v = bool_evaluator.pbs_info.g_k_dlog_map.clone(); - // v.sort(); - println!("{:?}", v); - let client_key = bool_evaluator.client_key(); - let server_key = bool_evaluator.server_key(&client_key); - let server_key_eval_domain = - ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from(&server_key); - - let ring_size = bool_evaluator.pbs_info.parameters.rlwe_n().0; - let rlwe_q = bool_evaluator.pbs_info.rlwe_q().q().unwrap(); - let mut rng = DefaultSecureRng::new(); - - let mut m = vec![0u64; ring_size as usize]; - RandomFillUniformInModulus::random_fill(&mut rng, &rlwe_q, m.as_mut_slice()); - - let ntt_op = &bool_evaluator.pbs_info.rlwe_nttop; - let mod_op = &bool_evaluator.pbs_info.rlwe_modop; - - // RLWE_{s}(m) - let mut seed_rlwe = [0u8; 32]; - rng.fill_bytes(&mut seed_rlwe); - let mut seeded_rlwe_m = SeededRlweCiphertext::empty(ring_size as usize, seed_rlwe, rlwe_q); - let mut p_rng = DefaultSecureRng::new_seeded(seed_rlwe); - secret_key_encrypt_rlwe( - &m, - &mut seeded_rlwe_m.data, - client_key.sk_rlwe.values(), - mod_op, - ntt_op, - &mut p_rng, - &mut rng, - ); - let mut rlwe_m = RlweCiphertext::>, DefaultSecureRng>::from(&seeded_rlwe_m); - - let k = 1; - let auto_k = (5usize).pow(k as u32); - // let auto_k = -5; - - let decomposer = bool_evaluator.pbs_info.auto_decomposer(); - - // Send RLWE_{s}(m) -> RLWE_{s}(m^k) - let mut scratch_space = - vec![vec![0u64; ring_size as usize]; decomposer.decomposition_count() + 2]; - let (auto_map_index, auto_map_sign) = bool_evaluator.pbs_info.rlwe_auto_map(k); - galois_auto( - &mut rlwe_m, - server_key_eval_domain.galois_key_for_auto(k), - &mut scratch_space, - &auto_map_index, - &auto_map_sign, - mod_op, - ntt_op, - decomposer, - ); - - let rlwe_m_k = rlwe_m; - - // Decrypt RLWE_{s}(m^k) and check - let mut encoded_m_k_back = vec![0u64; ring_size as usize]; - decrypt_rlwe( - &rlwe_m_k, - client_key.sk_rlwe.values(), - &mut encoded_m_k_back, - ntt_op, - mod_op, - ); - - { - let mut m_k = vec![0u64; ring_size]; - let (auto_map_index, auto_map_sign) = generate_auto_map(ring_size, auto_k as isize); - izip!(m.iter(), auto_map_index.iter(), auto_map_sign.iter()).for_each( - |(v, to_index, sign)| { - if !*sign { - m_k[*to_index] = (rlwe_q - *v); - } else { - m_k[*to_index] = *v; - } - }, - ); - let noise = measure_noise(&rlwe_m_k, &m_k, ntt_op, mod_op, client_key.sk_rlwe.values()); - println!("Ksk noise: {noise}"); - } - } - #[test] fn bool_encrypt_decrypt_works() { let bool_evaluator = BoolEvaluator::< @@ -2178,13 +1349,6 @@ mod tests { #[test] fn bool_nand() { - DefaultSecureRng::with_local_mut(|r| { - let rng = DefaultSecureRng::new_seeded([19u8; 32]); - *r = rng; - }); - - // let mog = WordSizeModulus::>::new(12u64); - let mut bool_evaluator = BoolEvaluator::< Vec>, NttBackendU64, @@ -2221,13 +1385,13 @@ mod tests { }; let n = measure_noise_lwe( &ct0, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &ct0, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) @@ -2240,27 +1404,18 @@ mod tests { }; let n = measure_noise_lwe( &ct1, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &ct1, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) }; - // // // Trace PBS - // PBSTracer::with_local(|t| { - // t.trace( - // &SP_BOOL_PARAMS, - // &client_key.sk_lwe.values(), - // client_key.sk_rlwe.values(), - // ) - // }); - // Calculate noise in ciphertext post PBS let noise_out = { let ideal = if m_out { @@ -2270,13 +1425,13 @@ mod tests { }; let n = measure_noise_lwe( &ct_back, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &ct_back, - client_key.sk_rlwe.values(), + client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) @@ -2352,7 +1507,7 @@ mod tests { let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { - izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); @@ -2373,7 +1528,7 @@ mod tests { DefaultSecureRng, ModularOpsU64>, >::from(public_key_share.as_slice()); - let lwe_ct = bool_evaluator.pk_encrypt(&collective_pk.key, m); + let lwe_ct = bool_evaluator.pk_encrypt(collective_pk.key(), m); let decryption_shares = parties .iter() @@ -2478,7 +1633,7 @@ mod tests { let server_key_shares = parties .iter() .map(|k| { - bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key, k) + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key(), k) }) .collect_vec(); let seeded_server_key = @@ -2491,25 +1646,25 @@ mod tests { let ideal_client_key = { let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { - izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; parties.iter().for_each(|k| { - izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); - ClientKey { - sk_lwe: LweSecret { - values: ideal_lwe_sk, - }, - sk_rlwe: RlweSecret { + ClientKey::new( + RlweSecret { values: ideal_rlwe_sk, }, - } + LweSecret { + values: ideal_lwe_sk, + }, + ) }; ( @@ -2537,8 +1692,8 @@ mod tests { let mut m0 = true; let mut m1 = false; - let mut lwe0 = bool_evaluator.pk_encrypt(&collective_pk.key, m0); - let mut lwe1 = bool_evaluator.pk_encrypt(&collective_pk.key, m1); + let mut lwe0 = bool_evaluator.pk_encrypt(collective_pk.key(), m0); + let mut lwe1 = bool_evaluator.pk_encrypt(collective_pk.key(), m1); for _ in 0..2000 { let lwe_out = bool_evaluator.nand(&lwe0, &lwe1, &server_key_eval); @@ -2555,13 +1710,13 @@ mod tests { }; let n = measure_noise_lwe( &lwe0, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &lwe0, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) @@ -2574,13 +1729,13 @@ mod tests { }; let n = measure_noise_lwe( &lwe1, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, &ideal, ); let v = decrypt_lwe( &lwe1, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) @@ -2603,13 +1758,13 @@ mod tests { }; let n = measure_noise_lwe( &lwe_out, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, &ideal_m, ); let v = decrypt_lwe( &lwe_out, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &bool_evaluator.pbs_info.rlwe_modop, ); (n, v) @@ -2679,25 +1834,25 @@ mod tests { let ideal_client_key = { let mut ideal_rlwe_sk = vec![0i32; bool_evaluator.pbs_info.rlwe_n()]; parties.iter().for_each(|k| { - izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe.values()).for_each(|(ideal_i, s_i)| { + izip!(ideal_rlwe_sk.iter_mut(), k.sk_rlwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); let mut ideal_lwe_sk = vec![0i32; bool_evaluator.pbs_info.lwe_n()]; parties.iter().for_each(|k| { - izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe.values()).for_each(|(ideal_i, s_i)| { + izip!(ideal_lwe_sk.iter_mut(), k.sk_lwe().values()).for_each(|(ideal_i, s_i)| { *ideal_i = *ideal_i + s_i; }); }); - ClientKey { - sk_lwe: LweSecret { - values: ideal_lwe_sk, - }, - sk_rlwe: RlweSecret { + ClientKey::new( + RlweSecret { values: ideal_rlwe_sk, }, - } + LweSecret { + values: ideal_lwe_sk, + }, + ) }; // check noise in freshly encrypted RLWE ciphertext (ie var_fresh) @@ -2723,7 +1878,7 @@ mod tests { let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; public_key_encrypt_rlwe::<_, _, _, _, i32, _>( &mut rlwe_ct, - &collective_pk.key, + collective_pk.key(), &m, rlwe_modop, rlwe_nttop, @@ -2733,7 +1888,7 @@ mod tests { let mut m_back = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_back, rlwe_nttop, rlwe_modop, @@ -2767,7 +1922,7 @@ mod tests { let server_key_shares = parties .iter() .map(|k| { - bool_evaluator.multi_party_server_key_share(pbs_cr_seed, &collective_pk.key, k) + bool_evaluator.multi_party_server_key_share(pbs_cr_seed, collective_pk.key(), k) }) .collect_vec(); @@ -2778,8 +1933,8 @@ mod tests { if true { let mut check = Stats { samples: vec![] }; izip!( - ideal_client_key.sk_lwe.values.iter(), - seeded_server_key.rgsw_cts.iter() + ideal_client_key.sk_lwe().values.iter(), + seeded_server_key.rgsw_cts().iter() ) .for_each(|(s_i, rgsw_ct_i)| { // X^{s[i]} @@ -2793,7 +1948,7 @@ mod tests { // RLWE'(-sm) let mut neg_s_eval = - Vec::::try_convert_from(ideal_client_key.sk_rlwe.values(), rlwe_q); + Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); rlwe_modop.elwise_neg_mut(&mut neg_s_eval); rlwe_nttop.forward(&mut neg_s_eval); for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() { @@ -2817,7 +1972,7 @@ mod tests { let mut m_back = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_back, rlwe_nttop, rlwe_modop, @@ -2851,7 +2006,7 @@ mod tests { let mut m_back = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_back, rlwe_nttop, rlwe_modop, @@ -2881,8 +2036,8 @@ mod tests { let mut check = Stats { samples: vec![] }; izip!( - ideal_client_key.sk_lwe.values(), - server_key_eval_domain.rgsw_cts.iter() + ideal_client_key.sk_lwe().values(), + server_key_eval_domain.rgsw_cts().iter() ) .for_each(|(s_i, rgsw_ct_i)| { let mut m = vec![0u64; rlwe_n]; @@ -2890,7 +2045,7 @@ mod tests { let mut rlwe_ct = vec![vec![0u64; rlwe_n]; 2]; public_key_encrypt_rlwe::<_, _, _, _, i32, _>( &mut rlwe_ct, - &collective_pk.key, + collective_pk.key(), &m, rlwe_modop, rlwe_nttop, @@ -2936,7 +2091,7 @@ mod tests { let mut m_plus_e_times_m1 = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_plus_e_times_m1, rlwe_nttop, rlwe_modop, @@ -2954,7 +2109,7 @@ mod tests { let mut m_plus_e_times_m1_more_e = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_after, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_plus_e_times_m1_more_e, rlwe_nttop, rlwe_modop, @@ -2992,7 +2147,7 @@ mod tests { let mut check = Stats { samples: vec![] }; let mut neg_s_poly = - Vec::::try_convert_from(ideal_client_key.sk_rlwe.values(), rlwe_q); + Vec::::try_convert_from(ideal_client_key.sk_rlwe().values(), rlwe_q); rlwe_modop.elwise_neg_mut(neg_s_poly.as_mut_slice()); let g = bool_evaluator.pbs_info.g(); @@ -3036,7 +2191,7 @@ mod tests { ); decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_out, rlwe_nttop, rlwe_modop, @@ -3070,7 +2225,7 @@ mod tests { }; public_key_encrypt_rlwe::<_, _, _, _, i32, _>( &mut rlwe_ct, - &collective_pk.key, + collective_pk.key(), &m, rlwe_modop, rlwe_nttop, @@ -3082,7 +2237,7 @@ mod tests { let mut m_plus_e = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_plus_e, rlwe_nttop, rlwe_modop, @@ -3118,7 +2273,7 @@ mod tests { let mut m_out = vec![0u64; rlwe_n]; decrypt_rlwe( &rlwe_ct, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), &mut m_out, rlwe_nttop, rlwe_modop, @@ -3149,7 +2304,7 @@ mod tests { encrypt_lwe( &mut lwe_in_ct, &m, - ideal_client_key.sk_rlwe.values(), + ideal_client_key.sk_rlwe().values(), lwe_modop, &mut rng, ); @@ -3167,10 +2322,10 @@ mod tests { // We only care about noise added by LWE key switch // m+e let m_plus_e = - decrypt_lwe(&lwe_in_ct, ideal_client_key.sk_rlwe.values(), lwe_modop); + decrypt_lwe(&lwe_in_ct, ideal_client_key.sk_rlwe().values(), lwe_modop); let m_plus_e_plus_lwe_ksk_noise = - decrypt_lwe(&lwe_out, ideal_client_key.sk_lwe.values(), lwe_modop); + decrypt_lwe(&lwe_out, ideal_client_key.sk_lwe().values(), lwe_modop); let diff = lwe_modop.sub(&m_plus_e_plus_lwe_ksk_noise, &m_plus_e); diff --git a/src/bool/keys.rs b/src/bool/keys.rs new file mode 100644 index 0000000..bde666f --- /dev/null +++ b/src/bool/keys.rs @@ -0,0 +1,661 @@ +use std::{collections::HashMap, hash::Hash, marker::PhantomData}; + +use crate::{ + backend::{ModInit, VectorOps}, + lwe::LweSecret, + random::{NewWithSeed, RandomFillUniformInModulus}, + rgsw::RlweSecret, + utils::WithLocal, + Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowEntity, RowMut, +}; + +use super::{parameters, BoolEvaluator, BoolParameters, CiphertextModulus}; + +/// Client key with RLWE and LWE secrets +#[derive(Clone)] +pub struct ClientKey { + sk_rlwe: RlweSecret, + sk_lwe: LweSecret, +} + +mod impl_ck { + use super::*; + + // Client key + impl ClientKey { + pub(in super::super) fn random() -> Self { + let sk_rlwe = RlweSecret::random(0, 0); + let sk_lwe = LweSecret::random(0, 0); + Self { sk_rlwe, sk_lwe } + } + + pub(in super::super) fn new(sk_rlwe: RlweSecret, sk_lwe: LweSecret) -> Self { + Self { sk_rlwe, sk_lwe } + } + + pub(in super::super) fn sk_rlwe(&self) -> &RlweSecret { + &self.sk_rlwe + } + + pub(in super::super) fn sk_lwe(&self) -> &LweSecret { + &self.sk_lwe + } + } + + impl Encryptor> for ClientKey { + fn encrypt(&self, m: &bool) -> Vec { + BoolEvaluator::with_local(|e| e.sk_encrypt(*m, self)) + } + } + + impl Decryptor> for ClientKey { + fn decrypt(&self, c: &Vec) -> bool { + BoolEvaluator::with_local(|e| e.sk_decrypt(c, self)) + } + } + + impl MultiPartyDecryptor> for ClientKey { + type DecryptionShare = u64; + + fn gen_decryption_share(&self, c: &Vec) -> Self::DecryptionShare { + BoolEvaluator::with_local(|e| e.multi_party_decryption_share(c, &self)) + } + + fn aggregate_decryption_shares( + &self, + c: &Vec, + shares: &[Self::DecryptionShare], + ) -> bool { + BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c)) + } + } +} + +/// Public key +pub struct PublicKey { + key: M, + _phantom: PhantomData<(Rng, ModOp)>, +} + +pub(super) mod impl_pk { + use super::*; + + impl PublicKey { + pub(in super::super) fn key(&self) -> &M { + &self.key + } + } + + impl Encryptor> for PublicKey>, Rng, ModOp> { + fn encrypt(&self, m: &bool) -> Vec { + BoolEvaluator::with_local(|e| e.pk_encrypt(&self.key, *m)) + } + } + + impl Encryptor<[bool], Vec>> for PublicKey>, Rng, ModOp> { + fn encrypt(&self, m: &[bool]) -> Vec> { + BoolEvaluator::with_local(|e| e.pk_encrypt_batched(&self.key, m)) + } + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + ModOp, + > From, ModOp>> + for PublicKey + where + ::R: RowMut, + M::MatElement: Copy, + { + fn from( + value: SeededPublicKey, ModOp>, + ) -> Self { + let mut prng = Rng::new_with_seed(value.seed); + + let mut key = M::zeros(2, value.parameters.rlwe_n().0); + // sample A + RandomFillUniformInModulus::random_fill( + &mut prng, + value.parameters.rlwe_q(), + key.get_row_mut(0), + ); + // Copy over B + key.get_row_mut(1).copy_from_slice(value.part_b.as_ref()); + + PublicKey { + key, + _phantom: PhantomData, + } + } + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + ModOp: VectorOps + ModInit>, + > + From< + &[CommonReferenceSeededCollectivePublicKeyShare< + M::R, + Rng::Seed, + BoolParameters, + >], + > for PublicKey + where + ::R: RowMut, + Rng::Seed: Copy + PartialEq, + M::MatElement: PartialEq + Copy, + { + fn from( + value: &[CommonReferenceSeededCollectivePublicKeyShare< + M::R, + Rng::Seed, + BoolParameters, + >], + ) -> Self { + assert!(value.len() > 0); + + let parameters = &value[0].parameters; + let mut key = M::zeros(2, parameters.rlwe_n().0); + + // sample A + let seed = value[0].cr_seed; + let mut main_rng = Rng::new_with_seed(seed); + RandomFillUniformInModulus::random_fill( + &mut main_rng, + parameters.rlwe_q(), + key.get_row_mut(0), + ); + + // Sum all Bs + let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); + value.iter().for_each(|share_i| { + assert!(share_i.cr_seed == seed); + assert!(&share_i.parameters == parameters); + + rlweq_modop.elwise_add_mut(key.get_row_mut(1), share_i.share.as_ref()); + }); + + PublicKey { + key, + _phantom: PhantomData, + } + } + } +} + +/// Seeded public key +struct SeededPublicKey { + part_b: Ro, + seed: S, + parameters: P, + _phantom: PhantomData, +} + +mod impl_seeded_pk { + use super::*; + + impl + From<&[CommonReferenceSeededCollectivePublicKeyShare>]> + for SeededPublicKey, ModOp> + where + ModOp: VectorOps + ModInit>, + S: PartialEq + Clone, + R: RowMut + RowEntity + Clone, + R::Element: Clone + PartialEq, + { + fn from( + value: &[CommonReferenceSeededCollectivePublicKeyShare< + R, + S, + BoolParameters, + >], + ) -> Self { + assert!(value.len() > 0); + + let parameters = &value[0].parameters; + let cr_seed = value[0].cr_seed.clone(); + + // Sum all Bs + let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); + let mut part_b = value[0].share.clone(); + value.iter().skip(1).for_each(|share_i| { + assert!(&share_i.cr_seed == &cr_seed); + assert!(&share_i.parameters == parameters); + + rlweq_modop.elwise_add_mut(part_b.as_mut(), share_i.share.as_ref()); + }); + + Self { + part_b, + seed: cr_seed, + parameters: parameters.clone(), + _phantom: PhantomData, + } + } + } +} + +/// CRS seeded collective public key share +pub struct CommonReferenceSeededCollectivePublicKeyShare { + share: Ro, + cr_seed: S, + parameters: P, +} +impl CommonReferenceSeededCollectivePublicKeyShare { + pub(super) fn new(share: Ro, cr_seed: S, parameters: P) -> Self { + CommonReferenceSeededCollectivePublicKeyShare { + share, + cr_seed, + parameters, + } + } +} + +/// CRS seeded Multi-party server key share +pub struct CommonReferenceSeededMultiPartyServerKeyShare { + rgsw_cts: Vec, + /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding + /// to -g is at 0 + auto_keys: HashMap, + lwe_ksk: M::R, + /// Common reference seed + cr_seed: S, + parameters: P, +} + +impl CommonReferenceSeededMultiPartyServerKeyShare { + pub(super) fn new( + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, + ) -> Self { + CommonReferenceSeededMultiPartyServerKeyShare { + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed, + parameters, + } + } + + pub(super) fn cr_seed(&self) -> &S { + &self.cr_seed + } + + pub(super) fn parameters(&self) -> &P { + &self.parameters + } + + pub(super) fn auto_keys(&self) -> &HashMap { + &self.auto_keys + } + + pub(super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } + + pub(super) fn lwe_ksk(&self) -> &M::R { + &self.lwe_ksk + } +} + +/// CRS seeded MultiParty server key +pub struct SeededMultiPartyServerKey { + rgsw_cts: Vec, + /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding + /// to -g is at 0 + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, +} + +impl SeededMultiPartyServerKey { + pub(super) fn new( + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, + ) -> Self { + SeededMultiPartyServerKey { + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed, + parameters, + } + } + + pub(super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } +} + +/// Seeded single party server key +pub struct SeededServerKey { + /// Rgsw cts of LWE secret elements + pub(crate) rgsw_cts: Vec, + /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding + /// to -g is at 0 + pub(crate) auto_keys: HashMap, + /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret + pub(crate) lwe_ksk: M::R, + /// Parameters + pub(crate) parameters: P, + /// Main seed + pub(crate) seed: S, +} +impl SeededServerKey, S> { + pub(super) fn from_raw( + auto_keys: HashMap, + rgsw_cts: Vec, + lwe_ksk: M::R, + parameters: BoolParameters, + seed: S, + ) -> Self { + // sanity checks + auto_keys.iter().for_each(|v| { + assert!( + v.1.dimension() + == ( + parameters.auto_decomposition_count().0, + parameters.rlwe_n().0 + ) + ) + }); + + let (part_a_d, part_b_d) = parameters.rlwe_rgsw_decomposition_count(); + rgsw_cts.iter().for_each(|v| { + assert!(v.dimension() == (part_a_d.0 * 2 + part_b_d.0, parameters.rlwe_n().0)) + }); + assert!( + lwe_ksk.as_ref().len() + == (parameters.lwe_decomposition_count().0 * parameters.rlwe_n().0) + ); + + SeededServerKey { + rgsw_cts, + auto_keys, + lwe_ksk, + parameters, + seed, + } + } +} + +/// Server key in evaluation domain +pub(crate) struct ServerKeyEvaluationDomain { + /// Rgsw cts of LWE secret elements + rgsw_cts: Vec, + /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding + /// to -g is at 0 + galois_keys: HashMap, + /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret + lwe_ksk: M, + _phanton: PhantomData<(R, N)>, +} + +pub(super) mod impl_server_key_eval_domain { + use itertools::{izip, Itertools}; + + use crate::{ + ntt::{Ntt, NttInit}, + pbs::PbsKey, + }; + + use super::*; + + impl ServerKeyEvaluationDomain { + pub(in super::super) fn rgsw_cts(&self) -> &[M] { + &self.rgsw_cts + } + } + + impl< + M: MatrixMut + MatrixEntity, + R: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + + NewWithSeed, + N: NttInit> + Ntt, + > From<&SeededServerKey, R::Seed>> + for ServerKeyEvaluationDomain + where + ::R: RowMut, + M::MatElement: Copy, + R::Seed: Clone, + { + fn from(value: &SeededServerKey, R::Seed>) -> Self { + let mut main_prng = R::new_with_seed(value.seed.clone()); + let parameters = &value.parameters; + let g = parameters.g() as isize; + let ring_size = value.parameters.rlwe_n().0; + let lwe_n = value.parameters.lwe_n().0; + let rlwe_q = value.parameters.rlwe_q(); + let lwq_q = value.parameters.lwe_q(); + + let nttop = N::new(rlwe_q, ring_size); + + // galois keys + let mut auto_keys = HashMap::new(); + let auto_decomp_count = parameters.auto_decomposition_count().0; + let auto_element_dlogs = parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + let seeded_auto_key = value.auto_keys.get(&i).unwrap(); + assert!(seeded_auto_key.dimension() == (auto_decomp_count, ring_size)); + + let mut data = M::zeros(auto_decomp_count * 2, ring_size); + + // sample RLWE'_A(-s(X^k)) + data.iter_rows_mut().take(auto_decomp_count).for_each(|ri| { + RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) + }); + + // copy over RLWE'B_(-s(X^k)) + izip!( + data.iter_rows_mut().skip(auto_decomp_count), + seeded_auto_key.iter_rows() + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // Send to Evaluation domain + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + auto_keys.insert(i, data); + } + + // RGSW ciphertexts + let (rlrg_a_decomp, rlrg_b_decomp) = parameters.rlwe_rgsw_decomposition_count(); + let rgsw_cts = value + .rgsw_cts + .iter() + .map(|seeded_rgsw_si| { + assert!( + seeded_rgsw_si.dimension() + == (rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0, ring_size) + ); + + let mut data = M::zeros(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0 * 2, ring_size); + + // copy over RLWE'(-sm) + izip!( + data.iter_rows_mut().take(rlrg_a_decomp.0 * 2), + seeded_rgsw_si.iter_rows().take(rlrg_a_decomp.0 * 2) + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // sample RLWE'_A(m) + data.iter_rows_mut() + .skip(rlrg_a_decomp.0 * 2) + .take(rlrg_b_decomp.0) + .for_each(|ri| { + RandomFillUniformInModulus::random_fill( + &mut main_prng, + &rlwe_q, + ri.as_mut(), + ) + }); + + // copy over RLWE'_B(m) + izip!( + data.iter_rows_mut() + .skip(rlrg_a_decomp.0 * 2 + rlrg_b_decomp.0), + seeded_rgsw_si.iter_rows().skip(rlrg_a_decomp.0 * 2) + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // send polynomials to evaluation domain + data.iter_rows_mut() + .for_each(|ri| nttop.forward(ri.as_mut())); + + data + }) + .collect_vec(); + + // LWE ksk + let lwe_ksk = { + let d = parameters.lwe_decomposition_count().0; + assert!(value.lwe_ksk.as_ref().len() == d * ring_size); + + let mut data = M::zeros(d * ring_size, lwe_n + 1); + izip!(data.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each( + |(lwe_i, bi)| { + RandomFillUniformInModulus::random_fill( + &mut main_prng, + &lwq_q, + &mut lwe_i.as_mut()[1..], + ); + lwe_i.as_mut()[0] = *bi; + }, + ); + + data + }; + + ServerKeyEvaluationDomain { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk, + _phanton: PhantomData, + } + } + } + + impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed, + N: NttInit> + Ntt, + > From<&SeededMultiPartyServerKey>> + for ServerKeyEvaluationDomain + where + ::R: RowMut, + Rng::Seed: Copy, + Rng: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + M::MatElement: Copy, + { + fn from( + value: &SeededMultiPartyServerKey>, + ) -> Self { + let g = value.parameters.g() as isize; + let rlwe_n = value.parameters.rlwe_n().0; + let lwe_n = value.parameters.lwe_n().0; + let rlwe_q = value.parameters.rlwe_q(); + let lwe_q = value.parameters.lwe_q(); + + let mut main_prng = Rng::new_with_seed(value.cr_seed); + + let rlwe_nttop = N::new(rlwe_q, rlwe_n); + + // auto keys + let mut auto_keys = HashMap::new(); + let auto_d_count = value.parameters.auto_decomposition_count().0; + let auto_element_dlogs = value.parameters.auto_element_dlogs(); + for i in auto_element_dlogs.into_iter() { + let mut key = M::zeros(auto_d_count * 2, rlwe_n); + + // sample a + key.iter_rows_mut().take(auto_d_count).for_each(|ri| { + RandomFillUniformInModulus::random_fill(&mut main_prng, &rlwe_q, ri.as_mut()) + }); + + let key_part_b = value.auto_keys.get(&i).unwrap(); + assert!(key_part_b.dimension() == (auto_d_count, rlwe_n)); + izip!( + key.iter_rows_mut().skip(auto_d_count), + key_part_b.iter_rows() + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // send to evaluation domain + key.iter_rows_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())); + + auto_keys.insert(i, key); + } + + // rgsw cts + let (rlrg_d_a, rlrg_d_b) = value.parameters.rlwe_rgsw_decomposition_count(); + let rgsw_ct_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; + let rgsw_cts = value + .rgsw_cts + .iter() + .map(|ct_i_in| { + assert!(ct_i_in.dimension() == (rgsw_ct_out, rlwe_n)); + let mut eval_ct_i_out = M::zeros(rgsw_ct_out, rlwe_n); + + izip!(eval_ct_i_out.iter_rows_mut(), ct_i_in.iter_rows()).for_each( + |(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + rlwe_nttop.forward(to_ri.as_mut()); + }, + ); + + eval_ct_i_out + }) + .collect_vec(); + + // lwe ksk + let d_lwe = value.parameters.lwe_decomposition_count().0; + let mut lwe_ksk = M::zeros(rlwe_n * d_lwe, lwe_n + 1); + izip!(lwe_ksk.iter_rows_mut(), value.lwe_ksk.as_ref().iter()).for_each( + |(lwe_i, bi)| { + RandomFillUniformInModulus::random_fill( + &mut main_prng, + &lwe_q, + &mut lwe_i.as_mut()[1..], + ); + lwe_i.as_mut()[0] = *bi; + }, + ); + + ServerKeyEvaluationDomain { + rgsw_cts, + galois_keys: auto_keys, + lwe_ksk, + _phanton: PhantomData, + } + } + } + + impl PbsKey for ServerKeyEvaluationDomain { + type M = M; + fn galois_key_for_auto(&self, k: usize) -> &Self::M { + self.galois_keys.get(&k).unwrap() + } + fn rgsw_ct_lwe_si(&self, si: usize) -> &Self::M { + &self.rgsw_cts[si] + } + + fn lwe_ksk(&self) -> &Self::M { + &self.lwe_ksk + } + } +} diff --git a/src/bool/mod.rs b/src/bool/mod.rs index 992bafe..a8a4f9d 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -1,4 +1,176 @@ pub(crate) mod evaluator; +pub(crate) mod keys; pub(crate) mod parameters; pub type FheBool = Vec; + +use std::{cell::RefCell, sync::OnceLock}; + +use evaluator::*; +use keys::*; +use parameters::*; + +use crate::{ + backend::ModularOpsU64, + ntt::NttBackendU64, + random::{DefaultSecureRng, NewWithSeed}, + utils::{Global, WithLocal}, +}; + +thread_local! { + static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); + +} +static BOOL_SERVER_KEY: OnceLock< + ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>, +> = OnceLock::new(); + +static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); + +pub fn set_parameter_set(parameter: &BoolParameters) { + BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone())) +} + +pub fn set_mp_seed(seed: [u8; 32]) { + assert!( + MULTI_PARTY_CRS.set(MultiPartyCrs { seed: seed }).is_ok(), + "Attempted to set MP SEED twice." + ) +} + +fn set_server_key(key: ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>) { + assert!( + BOOL_SERVER_KEY.set(key).is_ok(), + "Attempted to set server key twice." + ); +} + +pub(crate) fn gen_keys() -> ( + ClientKey, + SeededServerKey>, BoolParameters, [u8; 32]>, +) { + BoolEvaluator::with_local_mut(|e| { + let ck = e.client_key(); + let sk = e.server_key(&ck); + + (ck, sk) + }) +} + +pub fn gen_client_key() -> ClientKey { + BoolEvaluator::with_local(|e| e.client_key()) +} + +pub fn gen_mp_keys_phase1( + ck: &ClientKey, +) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { + let seed = MultiPartyCrs::global().public_key_share_seed::(); + BoolEvaluator::with_local(|e| { + let pk_share = e.multi_party_public_key_share(seed, &ck); + pk_share + }) +} + +pub fn gen_mp_keys_phase2( + ck: &ClientKey, + pk: &PublicKey>, R, ModOp>, +) -> CommonReferenceSeededMultiPartyServerKeyShare>, BoolParameters, [u8; 32]> { + let seed = MultiPartyCrs::global().server_key_share_seed::(); + BoolEvaluator::with_local_mut(|e| { + let server_key_share = e.multi_party_server_key_share(seed, pk.key(), ck); + server_key_share + }) +} + +pub fn aggregate_public_key_shares( + shares: &[CommonReferenceSeededCollectivePublicKeyShare< + Vec, + [u8; 32], + BoolParameters, + >], +) -> PublicKey>, DefaultSecureRng, ModularOpsU64>> { + PublicKey::from(shares) +} + +pub fn aggregate_server_key_shares( + shares: &[CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + [u8; 32], + >], +) -> SeededMultiPartyServerKey>, [u8; 32], BoolParameters> { + BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) +} + +// SERVER KEY EVAL DOMAIN // +impl SeededServerKey>, BoolParameters, [u8; 32]> { + pub fn set_server_key(&self) { + set_server_key(ServerKeyEvaluationDomain::< + _, + DefaultSecureRng, + NttBackendU64, + >::from(self)); + } +} + +impl + SeededMultiPartyServerKey< + Vec>, + ::Seed, + BoolParameters, + > +{ + pub fn set_server_key(&self) { + set_server_key(ServerKeyEvaluationDomain::< + Vec>, + DefaultSecureRng, + NttBackendU64, + >::from(self)) + } +} + +impl Global for ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64> { + fn global() -> &'static Self { + BOOL_SERVER_KEY.get().unwrap() + } +} + +// MULTIPARTY CRS // +impl Global for MultiPartyCrs<[u8; 32]> { + fn global() -> &'static Self { + MULTI_PARTY_CRS + .get() + .expect("Multi Party Common Reference String not set") + } +} + +// BOOL EVALUATOR // +impl WithLocal + for BoolEvaluator< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + > +{ + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + BOOL_EVALUATOR.with_borrow(|s| func(s)) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) + } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) + } +} diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 1a64401..1e8d7d8 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -1,9 +1,9 @@ -use num_traits::{ConstZero, FromPrimitive, PrimInt, ToPrimitive, Zero}; +use num_traits::{ConstZero, FromPrimitive, PrimInt}; use crate::{backend::Modulus, decomposer::Decomposer}; #[derive(Clone, PartialEq)] -pub struct BoolParameters { +pub(crate) struct BoolParameters { rlwe_q: CiphertextModulus, lwe_q: CiphertextModulus, br_q: usize, diff --git a/src/pbs.rs b/src/pbs.rs index bad1bcb..127e6c7 100644 --- a/src/pbs.rs +++ b/src/pbs.rs @@ -214,7 +214,7 @@ fn blind_rotation< >( trivial_rlwe_test_poly: &mut MT, scratch_matrix: &mut Mmut, - g: isize, + _g: isize, w: usize, q: usize, gk_to_si: &[Vec], diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 140efbe..9cb394c 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -1,11 +1,8 @@ use itertools::Itertools; use crate::{ - bool::evaluator::{ - BoolEvaluator, ClientKey, PublicKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY, - }, - utils::{Global, WithLocal}, - Decryptor, Encryptor, Matrix, MultiPartyDecryptor, + bool::keys::{ClientKey, PublicKey}, + Decryptor, Encryptor, MultiPartyDecryptor, }; mod ops; @@ -100,7 +97,7 @@ mod frontend { eight_bit_mul, }; use crate::{ - bool::evaluator::{BoolEvaluator, ServerKeyEvaluationDomain}, + bool::{evaluator::BoolEvaluator, keys::ServerKeyEvaluationDomain}, utils::{Global, WithLocal}, }; @@ -307,12 +304,10 @@ mod tests { use crate::{ bool::{ - evaluator::{ - aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys, - gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set, - BoolEvaluator, ClientKey, - }, + aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys, + gen_mp_keys_phase1, gen_mp_keys_phase2, parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}, + set_mp_seed, set_parameter_set, }, shortint::types::FheUint8, Decryptor, Encryptor, MultiPartyDecryptor, diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs index 73fb363..88db460 100644 --- a/src/shortint/ops.rs +++ b/src/shortint/ops.rs @@ -1,18 +1,6 @@ -use std::mem::MaybeUninit; - use itertools::{izip, Itertools}; -use num_traits::PrimInt; - -use crate::{ - backend::ModularOpsU64, - bool::{ - evaluator::{BoolEvaluator, BooleanGates, ClientKey, ServerKeyEvaluationDomain}, - parameters::CiphertextModulus, - }, - ntt::NttBackendU64, - random::DefaultSecureRng, - Decryptor, -}; + +use crate::bool::evaluator::BooleanGates; pub(super) fn half_adder( evaluator: &mut E,