diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index d04bcee..00ed7db 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -2070,41 +2070,6 @@ where }) } - pub(super) fn multi_party_decryption_share>( - &self, - lwe_ct: &M::R, - client_key: &K, - ) -> ::MatElement { - assert!(lwe_ct.as_ref().len() == self.pbs_info.parameters.rlwe_n().0 + 1); - let modop = &self.pbs_info.rlwe_modop; - let mut neg_s = - M::R::try_convert_from(&client_key.sk_rlwe(), &self.pbs_info.parameters.rlwe_q()); - modop.elwise_neg_mut(neg_s.as_mut()); - - let mut neg_sa = M::MatElement::zero(); - izip!(lwe_ct.as_ref().iter().skip(1), neg_s.as_ref().iter()).for_each(|(ai, nsi)| { - neg_sa = modop.add(&neg_sa, &modop.mul(ai, nsi)); - }); - - let e = DefaultSecureRng::with_local_mut(|rng| { - RandomGaussianElementInModulus::random(rng, self.pbs_info.parameters.rlwe_q()) - }); - let share = modop.add(&neg_sa, &e); - - share - } - - pub(crate) fn multi_party_decrypt(&self, shares: &[M::MatElement], lwe_ct: &M::R) -> bool { - let modop = &self.pbs_info.rlwe_modop; - let mut sum_a = M::MatElement::zero(); - shares - .iter() - .for_each(|share_i| sum_a = modop.add(&sum_a, &share_i)); - - let encoded_m = modop.add(&lwe_ct.as_ref()[0], &sum_a); - self.pbs_info.parameters.rlwe_q().decode(encoded_m) - } - pub fn sk_encrypt>( &self, m: bool, diff --git a/src/bool/keys.rs b/src/bool/keys.rs index 3bbd727..a40bf3e 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -99,6 +99,7 @@ mod impl_ck { } } + #[cfg(feature = "interactive_mp")] impl InteractiveMultiPartyClientKey for ClientKey<[u8; 32], E> { type Element = i32; fn sk_lwe(&self) -> Vec { @@ -109,6 +110,7 @@ mod impl_ck { } } + #[cfg(feature = "non_interactive_mp")] impl NonInteractiveMultiPartyClientKey for ClientKey<[u8; 32], E> { type Element = i32; fn sk_lwe(&self) -> Vec { diff --git a/src/bool/mod.rs b/src/bool/mod.rs index e465f7d..77ad7e5 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -205,25 +205,6 @@ mod common_mp_enc_dec { type Mat = Vec>; - impl MultiPartyDecryptor::R> for super::keys::ClientKey<[u8; 32], E> { - type DecryptionShare = ::MatElement; - - /// Generate multi-party decryption share for LWE ciphertext `c` - fn gen_decryption_share(&self, c: &::R) -> Self::DecryptionShare { - BoolEvaluator::with_local(|e| e.multi_party_decryption_share(c, self)) - } - - /// Aggregate mult-party decryptions shares of all parties, decrypt LWE - /// ciphertext `c`, and return the bool plaintext - fn aggregate_decryption_shares( - &self, - c: &::R, - shares: &[Self::DecryptionShare], - ) -> bool { - BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c)) - } - } - impl SampleExtractor<::R> for Mat { /// Sample extract coefficient at `index` as a LWE ciphertext from RLWE /// ciphertext `Self` diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index d3a5030..623bca3 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, ops::Mul, sync::OnceLock}; +use std::{cell::RefCell, sync::OnceLock}; use crate::{ backend::{ModularOpsU64, ModulusPowerOf2}, @@ -184,9 +184,13 @@ impl Global for RuntimeServerKey { mod impl_enc_dec { use crate::{ bool::evaluator::BoolEncoding, + multi_party::{ + multi_party_aggregate_decryption_shares_and_decrypt, multi_party_decryption_share, + }, pbs::{sample_extract, PbsInfo}, rgsw::public_key_encrypt_rlwe, - Encryptor, Matrix, MatrixEntity, RowEntity, + utils::TryConvertFrom1, + Encryptor, Matrix, MatrixEntity, MultiPartyDecryptor, RowEntity, }; use itertools::Itertools; use num_traits::{ToPrimitive, Zero}; @@ -254,14 +258,50 @@ mod impl_enc_dec { }) } } + + impl MultiPartyDecryptor::R> for K + where + K: InteractiveMultiPartyClientKey, + ::R: + TryConvertFrom1<[K::Element], CiphertextModulus<::MatElement>>, + { + type DecryptionShare = ::MatElement; + + fn gen_decryption_share(&self, c: &::R) -> Self::DecryptionShare { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + multi_party_decryption_share( + c, + self.sk_rlwe().as_slice(), + e.pbs_info().modop_rlweq(), + rng, + ) + }) + }) + } + + fn aggregate_decryption_shares( + &self, + c: &::R, + shares: &[Self::DecryptionShare], + ) -> bool { + BoolEvaluator::with_local(|e| { + let noisy_m = multi_party_aggregate_decryption_shares_and_decrypt( + c, + shares, + e.pbs_info().modop_rlweq(), + ); + + e.pbs_info().rlwe_q().decode(noisy_m) + }) + } + } } #[cfg(test)] mod tests { - use std::thread::panicking; - use itertools::Itertools; - use rand::{thread_rng, RngCore}; + use rand::{thread_rng, Rng, RngCore}; use crate::{ bool::{ @@ -269,7 +309,7 @@ mod tests { keys::tests::{ideal_sk_rlwe, measure_noise_lwe}, BooleanGates, }, - Encryptor, MultiPartyDecryptor, + Encryptor, MultiPartyDecryptor, SampleExtractor, }; use super::*; @@ -363,13 +403,52 @@ mod tests { } } + #[test] + fn batched_fhe_u8s_extract_works() { + set_parameter_set(ParameterSelector::InteractiveLTE2Party); + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_mp_seed(seed); + + let parties = 2; + let cks = (0..parties).map(|_| gen_client_key()).collect_vec(); + + // round 1 + let pk_shares = cks.iter().map(|k| gen_mp_keys_phase1(k)).collect_vec(); + + // collective pk + let pk = aggregate_public_key_shares(&pk_shares); + + let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); + + let batch_size = parameters.rlwe_n().0 * 3 + 123; + let m = (0..batch_size) + .map(|_| thread_rng().gen::()) + .collect_vec(); + + let seeded_ct = pk.encrypt(m.as_slice()); + + let m_back = (0..batch_size) + .map(|i| { + let ct = seeded_ct.extract_at(i); + cks[0].aggregate_decryption_shares( + &ct, + &cks.iter() + .map(|k| k.gen_decryption_share(&ct)) + .collect_vec(), + ) + }) + .collect_vec(); + + assert_eq!(m, m_back); + } + mod sp_api { use num_traits::ToPrimitive; - use rand::Rng; use crate::{ bool::impl_bool_frontend::FheBool, pbs::PbsInfo, rgsw::seeded_secret_key_encrypt_rlwe, - Decryptor, SampleExtractor, + Decryptor, }; use super::*; @@ -501,28 +580,6 @@ mod tests { } } - #[test] - fn batch_extract_works() { - set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); - - let (ck, sk) = gen_keys(); - sk.set_server_key(); - - let batch_size = (SP_TEST_BOOL_PARAMS.rlwe_n().0 * 3 + 123); - let m = (0..batch_size) - .map(|_| thread_rng().gen::()) - .collect_vec(); - - let seeded_ct = ck.encrypt(m.as_slice()); - let ct = seeded_ct.unseed::>>(); - - let m_back = (0..batch_size) - .map(|i| ck.decrypt(&ct.extract_at(i))) - .collect_vec(); - - assert_eq!(m, m_back); - } - #[test] #[cfg(feature = "interactive_mp")] fn all_uint8_apis() { diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs index 6034833..5ada3ec 100644 --- a/src/bool/ni_mp_api.rs +++ b/src/bool/ni_mp_api.rs @@ -200,11 +200,15 @@ pub(super) struct BatchedFheBools { mod impl_enc_dec { use crate::{ bool::{evaluator::BoolEncoding, keys::NonInteractiveMultiPartyClientKey}, + multi_party::{ + multi_party_aggregate_decryption_shares_and_decrypt, multi_party_decryption_share, + }, pbs::{sample_extract, PbsInfo, WithShoupRepr}, random::{NewWithSeed, RandomFillUniformInModulus}, rgsw::{rlwe_key_switch, seeded_secret_key_encrypt_rlwe}, utils::TryConvertFrom1, - Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, + Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, + RowEntity, RowMut, }; use itertools::Itertools; use num_traits::{ToPrimitive, Zero}; @@ -351,6 +355,44 @@ mod impl_enc_dec { } } + impl MultiPartyDecryptor::R> for K + where + K: NonInteractiveMultiPartyClientKey, + ::R: + TryConvertFrom1<[K::Element], CiphertextModulus<::MatElement>>, + { + type DecryptionShare = ::MatElement; + + fn gen_decryption_share(&self, c: &::R) -> Self::DecryptionShare { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + multi_party_decryption_share( + c, + self.sk_rlwe().as_slice(), + e.pbs_info().modop_rlweq(), + rng, + ) + }) + }) + } + + fn aggregate_decryption_shares( + &self, + c: &::R, + shares: &[Self::DecryptionShare], + ) -> bool { + BoolEvaluator::with_local(|e| { + let noisy_m = multi_party_aggregate_decryption_shares_and_decrypt( + c, + shares, + e.pbs_info().modop_rlweq(), + ); + + e.pbs_info().rlwe_q().decode(noisy_m) + }) + } + } + impl KeySwitchWithId for Mat { /// Key switch RLWE ciphertext `Self` from user j's RLWE secret u_j /// to ideal RLWE secret `s` of non-interactive multi-party protocol. diff --git a/src/multi_party.rs b/src/multi_party.rs index 4c9361c..3f51775 100644 --- a/src/multi_party.rs +++ b/src/multi_party.rs @@ -1,13 +1,16 @@ use std::fmt::Debug; use itertools::izip; +use num_traits::Zero; use crate::{ - backend::{GetModulus, VectorOps}, + backend::{GetModulus, Modulus, VectorOps}, ntt::Ntt, - random::{RandomFillGaussianInModulus, RandomFillUniformInModulus}, + random::{ + RandomFillGaussianInModulus, RandomFillUniformInModulus, RandomGaussianElementInModulus, + }, utils::TryConvertFrom1, - Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, + ArithmeticOps, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, }; pub(crate) fn public_key_share< @@ -50,6 +53,59 @@ pub(crate) fn public_key_share< modop.elwise_add_mut(share_out.as_mut(), s.as_ref()); // s*e + e } +/// Generate decryption share for LWE ciphertext `lwe_ct` with user's secret `s` +pub(crate) fn multi_party_decryption_share< + R: RowMut + RowEntity, + Mod: Modulus, + ModOp: ArithmeticOps + VectorOps + GetModulus, + Rng: RandomGaussianElementInModulus, + S, +>( + lwe_ct: &R, + s: &[S], + mod_op: &ModOp, + rng: &mut Rng, +) -> R::Element +where + R: TryConvertFrom1<[S], Mod>, + R::Element: Zero, +{ + assert!(lwe_ct.as_ref().len() == s.len() + 1); + let mut neg_s = R::try_convert_from(s, mod_op.modulus()); + mod_op.elwise_neg_mut(neg_s.as_mut()); + + // share = (\sum -s_i * a_i) + e + let mut share = R::Element::zero(); + izip!(neg_s.as_ref().iter(), lwe_ct.as_ref().iter().skip(1)).for_each(|(si, ai)| { + share = mod_op.add(&share, &mod_op.mul(si, ai)); + }); + + let e = rng.random(mod_op.modulus()); + share = mod_op.add(&share, &e); + + share +} + +/// Aggregate decryption shares for `lwe_ct` and return noisy decryption output +/// `m + e` +pub(crate) fn multi_party_aggregate_decryption_shares_and_decrypt< + R: RowMut + RowEntity, + ModOp: ArithmeticOps, +>( + lwe_ct: &R, + shares: &[R::Element], + mod_op: &ModOp, +) -> R::Element +where + R::Element: Zero, +{ + let mut sum_shares = R::Element::zero(); + shares + .iter() + .for_each(|v| sum_shares = mod_op.add(&sum_shares, v)); + mod_op.add(&lwe_ct.as_ref()[0], &sum_shares) +} + pub(crate) fn non_interactive_rgsw_ct< M: MatrixMut + MatrixEntity, S, diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs index c909b24..e8f8e77 100644 --- a/src/shortint/enc_dec.rs +++ b/src/shortint/enc_dec.rs @@ -8,10 +8,11 @@ use crate::{ RowMut, SampleExtractor, }; -/// Fhe UInt8 type +/// Fhe UInt8 /// -/// - Stores encryptions of bits in little endian (i.e least signficant bit -/// stored at 0th index and most signficant bit stores at 7th index) +/// Note that `Self.data` stores encryptions of bits in little endian (i.e least +/// signficant bit stored at 0th index and most signficant bit stores at 7th +/// index) #[derive(Clone)] pub struct FheUint8 { pub(super) data: Vec, @@ -27,7 +28,9 @@ impl FheUint8 { } } -/// Stored a batch of Fhe Uint8 ciphertext as collection of RLWE ciphertexts +/// Stores a batch of Fhe Uint8 ciphertext as collection of unseeded RLWE +/// ciphertexts always encrypted under the ideal RLWE secret `s` of the MPC +/// protocol /// /// To extract Fhe Uint8 ciphertext at `index` call `self.extract(index)` pub struct BatchedFheUint8 { @@ -37,6 +40,68 @@ pub struct BatchedFheUint8 { count: usize, } +impl Encryptor<[u8], BatchedFheUint8> for K +where + K: Encryptor<[bool], Vec>, +{ + /// Encrypt a batch of uint8s packed in vector of RLWE ciphertexts + /// + /// Uint8s can be extracted from `BatchedFheUint8` with `SampleExtractor` + fn encrypt(&self, m: &[u8]) -> BatchedFheUint8 { + let bool_m = m + .iter() + .flat_map(|v| { + (0..8) + .into_iter() + .map(|i| ((*v >> i) & 1) == 1) + .collect_vec() + }) + .collect_vec(); + let cts = K::encrypt(&self, &bool_m); + BatchedFheUint8 { + data: cts, + count: m.len(), + } + } +} + +impl> From<&SeededBatchedFheUint8> + for BatchedFheUint8 +where + ::R: RowMut, +{ + /// Unseeds collection of seeded RLWE ciphertext in SeededBatchedFheUint8 + /// and returns as `Self` + fn from(value: &SeededBatchedFheUint8) -> Self { + BoolEvaluator::with_local(|e| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + let rlwe_q = parameters.rlwe_q(); + + let mut prng = DefaultSecureRng::new_seeded(value.seed); + let rlwes = value + .data + .iter() + .map(|partb| { + let mut rlwe = M::zeros(2, ring_size); + + // sample A + RandomFillUniformInModulus::random_fill(&mut prng, rlwe_q, rlwe.get_row_mut(0)); + + // Copy over B + rlwe.get_row_mut(1).copy_from_slice(partb.as_ref()); + + rlwe + }) + .collect_vec(); + Self { + data: rlwes, + count: value.count, + } + }) + } +} + impl SampleExtractor> for BatchedFheUint8 where C: SampleExtractor, @@ -82,8 +147,28 @@ where } } +/// Stores a batch of FheUint8s packed in a collection unseeded RLWE ciphertexts +/// +/// `Self` stores unseeded RLWE ciphertexts encrypted under user's RLWE secret +/// `u_j` and is different from `BatchFheUint8` which stores collection of RLWE +/// ciphertexts under ideal RLWE secret `s` of the (non-interactive/interactive) +/// MPC protocol. +/// +/// To extract FheUint8s from `Self`'s collection of RLWE ciphertexts, first +/// switch `Self` to `BatchFheUint8` with `key_switch(user_id)` where `user_id` +/// is user's id. This key switches collection of RLWE ciphertexts from +/// user's RLWE secret `u_j` to ideal RLWE secret `s` of the MPC protocol. Then +/// proceed to use `SampleExtract` on `BatchFheUint8` (for ex, call +/// `extract_at(0)` to extract FheUint8 stored at index 0) +pub struct NonInteractiveBatchedFheUint8 { + /// Vector of RLWE ciphertexts `C` + data: Vec, + /// Count of FheUint8s packed in vector of RLWE ciphertexts + count: usize, +} + impl> From<&SeededBatchedFheUint8> - for BatchedFheUint8 + for NonInteractiveBatchedFheUint8 where ::R: RowMut, { @@ -119,6 +204,27 @@ where } } +impl KeySwitchWithId> for NonInteractiveBatchedFheUint8 +where + C: KeySwitchWithId, +{ + /// Key switch `Self`'s collection of RLWE cihertexts encrypted under user's + /// RLWE secret `u_j` to ideal RLWE secret `s` of the MPC protocol. + /// + /// - user_id: user id of user `j` + fn key_switch(&self, user_id: usize) -> BatchedFheUint8 { + let data = self + .data + .iter() + .map(|c| c.key_switch(user_id)) + .collect_vec(); + BatchedFheUint8 { + data, + count: self.count, + } + } +} + pub struct SeededBatchedFheUint8 { /// Vector of Seeded RLWE ciphertexts `C`. /// @@ -131,21 +237,12 @@ pub struct SeededBatchedFheUint8 { count: usize, } -impl SeededBatchedFheUint8 { - pub fn unseed(&self) -> BatchedFheUint8 - where - BatchedFheUint8: for<'a> From<&'a SeededBatchedFheUint8>, - M: Matrix, - { - BatchedFheUint8::from(self) - } -} - impl Encryptor<[u8], SeededBatchedFheUint8> for K where K: Encryptor<[bool], (Vec, S)>, { - /// Encrypt a slice of u8s of arbitray length as `SeededBatchedFheUint8` + /// Encrypt a slice of u8s of arbitray length packed into collection of + /// seeded RLWE ciphertexts and return `SeededBatchedFheUint8` fn encrypt(&self, m: &[u8]) -> SeededBatchedFheUint8 { // convert vector of u8s to vector bools let bool_m = m @@ -161,46 +258,42 @@ where } } -impl Encryptor<[u8], BatchedFheUint8> for K -where - K: Encryptor<[bool], Vec>, -{ - fn encrypt(&self, m: &[u8]) -> BatchedFheUint8 { - let bool_m = m - .iter() - .flat_map(|v| { - (0..8) - .into_iter() - .map(|i| ((*v >> i) & 1) == 1) - .collect_vec() - }) - .collect_vec(); - let cts = K::encrypt(&self, &bool_m); - BatchedFheUint8 { - data: cts, - count: m.len(), - } - } -} - -impl KeySwitchWithId> for BatchedFheUint8 -where - C: KeySwitchWithId, -{ - /// Key switching collection of RLWE ciphertexts in `BatchedFheUint8` from - /// user j's RLWE secret u_j to ideal RLWE secret key `s` of the protocol. +impl SeededBatchedFheUint8 { + /// Unseed collection of seeded RLWE ciphertexts of `Self` and returns + /// `NonInteractiveBatchedFheUint8` with collection of unseeded RLWE + /// ciphertexts. /// - /// - user_id: user id of user j - fn key_switch(&self, user_id: usize) -> BatchedFheUint8 { - let data = self - .data - .iter() - .map(|c| c.key_switch(user_id)) - .collect_vec(); - BatchedFheUint8 { - data, - count: self.count, - } + /// In non-interactive MPC setting, RLWE ciphertexts are encrypted under + /// user's RLWE secret `u_j`. The RLWE ciphertexts must be key switched to + /// ideal RLWE secret `s` of the MPC protocol before use. + /// + /// Note that we don't provide `unseed` API from `Self` to + /// `BatchedFheUint8`. This is because: + /// + /// - In non-interactive setting (1) client encrypts private inputs using + /// their secret `u_j` as `SeededBatchedFheUint8` and sends it to the + /// server. (2) Server unseeds `SeededBatchedFheUint8` into + /// `NonInteractiveBatchedFheUint8` indicating that private inputs are + /// still encrypted under user's RLWE secret `u_j`. (3) Server key + /// switches `NonInteractiveBatchedFheUint8` from user's RLWE secret `u_j` + /// to ideal RLWE secret `s` and outputs `BatchedFheUint8`. (4) + /// `BatchedFheUint8` always stores RLWE secret under ideal RLWE secret of + /// the protocol. Hence, it is safe to extract FheUint8s. Server proceeds + /// to extract necessary FheUint8s. + /// + /// - In interactive setting (1) client always encrypts private inputs using + /// public key corresponding to ideal RLWE secret `s` of the protocol and + /// produces `BatchedFheUint8`. (2) Given `BatchedFheUint8` stores + /// collection of RLWE ciphertext under ideal RLWE secret `s`, server can + /// directly extract necessary FheUint8s to use. + /// + /// Thus, there's no need to go directly from `Self` to `BatchedFheUint8`. + pub fn unseed(&self) -> NonInteractiveBatchedFheUint8 + where + NonInteractiveBatchedFheUint8: for<'a> From<&'a SeededBatchedFheUint8>, + M: Matrix, + { + NonInteractiveBatchedFheUint8::from(self) } }