diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 9bde3f6..8a37404 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -46,24 +46,29 @@ use crate::{ }; use super::{ - keys::ClientKey, parameters::{BoolParameters, CiphertextModulus}, - CommonReferenceSeededCollectivePublicKeyShare, CommonReferenceSeededMultiPartyServerKeyShare, - DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, NonInteractiveClientKey, - SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, - ShoupServerKeyEvaluationDomain, + ClientKey, CommonReferenceSeededCollectivePublicKeyShare, + CommonReferenceSeededMultiPartyServerKeyShare, DecompositionCount, DecompostionLogBase, + DoubleDecomposerParams, SeededMultiPartyServerKey, SeededNonInteractiveMultiPartyServerKey, + SeededSinglePartyServerKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, + ThrowMeAwayKey, }; -pub struct NonInteractiveMultiPartyServerKeyShare { +pub struct NonInteractiveMultiPartyServerKeyShare { /// (ak*si + e + \beta ui, ak*si + e) ni_rgsw_cts: (Vec, Vec), ui_to_s_ksk: M, others_ksk_zero_encs: Vec, + + auto_keys_share: HashMap, + lwe_ksk_share: M::R, + user_index: usize, + cr_seed: S, } -impl NonInteractiveMultiPartyServerKeyShare { - fn zero_enc_for_ui_to_s_ksk_for_user_i(&self, user_i: usize) -> &M { +impl NonInteractiveMultiPartyServerKeyShare { + fn ui_to_s_ksk_zero_encs_for_user_i(&self, user_i: usize) -> &M { assert!(user_i != self.user_index); if user_i < self.user_index { &self.others_ksk_zero_encs[user_i] @@ -77,46 +82,68 @@ pub struct MultiPartyCrs { pub(super) seed: S, } +fn puncture_p_rng>(p_rng: &mut R, times: usize) -> S { + let mut out = S::default(); + for _ in 0..times { + RandomFill::::random_fill(p_rng, &mut out); + } + return out; +} + +/// Common reference seed used for non-interactive multi-party. +/// +/// Initial Seed +/// Puncture 1 -> Key Seed +/// Puncture 1 -> Rgsw ciphertext seed +/// Puncture 2 -> auto keys seed +/// Puncture 3 -> Lwe key switching key seed +/// Puncture 2 -> user specific seed for u_j to s ksk +/// Punture j+1 -> user j's seed +#[derive(Clone)] pub struct NonInteractiveMultiPartyCrs { pub(super) seed: S, } +// impl Clone for NonInteractiveMultiPartyCrs where S: Clone {} +// impl Copy for NonInteractiveMultiPartyCrs where S: Copy {} + impl NonInteractiveMultiPartyCrs { - fn server_key_share_seed + RandomFill>(&self) -> S { + fn key_seed + RandomFill>(&self) -> S { let mut p_rng = R::new_with_seed(self.seed); + puncture_p_rng(&mut p_rng, 1) + } - // for main server key share seed sample once - let mut out = S::default(); - RandomFill::::random_fill(&mut p_rng, &mut out); + pub(crate) fn rgsw_cts_seed + RandomFill>(&self) -> S { + let key_seed = self.key_seed::(); + let mut p_rng = R::new_with_seed(key_seed); + puncture_p_rng(&mut p_rng, 1) + } - out + pub(crate) fn auto_keys_cts_seed + RandomFill>(&self) -> S { + let key_seed = self.key_seed::(); + let mut p_rng = R::new_with_seed(key_seed); + puncture_p_rng(&mut p_rng, 2) + } + + pub(crate) fn lwe_ksk_cts_seed + RandomFill>(&self) -> S { + let key_seed = self.key_seed::(); + let mut p_rng = R::new_with_seed(key_seed); + puncture_p_rng(&mut p_rng, 3) } fn ui_to_s_ks_seed + RandomFill>(&self) -> S { let mut p_rng = R::new_with_seed(self.seed); - - // puncture twice - let mut out = S::default(); - RandomFill::::random_fill(&mut p_rng, &mut out); - RandomFill::::random_fill(&mut p_rng, &mut out); - - out + puncture_p_rng(&mut p_rng, 2) } - fn ui_to_s_ks_seed_for_user_i + RandomFill>( + pub(crate) fn ui_to_s_ks_seed_for_user_i + RandomFill>( &self, user_i: usize, ) -> S { let ks_seed = self.ui_to_s_ks_seed::(); let mut p_rng = R::new_with_seed(ks_seed); - // puncture user_i times - let mut out = S::default(); - for _ in 0..user_i + 1 { - RandomFill::::random_fill(&mut p_rng, &mut out); - } - - out + puncture_p_rng(&mut p_rng, user_i + 1) } } @@ -636,7 +663,7 @@ where ClientKey::new(sk_rlwe, sk_lwe) } - pub(super) fn non_interactive_client_key(&self) -> NonInteractiveClientKey { + pub(super) fn non_interactive_client_key(&self) -> ThrowMeAwayKey { let sk_lwe = LweSecret::random( self.pbs_info.parameters.lwe_n().0 >> 1, self.pbs_info.parameters.lwe_n().0, @@ -649,13 +676,13 @@ where self.pbs_info.parameters.rlwe_n().0 >> 1, self.pbs_info.parameters.rlwe_n().0, ); - NonInteractiveClientKey::new(sk_rlwe, sk_u_rlwe, sk_lwe) + ThrowMeAwayKey::new(sk_rlwe, sk_u_rlwe, sk_lwe) } pub(super) fn single_party_server_key( &self, client_key: &ClientKey, - ) -> SeededServerKey, [u8; 32]> { + ) -> SeededSinglePartyServerKey, [u8; 32]> { DefaultSecureRng::with_local_mut(|rng| { let mut main_seed = [0u8; 32]; rng.fill_bytes(&mut main_seed); @@ -748,7 +775,7 @@ where rng, ); - SeededServerKey::from_raw( + SeededSinglePartyServerKey::from_raw( auto_keys, rgsw_cts, lwe_ksk, @@ -889,15 +916,42 @@ where &self, cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, total_users: usize, - key_shares: &[NonInteractiveMultiPartyServerKeyShare], - ) -> Vec + key_shares: &[NonInteractiveMultiPartyServerKeyShare< + M, + NonInteractiveMultiPartyCrs<[u8; 32]>, + >], + ) -> SeededNonInteractiveMultiPartyServerKey< + M, + NonInteractiveMultiPartyCrs<[u8; 32]>, + BoolParameters, + > where M: Clone + Debug, { + // sanity checks + let key_order = { + let existing_key_order = key_shares.iter().map(|s| s.user_index).collect_vec(); + + // record the order s.t. key_order[i] stores the position of i^th + // users key share in existing order + let mut key_order = Vec::with_capacity(existing_key_order.len()); + (0..total_users).map(|i| { + // find i + let index = existing_key_order + .iter() + .position(|x| x == &i) + .expect(&format!("Missing user {i}'s key!")); + key_order.push(index); + }); + + key_order + }; + let rlwe_modop = &self.pbs_info().rlwe_modop; let nttop = &self.pbs_info().rlwe_nttop; let ring_size = self.parameters().rlwe_n().0; let rlwe_q = self.parameters().rlwe_q(); + let lwe_modop = self.pbs_info().modop_lweq(); // genrate key switching key from u_i to s let ui_to_s_ksk_decomposition_count = self @@ -914,7 +968,7 @@ where .iter() .filter(|x| x.user_index != share.user_index) .for_each(|(other_share)| { - let op2 = other_share.zero_enc_for_ui_to_s_ksk_for_user_i(share.user_index); + let op2 = other_share.ui_to_s_ksk_zero_encs_for_user_i(share.user_index); assert!(op2.dimension() == (ui_to_s_ksk_decomposition_count.0, ring_size)); izip!(useri_ui_to_s_ksk.iter_rows_mut(), op2.iter_rows()).for_each( |(add_to, add_from)| { @@ -926,110 +980,111 @@ where }) .collect_vec(); - let mut key_prng = - DefaultSecureRng::new_seeded(cr_seed.server_key_share_seed::()); + let rgsw_cts = { + let mut rgsw_prng = + DefaultSecureRng::new_seeded(cr_seed.rgsw_cts_seed::()); - let rgsw_by_rgsw_decomposer = self - .parameters() - .rgsw_rgsw_decomposer::>(); - - // Generate RGSW Cts - let rgsw_cts_all_users_eval = { - // temporarily put ui_to_s in evaluation domain and sample a_i's for u_i to s - // ksk for upcomong key switches - ui_to_s_ksks.iter_mut().for_each(|ksk_i| { - ksk_i - .iter_rows_mut() - .for_each(|r| nttop.forward(r.as_mut())) - }); - let ui_to_s_ksks_part_a_eval = key_shares - .iter() - .map(|share| { - let mut ksk_prng = DefaultSecureRng::new_seeded( - cr_seed.ui_to_s_ks_seed_for_user_i::(share.user_index), - ); - let mut ais = M::zeros(ui_to_s_ksk_decomposition_count.0, ring_size); + let rgsw_by_rgsw_decomposer = self + .parameters() + .rgsw_rgsw_decomposer::>(); - ais.iter_rows_mut().for_each(|r_ai| { - RandomFillUniformInModulus::random_fill( - &mut ksk_prng, - rlwe_q, - r_ai.as_mut(), + // Generate RGSW Cts + let rgsw_cts_all_users_eval = { + // temporarily put ui_to_s in evaluation domain and sample a_i's for u_i to s + // ksk for upcomong key switches + ui_to_s_ksks.iter_mut().for_each(|ksk_i| { + ksk_i + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())) + }); + let ui_to_s_ksks_part_a_eval = key_shares + .iter() + .map(|share| { + let mut ksk_prng = DefaultSecureRng::new_seeded( + cr_seed + .ui_to_s_ks_seed_for_user_i::(share.user_index), ); + let mut ais = M::zeros(ui_to_s_ksk_decomposition_count.0, ring_size); + + ais.iter_rows_mut().for_each(|r_ai| { + RandomFillUniformInModulus::random_fill( + &mut ksk_prng, + rlwe_q, + r_ai.as_mut(), + ); + + nttop.forward(r_ai.as_mut()) + }); + ais + }) + .collect_vec(); - nttop.forward(r_ai.as_mut()) - }); - ais - }) - .collect_vec(); - - let max_rgrg_deocmposer = if rgsw_by_rgsw_decomposer.a().decomposition_count() - > rgsw_by_rgsw_decomposer.b().decomposition_count() - { - rgsw_by_rgsw_decomposer.a() - } else { - rgsw_by_rgsw_decomposer.b() - }; + let max_rgrg_deocmposer = if rgsw_by_rgsw_decomposer.a().decomposition_count() + > rgsw_by_rgsw_decomposer.b().decomposition_count() + { + rgsw_by_rgsw_decomposer.a() + } else { + rgsw_by_rgsw_decomposer.b() + }; - let ui_to_s_ksk_decomposer = self + let ui_to_s_ksk_decomposer = self .parameters() .non_interactive_ui_to_s_key_switch_decomposer::>( ); - // Generate a_i*s + E = \sum_{j \in P} a_i*s_j + e for all rlwes in each - // non-interactive rgsws. Then decompose and put decompositions in evaluation - // for u_i -> s key switch. - let decomp_ni_rgsws_part_1_acc = { - let mut tmp_space = M::R::zeros(ring_size); - - (0..self.parameters().lwe_n().0) - .into_iter() - .map(|lwe_index| { - (0..max_rgrg_deocmposer.decomposition_count()) - .into_iter() - .map(|d_index| { - let mut sum = M::zeros( - ui_to_s_ksk_decomposer.decomposition_count(), - ring_size, - ); - - // set temp_space to all zeros - tmp_space.as_mut().fill(M::MatElement::zero()); - - // a_i*s + E - key_shares.iter().for_each(|s| { - rlwe_modop.elwise_add_mut( - tmp_space.as_mut(), - s.ni_rgsw_cts.1[lwe_index].get_row_slice(d_index), + // Generate a_i*s + E = \sum_{j \in P} a_i*s_j + e for all rlwes in each + // non-interactive rgsws. Then decompose and put decompositions in evaluation + // for u_i -> s key switch. + let decomp_ni_rgsws_part_1_acc = { + let mut tmp_space = M::R::zeros(ring_size); + + (0..self.parameters().lwe_n().0) + .into_iter() + .map(|lwe_index| { + (0..max_rgrg_deocmposer.decomposition_count()) + .into_iter() + .map(|d_index| { + let mut sum = M::zeros( + ui_to_s_ksk_decomposer.decomposition_count(), + ring_size, ); - }); - - tmp_space.as_ref().iter().enumerate().for_each(|(ri, el)| { - ui_to_s_ksk_decomposer - .decompose_iter(el) - .enumerate() - .for_each(|(row_j, d_el)| { - (sum.as_mut()[row_j]).as_mut()[ri] = d_el; - }); - }); - - sum.iter_rows_mut().for_each(|r| nttop.forward(r.as_mut())); - - sum - }) - .collect_vec() - }) - .collect_vec() - }; - // Sample a_i's are used to generate non-interactive rgsw cts for all lwe - // indices. Since a_i are just needed for key switches , decompose them - // and put them in evaluation domain - // Decomposition count used for RGSW ct in non-interactive key share gen equals - // max of A and B decomposition required in RGSWxRGSW. This is because same - // polynomials are used to generate RLWE'(m) and RLWE'(-sm) - let decomp_ni_rgsw_neg_ais = - { + // set temp_space to all zeros + tmp_space.as_mut().fill(M::MatElement::zero()); + + // a_i*s + E + key_shares.iter().for_each(|s| { + rlwe_modop.elwise_add_mut( + tmp_space.as_mut(), + s.ni_rgsw_cts.1[lwe_index].get_row_slice(d_index), + ); + }); + + tmp_space.as_ref().iter().enumerate().for_each(|(ri, el)| { + ui_to_s_ksk_decomposer + .decompose_iter(el) + .enumerate() + .for_each(|(row_j, d_el)| { + (sum.as_mut()[row_j]).as_mut()[ri] = d_el; + }); + }); + + sum.iter_rows_mut().for_each(|r| nttop.forward(r.as_mut())); + + sum + }) + .collect_vec() + }) + .collect_vec() + }; + + // Sample a_i's are used to generate non-interactive rgsw cts for all lwe + // indices. Since a_i are just needed for key switches , decompose them + // and put them in evaluation domain + // Decomposition count used for RGSW ct in non-interactive key share gen equals + // max of A and B decomposition required in RGSWxRGSW. This is because same + // polynomials are used to generate RLWE'(m) and RLWE'(-sm) + let decomp_ni_rgsw_neg_ais = { let mut tmp_space = M::R::zeros(ring_size); (0..self.parameters().lwe_n().0) .into_iter() @@ -1048,7 +1103,7 @@ where (0..max_rgrg_deocmposer.decomposition_count()) .map(|_| { RandomFillUniformInModulus::random_fill( - &mut key_prng, + &mut rgsw_prng, rlwe_q, tmp_space.as_mut(), ); @@ -1085,244 +1140,296 @@ where .collect_vec() }; - // genrate RGSW cts - let rgsw_cts_all_users_eval = izip!( - key_shares.iter(), - ui_to_s_ksks.iter(), - ui_to_s_ksks_part_a_eval.iter() - ) - .map( - |(share, user_uitos_ksk_partb_eval, user_uitos_ksk_parta_eval)| { - // RGSW_s(X^{s[i]}) - let rgsw_cts_user_i_eval = izip!( - share.ni_rgsw_cts.0.iter(), - decomp_ni_rgsw_neg_ais.iter(), - decomp_ni_rgsws_part_1_acc.iter() - ) - .map( - |( - m_encs_under_ui, - decomposed_rgsw_i_neg_ais, - decomposed_acc_rgsw_part1, - )| { - let d_a = rgsw_by_rgsw_decomposer.a().decomposition_count(); - let d_b = rgsw_by_rgsw_decomposer.b().decomposition_count(); - let max_d = std::cmp::max(d_a, d_b); - - assert!(decomposed_rgsw_i_neg_ais.len() == max_d); - - // To be RGSW(X^{s[i]}) = [RLWE'(-sm), RLWE'(m)] - let mut rgsw_ct_eval = M::zeros(d_a * 2 + d_b * 2, ring_size); - let (rlwe_dash_nsm, rlwe_dash_m) = - rgsw_ct_eval.split_at_row_mut(d_a * 2); - - let mut scratch_row = M::R::zeros(ring_size); - - let mut m_encs_under_ui_eval = m_encs_under_ui.clone(); - m_encs_under_ui_eval - .iter_rows_mut() - .for_each(|r| nttop.forward(r.as_mut())); - - // RLWE(-sm) - { - // Recall that we have RLWE(a_i * s + e). We key - // switch RLWE(a_i * s + e) using ksk(u_i -> s) to - // get RLWE(a_i*s*u_i + e*u_i). - // Given (u_i * a_i + e + \beta m), we obtain - // RLWE(-sm) = RLWE(a_i*s*u_i + e*u_i) + (0, (u_i * - // a_i + e + \beta m)) - // - // Again RLWE'(-sm) only cares for RLWE(\beta -sm) - // with scaling factor \beta corresponding to most - // signficant d_a limbs. Hence, we skip (d_max - - // d_a) least signficant limbs - - let (rlwe_dash_neg_sm_part_a, rlwe_dash_neg_sm_part_b) = - rlwe_dash_nsm.split_at_mut(d_a); - - izip!( - rlwe_dash_neg_sm_part_a.iter_mut(), - rlwe_dash_neg_sm_part_b.iter_mut(), - decomposed_acc_rgsw_part1.iter().skip(max_d - d_a), - m_encs_under_ui_eval.iter_rows().skip(max_d - d_a) - ) - .for_each( - |(rlwe_a, rlwe_b, decomp_ai_s, beta_m_enc_ui)| { - // RLWE_s(a_i * s * u_i + u_i * e) = decomp * - // ksk(u_i -> s) - izip!( - decomp_ai_s.iter_rows(), - user_uitos_ksk_partb_eval.iter_rows(), - user_uitos_ksk_parta_eval.iter_rows() - ) - .for_each( - |(a0, part_b, part_a)| { - // rlwe_b += decomp[i] * ksk part_b[i] - rlwe_modop.elwise_mul( - scratch_row.as_mut(), - a0.as_ref(), - part_b.as_ref(), - ); - rlwe_modop.elwise_add_mut( - rlwe_b.as_mut(), - scratch_row.as_ref(), - ); - - // rlwe_a += decomp[i] * ksk - // part_a[i] - rlwe_modop.elwise_mul( - scratch_row.as_mut(), - a0.as_ref(), - part_a.as_ref(), - ); - rlwe_modop.elwise_add_mut( - rlwe_a.as_mut(), - scratch_row.as_ref(), - ); - }, - ); + // genrate RGSW cts + let rgsw_cts_all_users_eval = izip!( + key_shares.iter(), + ui_to_s_ksks.iter(), + ui_to_s_ksks_part_a_eval.iter() + ) + .map( + |(share, user_uitos_ksk_partb_eval, user_uitos_ksk_parta_eval)| { + // RGSW_s(X^{s[i]}) + let rgsw_cts_user_i_eval = izip!( + share.ni_rgsw_cts.0.iter(), + decomp_ni_rgsw_neg_ais.iter(), + decomp_ni_rgsws_part_1_acc.iter() + ) + .map( + |( + m_encs_under_ui, + decomposed_rgsw_i_neg_ais, + decomposed_acc_rgsw_part1, + )| { + let d_a = rgsw_by_rgsw_decomposer.a().decomposition_count(); + let d_b = rgsw_by_rgsw_decomposer.b().decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + + assert!(decomposed_rgsw_i_neg_ais.len() == max_d); + + // To be RGSW(X^{s[i]}) = [RLWE'(-sm), RLWE'(m)] + let mut rgsw_ct_eval = M::zeros(d_a * 2 + d_b * 2, ring_size); + let (rlwe_dash_nsm, rlwe_dash_m) = + rgsw_ct_eval.split_at_row_mut(d_a * 2); + + let mut scratch_row = M::R::zeros(ring_size); + + let mut m_encs_under_ui_eval = m_encs_under_ui.clone(); + m_encs_under_ui_eval + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())); + + // RLWE(-sm) + { + // Recall that we have RLWE(a_i * s + e). We key + // switch RLWE(a_i * s + e) using ksk(u_i -> s) to + // get RLWE(a_i*s*u_i + e*u_i). + // Given (u_i * a_i + e + \beta m), we obtain + // RLWE(-sm) = RLWE(a_i*s*u_i + e*u_i) + (0, (u_i * + // a_i + e + \beta m)) + // + // Again RLWE'(-sm) only cares for RLWE(\beta -sm) + // with scaling factor \beta corresponding to most + // signficant d_a limbs. Hence, we skip (d_max - + // d_a) least signficant limbs + + let (rlwe_dash_neg_sm_part_a, rlwe_dash_neg_sm_part_b) = + rlwe_dash_nsm.split_at_mut(d_a); + + izip!( + rlwe_dash_neg_sm_part_a.iter_mut(), + rlwe_dash_neg_sm_part_b.iter_mut(), + decomposed_acc_rgsw_part1.iter().skip(max_d - d_a), + m_encs_under_ui_eval.iter_rows().skip(max_d - d_a) + ) + .for_each( + |(rlwe_a, rlwe_b, decomp_ai_s, beta_m_enc_ui)| { + // RLWE_s(a_i * s * u_i + u_i * e) = decomp + // * ksk(u_i + // -> s) + izip!( + decomp_ai_s.iter_rows(), + user_uitos_ksk_partb_eval.iter_rows(), + user_uitos_ksk_parta_eval.iter_rows() + ) + .for_each( + |(a0, part_b, part_a)| { + // rlwe_b += decomp[i] * ksk + // part_b[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_b.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_b.as_mut(), + scratch_row.as_ref(), + ); + + // rlwe_a += decomp[i] * ksk + // part_a[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_a.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_a.as_mut(), + scratch_row.as_ref(), + ); + }, + ); + + // RLWE_s(-sm) = RLWE_s(a_i * s * u_i + u_i + // * e) + (0, a_i * u + e + m) + rlwe_modop.elwise_add_mut( + rlwe_a.as_mut(), + beta_m_enc_ui.as_ref(), + ); + }, + ); + } - // RLWE_s(-sm) = RLWE_s(a_i * s * u_i + u_i - // * e) + (0, a_i * u + e + m) - rlwe_modop.elwise_add_mut( - rlwe_a.as_mut(), - beta_m_enc_ui.as_ref(), - ); - }, - ); - } + // RLWE(m) + { + // Routine: + // Let RLWE(-a_i * u_i) = (decomp<-a_i> \cdot ksk(u_i -> s)), + // then RLWE(m) = RLWE(a'*s + e + m) = (a_i + // * u_i + e + m, 0) + RLWE(-a_i * u_i) + // + // Since RLWE'(m) only cares for RLWE ciphertexts corresponding + // to higher d_b limbs, we skip routine 1 for lower max(d_a, + // d_b) - d_b limbs + let (rlwe_dash_m_part_a, rlwe_dash_m_part_b) = + rlwe_dash_m.split_at_mut(d_b); + + izip!( + rlwe_dash_m_part_a.iter_mut(), + rlwe_dash_m_part_b.iter_mut(), + decomposed_rgsw_i_neg_ais.iter().skip(max_d - d_b), + m_encs_under_ui_eval.iter_rows().skip(max_d - d_b) + ) + .for_each( + |(rlwe_a, rlwe_b, decomp_neg_ai, beta_m_enc_ui)| { + // RLWE_s(-a_i * ui) = decomp<-a_i> \cdot ksk(ui -> s) + izip!( + decomp_neg_ai.iter_rows(), + user_uitos_ksk_partb_eval.iter_rows(), + user_uitos_ksk_parta_eval.iter_rows() + ) + .for_each( + |(a0, part_b, part_a)| { + // rlwe_b += decomp[i] * ksk part_b[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_b.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_b.as_mut(), + scratch_row.as_ref(), + ); + + // rlwe_a += decomp[i] * ksk + // part_a[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_a.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_a.as_mut(), + scratch_row.as_ref(), + ); + }, + ); + + // RLWE_s(m) = (a_i * ui + e + \beta m, 0) + // + + // RLWE_s(-a_i * ui) + rlwe_modop.elwise_add_mut( + rlwe_b.as_mut(), + beta_m_enc_ui.as_ref(), + ); + }, + ); + } - // RLWE(m) - { - // Routine: - // Let RLWE(-a_i * u_i) = (decomp<-a_i> \cdot ksk(u_i -> s)), - // then RLWE(m) = RLWE(a'*s + e + m) = (a_i - // * u_i + e + m, 0) + RLWE(-a_i * u_i) - // - // Since RLWE'(m) only cares for RLWE ciphertexts corresponding - // to higher d_b limbs, we skip routine 1 for lower max(d_a, - // d_b) - d_b limbs - let (rlwe_dash_m_part_a, rlwe_dash_m_part_b) = - rlwe_dash_m.split_at_mut(d_b); - - izip!( - rlwe_dash_m_part_a.iter_mut(), - rlwe_dash_m_part_b.iter_mut(), - decomposed_rgsw_i_neg_ais.iter().skip(max_d - d_b), - m_encs_under_ui_eval.iter_rows().skip(max_d - d_b) - ) - .for_each( - |(rlwe_a, rlwe_b, decomp_neg_ai, beta_m_enc_ui)| { - // RLWE_s(-a_i * ui) = decomp<-a_i> \cdot ksk(ui -> s) - izip!( - decomp_neg_ai.iter_rows(), - user_uitos_ksk_partb_eval.iter_rows(), - user_uitos_ksk_parta_eval.iter_rows() - ) - .for_each( - |(a0, part_b, part_a)| { - // rlwe_b += decomp[i] * ksk part_b[i] - rlwe_modop.elwise_mul( - scratch_row.as_mut(), - a0.as_ref(), - part_b.as_ref(), - ); - rlwe_modop.elwise_add_mut( - rlwe_b.as_mut(), - scratch_row.as_ref(), - ); - - // rlwe_a += decomp[i] * ksk - // part_a[i] - rlwe_modop.elwise_mul( - scratch_row.as_mut(), - a0.as_ref(), - part_a.as_ref(), - ); - rlwe_modop.elwise_add_mut( - rlwe_a.as_mut(), - scratch_row.as_ref(), - ); - }, - ); + rgsw_ct_eval + }, + ) + .collect_vec(); + rgsw_cts_user_i_eval + }, + ) + .collect_vec(); - // RLWE_s(m) = (a_i * ui + e + \beta m, 0) - // + - // RLWE_s(-a_i * ui) - rlwe_modop.elwise_add_mut( - rlwe_b.as_mut(), - beta_m_enc_ui.as_ref(), - ); - }, - ); - } + // put u_i -> s ksks back in coefficient domain + ui_to_s_ksks.iter_mut().for_each(|ksk_i| { + ksk_i + .iter_rows_mut() + .for_each(|r| nttop.backward(r.as_mut())) + }); - rgsw_ct_eval - }, - ) - .collect_vec(); - rgsw_cts_user_i_eval - }, - ) - .collect_vec(); + rgsw_cts_all_users_eval + }; - // put u_i -> s ksks back in coefficient domain - ui_to_s_ksks.iter_mut().for_each(|ksk_i| { - ksk_i + // RGSW x RGSW + let lwe_n = self.parameters().lwe_n().0; + let mut scratch_matrix = M::zeros( + std::cmp::max( + rgsw_by_rgsw_decomposer.a().decomposition_count(), + rgsw_by_rgsw_decomposer.b().decomposition_count(), + ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 + + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), + ring_size, + ); + let rgsw_cts_untrimmed = (0..lwe_n).map(|s_index| { + // copy over s_index^th rgsw ct of user 0. Use it to accumulate RGSW products of + // all RGSW ciphertexts at s_index + let mut rgsw_i = rgsw_cts_all_users_eval[0][s_index].clone(); + rgsw_i .iter_rows_mut() - .for_each(|r| nttop.backward(r.as_mut())) + .for_each(|r| nttop.backward(r.as_mut())); + + rgsw_cts_all_users_eval + .iter() + .skip(1) + .for_each(|user_i_rgsws| { + rgsw_by_rgsw_inplace( + &mut rgsw_i, + &user_i_rgsws[s_index], + &rgsw_by_rgsw_decomposer, + &mut scratch_matrix, + nttop, + rlwe_modop, + ); + }); + + rgsw_i }); - rgsw_cts_all_users_eval + // After this point we don't require RGSW cts for RGSWxRGSW + // multiplicaiton anymore. So we trim them to suit RLWExRGSW require for PBS + let rgsw_cts = rgsw_cts_untrimmed + .map(|rgsw_ct| { + trim_rgsw_ct_matrix_from_rgrg_to_rlrg( + rgsw_ct, + self.parameters().rgsw_by_rgsw_decomposition_params(), + self.parameters().rlwe_by_rgsw_decomposition_params(), + ) + }) + .collect_vec(); + rgsw_cts }; - // RGSW x RGSW - let lwe_n = self.parameters().lwe_n().0; - let mut scratch_matrix = M::zeros( - std::cmp::max( - rgsw_by_rgsw_decomposer.a().decomposition_count(), - rgsw_by_rgsw_decomposer.b().decomposition_count(), - ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 - + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), - ring_size, - ); - let rgsw_cts_untrimmed = (0..lwe_n).map(|s_index| { - // copy over s_index^th rgsw ct of user 0. Use it to accumulate RGSW products of - // all RGSW ciphertexts at s_index - let mut rgsw_i = rgsw_cts_all_users_eval[0][s_index].clone(); - rgsw_i - .iter_rows_mut() - .for_each(|r| nttop.backward(r.as_mut())); - - rgsw_cts_all_users_eval - .iter() - .skip(1) - .for_each(|user_i_rgsws| { - rgsw_by_rgsw_inplace( - &mut rgsw_i, - &user_i_rgsws[s_index], - &rgsw_by_rgsw_decomposer, - &mut scratch_matrix, - nttop, - rlwe_modop, + // auto keys + let auto_keys = { + let mut auto_keys = HashMap::new(); + let auto_elements_dlog = self.parameters().auto_element_dlogs(); + for i in auto_elements_dlog.into_iter() { + let mut key = M::zeros(self.parameters().auto_decomposition_count().0, ring_size); + + key_shares.iter().for_each(|s| { + let auto_key_share_i = s.auto_keys_share.get(&i).expect("Auto key {i} missing"); + assert!( + auto_key_share_i.dimension() + == (self.parameters().auto_decomposition_count().0, ring_size) + ); + izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( + |(partb_out, partb_share)| { + rlwe_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); + }, ); }); - rgsw_i - }); + auto_keys.insert(i, key); + } + auto_keys + }; - // After this point we don't require RGSW cts for RGSWxRGSW - // multiplicaiton anymore. So we trim them to suit RLWExRGSW require for PBS - let rgsw_cts = rgsw_cts_untrimmed - .map(|rgsw_ct| { - trim_rgsw_ct_matrix_from_rgrg_to_rlrg( - rgsw_ct, - self.parameters().rgsw_by_rgsw_decomposition_params(), - self.parameters().rlwe_by_rgsw_decomposition_params(), - ) - }) - .collect_vec(); - rgsw_cts + // LWE ksk + let lwe_ksk = { + let mut lwe_ksk = + M::R::zeros(self.parameters().lwe_decomposition_count().0 * ring_size); + key_shares.iter().for_each(|s| { + assert!( + s.lwe_ksk_share.as_ref().len() + == self.parameters().lwe_decomposition_count().0 * ring_size + ); + lwe_modop.elwise_add_mut(lwe_ksk.as_mut(), s.lwe_ksk_share.as_ref()); + }); + lwe_ksk + }; + + SeededNonInteractiveMultiPartyServerKey::new( + ui_to_s_ksks, + key_order, + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed.clone(), + self.parameters().clone(), + ) } pub(super) fn non_interactive_multi_party_key_share( @@ -1331,16 +1438,15 @@ where cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, self_index: usize, total_users: usize, - client_key: &NonInteractiveClientKey, - ) -> NonInteractiveMultiPartyServerKeyShare { + client_key: &ThrowMeAwayKey, + ) -> NonInteractiveMultiPartyServerKeyShare> { // TODO: check whether parameters support `total_users` + let nttop = self.pbs_info().nttop_rlweq(); + let rlwe_modop = self.pbs_info().modop_rlweq(); + let ring_size = self.pbs_info().rlwe_n(); + let rlwe_q = self.parameters().rlwe_q(); - DefaultSecureRng::with_local_mut(|rng| { - let nttop = self.pbs_info().nttop_rlweq(); - let rlwe_modop = self.pbs_info().modop_rlweq(); - let ring_size = self.pbs_info().rlwe_n(); - let rlwe_q = self.parameters().rlwe_q(); - + let (ui_to_s_ksk, zero_encs_for_others) = DefaultSecureRng::with_local_mut(|rng| { // ui_to_s_ksk let non_interactive_decomposer = self .parameters() @@ -1383,10 +1489,13 @@ where }) .collect_vec(); - // Main Key gen follows // + (ui_to_s_ksk, zero_encs_for_others) + }); - let mut key_prng = - DefaultSecureRng::new_seeded(cr_seed.server_key_share_seed::()); + // Non-interactive RGSW cts = (a_i * u_j + e + \beta X^{s[i]}, a_i * s_j + e') + let ni_rgsw_cts = DefaultSecureRng::with_local_mut(|rng| { + let mut rgsw_cts_prng = + DefaultSecureRng::new_seeded(cr_seed.rgsw_cts_seed::()); // generate non-interactive rgsw cts let rgsw_by_rgsw_decomposer = self .parameters() @@ -1408,11 +1517,12 @@ where .map(|s_i| { // X^{s[i]} let mut m = M::R::zeros(ring_size); - if *s_i < 0 { + let s_i = s_i * (self.pbs_info().embedding_factor() as i32); + if s_i < 0 { // X^{-s[i]} -> -X^{N+s[i]} m.as_mut()[ring_size - (s_i.abs() as usize)] = rlwe_q.neg_one(); } else { - m.as_mut()[*s_i as usize] = M::MatElement::one(); + m.as_mut()[s_i as usize] = M::MatElement::one(); } non_interactive_rgsw_ct::( @@ -1420,20 +1530,112 @@ where client_key.sk_u_rlwe().values(), m.as_ref(), &ni_rgrg_gadget_vec, - &mut key_prng, + &mut rgsw_cts_prng, rng, nttop, rlwe_modop, ) }) .unzip(); + ni_rgsw_cts + }); - NonInteractiveMultiPartyServerKeyShare { - ni_rgsw_cts, - ui_to_s_ksk, - others_ksk_zero_encs: zero_encs_for_others, - user_index: self_index, + // Auto key share + let auto_keys_share = { + let auto_seed = cr_seed.auto_keys_cts_seed::(); + self._common_rountine_multi_party_auto_keys_share_gen(auto_seed, client_key.sk_rlwe()) + }; + + // Lwe Ksk share + let lwe_ksk_share = { + let lwe_ksk_seed = cr_seed.lwe_ksk_cts_seed::(); + self._common_rountine_multi_party_lwe_ksk_share_gen( + lwe_ksk_seed, + client_key.sk_rlwe(), + client_key.sk_lwe(), + ) + }; + + NonInteractiveMultiPartyServerKeyShare { + ni_rgsw_cts, + ui_to_s_ksk, + others_ksk_zero_encs: zero_encs_for_others, + user_index: self_index, + auto_keys_share, + lwe_ksk_share, + cr_seed: cr_seed.clone(), + } + } + + fn _common_rountine_multi_party_auto_keys_share_gen( + &self, + auto_seed: ::Seed, + sk_rlwe: &RlweSecret, + ) -> HashMap { + let g = self.pbs_info.parameters.g(); + let ring_size = self.pbs_info.parameters.rlwe_n().0; + let br_q = self.pbs_info.parameters.br_q(); + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; + + DefaultSecureRng::with_local_mut(|rng| { + let mut p_rng = DefaultSecureRng::new_seeded(auto_seed); + + let mut auto_keys = HashMap::new(); + let auto_gadget = self.pbs_info.auto_decomposer.gadget_vector(); + let auto_element_dlogs = self.pbs_info.parameters.auto_element_dlogs(); + + for i in auto_element_dlogs.into_iter() { + let g_pow = if i == 0 { + -(g as isize) + } else { + (g.pow(i as u32) % br_q) as isize + }; + + let mut ksk_out = M::zeros( + self.pbs_info.auto_decomposer.decomposition_count(), + ring_size, + ); + galois_key_gen( + &mut ksk_out, + sk_rlwe.values(), + g_pow, + &auto_gadget, + rlweq_modop, + rlweq_nttop, + &mut p_rng, + rng, + ); + auto_keys.insert(i, ksk_out); } + + auto_keys + }) + } + + fn _common_rountine_multi_party_lwe_ksk_share_gen( + &self, + lwe_ksk_seed: ::Seed, + sk_rlwe: &RlweSecret, + sk_lwe: &LweSecret, + ) -> M::R { + DefaultSecureRng::with_local_mut(|rng| { + let mut p_rng = DefaultSecureRng::new_seeded(lwe_ksk_seed); + let mut lwe_ksk = M::R::zeros( + self.pbs_info.lwe_decomposer.decomposition_count() * self.parameters().rlwe_n().0, + ); + let lwe_modop = &self.pbs_info.lwe_modop; + let d_lwe_gadget_vec = self.pbs_info.lwe_decomposer.gadget_vector(); + lwe_ksk_keygen( + sk_rlwe.values(), + sk_lwe.values(), + &mut lwe_ksk, + &d_lwe_gadget_vec, + lwe_modop, + &mut p_rng, + rng, + ); + lwe_ksk }) } @@ -3104,7 +3306,7 @@ mod tests { }); }); - let mut stats = Stats::new(); + // let mut stats = Stats::new(); let (rlrg_decomp_a, rlrg_decomp_b) = evaluator .parameters() @@ -3118,46 +3320,46 @@ mod tests { rlwe_modop.elwise_neg_mut(&mut neg_s_poly_eval); nttop.forward(neg_s_poly_eval.as_mut()); - rgsw_cts.iter().enumerate().for_each(|(s_index, ct)| { - // X^{lwe_s[i]} - let mut m = vec![0u64; ring_size]; - if ideal_lwe[s_index] < 0 { - m[ring_size - (ideal_lwe[s_index].abs() as usize)] = rlwe_q.neg_one(); - } else { - m[(ideal_lwe[s_index] as usize)] = 1; - } + // rgsw_cts.iter().enumerate().for_each(|(s_index, ct)| { + // // X^{lwe_s[i]} + // let mut m = vec![0u64; ring_size]; + // if ideal_lwe[s_index] < 0 { + // m[ring_size - (ideal_lwe[s_index].abs() as usize)] = + // rlwe_q.neg_one(); } else { + // m[(ideal_lwe[s_index] as usize)] = 1; + // } - let mut neg_sm = m.clone(); - nttop.forward(&mut neg_sm); - rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval); - nttop.backward(&mut neg_sm); + // let mut neg_sm = m.clone(); + // nttop.forward(&mut neg_sm); + // rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval); + // nttop.backward(&mut neg_sm); - // RLWE'(-sm) - gadget_vec_a.iter().enumerate().for_each(|(index, beta)| { - // RLWE(\beta -sm) + // // RLWE'(-sm) + // gadget_vec_a.iter().enumerate().for_each(|(index, beta)| { + // // RLWE(\beta -sm) - // \beta * -sX^[lwe_s[i]] - let mut beta_neg_sm = neg_sm.clone(); - rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta); + // // \beta * -sX^[lwe_s[i]] + // let mut beta_neg_sm = neg_sm.clone(); + // rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta); - // extract RLWE(-sm \beta) - let mut rlwe = vec![vec![0u64; ring_size]; 2]; - rlwe[0].copy_from_slice(&ct[index]); - rlwe[1].copy_from_slice(&ct[index + d_a]); + // // extract RLWE(-sm \beta) + // let mut rlwe = vec![vec![0u64; ring_size]; 2]; + // rlwe[0].copy_from_slice(&ct[index]); + // rlwe[1].copy_from_slice(&ct[index + d_a]); - // decrypt - let mut m_out = vec![0u64; ring_size]; - decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop, rlwe_modop); - // println!("{:?}", &beta_neg_sm); + // // decrypt + // let mut m_out = vec![0u64; ring_size]; + // decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop, + // rlwe_modop); // println!("{:?}", &beta_neg_sm); - let mut diff = m_out; - rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm); + // let mut diff = m_out; + // rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm); - stats.add_more(&Vec::::try_convert_from(&diff, rlwe_q)); - }); - }); + // stats.add_more(&Vec::::try_convert_from(&diff, rlwe_q)); + // }); + // }); - println!("Stats: {}", stats.std_dev().abs().log2()); + // println!("Stats: {}", stats.std_dev().abs().log2()); } } diff --git a/src/bool/keys.rs b/src/bool/keys.rs index 9f8de2f..c8a0a12 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -12,6 +12,25 @@ use crate::{ use super::{parameters, BoolEvaluator, BoolParameters, CiphertextModulus}; +trait SinglePartyClientKey { + type Element; + fn sk_rlwe(&self) -> &[Self::Element]; + fn sk_lwe(&self) -> &[Self::Element]; +} + +trait InteractiveMultiPartyClientKey { + type Element; + fn sk_rlwe(&self) -> &[Self::Element]; + fn sk_lwe(&self) -> &[Self::Element]; +} + +trait NonInteractiveMultiPartyClientKey { + type Element; + fn sk_rlwe(&self) -> &[Self::Element]; + fn sk_u_rlwe(&self) -> &[Self::Element]; + fn sk_lwe(&self) -> &[Self::Element]; +} + /// Client key with RLWE and LWE secrets #[derive(Clone)] pub struct ClientKey { @@ -21,7 +40,7 @@ pub struct ClientKey { /// Client key with RLWE and LWE secrets #[derive(Clone)] -pub struct NonInteractiveClientKey { +pub struct ThrowMeAwayKey { sk_rlwe: RlweSecret, sk_u_rlwe: RlweSecret, sk_lwe: LweSecret, @@ -46,7 +65,7 @@ mod impl_ck { } // Client key - impl NonInteractiveClientKey { + impl ThrowMeAwayKey { pub(in super::super) fn new( sk_rlwe: RlweSecret, sk_u_rlwe: RlweSecret, @@ -369,7 +388,7 @@ impl SeededMultiPartyServerKey { } /// Seeded single party server key -pub struct SeededServerKey { +pub struct SeededSinglePartyServerKey { /// Rgsw cts of LWE secret elements pub(crate) rgsw_cts: Vec, /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding @@ -382,7 +401,7 @@ pub struct SeededServerKey { /// Main seed pub(crate) seed: S, } -impl SeededServerKey, S> { +impl SeededSinglePartyServerKey, S> { pub(super) fn from_raw( auto_keys: HashMap, rgsw_cts: Vec, @@ -410,7 +429,7 @@ impl SeededServerKey, S> { == (parameters.lwe_decomposition_count().0 * parameters.rlwe_n().0) ); - SeededServerKey { + SeededSinglePartyServerKey { rgsw_cts, auto_keys, lwe_ksk, @@ -438,6 +457,7 @@ pub(super) mod impl_server_key_eval_domain { use crate::{ backend::Modulus, + bool::{NonInteractiveMultiPartyCrs, SeededNonInteractiveMultiPartyServerKey}, ntt::{Ntt, NttInit}, pbs::PbsKey, }; @@ -455,14 +475,16 @@ pub(super) mod impl_server_key_eval_domain { R: RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + NewWithSeed, N: NttInit> + Ntt, - > From<&SeededServerKey, R::Seed>> + > From<&SeededSinglePartyServerKey, R::Seed>> for ServerKeyEvaluationDomain, R, N> where ::R: RowMut, M::MatElement: Copy, R::Seed: Clone, { - fn from(value: &SeededServerKey, R::Seed>) -> Self { + fn from( + value: &SeededSinglePartyServerKey, R::Seed>, + ) -> Self { let mut main_prng = R::new_with_seed(value.seed.clone()); let parameters = &value.parameters; let g = parameters.g() as isize; @@ -697,7 +719,218 @@ pub(super) mod impl_server_key_eval_domain { } } -/// Server key in evaluation domain +pub(crate) struct NonInteractiveServerKeyEvaluationDomain { + /// RGSW ciphertexts ideal lwe secret key elements under ideal rlwe secret + rgsw_cts: Vec, + /// Automorphism keys under ideal rlwe secret + auto_keys: HashMap, + /// LWE key switching key from Q -> Q_{ks} + lwe_ksk: M, + /// Key switching key from user j to ideal secret key s. User j's ksk is at + /// j'th element + ui_to_s_ksks: Vec, + parameters: P, + _phanton: PhantomData<(R, N)>, +} + +pub(super) mod impl_non_interactive_server_key_eval_domain { + use itertools::{izip, Itertools}; + + use crate::{bool::NonInteractiveMultiPartyCrs, random::RandomFill, Ntt, NttInit}; + + use super::*; + + impl + From< + SeededNonInteractiveMultiPartyServerKey< + M, + NonInteractiveMultiPartyCrs, + BoolParameters, + >, + > for NonInteractiveServerKeyEvaluationDomain, Rng, N> + where + M: MatrixMut + MatrixEntity + Clone, + Rng: NewWithSeed + + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> + + RandomFill<::Seed>, + N: Ntt + NttInit>, + M::R: RowMut, + M::MatElement: Copy, + Rng::Seed: Clone + Copy + Default, + { + fn from( + value: SeededNonInteractiveMultiPartyServerKey< + M, + NonInteractiveMultiPartyCrs, + BoolParameters, + >, + ) -> Self { + let rlwe_nttop = N::new(value.parameters.rlwe_q(), value.parameters.rlwe_n().0); + let ring_size = value.parameters.rlwe_n().0; + + // RGSW cts + // copy over rgsw cts and send to evaluation domain + let mut rgsw_cts = value.rgsw_cts.clone(); + rgsw_cts.iter_mut().for_each(|c| { + c.iter_rows_mut() + .for_each(|ri| rlwe_nttop.forward(ri.as_mut())) + }); + + // Auto keys + // populate pseudo random part of auto keys. Then send auto keys to + // evaluation domain + let mut auto_keys = HashMap::new(); + let auto_seed = value.cr_seed.auto_keys_cts_seed::(); + let mut auto_prng = Rng::new_with_seed(auto_seed); + let auto_element_dlogs = value.parameters.auto_element_dlogs(); + let d_auto = value.parameters.auto_decomposition_count().0; + auto_element_dlogs.iter().for_each(|el| { + let auto_part_b = value + .auto_keys + .get(el) + .expect(&format!("Auto key for element g^{el} not found")); + + assert!(auto_part_b.dimension() == (d_auto, ring_size)); + + let mut auto_ct = M::zeros(d_auto, ring_size); + + // sample part A + auto_ct.iter_rows_mut().take(d_auto).for_each(|ri| { + RandomFillUniformInModulus::random_fill( + &mut auto_prng, + value.parameters.rlwe_q(), + ri.as_mut(), + ) + }); + + // Copy over part B + izip!( + auto_ct.iter_rows_mut().skip(d_auto), + auto_part_b.iter_rows() + ) + .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); + + // send to evaluation domain + auto_ct + .iter_rows_mut() + .for_each(|r| rlwe_nttop.forward(r.as_mut())); + + auto_keys.insert(*el, auto_ct); + }); + + // LWE ksk + // populate pseudo random part of lwe ciphertexts in ksk and copy over part b + // elements + let lwe_ksk_seed = value.cr_seed.lwe_ksk_cts_seed::(); + let mut lwe_ksk_prng = Rng::new_with_seed(lwe_ksk_seed); + let mut lwe_ksk = M::zeros( + value.parameters.lwe_decomposition_count().0 * ring_size, + value.parameters.lwe_n().0 + 1, + ); + lwe_ksk.iter_rows_mut().for_each(|ri| { + // first element is resereved for part b. Only sample a_is in the rest + RandomFillUniformInModulus::random_fill( + &mut lwe_ksk_prng, + value.parameters.lwe_q(), + &mut ri.as_mut()[1..], + ) + }); + // copy over part bs + izip!(value.lwe_ksk.as_ref().iter(), lwe_ksk.iter_rows_mut()).for_each( + |(b_el, lwe_ct)| { + lwe_ct.as_mut()[0] = *b_el; + }, + ); + + // u_i to s ksk + let d_uitos = value + .parameters + .non_interactive_ui_to_s_key_switch_decomposition_count() + .0; + let total_users = *value.ui_to_s_ksks_key_order.iter().max().unwrap(); + let ui_to_s_ksks = (0..total_users) + .map(|user_index| { + let user_i_seed = value.cr_seed.ui_to_s_ks_seed_for_user_i::(user_index); + let mut prng = Rng::new_with_seed(user_i_seed); + + let mut ksk_ct = M::zeros(d_uitos * 2, ring_size); + + ksk_ct.iter_rows_mut().take(d_uitos).for_each(|r| { + RandomFillUniformInModulus::random_fill( + &mut prng, + value.parameters.rlwe_q(), + r.as_mut(), + ); + }); + + let incoming_ksk_partb_ref = + &value.ui_to_s_ksks[value.ui_to_s_ksks_key_order[user_index]]; + assert!(ksk_ct.dimension() == (d_uitos, ring_size)); + izip!( + ksk_ct.iter_rows_mut().skip(d_uitos), + incoming_ksk_partb_ref.iter_rows() + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + ksk_ct + .iter_rows_mut() + .for_each(|r| rlwe_nttop.forward(r.as_mut())); + ksk_ct + }) + .collect_vec(); + + NonInteractiveServerKeyEvaluationDomain { + rgsw_cts, + auto_keys, + lwe_ksk, + ui_to_s_ksks, + parameters: value.parameters.clone(), + _phanton: PhantomData, + } + } + } +} + +pub struct SeededNonInteractiveMultiPartyServerKey { + /// u_i to s key switching keys in random order + ui_to_s_ksks: Vec, + /// Defines order for u_i to s key switchin keys by storing the index of + /// user j's ksk in `ui_to_s_ksks` at index `j`. Find user j's u_i to s ksk + /// at `ui_to_s_ksks[ui_to_s_ksks_key_order[j]]` + ui_to_s_ksks_key_order: Vec, + /// RGSW ciphertets + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, +} + +impl SeededNonInteractiveMultiPartyServerKey { + pub(super) fn new( + ui_to_s_ksks: Vec, + ui_to_s_ksks_key_order: Vec, + rgsw_cts: Vec, + auto_keys: HashMap, + lwe_ksk: M::R, + cr_seed: S, + parameters: P, + ) -> Self { + Self { + ui_to_s_ksks, + ui_to_s_ksks_key_order, + rgsw_cts, + auto_keys, + lwe_ksk, + cr_seed, + parameters, + } + } +} + +/// Server key in evaluation domain with Shoup representations pub(crate) struct ShoupServerKeyEvaluationDomain { /// Rgsw cts of LWE secret elements rgsw_cts: Vec>, diff --git a/src/bool/mod.rs b/src/bool/mod.rs index 9fb0318..4a06c2b 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -1,8 +1,12 @@ pub(crate) mod evaluator; pub(crate) mod keys; -pub mod noise; +mod mp_api; +mod ni_mp_api; +mod noise; pub(crate) mod parameters; +pub use mp_api::*; + pub type FheBool = Vec; use std::{cell::RefCell, sync::OnceLock}; @@ -10,171 +14,3 @@ use std::{cell::RefCell, sync::OnceLock}; use evaluator::*; use keys::*; use parameters::*; - -use crate::{ - backend::{ModularOpsU64, ModulusPowerOf2}, - ntt::NttBackendU64, - random::{DefaultSecureRng, NewWithSeed}, - utils::{Global, WithLocal}, -}; - -thread_local! { - static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModulusPowerOf2>, ShoupServerKeyEvaluationDomain>>>>> = RefCell::new(None); - -} -static BOOL_SERVER_KEY: OnceLock>>> = OnceLock::new(); - -static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); - -pub enum ParameterSelector { - MultiPartyLessThanOrEqualTo16, -} - -pub fn set_parameter_set(select: ParameterSelector) { - match select { - ParameterSelector::MultiPartyLessThanOrEqualTo16 => { - BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(SMALL_MP_BOOL_PARAMS))); - } - } -} - -pub fn set_mp_seed(seed: [u8; 32]) { - assert!( - MULTI_PARTY_CRS.set(MultiPartyCrs { seed: seed }).is_ok(), - "Attempted to set MP SEED twice." - ) -} - -fn set_server_key(key: ShoupServerKeyEvaluationDomain>>) { - assert!( - BOOL_SERVER_KEY.set(key).is_ok(), - "Attempted to set server key twice." - ); -} - -pub(crate) fn gen_keys() -> ( - ClientKey, - SeededServerKey>, BoolParameters, [u8; 32]>, -) { - BoolEvaluator::with_local_mut(|e| { - let ck = e.client_key(); - let sk = e.single_party_server_key(&ck); - - (ck, sk) - }) -} - -pub fn gen_client_key() -> ClientKey { - BoolEvaluator::with_local(|e| e.client_key()) -} - -pub fn gen_mp_keys_phase1( - ck: &ClientKey, -) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { - let seed = MultiPartyCrs::global().public_key_share_seed::(); - BoolEvaluator::with_local(|e| { - let pk_share = e.multi_party_public_key_share(seed, &ck); - pk_share - }) -} - -pub fn gen_mp_keys_phase2( - ck: &ClientKey, - pk: &PublicKey>, R, ModOp>, -) -> CommonReferenceSeededMultiPartyServerKeyShare>, BoolParameters, [u8; 32]> { - let seed = MultiPartyCrs::global().server_key_share_seed::(); - BoolEvaluator::with_local_mut(|e| { - let server_key_share = e.multi_party_server_key_share(seed, pk.key(), ck); - server_key_share - }) -} - -pub fn aggregate_public_key_shares( - shares: &[CommonReferenceSeededCollectivePublicKeyShare< - Vec, - [u8; 32], - BoolParameters, - >], -) -> PublicKey>, DefaultSecureRng, ModularOpsU64>> { - PublicKey::from(shares) -} - -pub fn aggregate_server_key_shares( - shares: &[CommonReferenceSeededMultiPartyServerKeyShare< - Vec>, - BoolParameters, - [u8; 32], - >], -) -> SeededMultiPartyServerKey>, [u8; 32], BoolParameters> { - BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) -} - -// SERVER KEY EVAL (/SHOUP) DOMAIN // -impl SeededServerKey>, BoolParameters, [u8; 32]> { - pub fn set_server_key(&self) { - let eval = ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self); - set_server_key(ShoupServerKeyEvaluationDomain::from(eval)); - } -} - -impl - SeededMultiPartyServerKey< - Vec>, - ::Seed, - BoolParameters, - > -{ - pub fn set_server_key(&self) { - set_server_key(ShoupServerKeyEvaluationDomain::from( - ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self), - )) - } -} - -// MULTIPARTY CRS // -impl Global for MultiPartyCrs<[u8; 32]> { - fn global() -> &'static Self { - MULTI_PARTY_CRS - .get() - .expect("Multi Party Common Reference String not set") - } -} - -// BOOL EVALUATOR // -impl WithLocal - for BoolEvaluator< - Vec>, - NttBackendU64, - ModularOpsU64>, - ModulusPowerOf2>, - ShoupServerKeyEvaluationDomain>>, - > -{ - fn with_local(func: F) -> R - where - F: Fn(&Self) -> R, - { - BOOL_EVALUATOR.with_borrow(|s| func(s.as_ref().expect("Parameters not set"))) - } - - fn with_local_mut(func: F) -> R - where - F: Fn(&mut Self) -> R, - { - BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) - } - - fn with_local_mut_mut(func: &mut F) -> R - where - F: FnMut(&mut Self) -> R, - { - BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) - } -} - -pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain>>; -impl Global for RuntimeServerKey { - fn global() -> &'static Self { - BOOL_SERVER_KEY.get().expect("Server key not set!") - } -} diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs new file mode 100644 index 0000000..cb3ba33 --- /dev/null +++ b/src/bool/mp_api.rs @@ -0,0 +1,169 @@ +use crate::{ + backend::{ModularOpsU64, ModulusPowerOf2}, + ntt::NttBackendU64, + random::{DefaultSecureRng, NewWithSeed}, + utils::{Global, WithLocal}, +}; + +use super::*; + +thread_local! { + static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModulusPowerOf2>, ShoupServerKeyEvaluationDomain>>>>> = RefCell::new(None); + +} +static BOOL_SERVER_KEY: OnceLock>>> = OnceLock::new(); + +static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); + +pub enum ParameterSelector { + MultiPartyLessThanOrEqualTo16, +} + +pub fn set_parameter_set(select: ParameterSelector) { + match select { + ParameterSelector::MultiPartyLessThanOrEqualTo16 => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(SMALL_MP_BOOL_PARAMS))); + } + } +} + +pub fn set_mp_seed(seed: [u8; 32]) { + assert!( + MULTI_PARTY_CRS.set(MultiPartyCrs { seed: seed }).is_ok(), + "Attempted to set MP SEED twice." + ) +} + +fn set_server_key(key: ShoupServerKeyEvaluationDomain>>) { + assert!( + BOOL_SERVER_KEY.set(key).is_ok(), + "Attempted to set server key twice." + ); +} + +pub(crate) fn gen_keys() -> ( + ClientKey, + SeededSinglePartyServerKey>, BoolParameters, [u8; 32]>, +) { + BoolEvaluator::with_local_mut(|e| { + let ck = e.client_key(); + let sk = e.single_party_server_key(&ck); + + (ck, sk) + }) +} + +pub fn gen_client_key() -> ClientKey { + BoolEvaluator::with_local(|e| e.client_key()) +} + +pub fn gen_mp_keys_phase1( + ck: &ClientKey, +) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { + let seed = MultiPartyCrs::global().public_key_share_seed::(); + BoolEvaluator::with_local(|e| { + let pk_share = e.multi_party_public_key_share(seed, &ck); + pk_share + }) +} + +pub fn gen_mp_keys_phase2( + ck: &ClientKey, + pk: &PublicKey>, R, ModOp>, +) -> CommonReferenceSeededMultiPartyServerKeyShare>, BoolParameters, [u8; 32]> { + let seed = MultiPartyCrs::global().server_key_share_seed::(); + BoolEvaluator::with_local_mut(|e| { + let server_key_share = e.multi_party_server_key_share(seed, pk.key(), ck); + server_key_share + }) +} + +pub fn aggregate_public_key_shares( + shares: &[CommonReferenceSeededCollectivePublicKeyShare< + Vec, + [u8; 32], + BoolParameters, + >], +) -> PublicKey>, DefaultSecureRng, ModularOpsU64>> { + PublicKey::from(shares) +} + +pub fn aggregate_server_key_shares( + shares: &[CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + [u8; 32], + >], +) -> SeededMultiPartyServerKey>, [u8; 32], BoolParameters> { + BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) +} + +// SERVER KEY EVAL (/SHOUP) DOMAIN // +impl SeededSinglePartyServerKey>, BoolParameters, [u8; 32]> { + pub fn set_server_key(&self) { + let eval = ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self); + set_server_key(ShoupServerKeyEvaluationDomain::from(eval)); + } +} + +impl + SeededMultiPartyServerKey< + Vec>, + ::Seed, + BoolParameters, + > +{ + pub fn set_server_key(&self) { + set_server_key(ShoupServerKeyEvaluationDomain::from( + ServerKeyEvaluationDomain::<_, _, DefaultSecureRng, NttBackendU64>::from(self), + )) + } +} + +// MULTIPARTY CRS // +impl Global for MultiPartyCrs<[u8; 32]> { + fn global() -> &'static Self { + MULTI_PARTY_CRS + .get() + .expect("Multi Party Common Reference String not set") + } +} + +// BOOL EVALUATOR // +impl WithLocal + for BoolEvaluator< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModulusPowerOf2>, + ShoupServerKeyEvaluationDomain>>, + > +{ + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + BOOL_EVALUATOR.with_borrow(|s| func(s.as_ref().expect("Parameters not set"))) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) + } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s.as_mut().expect("Parameters not set"))) + } +} + +pub(crate) type RuntimeServerKey = ShoupServerKeyEvaluationDomain>>; +impl Global for RuntimeServerKey { + fn global() -> &'static Self { + BOOL_SERVER_KEY.get().expect("Server key not set!") + } +} diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/bool/noise.rs b/src/bool/noise.rs index 7674f93..8979c8b 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -4,9 +4,9 @@ mod test { use crate::{ backend::{ArithmeticOps, ModularOpsU64, Modulus, ModulusPowerOf2}, bool::{ - set_parameter_set, BoolEncoding, BoolEvaluator, BooleanGates, CiphertextModulus, - ClientKey, PublicKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, - MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, + BoolEncoding, BoolEvaluator, BooleanGates, CiphertextModulus, ClientKey, PublicKey, + ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, MP_BOOL_PARAMS, + SMALL_MP_BOOL_PARAMS, }, lwe::{decrypt_lwe, LweSecret}, ntt::NttBackendU64,