From 1d7099600ad8b5d63818c3951ed1650f12c43764 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 24 Jun 2024 15:26:53 +0700 Subject: [PATCH] add differing base feature for RLWExRGSw and RGSWxRGSW for interactive mpc --- src/bool/evaluator.rs | 449 +++++++++++++++++++++++++---------------- src/bool/keys.rs | 24 ++- src/bool/mp_api.rs | 17 +- src/bool/noise.rs | 11 +- src/bool/parameters.rs | 8 +- src/rgsw/mod.rs | 2 + src/rgsw/runtime.rs | 51 +++-- src/utils.rs | 29 ++- 8 files changed, 381 insertions(+), 210 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 5166fc3..b19e9cf 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -41,8 +41,8 @@ use crate::{ RlweCiphertext, RlweSecret, }, utils::{ - fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, - puncture_p_rng, Global, TryConvertFrom1, WithLocal, + encode_x_pow_si_with_emebedding_factor, fill_random_ternary_secret_with_hamming_weight, + generate_prime, mod_exponent, puncture_p_rng, Global, TryConvertFrom1, WithLocal, }, Decryptor, Encoder, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, RowEntity, RowMut, Secret, @@ -527,6 +527,24 @@ where reduced_ct_i_out } +/// Assigns user with user_id segement of LWE secret indices for which they +/// generate RGSW(X^{s[i]}) as the leader (i.e. for RLWExRGSW). If returned +/// tuple is (start, end), user's segment is [start, end) +pub(super) fn interactive_mult_party_user_id_lwe_segment( + user_id: usize, + total_users: usize, + lwe_n: usize, +) -> (usize, usize) { + let per_user = (lwe_n as f64 / total_users as f64) + .ceil() + .to_usize() + .unwrap(); + ( + per_user * user_id, + std::cmp::min(per_user * (user_id + 1), lwe_n), + ) +} + impl BoolEvaluator where M: MatrixEntity + MatrixMut, @@ -800,6 +818,8 @@ where pub(super) fn multi_party_server_key_share>( &self, + user_id: usize, + total_users: usize, cr_seed: &MultiPartyCrs<[u8; 32]>, collective_pk: &M, client_key: &K, @@ -809,10 +829,7 @@ where MultiPartyCrs<[u8; 32]>, > { assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty); - // let user_id = 0; - - // let user_segment_start = 0; - // let user_segment_end = 1; + assert!(user_id < total_users); let sk_rlwe = client_key.sk_rlwe(); let sk_lwe = client_key.sk_lwe(); @@ -836,41 +853,85 @@ where ); // rgsw ciphertexts of lwe secret elements - let rgsw_cts = DefaultSecureRng::with_local_mut(|rng| { - let rgsw_rgsw_decomposer = self - .pbs_info - .parameters - .rgsw_rgsw_decomposer::>(); - let (rgrg_d_a, rgrg_d_b) = ( - rgsw_rgsw_decomposer.0.decomposition_count(), - rgsw_rgsw_decomposer.1.decomposition_count(), - ); - let (rgrg_gadget_a, rgrg_gadget_b) = ( - rgsw_rgsw_decomposer.0.gadget_vector(), - rgsw_rgsw_decomposer.1.gadget_vector(), + let (self_leader_rgsws, not_self_leader_rgsws) = DefaultSecureRng::with_local_mut(|rng| { + let mut self_leader_rgsw = vec![]; + let mut not_self_leader_rgsws = vec![]; + + let (segment_start, segment_end) = interactive_mult_party_user_id_lwe_segment( + user_id, + total_users, + self.pbs_info().lwe_n(), ); - let rgsw_cts = sk_lwe - .iter() - .map(|si| { - let mut m = M::R::zeros(ring_size); - //TODO(Jay): It will be nice to have a function that returns polynomial - // (monomial infact!) corresponding to secret element embedded in ring X^{2N+1}. - // Save lots of mistakes where one forgest to emebed si in bigger ring. - let si = *si * (self.pbs_info.embedding_factor as i32); - if si < 0 { - // X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N - // (which it is given si is secret element) - m.as_mut()[ring_size - (si.abs() as usize)] = rlwe_q.neg_one(); - } else { - m.as_mut()[si as usize] = M::MatElement::one(); - } - // public key RGSW encryption has no part that can be seeded, unlike secret key - // RGSW encryption where RLWE'_A(m) is seeded + // self LWE secret indices + { + // LWE secret indices for which user is the leader they need to send RGSW(m) for + // RLWE x RGSW multiplication + let rlrg_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); + let (rlrg_d_a, rlrg_d_b) = ( + rlrg_decomposer.a().decomposition_count(), + rlrg_decomposer.b().decomposition_count(), + ); + let (gadget_a, gadget_b) = ( + rlrg_decomposer.a().gadget_vector(), + rlrg_decomposer.b().gadget_vector(), + ); + for s_index in segment_start..segment_end { + let mut out_rgsw = M::zeros(rlrg_d_a * 2 + rlrg_d_b * 2, ring_size); + public_key_encrypt_rgsw( + &mut out_rgsw, + &encode_x_pow_si_with_emebedding_factor::< + M::R, + CiphertextModulus, + >( + sk_lwe[s_index], + self.pbs_info().embedding_factor(), + ring_size, + self.pbs_info().rlwe_q(), + ) + .as_ref(), + collective_pk, + &gadget_a, + &gadget_b, + rlweq_modop, + rlweq_nttop, + rng, + ); + self_leader_rgsw.push(out_rgsw); + } + } + + // not self LWE secret indices + { + // LWE secret indices for which user isn't the leader, they need to send RGSW(m) + // for RGSW x RGSW multiplcation + let rgsw_rgsw_decomposer = self + .pbs_info + .parameters + .rgsw_rgsw_decomposer::>(); + let (rgrg_d_a, rgrg_d_b) = ( + rgsw_rgsw_decomposer.a().decomposition_count(), + rgsw_rgsw_decomposer.b().decomposition_count(), + ); + let (rgrg_gadget_a, rgrg_gadget_b) = ( + rgsw_rgsw_decomposer.a().gadget_vector(), + rgsw_rgsw_decomposer.b().gadget_vector(), + ); + + for s_index in (0..segment_start).chain(segment_end..self.parameters().lwe_n().0) { let mut out_rgsw = M::zeros(rgrg_d_a * 2 + rgrg_d_b * 2, ring_size); public_key_encrypt_rgsw( &mut out_rgsw, - &m.as_ref(), + &encode_x_pow_si_with_emebedding_factor::< + M::R, + CiphertextModulus, + >( + sk_lwe[s_index], + self.pbs_info().embedding_factor(), + ring_size, + self.pbs_info().rlwe_q(), + ) + .as_ref(), collective_pk, &rgrg_gadget_a, &rgrg_gadget_b, @@ -879,10 +940,11 @@ where rng, ); - out_rgsw - }) - .collect_vec(); - rgsw_cts + not_self_leader_rgsws.push(out_rgsw); + } + } + + (self_leader_rgsw, not_self_leader_rgsws) }); // LWE Ksk @@ -893,14 +955,173 @@ where ); CommonReferenceSeededMultiPartyServerKeyShare::new( - rgsw_cts, + self_leader_rgsws, + not_self_leader_rgsws, auto_keys, lwe_ksk, cr_seed.clone(), self.pbs_info.parameters.clone(), + user_id, ) } + pub(super) fn aggregate_multi_party_server_key_shares( + &self, + shares: &[CommonReferenceSeededMultiPartyServerKeyShare< + M, + BoolParameters, + MultiPartyCrs, + >], + ) -> SeededMultiPartyServerKey, BoolParameters> + where + S: PartialEq + Clone, + M: Clone, + { + assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty); + assert!(shares.len() > 0); + + let total_users = shares.len(); + + let parameters = shares[0].parameters().clone(); + let cr_seed = shares[0].cr_seed(); + + let rlwe_n = parameters.rlwe_n().0; + let g = parameters.g() as isize; + let rlwe_q = parameters.rlwe_q(); + let lwe_q = parameters.lwe_q(); + + // sanity checks + shares.iter().skip(1).for_each(|s| { + assert!(s.parameters() == ¶meters); + assert!(s.cr_seed() == cr_seed); + }); + + let rlweq_modop = &self.pbs_info.rlwe_modop; + let rlweq_nttop = &self.pbs_info.rlwe_nttop; + + // auto keys + let mut auto_keys = HashMap::new(); + let auto_elements_dlog = parameters.auto_element_dlogs(); + for i in auto_elements_dlog.into_iter() { + let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n); + + shares.iter().for_each(|s| { + let auto_key_share_i = s.auto_keys().get(&i).expect("Auto key {i} missing"); + assert!( + auto_key_share_i.dimension() + == (parameters.auto_decomposition_count().0, rlwe_n) + ); + izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( + |(partb_out, partb_share)| { + rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); + }, + ); + }); + + auto_keys.insert(i, key); + } + + // rgsw ciphertext (most expensive part!) + let rgsw_cts = { + let rgsw_by_rgsw_decomposer = + parameters.rgsw_rgsw_decomposer::>(); + let rlwe_x_rgsw_decomposer = self.pbs_info().rlwe_rgsw_decomposer(); + let rgsw_x_rgsw_dimension = ( + rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 + + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2, + rlwe_n, + ); + let rlwe_x_rgsw_dimension = ( + rlwe_x_rgsw_decomposer.a().decomposition_count() * 2 + + rlwe_x_rgsw_decomposer.b().decomposition_count() * 2, + rlwe_n, + ); + let mut rgsw_x_rgsw_scratch_mat = M::zeros( + std::cmp::max( + rgsw_by_rgsw_decomposer.a().decomposition_count(), + rgsw_by_rgsw_decomposer.b().decomposition_count(), + ) + rlwe_x_rgsw_dimension.0, + rlwe_n, + ); + + let shares_in_correct_order = (0..total_users) + .map(|i| shares.iter().find(|s| s.user_id() == i).unwrap()) + .collect_vec(); + + let lwe_n = self.parameters().lwe_n().0; + let (users_segments, users_segments_sizes): (Vec<(usize, usize)>, Vec) = (0 + ..total_users) + .map(|(user_id)| { + let (start_index, end_index) = + interactive_mult_party_user_id_lwe_segment(user_id, total_users, lwe_n); + ((start_index, end_index), end_index - start_index) + }) + .unzip(); + + let mut rgsw_cts = Vec::with_capacity(lwe_n); + users_segments + .iter() + .enumerate() + .for_each(|(user_id, user_segment)| { + let share = shares_in_correct_order[user_id]; + for secret_index in user_segment.0..user_segment.1 { + let mut rgsw_i = + share.self_leader_rgsws()[secret_index - user_segment.0].clone(); + // assert already exists in RGSW x RGSW rountine + assert!(rgsw_i.dimension() == rlwe_x_rgsw_dimension); + + // multiply leader's RGSW ct at `secret_index` with RGSW cts of other users + // for lwe index `secret_index` + (0..total_users) + .filter(|i| i != &user_id) + .for_each(|other_user_id| { + let mut offset = 0; + if other_user_id < user_id { + offset = users_segments_sizes[other_user_id]; + } + + let mut other_rgsw_i = shares_in_correct_order[other_user_id] + .not_self_leader_rgsws() + [secret_index.checked_sub(offset).unwrap()] + .clone(); + // assert already exists in RGSW x RGSW rountine + assert!(other_rgsw_i.dimension() == rgsw_x_rgsw_dimension); + + // send to evaluation domain for RGSwxRGSW mul + other_rgsw_i + .iter_rows_mut() + .for_each(|r| rlweq_nttop.forward(r.as_mut())); + + rgsw_by_rgsw_inplace( + &mut rgsw_i, + rlwe_x_rgsw_decomposer.a().decomposition_count(), + rlwe_x_rgsw_decomposer.b().decomposition_count(), + &other_rgsw_i, + &rgsw_by_rgsw_decomposer, + &mut rgsw_x_rgsw_scratch_mat, + rlweq_nttop, + rlweq_modop, + ) + }); + + rgsw_cts.push(rgsw_i); + } + }); + + rgsw_cts + }; + + // LWE ksks + let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0); + let lweq_modop = &self.pbs_info.lwe_modop; + shares.iter().for_each(|si| { + assert!(si.lwe_ksk().as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); + lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk().as_ref()) + }); + + SeededMultiPartyServerKey::new(rgsw_cts, auto_keys, lwe_ksk, cr_seed.clone(), parameters) + } + pub(super) fn aggregate_non_interactive_multi_party_key_share( &self, cr_seed: &NonInteractiveMultiPartyCrs<[u8; 32]>, @@ -1351,6 +1572,8 @@ where .for_each(|user_i_rgsws| { rgsw_by_rgsw_inplace( &mut rgsw_i, + rgsw_by_rgsw_decomposer.a().decomposition_count(), + rgsw_by_rgsw_decomposer.b().decomposition_count(), &user_i_rgsws[s_index], &rgsw_by_rgsw_decomposer, &mut scratch_matrix, @@ -1834,125 +2057,6 @@ where let m = decrypt_lwe(lwe_ct, &client_key.sk_rlwe(), &self.pbs_info.rlwe_modop); self.pbs_info.rlwe_q().decode(m) } - - pub(super) fn aggregate_multi_party_server_key_shares( - &self, - shares: &[CommonReferenceSeededMultiPartyServerKeyShare< - M, - BoolParameters, - MultiPartyCrs, - >], - ) -> SeededMultiPartyServerKey, BoolParameters> - where - S: PartialEq + Clone, - M: Clone, - { - assert_eq!(self.parameters().variant(), &ParameterVariant::MultiParty); - assert!(shares.len() > 0); - let parameters = shares[0].parameters().clone(); - let cr_seed = shares[0].cr_seed(); - - let rlwe_n = parameters.rlwe_n().0; - let g = parameters.g() as isize; - let rlwe_q = parameters.rlwe_q(); - let lwe_q = parameters.lwe_q(); - - // sanity checks - shares.iter().skip(1).for_each(|s| { - assert!(s.parameters() == ¶meters); - assert!(s.cr_seed() == cr_seed); - }); - - let rlweq_modop = &self.pbs_info.rlwe_modop; - let rlweq_nttop = &self.pbs_info.rlwe_nttop; - - // auto keys - let mut auto_keys = HashMap::new(); - let auto_elements_dlog = parameters.auto_element_dlogs(); - for i in auto_elements_dlog.into_iter() { - let mut key = M::zeros(parameters.auto_decomposition_count().0, rlwe_n); - - shares.iter().for_each(|s| { - let auto_key_share_i = s.auto_keys().get(&i).expect("Auto key {i} missing"); - assert!( - auto_key_share_i.dimension() - == (parameters.auto_decomposition_count().0, rlwe_n) - ); - izip!(key.iter_rows_mut(), auto_key_share_i.iter_rows()).for_each( - |(partb_out, partb_share)| { - rlweq_modop.elwise_add_mut(partb_out.as_mut(), partb_share.as_ref()); - }, - ); - }); - - auto_keys.insert(i, key); - } - - // rgsw ciphertext (most expensive part!) - let lwe_n = parameters.lwe_n().0; - let rgsw_by_rgsw_decomposer = - parameters.rgsw_rgsw_decomposer::>(); - let mut scratch_matrix = M::zeros( - std::cmp::max( - rgsw_by_rgsw_decomposer.a().decomposition_count(), - rgsw_by_rgsw_decomposer.b().decomposition_count(), - ) + (rgsw_by_rgsw_decomposer.a().decomposition_count() * 2 - + rgsw_by_rgsw_decomposer.b().decomposition_count() * 2), - rlwe_n, - ); - - let mut tmp_rgsw = - RgswCiphertext::::empty(rlwe_n, &rgsw_by_rgsw_decomposer, rlwe_q.clone()).data; - let rgsw_cts = (0..lwe_n).into_iter().map(|index| { - // copy over rgsw ciphertext for index^th secret element from first share and - // treat it as accumulating rgsw ciphertext - let mut rgsw_i = shares[0].rgsw_cts()[index].clone(); - - shares.iter().skip(1).for_each(|si| { - // copy over si's RGSW[index] ciphertext and send to evaluation domain - izip!(tmp_rgsw.iter_rows_mut(), si.rgsw_cts()[index].iter_rows()).for_each( - |(to_ri, from_ri)| { - to_ri.as_mut().copy_from_slice(from_ri.as_ref()); - rlweq_nttop.forward(to_ri.as_mut()) - }, - ); - - rgsw_by_rgsw_inplace( - &mut rgsw_i, - &tmp_rgsw, - &rgsw_by_rgsw_decomposer, - &mut scratch_matrix, - rlweq_nttop, - rlweq_modop, - ); - }); - - rgsw_i - }); - // d_a and d_b may differ for RGSWxRGSW multiplication and RLWExRGSW - // multiplication. After this point RGSW ciphertexts will only be used for - // RLWExRGSW multiplication (in blind rotation). Thus we drop any additional - // RLWE ciphertexts in RGSW ciphertexts after RGSw x RGSW multiplication - let rgsw_cts = rgsw_cts - .map(|ct_i_in| { - trim_rgsw_ct_matrix_from_rgrg_to_rlrg( - ct_i_in, - self.parameters().rgsw_by_rgsw_decomposition_params(), - self.parameters().rlwe_by_rgsw_decomposition_params(), - ) - }) - .collect_vec(); - - // LWE ksks - let mut lwe_ksk = M::R::zeros(rlwe_n * parameters.lwe_decomposition_count().0); - let lweq_modop = &self.pbs_info.lwe_modop; - shares.iter().for_each(|si| { - assert!(si.lwe_ksk().as_ref().len() == rlwe_n * parameters.lwe_decomposition_count().0); - lweq_modop.elwise_add_mut(lwe_ksk.as_mut(), si.lwe_ksk().as_ref()) - }); - - SeededMultiPartyServerKey::new(rgsw_cts, auto_keys, lwe_ksk, cr_seed.clone(), parameters) - } } impl BoolEvaluator @@ -2267,6 +2371,8 @@ mod tests { }); }); + let mut rng = DefaultSecureRng::new(); + // check noise in freshly encrypted RLWE ciphertext (ie var_fresh) if false { let mut rng = DefaultSecureRng::new(); @@ -2316,9 +2422,6 @@ mod tests { if true { // Generate server key shares - let mut rng = DefaultSecureRng::new(); - let mut pk_cr_seed = [0u8; 32]; - rng.fill_bytes(&mut pk_cr_seed); let public_key_share = parties .iter() .map(|k| bool_evaluator.multi_party_public_key_share(&int_mp_seed, k)) @@ -2329,12 +2432,13 @@ mod tests { ModularOpsU64>, >::from(public_key_share.as_slice()); - let pbs_cr_seed = [0u8; 32]; - rng.fill_bytes(&mut pk_cr_seed); let server_key_shares = parties .iter() - .map(|k| { + .enumerate() + .map(|(user_id, k)| { bool_evaluator.multi_party_server_key_share( + user_id, + no_of_parties, &int_mp_seed, collective_pk.key(), k, @@ -2351,13 +2455,12 @@ mod tests { izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each( |(s_i, rgsw_ct_i)| { // X^{s[i]} - let mut m_si = vec![0u64; rlwe_n]; - let s_i = *s_i * (bool_evaluator.pbs_info.embedding_factor as i32); - if s_i < 0 { - m_si[rlwe_n - (s_i.abs() as usize)] = rlwe_q.neg_one(); - } else { - m_si[s_i as usize] = 1; - } + let m_si = encode_x_pow_si_with_emebedding_factor::, _>( + *s_i, + bool_evaluator.pbs_info.embedding_factor, + rlwe_n, + rlwe_q, + ); // RLWE'(-sm) let mut neg_s_eval = diff --git a/src/bool/keys.rs b/src/bool/keys.rs index 99bc03f..03b8911 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -314,7 +314,8 @@ impl CommonReferenceSeededCollectivePublicKeyShare { /// CRS seeded Multi-party server key share pub struct CommonReferenceSeededMultiPartyServerKeyShare { - rgsw_cts: Vec, + self_leader_rgsws: Vec, + not_self_leader_rgsws: Vec, /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// to -g is at 0 auto_keys: HashMap, @@ -322,22 +323,27 @@ pub struct CommonReferenceSeededMultiPartyServerKeyShare { /// Common reference seed cr_seed: S, parameters: P, + user_id: usize, } impl CommonReferenceSeededMultiPartyServerKeyShare { pub(super) fn new( - rgsw_cts: Vec, + self_leader_rgsws: Vec, + not_self_leader_rgsws: Vec, auto_keys: HashMap, lwe_ksk: M::R, cr_seed: S, parameters: P, + user_id: usize, ) -> Self { CommonReferenceSeededMultiPartyServerKeyShare { - rgsw_cts, + self_leader_rgsws, + not_self_leader_rgsws, auto_keys, lwe_ksk, cr_seed, parameters, + user_id, } } @@ -353,13 +359,21 @@ impl CommonReferenceSeededMultiPartyServerKeyShare { &self.auto_keys } - pub(super) fn rgsw_cts(&self) -> &[M] { - &self.rgsw_cts + pub(crate) fn self_leader_rgsws(&self) -> &[M] { + &self.self_leader_rgsws + } + + pub(super) fn not_self_leader_rgsws(&self) -> &[M] { + &self.not_self_leader_rgsws } pub(super) fn lwe_ksk(&self) -> &M::R { &self.lwe_ksk } + + pub(super) fn user_id(&self) -> usize { + self.user_id + } } /// CRS seeded MultiParty server key diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index f77ec09..4dd7279 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -58,6 +58,8 @@ pub fn gen_mp_keys_phase1( pub fn gen_mp_keys_phase2( ck: &ClientKey, + user_id: usize, + total_users: usize, pk: &PublicKey>, R, ModOp>, ) -> CommonReferenceSeededMultiPartyServerKeyShare< Vec>, @@ -65,8 +67,13 @@ pub fn gen_mp_keys_phase2( MultiPartyCrs<[u8; 32]>, > { BoolEvaluator::with_local_mut(|e| { - let server_key_share = - e.multi_party_server_key_share(MultiPartyCrs::global(), pk.key(), ck); + let server_key_share = e.multi_party_server_key_share( + user_id, + total_users, + MultiPartyCrs::global(), + pk.key(), + ck, + ); server_key_share }) } @@ -251,7 +258,11 @@ mod tests { let pk = aggregate_public_key_shares(&pk_shares); // round 2 - let server_key_shares = cks.iter().map(|k| gen_mp_keys_phase2(k, &pk)).collect_vec(); + let server_key_shares = cks + .iter() + .enumerate() + .map(|(user_id, k)| gen_mp_keys_phase2(k, user_id, parties, &pk)) + .collect_vec(); // server key let server_key = aggregate_server_key_shares(&server_key_shares); diff --git a/src/bool/noise.rs b/src/bool/noise.rs index dcf87e3..6851f9d 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -13,6 +13,7 @@ mod test { }, evaluator::MultiPartyCrs, ntt::NttBackendU64, + parameters::OPTIMISED_SMALL_MP_BOOL_PARAMS, random::DefaultSecureRng, }; @@ -25,7 +26,7 @@ mod test { ModularOpsU64>, ModulusPowerOf2>, ShoupServerKeyEvaluationDomain>>, - >::new(SMALL_MP_BOOL_PARAMS); + >::new(OPTIMISED_SMALL_MP_BOOL_PARAMS); let parties = 2; @@ -72,7 +73,10 @@ mod test { // round 2 let server_key_shares = cks .iter() - .map(|c| evaluator.multi_party_server_key_share(&cr_seed, &pk.key(), c)) + .enumerate() + .map(|(index, c)| { + evaluator.multi_party_server_key_share(index, parties, &cr_seed, &pk.key(), c) + }) .collect_vec(); let server_key = evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); @@ -89,9 +93,6 @@ mod test { let mut c_m0 = evaluator.pk_encrypt(pk.key(), m0); let mut c_m1 = evaluator.pk_encrypt(pk.key(), m1); - let true_el_encoded = evaluator.parameters().rlwe_q().true_el(); - let false_el_encoded = evaluator.parameters().rlwe_q().false_el(); - // let mut stats = Stats::new(); for _ in 0..1000 { diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index f8336e4..f6f801f 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -494,14 +494,14 @@ pub(crate) const OPTIMISED_SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParam lwe_n: LweDimension(500), lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)), rlrg_decomposer_params: ( - DecompostionLogBase(24), + DecompostionLogBase(16), (DecompositionCount(1), DecompositionCount(1)), ), rgrg_decomposer_params: Some(( - DecompostionLogBase(12), - (DecompositionCount(3), DecompositionCount(3)), + DecompostionLogBase(8), + (DecompositionCount(6), DecompositionCount(6)), )), - auto_decomposer_params: (DecompostionLogBase(20), DecompositionCount(1)), + auto_decomposer_params: (DecompostionLogBase(24), DecompositionCount(1)), non_interactive_ui_to_s_key_switch_decomposer: None, g: 5, w: 10, diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 87b6ef3..f121aa1 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -1114,6 +1114,8 @@ pub(crate) mod tests { ); rgsw_by_rgsw_inplace( &mut rgsw_carrym, + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count(), &rgsw_m.data, &decomposer, &mut scratch_matrix, diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index e8d19c9..d1d0eba 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -546,14 +546,19 @@ pub(crate) fn rlwe_by_rgsw_shoup< /// - rgsw_1_eval: RGSW(m1) in Evaluation domain /// - scratch_matrix_d_plus_rgsw_by_ring: scratch space matrix with rows /// (max(d_a, d_b) + d_a*2+d_b*2) and columns ring_size +/// +/// ## Note: +/// - We treat RGSW x RGSW as multiple RLWE x RGSW multiplications. . pub(crate) fn rgsw_by_rgsw_inplace< Mmut: MatrixMut, D: RlweDecomposer, ModOp: VectorOps, NttOp: Ntt, >( - rgsw_0: &mut Mmut, - rgsw_1_eval: &Mmut, + rgsw0: &mut Mmut, + rgsw0_da: usize, + rgsw0_db: usize, + rgsw1_eval: &Mmut, decomposer: &D, scratch_matrix: &mut Mmut, ntt_op: &NttOp, @@ -567,11 +572,12 @@ pub(crate) fn rgsw_by_rgsw_inplace< let d_a = decomposer_a.decomposition_count(); let d_b = decomposer_b.decomposition_count(); let max_d = std::cmp::max(d_a, d_b); - let rgsw_rows = d_a * 2 + d_b * 2; - assert!(rgsw_0.dimension().0 == rgsw_rows); - let ring_size = rgsw_0.dimension().1; - assert!(rgsw_1_eval.dimension() == (rgsw_rows, ring_size)); - assert!(scratch_matrix.fits(max_d + rgsw_rows, ring_size)); + let rgsw1_rows = d_a * 2 + d_b * 2; + let rgsw0_rows = rgsw0_da * 2 + rgsw0_db * 2; + let ring_size = rgsw0.dimension().1; + assert!(rgsw0.dimension().0 == rgsw0_rows); + assert!(rgsw1_eval.dimension() == (rgsw1_rows, ring_size)); + assert!(scratch_matrix.fits(max_d + rgsw0_rows, ring_size)); let (decomp_r_space, rgsw_space) = scratch_matrix.split_at_row_mut(max_d); @@ -579,18 +585,25 @@ pub(crate) fn rgsw_by_rgsw_inplace< rgsw_space .iter_mut() .for_each(|ri| ri.as_mut().fill(Mmut::MatElement::zero())); - let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(d_a * 2); + let (rlwe_dash_space_nsm, rlwe_dash_space_m) = rgsw_space.split_at_mut(rgsw0_da * 2); let (rlwe_dash_space_nsm_parta, rlwe_dash_space_nsm_partb) = - rlwe_dash_space_nsm.split_at_mut(d_a); - let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = rlwe_dash_space_m.split_at_mut(d_b); + rlwe_dash_space_nsm.split_at_mut(rgsw0_da); + let (rlwe_dash_space_m_parta, rlwe_dash_space_m_partb) = + rlwe_dash_space_m.split_at_mut(rgsw0_db); - let (rgsw0_nsm, rgsw0_m) = rgsw_0.split_at_row(d_a * 2); - let (rgsw1_nsm, rgsw1_m) = rgsw_1_eval.split_at_row(d_a * 2); + let (rgsw0_nsm, rgsw0_m) = rgsw0.split_at_row(rgsw0_da * 2); + let (rgsw1_nsm, rgsw1_m) = rgsw1_eval.split_at_row(d_a * 2); // RGSW x RGSW izip!( - rgsw0_nsm.iter().take(d_a).chain(rgsw0_m.iter().take(d_b)), - rgsw0_nsm.iter().skip(d_a).chain(rgsw0_m.iter().skip(d_b)), + rgsw0_nsm + .iter() + .take(rgsw0_da) + .chain(rgsw0_m.iter().take(rgsw0_db)), + rgsw0_nsm + .iter() + .skip(rgsw0_da) + .chain(rgsw0_m.iter().skip(rgsw0_db)), rlwe_dash_space_nsm_parta .iter_mut() .chain(rlwe_dash_space_m_parta.iter_mut()), @@ -599,7 +612,9 @@ pub(crate) fn rgsw_by_rgsw_inplace< .chain(rlwe_dash_space_m_partb.iter_mut()), ) .for_each(|(rlwe_a, rlwe_b, rlwe_out_a, rlwe_out_b)| { - // Part A + // RLWE(m0) x RGSW(m1) + + // Part A: Decomp \cdot RLWE'(-sm1) decompose_r(rlwe_a.as_ref(), decomp_r_space.as_mut(), decomposer_a); decomp_r_space .iter_mut() @@ -618,7 +633,7 @@ pub(crate) fn rgsw_by_rgsw_inplace< mod_op, ); - // Part B + // Part B: Decompose \cdot RLWE'(m1) decompose_r(rlwe_b.as_ref(), decomp_r_space.as_mut(), decomposer_b); decomp_r_space .iter_mut() @@ -639,11 +654,11 @@ pub(crate) fn rgsw_by_rgsw_inplace< }); // copy over RGSW(m0m1) into RGSW(m0) - izip!(rgsw_0.iter_rows_mut(), rgsw_space.iter()) + izip!(rgsw0.iter_rows_mut(), rgsw_space.iter()) .for_each(|(to_ri, from_ri)| to_ri.as_mut().copy_from_slice(from_ri.as_ref())); // send back to coefficient domain - rgsw_0 + rgsw0 .iter_rows_mut() .for_each(|ri| ntt_op.backward(ri.as_mut())); } diff --git a/src/utils.rs b/src/utils.rs index 45fea91..aad1144 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,12 +1,12 @@ use std::{fmt::Debug, usize, vec}; use itertools::{izip, Itertools}; -use num_traits::{FromPrimitive, PrimInt, Signed}; +use num_traits::{FromPrimitive, One, PrimInt, Signed}; use crate::{ backend::Modulus, random::{RandomElementInModulus, RandomFill}, - Matrix, + Matrix, Row, RowEntity, RowMut, }; pub trait WithLocal { fn with_local(func: F) -> R @@ -190,6 +190,31 @@ pub fn negacyclic_mul T>( return r; } +/// Returns a polynomial X^{emebedding_factor * si} \mod {Z_Q / X^{N}+1} +pub(crate) fn encode_x_pow_si_with_emebedding_factor< + R: RowEntity + RowMut, + M: Modulus, +>( + si: i32, + embedding_factor: usize, + ring_size: usize, + modulus: &M, +) -> R +where + R::Element: One, +{ + assert!((si.abs() as usize) < ring_size); + let mut m = R::zeros(ring_size); + let si = si * (embedding_factor as i32); + if si < 0 { + // X^{-si} = X^{2N-si} = -X^{N-si}, assuming abs(si) < N + m.as_mut()[ring_size - (si.abs() as usize)] = modulus.neg_one(); + } else { + m.as_mut()[si as usize] = R::Element::one(); + } + m +} + pub(crate) fn puncture_p_rng>( p_rng: &mut R, times: usize,