diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index de35acf..3f186ec 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -1,5 +1,6 @@ use std::{ cell::{OnceCell, RefCell}, + clone, collections::HashMap, fmt::{Debug, Display}, iter::Once, @@ -20,8 +21,8 @@ use crate::{ multi_party::public_key_share, ntt::{self, Ntt, NttBackendU64, NttInit}, random::{ - DefaultSecureRng, NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus, - RandomGaussianElementInModulus, + DefaultSecureRng, NewWithSeed, RandomFill, RandomFillGaussianInModulus, + RandomFillUniformInModulus, RandomGaussianElementInModulus, }, rgsw::{ decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, public_key_encrypt_rgsw, @@ -32,10 +33,11 @@ use crate::{ fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, Global, TryConvertFrom1, WithLocal, }, - Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, + Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, RowEntity, + RowMut, Secret, }; -use super::parameters::{BoolParameters, CiphertextModulus}; +use super::parameters::{self, BoolParameters, CiphertextModulus}; thread_local! { pub(crate) static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); @@ -45,10 +47,19 @@ pub(crate) static BOOL_SERVER_KEY: OnceLock< ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>, > = OnceLock::new(); +pub(crate) static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); + pub fn set_parameter_set(parameter: &BoolParameters) { BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone())) } +pub fn set_mp_seed(seed: [u8; 32]) { + assert!( + MULTI_PARTY_CRS.set(MultiPartyCrs { seed: seed }).is_ok(), + "Attempted to set MP SEED twice." + ) +} + fn set_server_key(key: ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>) { assert!( BOOL_SERVER_KEY.set(key).is_ok(), @@ -56,7 +67,7 @@ fn set_server_key(key: ServerKeyEvaluationDomain>, DefaultSecureRng ); } -pub fn gen_keys() -> ( +pub(crate) fn gen_keys() -> ( ClientKey, SeededServerKey>, BoolParameters, [u8; 32]>, ) { @@ -67,6 +78,93 @@ pub fn gen_keys() -> ( (ck, sk) }) } + +pub fn gen_client_key() -> ClientKey { + BoolEvaluator::with_local(|e| e.client_key()) +} + +pub fn gen_mp_keys_phase1( + ck: &ClientKey, +) -> CommonReferenceSeededCollectivePublicKeyShare, [u8; 32], BoolParameters> { + let seed = MultiPartyCrs::global().public_key_share_seed::(); + BoolEvaluator::with_local(|e| { + let pk_share = e.multi_party_public_key_share(seed, &ck); + pk_share + }) +} + +pub fn gen_mp_keys_phase2( + ck: &ClientKey, + pk: &PublicKey>, R, ModOp>, +) -> CommonReferenceSeededMultiPartyServerKeyShare>, BoolParameters, [u8; 32]> { + let seed = MultiPartyCrs::global().server_key_share_seed::(); + BoolEvaluator::with_local_mut(|e| { + let server_key_share = e.multi_party_server_key_share(seed, &pk.key, ck); + server_key_share + }) +} + +pub fn aggregate_public_key_shares( + shares: &[CommonReferenceSeededCollectivePublicKeyShare< + Vec, + [u8; 32], + BoolParameters, + >], +) -> PublicKey>, DefaultSecureRng, ModularOpsU64>> { + PublicKey::from(shares) +} + +pub fn aggregate_server_key_shares( + shares: &[CommonReferenceSeededMultiPartyServerKeyShare< + Vec>, + BoolParameters, + [u8; 32], + >], +) -> SeededMultiPartyServerKey>, [u8; 32], BoolParameters> { + BoolEvaluator::with_local(|e| e.aggregate_multi_party_server_key_shares(shares)) +} + +// GENERIC BELOW + +pub struct MultiPartyCrs { + seed: S, +} + +impl MultiPartyCrs { + /// Seed to generate public key share using MultiPartyCrs as the main seed. + /// + /// Public key seed equals the 1st seed extracted from PRNG Seeded with + /// MiltiPartyCrs's seed. + fn public_key_share_seed + RandomFill>(&self) -> S { + let mut prng = Rng::new_with_seed(self.seed); + + let mut seed = S::default(); + RandomFill::::random_fill(&mut prng, &mut seed); + seed + } + + /// Seed to generate server key share using MultiPartyCrs as the main seed. + /// + /// Server key seed equals the 2nd seed extracted from PRNG Seeded with + /// MiltiPartyCrs's seed. + fn server_key_share_seed + RandomFill>(&self) -> S { + let mut prng = Rng::new_with_seed(self.seed); + + let mut seed = S::default(); + RandomFill::::random_fill(&mut prng, &mut seed); + RandomFill::::random_fill(&mut prng, &mut seed); + seed + } +} + +impl Global for MultiPartyCrs<[u8; 32]> { + fn global() -> &'static Self { + MULTI_PARTY_CRS + .get() + .expect("Multi Party Common Reference String not set") + } +} + pub(crate) trait BooleanGates { type Ciphertext: RowEntity; type Key; @@ -323,19 +421,116 @@ impl Decryptor> for ClientKey { } } -struct MultiPartyDecryptionShare { - share: E, +impl MultiPartyDecryptor> for ClientKey { + type DecryptionShare = u64; + + fn gen_decryption_share(&self, c: &Vec) -> Self::DecryptionShare { + BoolEvaluator::with_local(|e| e.multi_party_decryption_share(c, &self)) + } + + fn aggregate_decryption_shares(&self, c: &Vec, shares: &[Self::DecryptionShare]) -> bool { + BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c)) + } } -struct CommonReferenceSeededCollectivePublicKeyShare { +// struct MultiPartyDecryptionShare { +// share: E, +// } + +pub struct CommonReferenceSeededCollectivePublicKeyShare { share: R, cr_seed: S, parameters: P, } -struct PublicKey { +struct SeededPublicKey { + part_b: R, + seed: S, + parameters: P, + _phantom: PhantomData, +} + +impl + From<&[CommonReferenceSeededCollectivePublicKeyShare>]> + for SeededPublicKey, ModOp> +where + ModOp: VectorOps + ModInit>, + S: PartialEq + Clone, + R: RowMut + RowEntity + Clone, + R::Element: Clone + PartialEq, +{ + fn from( + value: &[CommonReferenceSeededCollectivePublicKeyShare>], + ) -> Self { + assert!(value.len() > 0); + + let parameters = &value[0].parameters; + let cr_seed = value[0].cr_seed.clone(); + + // Sum all Bs + let rlweq_modop = ModOp::new(parameters.rlwe_q().clone()); + let mut part_b = value[0].share.clone(); + value.iter().skip(1).for_each(|share_i| { + assert!(&share_i.cr_seed == &cr_seed); + assert!(&share_i.parameters == parameters); + + rlweq_modop.elwise_add_mut(part_b.as_mut(), share_i.share.as_ref()); + }); + + Self { + part_b, + seed: cr_seed, + parameters: parameters.clone(), + _phantom: PhantomData, + } + } +} + +pub struct PublicKey { key: M, - _phantom: PhantomData<(R, O)>, + _phantom: PhantomData<(Rng, ModOp)>, +} + +impl Encryptor> for PublicKey>, Rng, ModOp> { + fn encrypt(&self, m: &bool) -> Vec { + BoolEvaluator::with_local(|e| e.pk_encrypt(&self.key, *m)) + } +} + +impl Encryptor<[bool], Vec>> for PublicKey>, Rng, ModOp> { + fn encrypt(&self, m: &[bool]) -> Vec> { + BoolEvaluator::with_local(|e| e.pk_encrypt_batched(&self.key, m)) + } +} + +impl< + M: MatrixMut + MatrixEntity, + Rng: NewWithSeed + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus>, + ModOp, + > From, ModOp>> + for PublicKey +where + ::R: RowMut, + M::MatElement: Copy, +{ + fn from(value: SeededPublicKey, ModOp>) -> Self { + let mut prng = Rng::new_with_seed(value.seed); + + let mut key = M::zeros(2, value.parameters.rlwe_n().0); + // sample A + RandomFillUniformInModulus::random_fill( + &mut prng, + value.parameters.rlwe_q(), + key.get_row_mut(0), + ); + // Copy over B + key.get_row_mut(1).copy_from_slice(value.part_b.as_ref()); + + PublicKey { + key, + _phantom: PhantomData, + } + } } impl< @@ -392,7 +587,7 @@ where } } -struct CommonReferenceSeededMultiPartyServerKeyShare { +pub struct CommonReferenceSeededMultiPartyServerKeyShare { rgsw_cts: Vec, /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// to -g is at 0 @@ -402,7 +597,7 @@ struct CommonReferenceSeededMultiPartyServerKeyShare { cr_seed: S, parameters: P, } -struct SeededMultiPartyServerKey { +pub struct SeededMultiPartyServerKey { rgsw_cts: Vec, /// Auto keys. Key corresponding to g^{k} is at index `k`. Key corresponding /// to -g is at 0 @@ -412,6 +607,22 @@ struct SeededMultiPartyServerKey { parameters: P, } +impl + SeededMultiPartyServerKey< + Vec>, + ::Seed, + BoolParameters, + > +{ + pub fn set_server_key(&self) { + set_server_key(ServerKeyEvaluationDomain::< + Vec>, + DefaultSecureRng, + NttBackendU64, + >::from(self)) + } +} + /// Seeded single party server key pub struct SeededServerKey { /// Rgsw cts of LWE secret elements @@ -709,7 +920,6 @@ where } } - impl PbsKey for ServerKeyEvaluationDomain { type M = M; fn galois_key_for_auto(&self, k: usize) -> &Self::M { @@ -1241,7 +1451,7 @@ where &self, lwe_ct: &M::R, client_key: &ClientKey, - ) -> MultiPartyDecryptionShare<::MatElement> { + ) -> ::MatElement { assert!(lwe_ct.as_ref().len() == self.pbs_info.parameters.rlwe_n().0 + 1); let modop = &self.pbs_info.rlwe_modop; let mut neg_s = M::R::try_convert_from( @@ -1262,34 +1472,44 @@ where }); let share = modop.add(&neg_sa, &e); - MultiPartyDecryptionShare { share } + share } - pub(crate) fn multi_party_decrypt( - &self, - shares: &[MultiPartyDecryptionShare], - lwe_ct: &M::R, - ) -> bool { + pub(crate) fn multi_party_decrypt(&self, shares: &[M::MatElement], lwe_ct: &M::R) -> bool { let modop = &self.pbs_info.rlwe_modop; let mut sum_a = M::MatElement::zero(); shares .iter() - .for_each(|share_i| sum_a = modop.add(&sum_a, &share_i.share)); + .for_each(|share_i| sum_a = modop.add(&sum_a, &share_i)); let encoded_m = modop.add(&lwe_ct.as_ref()[0], &sum_a); self.pbs_info.parameters.rlwe_q().decode(encoded_m) } - /// First encrypt as RLWE(m) with m as constant polynomial and extract it as - /// LWE ciphertext pub(crate) fn pk_encrypt(&self, pk: &M, m: bool) -> M::R { + self.pk_encrypt_batched(pk, &vec![m]).remove(0) + } + + /// Encrypts a batch booleans as multiple LWE ciphertexts. + /// + /// For public key encryption we first encrypt `m` as a RLWE ciphertext and + /// then sample extract LWE samples at required indices. + /// + /// - TODO(Jay:) Communication can be improved by not sample exctracting and + /// instead just truncate degree 0 values (part Bs) + pub(crate) fn pk_encrypt_batched(&self, pk: &M, m: &[bool]) -> Vec { DefaultSecureRng::with_local_mut(|rng| { + let ring_size = self.pbs_info.parameters.rlwe_n().0; + assert!( + m.len() <= ring_size, + "Cannot batch encrypt > ring_size{ring_size} elements at once" + ); + let modop = &self.pbs_info.rlwe_modop; let nttop = &self.pbs_info.rlwe_nttop; // RLWE(0) // sample ephemeral key u - let ring_size = self.pbs_info.parameters.rlwe_n().0; let mut u = vec![0i32; ring_size]; fill_random_ternary_secret_with_hamming_weight(u.as_mut(), ring_size >> 1, rng); let mut u = M::R::try_convert_from(&u, &self.pbs_info.parameters.rlwe_q()); @@ -1326,22 +1546,31 @@ where modop.elwise_add_mut(rlwe.get_row_mut(1), ub.as_ref()); //FIXME(Jay): Figure out a way to get Q/8 form modulus - let m = if m { - // Q/8 - self.pbs_info.rlwe_q().true_el() - } else { - // -Q/8 - self.pbs_info.rlwe_q().false_el() - }; - - // b*u + e1 + m, where m is constant polynomial - rlwe.set(1, 0, modop.add(rlwe.get(1, 0), &m)); - - // sample extract index 0 - let mut lwe_out = M::R::zeros(ring_size + 1); - sample_extract(&mut lwe_out, &rlwe, modop, 0); + let mut m_vec = M::R::zeros(ring_size); + izip!(m_vec.as_mut().iter_mut(), m.iter()).for_each(|(m_el, m_bool)| { + if *m_bool { + // Q/8 + *m_el = self.pbs_info.rlwe_q().true_el() + } else { + // -Q/8 + *m_el = self.pbs_info.rlwe_q().false_el() + } + }); - lwe_out + // b*u + e1 + m + modop.elwise_add_mut(rlwe.get_row_mut(1), m_vec.as_ref()); + // rlwe.set(1, 0, modop.add(rlwe.get(1, 0), &m)); + + // sample extract index required indices + let samples = m.len(); + (0..samples) + .into_iter() + .map(|i| { + let mut lwe_out = M::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, &rlwe, modop, i); + lwe_out + }) + .collect_vec() }) } @@ -2103,7 +2332,6 @@ fn pbs, K: PbsK pbs_key, ); - // sample extract sample_extract(lwe_in, &trivial_rlwe_test_poly, pbs_info.modop_rlweq(), 0); } @@ -2731,7 +2959,7 @@ mod tests { >::new(MP_BOOL_PARAMS); let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = - _multi_party_all_keygen(&bool_evaluator, 8); + _multi_party_all_keygen(&bool_evaluator, 64); let mut m0 = true; let mut m1 = false; diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 5aac433..1a64401 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -307,10 +307,10 @@ pub(crate) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { auto_decomposer_base: DecompostionLogBase(7), auto_decomposer_count: DecompositionCount(4), g: 5, - w: 10, + w: 5, }; -pub(super) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { +pub(crate) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { rlwe_q: CiphertextModulus::new_non_native(1152921504606830593), lwe_q: CiphertextModulus::new_non_native(1 << 20), br_q: 1 << 11, @@ -325,7 +325,7 @@ pub(super) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { auto_decomposer_base: DecompostionLogBase(12), auto_decomposer_count: DecompositionCount(5), g: 5, - w: 5, + w: 10, }; #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 43e6a5f..ed40987 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +use std::{iter::Once, sync::OnceLock}; + use itertools::{izip, Itertools}; use num::UnsignedInteger; use num_traits::{abs, Zero}; @@ -156,10 +158,17 @@ impl RowEntity for Vec { } } -trait Encryptor { +trait Encryptor { fn encrypt(&self, m: &M) -> C; } -trait Decryptor { +trait Decryptor { fn decrypt(&self, c: &C) -> M; } + +trait MultiPartyDecryptor { + type DecryptionShare; + + fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare; + fn aggregate_decryption_shares(&self, c: &C, shares: &[Self::DecryptionShare]) -> M; +} diff --git a/src/random.rs b/src/random.rs index acc743e..78b71ab 100644 --- a/src/random.rs +++ b/src/random.rs @@ -138,6 +138,21 @@ where } } +impl RandomFill<[T; 32]> for DefaultSecureRng +where + T: PrimInt + SampleUniform, +{ + fn random_fill(&mut self, container: &mut [T; 32]) { + izip!( + (&mut self.rng).sample_iter(Uniform::new_inclusive(T::zero(), T::max_value())), + container.iter_mut() + ) + .for_each(|(from, to)| { + *to = from; + }); + } +} + impl RandomElement for DefaultSecureRng where T: PrimInt + SampleUniform, diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 67f0e36..140efbe 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -1,9 +1,11 @@ use itertools::Itertools; use crate::{ - bool::evaluator::{BoolEvaluator, ClientKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY}, + bool::evaluator::{ + BoolEvaluator, ClientKey, PublicKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY, + }, utils::{Global, WithLocal}, - Decryptor, Encryptor, + Decryptor, Encryptor, Matrix, MultiPartyDecryptor, }; mod ops; @@ -26,6 +28,7 @@ impl Encryptor for ClientKey { impl Decryptor for ClientKey { fn decrypt(&self, c: &FheUint8) -> u8 { + assert!(c.data.len() == 8); let mut out = 0u8; c.data().iter().enumerate().for_each(|(index, bit_c)| { let bool = Decryptor::>::decrypt(self, bit_c); @@ -37,6 +40,60 @@ impl Decryptor for ClientKey { } } +impl Encryptor for PublicKey +where + PublicKey: Encryptor>, +{ + fn encrypt(&self, m: &u8) -> FheUint8 { + let cts = (0..8) + .into_iter() + .map(|i| { + let bit = ((m >> i) & 1) == 1; + Encryptor::>::encrypt(self, &bit) + }) + .collect_vec(); + FheUint8 { data: cts } + } +} + +impl MultiPartyDecryptor for ClientKey +where + ClientKey: MultiPartyDecryptor>, +{ + type DecryptionShare = Vec<>>::DecryptionShare>; + fn gen_decryption_share(&self, c: &FheUint8) -> Self::DecryptionShare { + assert!(c.data().len() == 8); + c.data() + .iter() + .map(|bit_c| { + let decryption_share = + MultiPartyDecryptor::>::gen_decryption_share(self, bit_c); + decryption_share + }) + .collect_vec() + } + + fn aggregate_decryption_shares(&self, c: &FheUint8, shares: &[Self::DecryptionShare]) -> u8 { + let mut out = 0u8; + + (0..8).into_iter().for_each(|i| { + // Collect bit i^th decryption share of each party + let bit_i_decryption_shares = shares.iter().map(|s| s[i]).collect_vec(); + let bit_i = MultiPartyDecryptor::>::aggregate_decryption_shares( + self, + &c.data()[i], + &bit_i_decryption_shares, + ); + + if bit_i { + out += 1 << i; + } + }); + + out + } +} + mod frontend { use super::ops::{ arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, @@ -245,15 +302,20 @@ mod frontend { #[cfg(test)] mod tests { + use itertools::Itertools; use num_traits::Euclid; use crate::{ bool::{ - evaluator::{gen_keys, set_parameter_set, BoolEvaluator}, - parameters::SP_BOOL_PARAMS, + evaluator::{ + aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys, + gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set, + BoolEvaluator, ClientKey, + }, + parameters::{MP_BOOL_PARAMS, SP_BOOL_PARAMS}, }, shortint::types::FheUint8, - Decryptor, Encryptor, + Decryptor, Encryptor, MultiPartyDecryptor, }; #[test] @@ -403,4 +465,57 @@ mod tests { } } } + + #[test] + fn fheuint8_test_multi_party() { + set_parameter_set(&MP_BOOL_PARAMS); + set_mp_seed([0; 32]); + + let parties = 8; + + // client keys and public key share + let cks = (0..parties) + .into_iter() + .map(|i| gen_client_key()) + .collect_vec(); + + // round 1: generate pulic key shares + let pk_shares = cks.iter().map(|key| gen_mp_keys_phase1(key)).collect_vec(); + + let public_key = aggregate_public_key_shares(&pk_shares); + + // round 2: generate server key shares + let server_key_shares = cks + .iter() + .map(|key| gen_mp_keys_phase2(key, &public_key)) + .collect_vec(); + + // server aggregates the server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // Clients use Pk to encrypt private inputs + let a = 8u8; + let b = 10u8; + let c = 155u8; + let ct_a = public_key.encrypt(&a); + let ct_b = public_key.encrypt(&b); + let ct_c = public_key.encrypt(&c); + + // server computes + // a*b + c + let mut ct_ab = &ct_a * &ct_b; + ct_ab += &ct_c; + + // decrypt ab and check + // generate decryption shares + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&ct_ab)) + .collect_vec(); + + // aggregate and decryption ab + let ab_add_c = cks[0].aggregate_decryption_shares(&ct_ab, &decryption_shares); + assert!(ab_add_c == (a.wrapping_mul(b)).wrapping_add(c)); + } }