From 3236fccd7e43e72f2f4127e08d56fff8f9771fcd Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Wed, 19 Jun 2024 13:19:48 +0530 Subject: [PATCH] non-interactive example --- Cargo.toml | 9 ++ .../{fheuint8.rs => interactive_fheuint8.rs} | 37 ++++---- examples/non_interactive_fheuint8.rs | 73 +++++++++++++++ src/bool/mod.rs | 31 ++++++- src/bool/mp_api.rs | 89 ++++++++++++++++++- src/bool/ni_mp_api.rs | 2 +- src/lib.rs | 8 +- src/shortint/enc_dec.rs | 29 +++++- 8 files changed, 249 insertions(+), 29 deletions(-) rename examples/{fheuint8.rs => interactive_fheuint8.rs} (60%) create mode 100644 examples/non_interactive_fheuint8.rs diff --git a/Cargo.toml b/Cargo.toml index 20aae5f..64323eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,12 @@ harness = false [[bench]] name = "modulus" harness = false + +[[example]] +name = "interactive_fheuint8" +path = "./examples/interactive_fheuint8.rs" + +[[example]] +name = "non_interactive_fheuint8" +path = "./examples/non_interactive_fheuint8.rs" +required-features = ["non_interactive_mp"] \ No newline at end of file diff --git a/examples/fheuint8.rs b/examples/interactive_fheuint8.rs similarity index 60% rename from examples/fheuint8.rs rename to examples/interactive_fheuint8.rs index db0e358..26ea2ea 100644 --- a/examples/fheuint8.rs +++ b/examples/interactive_fheuint8.rs @@ -6,13 +6,13 @@ fn plain_circuit(a: u8, b: u8, c: u8) -> u8 { (a + b) * c } -// fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) -> -// FheUint8 { &(fhe_a + fhe_b) * fhe_c -// } +fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) -> FheUint8 { + &(fhe_a + fhe_b) * fhe_c +} fn main() { set_parameter_set(ParameterSelector::MultiPartyLessThanOrEqualTo16); - let no_of_parties = 2; + let no_of_parties = 8; let client_keys = (0..no_of_parties) .into_iter() .map(|_| gen_client_key()) @@ -50,25 +50,22 @@ fn main() { let fhe_b = public_key.encrypt(&b); let fhe_c = public_key.encrypt(&c); - let fhe_batched = public_key.encrypt(vec![12, 3u8].as_slice()); - // fhe evaluation - // let now = std::time::Instant::now(); - // let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c); - // println!("Circuit time: {:?}", now.elapsed()); + let now = std::time::Instant::now(); + let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c); + println!("Circuit time: {:?}", now.elapsed()); - // // plain evaluation - // let out = plain_circuit(a, b, c); + // plain evaluation + let out = plain_circuit(a, b, c); - // // generate decryption shares to decrypt ciphertext fhe_out - // let decryption_shares = client_keys - // .iter() - // .map(|k| k.gen_decryption_share(&fhe_out)) - // .collect_vec(); + // generate decryption shares to decrypt ciphertext fhe_out + let decryption_shares = client_keys + .iter() + .map(|k| k.gen_decryption_share(&fhe_out)) + .collect_vec(); - // // decrypt fhe_out using decryption shares - // let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out, - // &decryption_shares); + // decrypt fhe_out using decryption shares + let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out, &decryption_shares); - // assert_eq!(got_out, out); + assert_eq!(got_out, out); } diff --git a/examples/non_interactive_fheuint8.rs b/examples/non_interactive_fheuint8.rs new file mode 100644 index 0000000..1c8b10b --- /dev/null +++ b/examples/non_interactive_fheuint8.rs @@ -0,0 +1,73 @@ +use bin_rs::*; +use itertools::Itertools; +use rand::{thread_rng, Rng, RngCore}; + +fn circuit(a: u8, b: u8, c: u8, d: u8) -> u8 { + ((a + b) * c) * d +} + +fn fhe_circuit(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 { + &(&(a + b) * c) * d +} + +fn main() { + set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16); + + // set CRS + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_common_reference_seed(seed); + + let no_of_parties = 2; + + // Generate client keys + let cks = (0..no_of_parties).map(|_| gen_client_key()).collect_vec(); + + // client 0 encrypts private input + let c0_a = thread_rng().gen::(); + let c0_b = thread_rng().gen::(); + let c0_batched_to_send = cks[0].encrypt(vec![c0_a, c0_b].as_slice()); + + // client 1 encrypts private input + let c1_a = thread_rng().gen::(); + let c1_b = thread_rng().gen::(); + let c1_batch_to_send = cks[1].encrypt(vec![c1_a, c1_b].as_slice()); + + // Both client indenpendently generate their server key shares + let server_key_shares = cks + .iter() + .enumerate() + .map(|(id, k)| gen_server_key_share(id, no_of_parties, k)) + .collect_vec(); + + // Server side + + // aggregates shares and generates server key + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // extract a and b from client0 inputs + let (ct_c0_a, ct_c0_b) = { + let ct = c0_batched_to_send.unseed::>>().key_switch(0); + (ct.extract(0), ct.extract(1)) + }; + + // extract a and b from client1 inputs + let (ct_c1_a, ct_c1_b) = { + let ct = c1_batch_to_send.unseed::>>().key_switch(1); + (ct.extract(0), ct.extract(1)) + }; + + let now = std::time::Instant::now(); + let c_out = fhe_circuit(&ct_c0_a, &ct_c1_a, &ct_c0_b, &ct_c1_b); + println!("Circuit Time: {:?}", now.elapsed()); + + // decrypt c_out + let decryption_shares = cks + .iter() + .map(|k| k.gen_decryption_share(&c_out)) + .collect_vec(); + let m_out = cks[0].aggregate_decryption_shares(&c_out, &decryption_shares); + let m_expected = circuit(c0_a, c1_a, c0_b, c1_b); + assert!(m_expected == m_out); +} diff --git a/src/bool/mod.rs b/src/bool/mod.rs index b8224d9..e6b5672 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -1,16 +1,21 @@ pub(crate) mod evaluator; mod keys; -mod mp_api; -mod ni_mp_api; mod noise; pub(crate) mod parameters; pub(crate) use keys::PublicKey; +#[cfg(feature = "interactive_mp")] +#[cfg(not(feature = "non_interactive_mp"))] +mod mp_api; +#[cfg(feature = "non_interactive_mp")] +mod ni_mp_api; + #[cfg(feature = "non_interactive_mp")] pub use ni_mp_api::*; #[cfg(feature = "interactive_mp")] +#[cfg(not(feature = "non_interactive_mp"))] pub use mp_api::*; pub type ClientKey = keys::ClientKey<[u8; 32], u64>; @@ -22,7 +27,11 @@ pub enum ParameterSelector { mod common_mp_enc_dec { use super::BoolEvaluator; - use crate::{utils::WithLocal, Matrix, MultiPartyDecryptor}; + use crate::{ + pbs::{sample_extract, PbsInfo}, + utils::WithLocal, + Matrix, MultiPartyDecryptor, RowEntity, SampleExtractor, + }; type Mat = Vec>; @@ -41,4 +50,20 @@ mod common_mp_enc_dec { BoolEvaluator::with_local(|e| e.multi_party_decrypt(shares, c)) } } + + impl SampleExtractor<::R> for Mat { + fn extract(&self, index: usize) -> ::R { + // input is RLWE ciphertext + assert!(self.dimension().0 == 2); + + let ring_size = self.dimension().1; + assert!(index < ring_size); + + BoolEvaluator::with_local(|e| { + let mut lwe_out = ::R::zeros(ring_size + 1); + sample_extract(&mut lwe_out, self, e.pbs_info().modop_rlweq(), index); + lwe_out + }) + } + } } diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index 7bb20af..7d745b7 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -9,7 +9,7 @@ use crate::{ use super::{evaluator::MultiPartyCrs, keys::*, parameters::*, ClientKey, ParameterSelector}; -pub type BoolEvaluator = super::evaluator::BoolEvaluator< +pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator< Vec>, NttBackendU64, ModularOpsU64>, @@ -304,8 +304,13 @@ mod tests { } mod sp_api { + use num_traits::ToPrimitive; + use rand::Rng; + use crate::{ - backend::ModulusPowerOf2, utils::WithLocal, Decryptor, ModularOpsU64, NttBackendU64, + backend::ModulusPowerOf2, evaluator::BoolEncoding, pbs::PbsInfo, + rgsw::secret_key_encrypt_rlwe, utils::WithLocal, Decryptor, ModularOpsU64, + NttBackendU64, SampleExtractor, }; use super::*; @@ -358,6 +363,86 @@ mod tests { BoolEvaluator::with_local(|e| e.sk_decrypt(c, self)) } } + impl Encryptor<[bool], (Vec>, [u8; 32])> for K + where + K: SinglePartyClientKey, + { + fn encrypt(&self, m: &[bool]) -> (Vec>, [u8; 32]) { + BoolEvaluator::with_local(|e| { + DefaultSecureRng::with_local_mut(|rng| { + let parameters = e.parameters(); + let ring_size = parameters.rlwe_n().0; + + let rlwe_count = ((m.len() as f64 / ring_size as f64).ceil()) + .to_usize() + .unwrap(); + + let mut seed = ::Seed::default(); + rng.fill_bytes(&mut seed); + let mut prng = DefaultSecureRng::new_seeded(seed); + + let sk_u = self.sk_rlwe(); + + // encrypt `m` into ceil(len(m)/N) RLWE ciphertexts + let rlwes = (0..rlwe_count) + .map(|index| { + let mut message = vec![0; ring_size]; + m[(index * ring_size) + ..std::cmp::min(m.len(), (index + 1) * ring_size)] + .iter() + .enumerate() + .for_each(|(i, v)| { + if *v { + message[i] = parameters.rlwe_q().true_el() + } else { + message[i] = parameters.rlwe_q().false_el() + } + }); + + // encrypt message + let mut rlwe_out = vec![0u64; parameters.rlwe_n().0]; + + secret_key_encrypt_rlwe( + &message, + &mut rlwe_out, + &sk_u, + e.pbs_info().modop_rlweq(), + e.pbs_info().nttop_rlweq(), + &mut prng, + rng, + ); + + rlwe_out + }) + .collect_vec(); + + (rlwes, seed) + }) + }) + } + } + + #[test] + fn batch_extract_works() { + set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); + + let (ck, sk) = gen_keys(); + sk.set_server_key(); + + let batch_size = (SP_TEST_BOOL_PARAMS.rlwe_n().0 * 3 + 123); + let m = (0..batch_size) + .map(|_| thread_rng().gen::()) + .collect_vec(); + + let seeded_ct = ck.encrypt(m.as_slice()); + let ct = seeded_ct.unseed::>>(); + + let m_back = (0..batch_size) + .map(|i| ck.decrypt(&ct.extract(i))) + .collect_vec(); + + assert_eq!(m, m_back); + } #[test] #[cfg(feature = "interactive_mp")] diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs index e26cc40..97610bc 100644 --- a/src/bool/ni_mp_api.rs +++ b/src/bool/ni_mp_api.rs @@ -19,7 +19,7 @@ use super::{ ClientKey, ParameterSelector, }; -pub type BoolEvaluator = super::evaluator::BoolEvaluator< +pub(crate) type BoolEvaluator = super::evaluator::BoolEvaluator< Vec>, NttBackendU64, ModularOpsU64>, diff --git a/src/lib.rs b/src/lib.rs index 8c7d242..955be3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,8 +23,10 @@ pub use backend::{ // aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, // gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set, // ParameterSelector, }; +pub use bool::*; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use ntt::{Ntt, NttBackendU64, NttInit}; +pub use shortint::FheUint8; pub trait Matrix: AsRef<[Self::R]> { type MatElement; @@ -184,6 +186,10 @@ pub trait KeySwitchWithId { fn key_switch(&self, user_id: usize) -> C; } -pub(crate) trait Encoder { +pub trait SampleExtractor { + fn extract(&self, index: usize) -> R; +} + +trait Encoder { fn encode(&self, v: F) -> T; } diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs index a8ee4c1..cce46f0 100644 --- a/src/shortint/enc_dec.rs +++ b/src/shortint/enc_dec.rs @@ -5,9 +5,13 @@ use crate::{ random::{DefaultSecureRng, RandomFillUniformInModulus}, utils::{TryConvertFrom1, WithLocal}, Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, - RowMut, + RowMut, SampleExtractor, }; +/// Fhe UInt8 type +/// +/// - Stores encryptions of bits in little endian (i.e least signficant bit +/// stored at 0th index and most signficant bit stores at 7th index) #[derive(Clone)] pub struct FheUint8 { pub(super) data: Vec, @@ -27,6 +31,28 @@ pub struct BatchedFheUint8 { data: Vec, } +impl SampleExtractor> for BatchedFheUint8 +where + C: SampleExtractor, +{ + fn extract(&self, index: usize) -> FheUint8 { + BoolEvaluator::with_local(|e| { + let ring_size = e.parameters().rlwe_n().0; + + let start_index = index * 8; + let end_index = (index + 1) * 8; + let data = (start_index..end_index) + .map(|i| { + let rlwe_index = i / ring_size; + let coeff_index = i % ring_size; + self.data[rlwe_index].extract(coeff_index) + }) + .collect_vec(); + FheUint8 { data } + }) + } +} + impl> From<&SeededBatchedFheUint8> for BatchedFheUint8 where @@ -85,7 +111,6 @@ where .flat_map(|v| (0..8).into_iter().map(|i| (((*v) >> i) & 1) == 1)) .collect_vec(); let (cts, seed) = K::encrypt(&self, &m); - dbg!(cts.len()); SeededBatchedFheUint8 { data: cts, seed } } }