From 72d8cafa9598cb04d47514a43bafca29d32fdc25 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Mon, 1 Jul 2024 15:26:10 +0530 Subject: [PATCH] add bool frontend --- src/bool/evaluator.rs | 57 +------------ src/bool/keys.rs | 4 +- src/bool/mod.rs | 180 +++++++++++++++++++++++++++++++++++++++- src/bool/mp_api.rs | 65 +++++++++++++-- src/bool/ni_mp_api.rs | 2 +- src/bool/print_noise.rs | 12 ++- src/lib.rs | 2 +- src/shortint/enc_dec.rs | 27 +----- src/shortint/mod.rs | 3 +- src/shortint/ops.rs | 2 +- 10 files changed, 257 insertions(+), 97 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 099a5a5..d04bcee 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -33,7 +33,7 @@ use crate::{ encode_x_pow_si_with_emebedding_factor, mod_exponent, puncture_p_rng, TryConvertFrom1, WithLocal, }, - Encoder, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, + BooleanGates, Encoder, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, }; use super::{ @@ -171,57 +171,6 @@ impl NonInteractiveMultiPartyCrs { } } -pub(crate) trait BooleanGates { - type Ciphertext: RowEntity; - type Key; - - fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); - fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); - fn or_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); - fn nor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); - fn xor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); - fn xnor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); - fn not_inplace(&mut self, c: &mut Self::Ciphertext); - - fn and( - &mut self, - c0: &Self::Ciphertext, - c1: &Self::Ciphertext, - key: &Self::Key, - ) -> Self::Ciphertext; - fn nand( - &mut self, - c0: &Self::Ciphertext, - c1: &Self::Ciphertext, - key: &Self::Key, - ) -> Self::Ciphertext; - fn or( - &mut self, - c0: &Self::Ciphertext, - c1: &Self::Ciphertext, - key: &Self::Key, - ) -> Self::Ciphertext; - fn nor( - &mut self, - c0: &Self::Ciphertext, - c1: &Self::Ciphertext, - key: &Self::Key, - ) -> Self::Ciphertext; - fn xor( - &mut self, - c0: &Self::Ciphertext, - c1: &Self::Ciphertext, - key: &Self::Key, - ) -> Self::Ciphertext; - fn xnor( - &mut self, - c0: &Self::Ciphertext, - c1: &Self::Ciphertext, - key: &Self::Key, - ) -> Self::Ciphertext; - fn not(&mut self, c: &Self::Ciphertext) -> Self::Ciphertext; -} - struct ScratchMemory where M: Matrix, @@ -2324,7 +2273,7 @@ where ); } - fn not_inplace(&mut self, c0: &mut M::R) { + fn not_inplace(&self, c0: &mut M::R) { let modop = &self.pbs_info.rlwe_modop; c0.as_mut().iter_mut().for_each(|v| *v = modop.neg(v)); } @@ -2395,7 +2344,7 @@ where out } - fn not(&mut self, c: &Self::Ciphertext) -> Self::Ciphertext { + fn not(&self, c: &Self::Ciphertext) -> Self::Ciphertext { let mut out = c.clone(); self.not_inplace(&mut out); out diff --git a/src/bool/keys.rs b/src/bool/keys.rs index e6ecc51..3bbd727 100644 --- a/src/bool/keys.rs +++ b/src/bool/keys.rs @@ -497,7 +497,7 @@ pub(super) mod impl_server_key_eval_domain { use itertools::{izip, Itertools}; use crate::{ - evaluator::InteractiveMultiPartyCrs, + bool::evaluator::InteractiveMultiPartyCrs, ntt::{Ntt, NttInit}, pbs::PbsKey, random::RandomFill, @@ -1268,7 +1268,7 @@ pub struct CommonReferenceSeededNonInteractiveMultiPartyServerKeyShare; +pub type FheBool = impl_bool_frontend::FheBool>; + +pub(crate) trait BooleanGates { + type Ciphertext: RowEntity; + type Key; + + fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn or_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn nor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn xor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn xnor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn not_inplace(&self, c: &mut Self::Ciphertext); + + fn and( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn nand( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn or( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn nor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn xor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn xnor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn not(&self, c: &Self::Ciphertext) -> Self::Ciphertext; +} + +mod impl_bool_frontend { + use crate::MultiPartyDecryptor; + + /// Fhe Bool ciphertext + #[derive(Clone)] + pub struct FheBool { + pub(crate) data: C, + } + + impl FheBool { + pub(crate) fn data(&self) -> &C { + &self.data + } + + pub(crate) fn data_mut(&mut self) -> &mut C { + &mut self.data + } + } + + impl MultiPartyDecryptor> for K + where + K: MultiPartyDecryptor, + { + type DecryptionShare = >::DecryptionShare; + + fn aggregate_decryption_shares( + &self, + c: &FheBool, + shares: &[Self::DecryptionShare], + ) -> bool { + self.aggregate_decryption_shares(&c.data, shares) + } + + fn gen_decryption_share(&self, c: &FheBool) -> Self::DecryptionShare { + self.gen_decryption_share(&c.data) + } + } + + mod ops { + use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not}; + + use crate::{ + utils::{Global, WithLocal}, + BooleanGates, + }; + + use super::super::{BoolEvaluator, RuntimeServerKey}; + + type FheBool = super::super::FheBool; + + impl BitAnd for &FheBool { + type Output = FheBool; + fn bitand(self, rhs: Self) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + FheBool { + data: e.and(self.data(), rhs.data(), key), + } + }) + } + } + + impl BitAndAssign for FheBool { + fn bitand_assign(&mut self, rhs: Self) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + e.and_inplace(&mut self.data_mut(), rhs.data(), key); + }); + } + } + + impl BitOr for &FheBool { + type Output = FheBool; + fn bitor(self, rhs: Self) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + FheBool { + data: e.or(self.data(), rhs.data(), key), + } + }) + } + } + + impl BitOrAssign for FheBool { + fn bitor_assign(&mut self, rhs: Self) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + e.or_inplace(&mut self.data_mut(), rhs.data(), key); + }); + } + } + + impl BitXor for &FheBool { + type Output = FheBool; + fn bitxor(self, rhs: Self) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = RuntimeServerKey::global(); + FheBool { + data: e.xor(self.data(), rhs.data(), key), + } + }) + } + } + + impl BitXorAssign for FheBool { + fn bitxor_assign(&mut self, rhs: Self) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = RuntimeServerKey::global(); + e.xor_inplace(&mut self.data_mut(), rhs.data(), key); + }); + } + } + + impl Not for &FheBool { + type Output = FheBool; + fn not(self) -> Self::Output { + BoolEvaluator::with_local(|e| FheBool { + data: e.not(self.data()), + }) + } + } + } +} mod common_mp_enc_dec { use super::BoolEvaluator; diff --git a/src/bool/mp_api.rs b/src/bool/mp_api.rs index e0529b3..22fc242 100644 --- a/src/bool/mp_api.rs +++ b/src/bool/mp_api.rs @@ -265,10 +265,10 @@ mod tests { use crate::{ bool::{ - evaluator::BooleanGates, + evaluator::BoolEncoding, keys::tests::{ideal_sk_rlwe, measure_noise_lwe}, + BooleanGates, }, - evaluator::{BoolEncoding, BoolEvaluator}, Encryptor, MultiPartyDecryptor, }; @@ -368,9 +368,8 @@ mod tests { use rand::Rng; use crate::{ - backend::ModulusPowerOf2, evaluator::BoolEncoding, pbs::PbsInfo, - rgsw::seeded_secret_key_encrypt_rlwe, utils::WithLocal, Decryptor, ModularOpsU64, - NttBackendU64, SampleExtractor, + bool::impl_bool_frontend::FheBool, pbs::PbsInfo, rgsw::seeded_secret_key_encrypt_rlwe, + Decryptor, SampleExtractor, }; use super::*; @@ -423,6 +422,27 @@ mod tests { BoolEvaluator::with_local(|e| e.sk_decrypt(c, self)) } } + + impl, C> Encryptor> for K + where + K: Encryptor, + { + fn encrypt(&self, m: &bool) -> FheBool { + FheBool { + data: self.encrypt(m), + } + } + } + + impl, C> Decryptor> for K + where + K: Decryptor, + { + fn decrypt(&self, c: &FheBool) -> bool { + self.decrypt(c.data()) + } + } + impl Encryptor<[bool], (Vec>, [u8; 32])> for K where K: SinglePartyClientKey, @@ -658,5 +678,40 @@ mod tests { } } } + + #[test] + #[cfg(feature = "interactive_mp")] + fn all_bool_apis() { + use crate::FheBool; + + set_single_party_parameter_sets(SP_TEST_BOOL_PARAMS); + + let (ck, sk) = gen_keys(); + sk.set_server_key(); + + for _ in 0..100 { + let a = thread_rng().gen_bool(0.5); + let b = thread_rng().gen_bool(0.5); + + let c_a: FheBool = ck.encrypt(&a); + let c_b: FheBool = ck.encrypt(&b); + + let c_out = &c_a & &c_b; + let out = ck.decrypt(&c_out); + assert_eq!(out, a & b, "Expected {} but got {out}", a & b); + + let c_out = &c_a | &c_b; + let out = ck.decrypt(&c_out); + assert_eq!(out, a | b, "Expected {} but got {out}", a | b); + + let c_out = &c_a ^ &c_b; + let out = ck.decrypt(&c_out); + assert_eq!(out, a ^ b, "Expected {} but got {out}", a ^ b); + + let c_out = !(&c_a); + let out = ck.decrypt(&c_out); + assert_eq!(out, !a, "Expected {} but got {out}", !a); + } + } } } diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs index 04ae464..6034833 100644 --- a/src/bool/ni_mp_api.rs +++ b/src/bool/ni_mp_api.rs @@ -406,8 +406,8 @@ mod tests { use crate::{ bool::{ - evaluator::BooleanGates, keys::tests::{ideal_sk_rlwe, measure_noise_lwe}, + BooleanGates, }, Encoder, Encryptor, KeySwitchWithId, MultiPartyDecryptor, }; diff --git a/src/bool/print_noise.rs b/src/bool/print_noise.rs index 6438d10..cb1764a 100644 --- a/src/bool/print_noise.rs +++ b/src/bool/print_noise.rs @@ -370,8 +370,10 @@ mod tests { fn qwerty() { use crate::{ aggregate_public_key_shares, aggregate_server_key_shares, - bool::keys::{key_size::KeySize, ServerKeyEvaluationDomain}, - evaluator::InteractiveMultiPartyCrs, + bool::{ + evaluator::InteractiveMultiPartyCrs, + keys::{key_size::KeySize, ServerKeyEvaluationDomain}, + }, gen_client_key, gen_mp_keys_phase1, gen_mp_keys_phase2, parameters::CiphertextModulus, random::DefaultSecureRng, @@ -432,9 +434,11 @@ mod tests { fn querty2() { use crate::{ aggregate_server_key_shares, - bool::keys::{key_size::KeySize, NonInteractiveServerKeyEvaluationDomain}, + bool::{ + evaluator::NonInteractiveMultiPartyCrs, + keys::{key_size::KeySize, NonInteractiveServerKeyEvaluationDomain}, + }, decomposer::DefaultDecomposer, - evaluator::NonInteractiveMultiPartyCrs, gen_client_key, gen_server_key_share, parameters::CiphertextModulus, random::DefaultSecureRng, diff --git a/src/lib.rs b/src/lib.rs index e504983..0e91204 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ pub use backend::{ // ParameterSelector, }; pub use bool::*; pub use ntt::{Ntt, NttBackendU64, NttInit}; -pub use shortint::{div_zero_error_flag, FheBool, FheUint8}; +pub use shortint::{div_zero_error_flag, FheUint8}; pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs index 4c2c593..30f8b3e 100644 --- a/src/shortint/enc_dec.rs +++ b/src/shortint/enc_dec.rs @@ -1,19 +1,13 @@ use itertools::Itertools; use crate::{ - bool::BoolEvaluator, + bool::{BoolEvaluator, FheBool}, random::{DefaultSecureRng, RandomFillUniformInModulus}, utils::WithLocal, Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowMut, SampleExtractor, }; -/// Fhe Bool ciphertext -#[derive(Clone)] -pub struct FheBool { - pub(super) data: C, -} - /// Fhe UInt8 type /// /// - Stores encryptions of bits in little endian (i.e least signficant bit @@ -210,25 +204,6 @@ where } } -impl MultiPartyDecryptor> for K -where - K: MultiPartyDecryptor, -{ - type DecryptionShare = >::DecryptionShare; - - fn aggregate_decryption_shares( - &self, - c: &FheBool, - shares: &[Self::DecryptionShare], - ) -> bool { - self.aggregate_decryption_shares(&c.data, shares) - } - - fn gen_decryption_share(&self, c: &FheBool) -> Self::DecryptionShare { - self.gen_decryption_share(&c.data) - } -} - impl Encryptor> for K where K: Encryptor, diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 09f65c8..1f23610 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -2,11 +2,10 @@ mod enc_dec; mod ops; pub type FheUint8 = enc_dec::FheUint8>; -pub type FheBool = enc_dec::FheBool>; use std::cell::RefCell; -use crate::bool::{evaluator::BooleanGates, BoolEvaluator, RuntimeServerKey}; +use crate::bool::{BoolEvaluator, BooleanGates, FheBool, RuntimeServerKey}; thread_local! { static DIV_ZERO_ERROR: RefCell> = RefCell::new(None); diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs index a6f6833..e4911ec 100644 --- a/src/shortint/ops.rs +++ b/src/shortint/ops.rs @@ -1,6 +1,6 @@ use itertools::{izip, Itertools}; -use crate::bool::evaluator::BooleanGates; +use crate::bool::BooleanGates; pub(super) fn half_adder( evaluator: &mut E,