diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 7ae85f8..0b221dd 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -8,6 +8,7 @@ use std::{ marker::PhantomData, ops::Shr, sync::OnceLock, + usize, }; use itertools::{izip, partition, Itertools}; @@ -21,7 +22,10 @@ use crate::{ }, decomposer::{Decomposer, DefaultDecomposer, NumInfo, RlweDecomposer}, lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, measure_noise_lwe, LweSecret}, - multi_party::{non_interactive_ksk_gen, public_key_share}, + multi_party::{ + non_interactive_ksk_gen, non_interactive_ksk_zero_encryptions_for_other_party_i, + non_interactive_rgsw_ct, public_key_share, + }, ntt::{self, Ntt, NttBackendU64, NttInit}, pbs::{pbs, sample_extract, PbsInfo, PbsKey, WithShoupRepr}, random::{ @@ -45,14 +49,77 @@ use super::{ keys::ClientKey, parameters::{BoolParameters, CiphertextModulus}, CommonReferenceSeededCollectivePublicKeyShare, CommonReferenceSeededMultiPartyServerKeyShare, - NonInteractiveClientKey, SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, + DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, NonInteractiveClientKey, + SeededMultiPartyServerKey, SeededServerKey, ServerKeyEvaluationDomain, ShoupServerKeyEvaluationDomain, }; +pub struct NonInteractiveMultiPartyServerKeyShare { + /// (ak*si + e + \beta ui, ak*si + e) + ni_rgsw_cts: (Vec, Vec), + ui_to_s_ksk: M, + others_ksk_zero_encs: Vec, + user_index: usize, +} + +impl NonInteractiveMultiPartyServerKeyShare { + fn zero_enc_for_ui_to_s_ksk_for_user_i(&self, user_i: usize) -> &M { + assert!(user_i != self.user_index); + if user_i < self.user_index { + &self.others_ksk_zero_encs[user_i] + } else { + &self.others_ksk_zero_encs[user_i - 1] + } + } +} + pub struct MultiPartyCrs { pub(super) seed: S, } +pub struct NonInteractiveMultiPartyCrs { + pub(super) seed: S, +} + +impl NonInteractiveMultiPartyCrs { + fn server_key_share_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.seed); + + // for main server key share seed sample once + let mut out = S::default(); + RandomFill::::random_fill(&mut p_rng, &mut out); + + out + } + + fn ui_to_s_ks_seed + RandomFill>(&self) -> S { + let mut p_rng = R::new_with_seed(self.seed); + + // puncture twice + let mut out = S::default(); + RandomFill::::random_fill(&mut p_rng, &mut out); + RandomFill::::random_fill(&mut p_rng, &mut out); + + out + } + + fn ui_to_s_ks_seed_for_user_i + RandomFill>( + &self, + user_i: usize, + ) -> S { + let ks_seed = self.ui_to_s_ks_seed::(); + let mut p_rng = R::new_with_seed(ks_seed); + + // puncture user_i times + let mut out = S::default(); + for _ in 0..user_i { + RandomFill::::random_fill(&mut p_rng, &mut out); + } + + out + } +} + impl MultiPartyCrs { /// Seed to generate public key share using MultiPartyCrs as the main seed. /// @@ -322,6 +389,94 @@ impl } } +fn trim_rgsw_ct_matrix_from_rgrg_to_rlrg< + M: MatrixMut + MatrixEntity, + D: DoubleDecomposerParams, +>( + rgsw_ct_in: M, + rgrg_params: D, + rlrg_params: D, +) -> M +where + M::R: RowMut, + M::MatElement: Copy, +{ + let (rgswrgsw_d_a, rgswrgsw_d_b) = ( + rgrg_params.decomposition_count_a(), + rgrg_params.decomposition_count_b(), + ); + let (rlrg_d_a, rlrg_d_b) = ( + rlrg_params.decomposition_count_a(), + rlrg_params.decomposition_count_b(), + ); + let rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2; + let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; + assert!(rgsw_ct_in.dimension().0 == rgsw_ct_rows_in); + assert!(rgswrgsw_d_a.0 >= rlrg_d_a.0, "RGSWxRGSW part A decomposition count {} must be >= RLWExRGSW part A decomposition count {}", rgswrgsw_d_a.0 , rlrg_d_a.0); + assert!(rgswrgsw_d_b.0 >= rlrg_d_b.0, "RGSWxRGSW part B decomposition count {} must be >= RLWExRGSW part B decomposition count {}", rgswrgsw_d_b.0 , rlrg_d_b.0); + + let mut reduced_ct_i_out = M::zeros(rgsw_ct_rows_out, rgsw_ct_in.dimension().1); + + // RLWE'(-sm) part A + izip!( + reduced_ct_i_out.iter_rows_mut().take(rlrg_d_a.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 - rlrg_d_a.0) + .take(rlrg_d_a.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // RLWE'(-sm) part B + izip!( + reduced_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0) + .take(rlrg_d_a.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 + (rgswrgsw_d_a.0 - rlrg_d_a.0)) + .take(rlrg_d_a.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // RLWE'(m) Part A + izip!( + reduced_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0 * 2) + .take(rlrg_d_b.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 * 2 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) + .take(rlrg_d_b.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + // RLWE'(m) Part B + izip!( + reduced_ct_i_out + .iter_rows_mut() + .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0) + .take(rlrg_d_b.0), + rgsw_ct_in + .iter_rows() + .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) + .take(rlrg_d_b.0) + ) + .for_each(|(to_ri, from_ri)| { + to_ri.as_mut().copy_from_slice(from_ri.as_ref()); + }); + + reduced_ct_i_out +} + impl BoolEvaluator where M: MatrixEntity + MatrixMut, @@ -481,6 +636,22 @@ where ClientKey::new(sk_rlwe, sk_lwe) } + pub(super) fn non_interactive_client_key(&self) -> NonInteractiveClientKey { + let sk_lwe = LweSecret::random( + self.pbs_info.parameters.lwe_n().0 >> 1, + self.pbs_info.parameters.lwe_n().0, + ); + let sk_rlwe = RlweSecret::random( + self.pbs_info.parameters.rlwe_n().0 >> 1, + self.pbs_info.parameters.rlwe_n().0, + ); + let sk_u_rlwe = RlweSecret::random( + self.pbs_info.parameters.rlwe_n().0 >> 1, + self.pbs_info.parameters.rlwe_n().0, + ); + NonInteractiveClientKey::new(sk_rlwe, sk_u_rlwe, sk_lwe) + } + pub(super) fn single_party_server_key( &self, client_key: &ClientKey, @@ -714,21 +885,547 @@ where }) } + pub(super) fn aggregate_non_interactive_multi_party_key_share( + &self, + cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, + total_users: usize, + key_shares: &[NonInteractiveMultiPartyServerKeyShare], + ) -> Vec + where + M: Clone, + { + let rlwe_modop = &self.pbs_info().rlwe_modop; + let nttop = &self.pbs_info().rlwe_nttop; + let ring_size = self.parameters().rlwe_n().0; + let rlwe_q = self.parameters().rlwe_q(); + + // genrate key switching key from u_i to s + let ui_to_s_ksk_decomposition_count = self + .parameters() + .non_interactive_ui_to_s_key_switch_decomposition_count(); + let mut ui_to_s_ksks = key_shares + .iter() + .map(|share| { + let mut useri_ui_to_s_ksk = share.ui_to_s_ksk.clone(); + key_shares + .iter() + .filter(|x| x.user_index != share.user_index) + .for_each(|(other_share)| { + let op2 = other_share.zero_enc_for_ui_to_s_ksk_for_user_i(share.user_index); + assert!(op2.dimension() == (ui_to_s_ksk_decomposition_count.0, ring_size)); + izip!(useri_ui_to_s_ksk.iter_rows_mut(), op2.iter_rows()).for_each( + |(add_to, add_from)| { + rlwe_modop.elwise_add_mut(add_to.as_mut(), add_from.as_ref()) + }, + ); + }); + useri_ui_to_s_ksk + }) + .collect_vec(); + + let mut key_prng = + DefaultSecureRng::new_seeded(cr_seed.server_key_share_seed::()); + + let rgsw_by_rgsw_decomposer = self + .parameters() + .rgsw_rgsw_decomposer::>(); + + // Generate RGSW Cts + let rgsw_cts_all_users_eval = { + // temporarily put ui_to_s in evaluation domain and sample a_i's for u_i to s + // ksk for upcomong key switches + ui_to_s_ksks.iter_mut().for_each(|ksk_i| { + ksk_i + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())) + }); + let ui_to_s_ksks_part_a_eval = key_shares + .iter() + .map(|share| { + let mut ksk_prng = DefaultSecureRng::new_seeded( + cr_seed.ui_to_s_ks_seed_for_user_i::(share.user_index), + ); + let mut ais = M::zeros(ui_to_s_ksk_decomposition_count.0, ring_size); + ais.iter_rows_mut().for_each(|r_ai| { + RandomFillUniformInModulus::random_fill( + &mut ksk_prng, + rlwe_q, + r_ai.as_mut(), + ); + nttop.forward(r_ai.as_mut()) + }); + ais + }) + .collect_vec(); + + let max_rgrg_deocmposer = if rgsw_by_rgsw_decomposer.a().decomposition_count() + > rgsw_by_rgsw_decomposer.b().decomposition_count() + { + rgsw_by_rgsw_decomposer.a() + } else { + rgsw_by_rgsw_decomposer.b() + }; + + let ui_to_s_ksk_decomposer = self + .parameters() + .non_interactive_ui_to_s_key_switch_decomposer::>( + ); + + // Generate a_i*s + E = \sum_{j \in P} a_i*s_j + e for all rlwes in each + // non-interactive rgsws. Then decompose and put decompositions in evaluation + // for u_i -> s key switch. + let decomp_ni_rgsws_part_1_acc = { + let mut tmp_space = M::R::zeros(ring_size); + + (0..self.parameters().lwe_n().0) + .into_iter() + .map(|lwe_index| { + (0..max_rgrg_deocmposer.decomposition_count()) + .into_iter() + .map(|d_index| { + let mut sum = M::zeros( + ui_to_s_ksk_decomposer.decomposition_count(), + ring_size, + ); + + // a_i*s + E + key_shares.iter().for_each(|s| { + rlwe_modop.elwise_add_mut( + tmp_space.as_mut(), + s.ni_rgsw_cts.1[lwe_index].get_row_slice(d_index), + ); + }); + + tmp_space.as_ref().iter().enumerate().for_each(|(ri, el)| { + ui_to_s_ksk_decomposer + .decompose_iter(el) + .enumerate() + .for_each(|(row_j, d_el)| { + (sum.as_mut()[row_j]).as_mut()[ri] = d_el; + }); + }); + + sum.iter_rows_mut().for_each(|r| nttop.forward(r.as_mut())); + + sum + }) + .collect_vec() + }) + .collect_vec() + }; + + // Sample a_i's are used to generate non-interactive rgsw cts for all lwe + // indices. Since a_i are just needed for key switches , decompose them + // and put them in evaluation domain + // Decomposition count used for RGSW ct in non-interactive key share gen equals + // max of A and B decomposition required in RGSWxRGSW. This is because same + // polynomials are used to generate RLWE'(m) and RLWE'(-sm) + let decomp_ni_rgsw_neg_ais = + { + let mut tmp_space = M::R::zeros(ring_size); + (0..self.parameters().lwe_n().0) + .into_iter() + .map(|_| { + // FIXME(Jay): well well, ais are only required for key switching to + // generate RLWE'(m) and RLWE'(m) itself + // requires RLWE(\beta^i m) for i \in RGSWxRGSW + // part B decomposition count. However, we still need to puncture prng + // for ais corresponding to ignored limbs. + // Probably it will be nice idea to + // avoid decomposition after punturing for a_i's corresponding to + // ignored limbs. Moreover, note that, for + // RGSWxRGSW often times decompostion count for part A + // > part B. Hence, it's very likely that we are doing + // unecesary decompositions for ignored limbs all the time. + (0..max_rgrg_deocmposer.decomposition_count()) + .map(|_| { + RandomFillUniformInModulus::random_fill( + &mut key_prng, + rlwe_q, + tmp_space.as_mut(), + ); + + // negate + rlwe_modop.elwise_neg_mut(tmp_space.as_mut()); + + // decomposer a_i for ui -> s key switch + let mut decomp_neg_ai = M::zeros( + ui_to_s_ksk_decomposer.decomposition_count(), + ring_size, + ); + tmp_space.as_ref().iter().enumerate().for_each( + |(index, el)| { + ui_to_s_ksk_decomposer + .decompose_iter(el) + .enumerate() + .for_each(|(row_j, d_el)| { + (decomp_neg_ai.as_mut()[row_j]).as_mut() + [index] = d_el; + }); + }, + ); + + // put in evaluation domain + decomp_neg_ai + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())); + + decomp_neg_ai + }) + .collect_vec() + }) + .collect_vec() + }; + + // genrate RGSW cts + let rgsw_cts_all_users_eval = izip!( + key_shares.iter(), + ui_to_s_ksks.iter(), + ui_to_s_ksks_part_a_eval.iter() + ) + .map( + |(share, user_uitos_ksk_partb_eval, user_uitos_ksk_parta_eval)| { + // RGSW_s(X^{s[i]}) + let rgsw_cts_user_i_eval = izip!( + share.ni_rgsw_cts.0.iter(), + decomp_ni_rgsw_neg_ais.iter(), + decomp_ni_rgsws_part_1_acc.iter() + ) + .map( + |( + m_encs_under_ui, + decomposed_rgsw_i_neg_ais, + decomposed_acc_rgsw_part1, + )| { + let d_a = rgsw_by_rgsw_decomposer.a().decomposition_count(); + let d_b = rgsw_by_rgsw_decomposer.b().decomposition_count(); + let max_d = std::cmp::max(d_a, d_b); + + assert!(decomposed_rgsw_i_neg_ais.len() == max_d); + + // To be RGSW(X^{s[i]}) = [RLWE'(-sm), RLWE'(m)] + let mut rgsw_ct_eval = M::zeros(d_a * 2 + d_b * 2, ring_size); + let (rlwe_dash_nsm, rlwe_dash_m) = + rgsw_ct_eval.split_at_row_mut(d_a * 2); + + let mut scratch_row = M::R::zeros(ring_size); + + let mut m_encs_under_ui_eval = m_encs_under_ui.clone(); + m_encs_under_ui_eval + .iter_rows_mut() + .for_each(|r| nttop.forward(r.as_mut())); + + // RLWE(-sm) + { + // Recall that we have RLWE(a_i * s + e). We key + // switch RLWE(a_i * s + e) using ksk(u_i -> s) to + // get RLWE(a_i*s*u_i + e*u_i). + // Given (u_i * a_i + e + \beta m), we obtain + // RLWE(-sm) = RLWE(a_i*s*u_i + e*u_i) + (0, (u_i * + // a_i + e + \beta m)) + // + // Again RLWE'(-sm) only cares for RLWE(\beta -sm) + // with scaling factor \beta corresponding to most + // signficant d_a limbs. Hence, we skip (d_max - + // d_a) least signficant limbs + + let (rlwe_dash_neg_sm_part_a, rlwe_dash_neg_sm_part_b) = + rlwe_dash_nsm.split_at_mut(d_a); + + izip!( + rlwe_dash_neg_sm_part_a.iter_mut(), + rlwe_dash_neg_sm_part_b.iter_mut(), + decomposed_acc_rgsw_part1.iter().skip(max_d - d_a), + m_encs_under_ui_eval.iter_rows().skip(max_d - d_a) + ) + .for_each( + |(rlwe_a, rlwe_b, decomp_ai_s, beta_m_enc_ui)| { + // RLWE_s(a_i * s * u_i + u_i * e) = decomp * + // ksk(u_i -> s) + izip!( + decomp_ai_s.iter_rows(), + user_uitos_ksk_partb_eval.iter_rows(), + user_uitos_ksk_parta_eval.iter_rows() + ) + .for_each( + |(a0, part_b, part_a)| { + // rlwe_b += decomp[i] * ksk part_b[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_b.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_b.as_mut(), + scratch_row.as_ref(), + ); + + // rlwe_a += decomp[i] * ksk + // part_a[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_a.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_a.as_mut(), + scratch_row.as_ref(), + ); + }, + ); + + // RLWE_s(-sm) = RLWE_s(a_i * s * u_i + u_i + // * e) + (0, a_i * u + e + m) + rlwe_modop.elwise_add_mut( + rlwe_a.as_mut(), + beta_m_enc_ui.as_ref(), + ); + }, + ); + } + + // RLWE(m) + { + // Routine: + // Let RLWE(-a_i * u_i) = (decomp<-a_i> \cdot ksk(u_i -> s)), + // then RLWE(m) = RLWE(a'*s + e + m) = (a_i + // * u_i + e + m, 0) + RLWE(-a_i * u_i) + // + // Since RLWE'(m) only cares for RLWE ciphertexts corresponding + // to higher d_b limbs, we skip routine 1 for lower max(d_a, + // d_b) - d_b limbs + let (rlwe_dash_m_part_a, rlwe_dash_m_part_b) = + rlwe_dash_m.split_at_mut(d_b); + + izip!( + rlwe_dash_m_part_a.iter_mut(), + rlwe_dash_m_part_b.iter_mut(), + decomposed_rgsw_i_neg_ais.iter().skip(max_d - d_b), + m_encs_under_ui_eval.iter_rows().skip(max_d - d_b) + ) + .for_each( + |(rlwe_a, rlwe_b, decomp_neg_ai, beta_m_enc_ui)| { + // RLWE_s(-a_i * ui) = decomp<-a_i> \cdot ksk(ui -> s) + izip!( + decomp_neg_ai.iter_rows(), + user_uitos_ksk_partb_eval.iter_rows(), + user_uitos_ksk_parta_eval.iter_rows() + ) + .for_each( + |(a0, part_b, part_a)| { + // rlwe_b += decomp[i] * ksk part_b[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_b.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_b.as_mut(), + scratch_row.as_ref(), + ); + + // rlwe_a += decomp[i] * ksk + // part_a[i] + rlwe_modop.elwise_mul( + scratch_row.as_mut(), + a0.as_ref(), + part_a.as_ref(), + ); + rlwe_modop.elwise_add_mut( + rlwe_a.as_mut(), + scratch_row.as_ref(), + ); + }, + ); + + // RLWE_s(m) = (a_i * ui + e + \beta m, 0) + // + + // RLWE_s(-a_i * ui) + rlwe_modop.elwise_add_mut( + rlwe_b.as_mut(), + beta_m_enc_ui.as_ref(), + ); + }, + ); + } + + rgsw_ct_eval + }, + ) + .collect_vec(); + rgsw_cts_user_i_eval + }, + ) + .collect_vec(); + + // put u_i -> s ksks back in coefficient domain + ui_to_s_ksks.iter_mut().for_each(|ksk_i| { + ksk_i + .iter_rows_mut() + .for_each(|r| nttop.backward(r.as_mut())) + }); + + rgsw_cts_all_users_eval + }; + + // RGSW x RGSW + let lwe_n = self.parameters().lwe_n().0; + let mut scratch_matrix = M::zeros( + std::cmp::max( + rgsw_by_rgsw_decomposer.a().decomposition_count(), + rgsw_by_rgsw_decomposer.b().decomposition_count(), + ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 + + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), + ring_size, + ); + let rgsw_cts_untrimmed = (0..lwe_n).map(|s_index| { + // copy over s_index^th rgsw ct of user 0. Use it to accumulate RGSW products of + // all RGSW ciphertexts at s_index + let mut rgsw_i = rgsw_cts_all_users_eval[0][s_index].clone(); + rgsw_i + .iter_rows_mut() + .for_each(|r| nttop.backward(r.as_mut())); + + rgsw_cts_all_users_eval + .iter() + .skip(1) + .for_each(|user_i_rgsws| { + rgsw_by_rgsw_inplace( + &mut rgsw_i, + &user_i_rgsws[s_index], + &rgsw_by_rgsw_decomposer, + &mut scratch_matrix, + nttop, + rlwe_modop, + ); + }); + + rgsw_i + }); + + // After this point we don't require RGSW cts for RGSWxRGSW + // multiplicaiton anymore. So we trim them to suit RLWExRGSW require for PBS + let rgsw_cts = rgsw_cts_untrimmed + .map(|rgsw_ct| { + trim_rgsw_ct_matrix_from_rgrg_to_rlrg( + rgsw_ct, + self.parameters().rgsw_by_rgsw_decomposition_params(), + self.parameters().rlwe_by_rgsw_decomposition_params(), + ) + }) + .collect_vec(); + rgsw_cts + } + pub(super) fn non_interactive_multi_party_key_share( - self_ui_to_ksk_seed: [u8; 32], - others_ui_to_ksk_seed: &[[u8; 32]], + &self, + // TODO(Jay): Should get a common reference seed here and derive the rest. + cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, + self_index: usize, + total_users: usize, client_key: &NonInteractiveClientKey, - ) { - // // ui_to_s_ksk - // non_interactive_ksk_gen( - // client_key.sk_rlwe().values(), - // client_key.sk_u_rlwe().values(), - // gadget_vec, - // p_rng, - // rng, - // nttop, - // modop, - // ) + ) -> NonInteractiveMultiPartyServerKeyShare { + // TODO: check whether parameters support `total_users` + + DefaultSecureRng::with_local_mut(|rng| { + let nttop = self.pbs_info().nttop_rlweq(); + let rlwe_modop = self.pbs_info().modop_rlweq(); + let ring_size = self.pbs_info().rlwe_n(); + let rlwe_q = self.parameters().rlwe_q(); + + // ui_to_s_ksk + let non_interactive_decomposer = self + .parameters() + .non_interactive_ui_to_s_key_switch_decomposer::>( + ); + let non_interactive_gadget_vec = non_interactive_decomposer.gadget_vector(); + let ui_to_s_ksk = { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ui_to_s_ks_seed_for_user_i::(self_index), + ); + non_interactive_ksk_gen::( + client_key.sk_rlwe().values(), + client_key.sk_u_rlwe().values(), + &non_interactive_gadget_vec, + &mut p_rng, + rng, + nttop, + rlwe_modop, + ) + }; + + // zero encryptions for others uj_to_s ksk + let all_users_except_self = (0..total_users).filter(|x| *x != self_index); + let zero_encs_for_others = all_users_except_self + .map(|other_user_index| { + let mut p_rng = DefaultSecureRng::new_seeded( + cr_seed.ui_to_s_ks_seed_for_user_i::(other_user_index), + ); + let zero_encs = + non_interactive_ksk_zero_encryptions_for_other_party_i::( + client_key.sk_rlwe().values(), + &non_interactive_gadget_vec, + &mut p_rng, + rng, + nttop, + rlwe_modop, + ); + zero_encs + }) + .collect_vec(); + + // Main Key gen follows // + + let mut key_prng = + DefaultSecureRng::new_seeded(cr_seed.server_key_share_seed::()); + // generate non-interactive rgsw cts + let rgsw_by_rgsw_decomposer = self + .parameters() + .rgsw_rgsw_decomposer::>(); + + let ni_rgrg_gadget_vec = { + if rgsw_by_rgsw_decomposer.a().decomposition_count() + > rgsw_by_rgsw_decomposer.b().decomposition_count() + { + rgsw_by_rgsw_decomposer.a().gadget_vector() + } else { + rgsw_by_rgsw_decomposer.b().gadget_vector() + } + }; + let ni_rgsw_cts: (Vec, Vec) = client_key + .sk_lwe() + .values() + .iter() + .map(|s_i| { + // X^{s[i]} + let mut m = M::R::zeros(ring_size); + if *s_i < 0 { + // X^{-s[i]} -> -X^{N+s[i]} + m.as_mut()[ring_size - (s_i.abs() as usize)] = rlwe_q.neg_one(); + } else { + m.as_mut()[*s_i as usize] = M::MatElement::one(); + } + + non_interactive_rgsw_ct::( + client_key.sk_rlwe().values(), + client_key.sk_u_rlwe().values(), + m.as_ref(), + &ni_rgrg_gadget_vec, + &mut key_prng, + rng, + nttop, + rlwe_modop, + ) + }) + .unzip(); + + NonInteractiveMultiPartyServerKeyShare { + ni_rgsw_cts, + ui_to_s_ksk, + others_ksk_zero_encs: zero_encs_for_others, + user_index: self_index, + } + }) } pub(super) fn multi_party_public_key_share( @@ -1018,75 +1715,13 @@ where // multiplication. After this point RGSW ciphertexts will only be used for // RLWExRGSW multiplication (in blind rotation). Thus we drop any additional // RLWE ciphertexts in RGSW ciphertexts after RGSw x RGSW multiplication - let (rgswrgsw_d_a, rgswrgsw_d_b) = self.pbs_info.parameters.rgsw_rgsw_decomposition_count(); - let (rlrg_d_a, rlrg_d_b) = self.pbs_info.parameters.rlwe_rgsw_decomposition_count(); - let rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2; - let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; - assert!(rgswrgsw_d_a.0 >= rlrg_d_a.0, "RGSWxRGSW part A decomposition count {} must be >= RLWExRGSW part A decomposition count {}", rgswrgsw_d_a.0 , rlrg_d_a.0); - assert!(rgswrgsw_d_b.0 >= rlrg_d_b.0, "RGSWxRGSW part B decomposition count {} must be >= RLWExRGSW part B decomposition count {}", rgswrgsw_d_b.0 , rlrg_d_b.0); let rgsw_cts = rgsw_cts .map(|ct_i_in| { - assert!(ct_i_in.dimension() == (rgsw_ct_rows_in, rlwe_n)); - let mut reduced_ct_i_out = M::zeros(rgsw_ct_rows_out, rlwe_n); - - // RLWE'(-sm) part A - izip!( - reduced_ct_i_out.iter_rows_mut().take(rlrg_d_a.0), - ct_i_in - .iter_rows() - .skip(rgswrgsw_d_a.0 - rlrg_d_a.0) - .take(rlrg_d_a.0) + trim_rgsw_ct_matrix_from_rgrg_to_rlrg( + ct_i_in, + self.parameters().rgsw_by_rgsw_decomposition_params(), + self.parameters().rlwe_by_rgsw_decomposition_params(), ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // RLWE'(-sm) part B - izip!( - reduced_ct_i_out - .iter_rows_mut() - .skip(rlrg_d_a.0) - .take(rlrg_d_a.0), - ct_i_in - .iter_rows() - .skip(rgswrgsw_d_a.0 + (rgswrgsw_d_a.0 - rlrg_d_a.0)) - .take(rlrg_d_a.0) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // RLWE'(m) Part A - izip!( - reduced_ct_i_out - .iter_rows_mut() - .skip(rlrg_d_a.0 * 2) - .take(rlrg_d_b.0), - ct_i_in - .iter_rows() - .skip(rgswrgsw_d_a.0 * 2 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) - .take(rlrg_d_b.0) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - // RLWE'(m) Part B - izip!( - reduced_ct_i_out - .iter_rows_mut() - .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0) - .take(rlrg_d_b.0), - ct_i_in - .iter_rows() - .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) - .take(rlrg_d_b.0) - ) - .for_each(|(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - }); - - reduced_ct_i_out }) .collect_vec(); @@ -1317,9 +1952,10 @@ mod tests { use rand_distr::Uniform; use crate::{ + backend::ModulusPowerOf2, bool::{ self, CommonReferenceSeededMultiPartyServerKeyShare, PublicKey, - SeededMultiPartyServerKey, SMALL_MP_BOOL_PARAMS, + SeededMultiPartyServerKey, NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, }, ntt::NttBackendU64, random::{RandomElementInModulus, DEFAULT_RNG}, @@ -2411,4 +3047,178 @@ mod tests { // } // } } + + #[test] + fn testtest() { + let evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModulusPowerOf2>, + ShoupServerKeyEvaluationDomain>>, + >::new(NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS); + let mp_seed = NonInteractiveMultiPartyCrs { seed: [0u8; 32] }; + + let ring_size = evaluator.parameters().rlwe_n().0; + let rlwe_q = evaluator.parameters().rlwe_q(); + let rlwe_modop = evaluator.pbs_info().modop_rlweq(); + let nttop = evaluator.pbs_info().nttop_rlweq(); + + let parties = 2; + + let cks = (0..parties) + .map(|_| evaluator.non_interactive_client_key()) + .collect_vec(); + + let key_shares = (0..parties) + .map(|i| evaluator.non_interactive_multi_party_key_share(&mp_seed, i, parties, &cks[i])) + .collect_vec(); + // dbg!(key_shares[1].user_index); + + let rgsw_cts = evaluator.aggregate_non_interactive_multi_party_key_share( + &mp_seed, + parties, + &key_shares, + ); + + let mut ideal_rlwe = vec![0; ring_size]; + cks.iter().for_each(|k| { + izip!(ideal_rlwe.iter_mut(), k.sk_rlwe().values().iter()).for_each(|(a, b)| { + *a = *a + b; + }); + }); + + let mut ideal_lwe = vec![0; evaluator.parameters().lwe_n().0]; + cks.iter().for_each(|k| { + izip!(ideal_lwe.iter_mut(), k.sk_lwe().values().iter()).for_each(|(a, b)| { + *a = *a + b; + }); + }); + + let mut stats = Stats::new(); + + let (rlrg_decomp_a, rlrg_decomp_b) = evaluator + .parameters() + .rlwe_rgsw_decomposer::>(); + let gadget_vec_a = rlrg_decomp_a.gadget_vector(); + let gadget_vec_b = rlrg_decomp_b.gadget_vector(); + let d_a = rlrg_decomp_a.decomposition_count(); + let d_b = rlrg_decomp_b.decomposition_count(); + let s_poly = Vec::::try_convert_from(ideal_rlwe.as_slice(), rlwe_q); + let mut neg_s_poly_eval = s_poly.clone(); + rlwe_modop.elwise_neg_mut(&mut neg_s_poly_eval); + nttop.forward(neg_s_poly_eval.as_mut()); + rgsw_cts.iter().enumerate().for_each(|(s_index, ct)| { + // X^{lwe_s[i]} + let mut m = vec![0u64; ring_size]; + if ideal_lwe[s_index] < 0 { + m[ring_size - (ideal_lwe[s_index].abs() as usize)] = rlwe_q.neg_one(); + } else { + m[(ideal_lwe[s_index] as usize)] = 1; + } + + let mut neg_sm = m.clone(); + nttop.forward(&mut neg_sm); + rlwe_modop.elwise_mul_mut(&mut neg_sm, &neg_s_poly_eval); + nttop.backward(&mut neg_sm); + + // RLWE'(-sm) + gadget_vec_a.iter().enumerate().for_each(|(index, beta)| { + // RLWE(\beta -sm) + + // \beta * -sX^[lwe_s[i]] + let mut beta_neg_sm = neg_sm.clone(); + rlwe_modop.elwise_scalar_mul_mut(&mut beta_neg_sm, beta); + + // extract RLWE(-sm \beta) + let mut rlwe = vec![vec![0u64; ring_size]; 2]; + rlwe[0].copy_from_slice(&ct[index]); + rlwe[1].copy_from_slice(&ct[index + d_a]); + + // decrypt + let mut m_out = vec![0u64; ring_size]; + decrypt_rlwe(&rlwe, &ideal_rlwe, &mut m_out, nttop, rlwe_modop); + + let mut diff = m_out; + rlwe_modop.elwise_sub_mut(&mut diff, &beta_neg_sm); + + stats.add_more(&Vec::::try_convert_from(&diff, rlwe_q)); + }); + }); + + println!("Stats: {}", stats.std_dev().abs().log2()); + } } + +// let (rgswrgsw_d_a, rgswrgsw_d_b) = +// self.pbs_info.parameters.rgsw_rgsw_decomposition_count(); let (rlrg_d_a, +// rlrg_d_b) = self.pbs_info.parameters.rlwe_rgsw_decomposition_count(); let +// rgsw_ct_rows_in = rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 * 2; +// let rgsw_ct_rows_out = rlrg_d_a.0 * 2 + rlrg_d_b.0 * 2; +// assert!(rgswrgsw_d_a.0 >= rlrg_d_a.0, "RGSWxRGSW part A decomposition count {} must be >= RLWExRGSW part A decomposition count {}", rgswrgsw_d_a.0 , rlrg_d_a.0); +// assert!(rgswrgsw_d_b.0 >= rlrg_d_b.0, "RGSWxRGSW part B decomposition count {} must be >= RLWExRGSW part B decomposition count {}", rgswrgsw_d_b.0 , rlrg_d_b.0); +// let rgsw_cts = rgsw_cts +// .map(|ct_i_in| { +// assert!(ct_i_in.dimension() == (rgsw_ct_rows_in, rlwe_n)); +// let mut reduced_ct_i_out = M::zeros(rgsw_ct_rows_out, rlwe_n); + +// // RLWE'(-sm) part A +// izip!( +// reduced_ct_i_out.iter_rows_mut().take(rlrg_d_a.0), +// ct_i_in +// .iter_rows() +// .skip(rgswrgsw_d_a.0 - rlrg_d_a.0) +// .take(rlrg_d_a.0) +// ) +// .for_each(|(to_ri, from_ri)| { +// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); +// }); + +// // RLWE'(-sm) part B +// izip!( +// reduced_ct_i_out +// .iter_rows_mut() +// .skip(rlrg_d_a.0) +// .take(rlrg_d_a.0), +// ct_i_in +// .iter_rows() +// .skip(rgswrgsw_d_a.0 + (rgswrgsw_d_a.0 - rlrg_d_a.0)) +// .take(rlrg_d_a.0) +// ) +// .for_each(|(to_ri, from_ri)| { +// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); +// }); + +// // RLWE'(m) Part A +// izip!( +// reduced_ct_i_out +// .iter_rows_mut() +// .skip(rlrg_d_a.0 * 2) +// .take(rlrg_d_b.0), +// ct_i_in +// .iter_rows() +// .skip(rgswrgsw_d_a.0 * 2 + (rgswrgsw_d_b.0 - rlrg_d_b.0)) +// .take(rlrg_d_b.0) +// ) +// .for_each(|(to_ri, from_ri)| { +// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); +// }); + +// // RLWE'(m) Part B +// izip!( +// reduced_ct_i_out +// .iter_rows_mut() +// .skip(rlrg_d_a.0 * 2 + rlrg_d_b.0) +// .take(rlrg_d_b.0), +// ct_i_in +// .iter_rows() +// .skip(rgswrgsw_d_a.0 * 2 + rgswrgsw_d_b.0 + +// (rgswrgsw_d_b.0 - rlrg_d_b.0)) .take(rlrg_d_b.0) +// ) +// .for_each(|(to_ri, from_ri)| { +// to_ri.as_mut().copy_from_slice(from_ri.as_ref()); +// }); + +// reduced_ct_i_out +// }) +// .collect_vec(); diff --git a/src/bool/keys.rs b/src/bool/keys.rs index da79d3b..9f8de2f 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -32,12 +32,6 @@ mod impl_ck { // Client key impl ClientKey { - pub(in super::super) fn random() -> Self { - let sk_rlwe = RlweSecret::random(0, 0); - let sk_lwe = LweSecret::random(0, 0); - Self { sk_rlwe, sk_lwe } - } - pub(in super::super) fn new(sk_rlwe: RlweSecret, sk_lwe: LweSecret) -> Self { Self { sk_rlwe, sk_lwe } } @@ -53,17 +47,6 @@ mod impl_ck { // Client key impl NonInteractiveClientKey { - pub(in super::super) fn random() -> Self { - let sk_rlwe = RlweSecret::random(0, 0); - let sk_u_rlwe = RlweSecret::random(0, 0); - let sk_lwe = LweSecret::random(0, 0); - Self { - sk_rlwe, - sk_u_rlwe, - sk_lwe, - } - } - pub(in super::super) fn new( sk_rlwe: RlweSecret, sk_u_rlwe: RlweSecret, diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 8c5245e..4e603ab 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -2,7 +2,7 @@ use num_traits::{ConstZero, FromPrimitive, PrimInt}; use crate::{backend::Modulus, decomposer::Decomposer}; -trait DoubleDecomposerParams { +pub(super) trait DoubleDecomposerParams { type Base; type Count; @@ -100,12 +100,14 @@ pub struct BoolParameters { DecompostionLogBase, (DecompositionCount, DecompositionCount), ), + auto_decomposer_params: (DecompostionLogBase, DecompositionCount), /// RGSW x RGSW decomposition count for (part A, part B) rgrg_decomposer_params: Option<( DecompostionLogBase, (DecompositionCount, DecompositionCount), )>, - auto_decomposer_params: (DecompostionLogBase, DecompositionCount), + non_interactive_ui_to_s_key_switch_decomposer: + Option<(DecompostionLogBase, DecompositionCount)>, g: usize, w: usize, variant: ParameterVariant, @@ -140,6 +142,27 @@ impl BoolParameters { self.w } + pub(crate) fn rlwe_by_rgsw_decomposition_params( + &self, + ) -> ( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ) { + self.rlrg_decomposer_params + } + + pub(crate) fn rgsw_by_rgsw_decomposition_params( + &self, + ) -> ( + DecompostionLogBase, + (DecompositionCount, DecompositionCount), + ) { + self.rgrg_decomposer_params.expect(&format!( + "Parameter variant {:?} does not support RGSWxRGSW", + self.variant + )) + } + pub(crate) fn rlwe_rgsw_decomposition_base(&self) -> DecompostionLogBase { self.rlrg_decomposer_params.0 } @@ -172,6 +195,18 @@ impl BoolParameters { self.lwe_decomposer_params.decomposition_count() } + pub(crate) fn non_interactive_ui_to_s_key_switch_decomposition_count( + &self, + ) -> DecompositionCount { + let params = self + .non_interactive_ui_to_s_key_switch_decomposer + .expect(&format!( + "Parameter variant {:?} does not support non-interactive", + self.variant + )); + params.decomposition_count() + } + pub(crate) fn rgsw_rgsw_decomposer>(&self) -> (D, D) where El: Copy, @@ -238,6 +273,25 @@ impl BoolParameters { ) } + pub(crate) fn non_interactive_ui_to_s_key_switch_decomposer>( + &self, + ) -> D + where + El: Copy, + { + let params = self + .non_interactive_ui_to_s_key_switch_decomposer + .expect(&format!( + "Parameter variant {:?} does not support non-interactive", + self.variant + )); + D::new( + self.rlwe_q.0, + params.decomposition_base().0, + params.decomposition_count().0, + ) + } + /// Returns dlogs of `g` for which auto keys are required as /// per the parameter. Given that autos are required for [-g, g, g^2, ..., /// g^w] function returns the following [0, 1, 2, ..., w] where `w` is @@ -397,6 +451,7 @@ pub(crate) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { ), rgrg_decomposer_params: None, auto_decomposer_params: (DecompostionLogBase(7), DecompositionCount(4)), + non_interactive_ui_to_s_key_switch_decomposer: None, g: 5, w: 5, variant: ParameterVariant::SingleParty, @@ -418,6 +473,7 @@ pub(crate) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { (DecompositionCount(5), DecompositionCount(5)), )), auto_decomposer_params: (DecompostionLogBase(12), DecompositionCount(5)), + non_interactive_ui_to_s_key_switch_decomposer: None, g: 5, w: 10, variant: ParameterVariant::MultiParty, @@ -439,19 +495,44 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: = BoolParameters:: { + rlwe_q: CiphertextModulus::new_non_native(36028797018820609), + lwe_q: CiphertextModulus::new_non_native(1 << 20), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(10), + lwe_decomposer_params: (DecompostionLogBase(4), DecompositionCount(5)), + rlrg_decomposer_params: ( + DecompostionLogBase(11), + (DecompositionCount(2), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(11), + (DecompositionCount(5), DecompositionCount(4)), + )), + auto_decomposer_params: (DecompostionLogBase(11), DecompositionCount(2)), + non_interactive_ui_to_s_key_switch_decomposer: Some(( + DecompostionLogBase(1), + DecompositionCount(55), + )), + g: 5, + w: 10, + variant: ParameterVariant::NonInteractiveMultiParty, +}; #[cfg(test)] mod tests { use crate::utils::generate_prime; #[test] fn find_prime() { - let bits = 55; - let ring_size = 1 << 15; + let bits = 60; + let ring_size = 1 << 11; let prime = generate_prime(bits, ring_size * 2, 1 << bits).unwrap(); dbg!(prime); } diff --git a/src/multi_party.rs b/src/multi_party.rs index 50d919c..48c4798 100644 --- a/src/multi_party.rs +++ b/src/multi_party.rs @@ -48,7 +48,7 @@ pub(crate) fn public_key_share< modop.elwise_add_mut(share_out.as_mut(), s.as_ref()); // s*e + e } -fn non_interactive_rgsw_ct< +pub(crate) fn non_interactive_rgsw_ct< M: MatrixMut + MatrixEntity, S, PRng: RandomFillUniformInModulus<[M::MatElement], ModOp::M>, @@ -140,7 +140,8 @@ pub(crate) fn non_interactive_ksk_gen< rng: &mut Rng, nttop: &NttOp, modop: &ModOp, -) where +) -> M +where ::R: RowMut + TryConvertFrom1<[S], ModOp::M> + RowEntity, M::MatElement: Copy, { @@ -153,7 +154,6 @@ pub(crate) fn non_interactive_ksk_gen< let mut s_poly_eval = M::R::try_convert_from(s, q); nttop.forward(s_poly_eval.as_mut()); let u_poly = M::R::try_convert_from(u, q); - // a_i * s + \beta u + e let mut ksk = M::zeros(d, ring_size); @@ -176,6 +176,8 @@ pub(crate) fn non_interactive_ksk_gen< // a_i * s + e + \beta * u modop.elwise_add_mut(e_ksk.as_mut(), scratch_space.as_ref()); }); + + ksk } pub(crate) fn non_interactive_ksk_zero_encryptions_for_other_party_i<