diff --git a/examples/fheuint8.rs b/examples/fheuint8.rs new file mode 100644 index 0000000..85c7a1c --- /dev/null +++ b/examples/fheuint8.rs @@ -0,0 +1,71 @@ +use bin_rs::*; +use itertools::Itertools; +use rand::{thread_rng, RngCore}; + +fn plain_circuit(a: u8, b: u8, c: u8) -> u8 { + (a + b) * c +} + +fn fhe_circuit(fhe_a: &FheUint8, fhe_b: &FheUint8, fhe_c: &FheUint8) -> FheUint8 { + &(fhe_a + fhe_b) * fhe_c +} + +fn main() { + set_parameter_set(ParameterSelector::MultiPartyLessThan16); + let no_of_parties = 2; + let client_keys = (0..no_of_parties) + .into_iter() + .map(|_| gen_client_key()) + .collect_vec(); + + // set Multi-Party seed + let mut seed = [0u8; 32]; + thread_rng().fill_bytes(&mut seed); + set_mp_seed(seed); + + // multi-party key gen round 1 + let pk_shares = client_keys + .iter() + .map(|k| gen_mp_keys_phase1(k)) + .collect_vec(); + + // create public key + let public_key = aggregate_public_key_shares(&pk_shares); + + // multi-party key gen round 2 + let server_key_shares = client_keys + .iter() + .map(|k| gen_mp_keys_phase2(k, &public_key)) + .collect_vec(); + + // server aggregates server key shares and sets it + let server_key = aggregate_server_key_shares(&server_key_shares); + server_key.set_server_key(); + + // private inputs + let a = 4u8; + let b = 6u8; + let c = 128u8; + let fhe_a = public_key.encrypt(&a); + let fhe_b = public_key.encrypt(&b); + let fhe_c = public_key.encrypt(&c); + + // fhe evaluation + let now = std::time::Instant::now(); + let fhe_out = fhe_circuit(&fhe_a, &fhe_b, &fhe_c); + println!("Circuit time: {:?}", now.elapsed()); + + // plain evaluation + let out = plain_circuit(a, b, c); + + // generate decryption shares to decrypt ciphertext fhe_out + let decryption_shares = client_keys + .iter() + .map(|k| k.gen_decryption_share(&fhe_out)) + .collect_vec(); + + // decrypt fhe_out using decryption shares + let got_out = client_keys[0].aggregate_decryption_shares(&fhe_out, &decryption_shares); + + assert_eq!(got_out, out); +} diff --git a/src/bool/mod.rs b/src/bool/mod.rs index 8d6d9ab..f48ad21 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -26,8 +26,16 @@ static BOOL_SERVER_KEY: OnceLock>>> static MULTI_PARTY_CRS: OnceLock> = OnceLock::new(); -pub fn set_parameter_set(parameter: &BoolParameters) { - BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(parameter.clone()))); +pub enum ParameterSelector { + MultiPartyLessThan16, +} + +pub fn set_parameter_set(select: ParameterSelector) { + match select { + ParameterSelector::MultiPartyLessThan16 => { + BOOL_EVALUATOR.with_borrow_mut(|v| *v = Some(BoolEvaluator::new(SMALL_MP_BOOL_PARAMS))); + } + } } pub fn set_mp_seed(seed: [u8; 32]) { diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index ba0f63a..f6c0f8d 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -3,7 +3,7 @@ use num_traits::{ConstZero, FromPrimitive, PrimInt}; use crate::{backend::Modulus, decomposer::Decomposer}; #[derive(Clone, PartialEq)] -pub(crate) struct BoolParameters { +pub struct BoolParameters { rlwe_q: CiphertextModulus, lwe_q: CiphertextModulus, br_q: usize, @@ -181,7 +181,7 @@ pub(crate) struct PolynomialSize(pub(crate) usize); /// T equals modulus when modulus is non-native. Otherwise T equals 0. bool is /// true when modulus is native, false otherwise. -pub(crate) struct CiphertextModulus(T, bool); +pub struct CiphertextModulus(T, bool); impl CiphertextModulus { const fn new_native() -> Self { diff --git a/src/lib.rs b/src/lib.rs index 99f8674..a723e4d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,8 +23,13 @@ mod utils; pub use backend::{ ArithmeticLazyOps, ArithmeticOps, ModInit, ModularOpsU64, ShoupMatrixFMA, VectorOps, }; +pub use bool::{ + aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_mp_keys_phase1, + gen_mp_keys_phase2, set_mp_seed, set_parameter_set, ParameterSelector, +}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use ntt::{Ntt, NttBackendU64, NttInit}; +pub use shortint::FheUint8; pub trait Matrix: AsRef<[Self::R]> { type MatElement; @@ -165,15 +170,15 @@ impl RowEntity for Vec { } } -trait Encryptor { +pub trait Encryptor { fn encrypt(&self, m: &M) -> C; } -trait Decryptor { +pub trait Decryptor { fn decrypt(&self, c: &C) -> M; } -trait MultiPartyDecryptor { +pub trait MultiPartyDecryptor { type DecryptionShare; fn gen_decryption_share(&self, c: &C) -> Self::DecryptionShare; diff --git a/src/random.rs b/src/random.rs index a06bcbd..88db76a 100644 --- a/src/random.rs +++ b/src/random.rs @@ -12,7 +12,7 @@ thread_local! { pub(crate) static DEFAULT_RNG: RefCell = RefCell::new(DefaultSecureRng::new_seeded([0u8;32])); } -pub(crate) trait NewWithSeed { +pub trait NewWithSeed { type Seed; fn new_with_seed(seed: Self::Seed) -> Self; } @@ -59,7 +59,7 @@ where fn random_fill(&mut self, modulus: &P, container: &mut M); } -pub(crate) struct DefaultSecureRng { +pub struct DefaultSecureRng { rng: ChaCha8Rng, } diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index fe3876a..e1f4f85 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -8,7 +8,7 @@ use crate::{ mod ops; mod types; -type FheUint8 = types::FheUint8>; +pub type FheUint8 = types::FheUint8>; impl Encryptor for ClientKey { fn encrypt(&self, m: &u8) -> FheUint8 { @@ -308,9 +308,7 @@ mod tests { use crate::{ bool::{ aggregate_public_key_shares, aggregate_server_key_shares, gen_client_key, gen_keys, - gen_mp_keys_phase1, gen_mp_keys_phase2, - parameters::{MP_BOOL_PARAMS, SMALL_MP_BOOL_PARAMS, SP_BOOL_PARAMS}, - set_mp_seed, set_parameter_set, + gen_mp_keys_phase1, gen_mp_keys_phase2, set_mp_seed, set_parameter_set, }, shortint::types::FheUint8, Decryptor, Encryptor, MultiPartyDecryptor, @@ -318,7 +316,7 @@ mod tests { #[test] fn all_uint8_apis() { - set_parameter_set(&SP_BOOL_PARAMS); + set_parameter_set(crate::ParameterSelector::MultiPartyLessThan16); let (ck, sk) = gen_keys(); sk.set_server_key(); @@ -466,7 +464,7 @@ mod tests { #[test] fn fheuint8_test_multi_party() { - set_parameter_set(&SMALL_MP_BOOL_PARAMS); + set_parameter_set(crate::ParameterSelector::MultiPartyLessThan16); set_mp_seed([0; 32]); let parties = 8; diff --git a/src/shortint/types.rs b/src/shortint/types.rs index 8178f36..e6be57a 100644 --- a/src/shortint/types.rs +++ b/src/shortint/types.rs @@ -1,5 +1,5 @@ #[derive(Clone)] -pub(super) struct FheUint8 { +pub struct FheUint8 { pub(super) data: Vec, }