diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 88c3f29..815a11c 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -2,12 +2,15 @@ use std::{ cell::{OnceCell, RefCell}, collections::HashMap, fmt::{Debug, Display}, + iter::Once, marker::PhantomData, ops::Shr, + sync::OnceLock, }; use itertools::{izip, partition, Itertools}; use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero}; +use rand_distr::uniform::SampleUniform; use crate::{ backend::{ArithmeticOps, GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, @@ -26,22 +29,95 @@ use crate::{ RlweCiphertext, RlweSecret, }, utils::{ - fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, + fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, Global, TryConvertFrom1, WithLocal, }, - Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, + Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; use super::parameters::{BoolParameters, CiphertextModulus}; thread_local! { - static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); + pub(crate) static BOOL_EVALUATOR: RefCell>, NttBackendU64, ModularOpsU64>, ModularOpsU64>>> = RefCell::new(BoolEvaluator::new(MP_BOOL_PARAMS)); + } +pub(crate) static BOOL_SERVER_KEY: OnceLock< + ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>, +> = OnceLock::new(); pub fn set_parameter_set(parameter: &BoolParameters) { BoolEvaluator::with_local_mut(|e| *e = BoolEvaluator::new(parameter.clone())) } +fn set_server_key(key: ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64>) { + assert!( + BOOL_SERVER_KEY.set(key).is_ok(), + "Attempted to set server key twice." + ); +} + +pub fn gen_keys() -> ( + ClientKey, + SeededServerKey>, BoolParameters, [u8; 32]>, +) { + BoolEvaluator::with_local_mut(|e| { + let ck = e.client_key(); + let sk = e.server_key(&ck); + + (ck, sk) + }) +} +pub(crate) trait BooleanGates { + type Ciphertext: RowEntity; + type Key; + + fn and_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn nand_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn or_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn nor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn xor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn xnor_inplace(&mut self, c0: &mut Self::Ciphertext, c1: &Self::Ciphertext, key: &Self::Key); + fn not_inplace(&mut self, c: &mut Self::Ciphertext); + + fn and( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn nand( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn or( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn nor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn xor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn xnor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext; + fn not(&mut self, c: &Self::Ciphertext) -> Self::Ciphertext; +} + impl WithLocal for BoolEvaluator< Vec>, @@ -63,6 +139,19 @@ impl WithLocal { BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + BOOL_EVALUATOR.with_borrow_mut(|s| func(s)) + } +} + +impl Global for ServerKeyEvaluationDomain>, DefaultSecureRng, NttBackendU64> { + fn global() -> &'static Self { + BOOL_SERVER_KEY.get().unwrap() + } } struct ScratchMemory @@ -206,7 +295,7 @@ trait PbsInfo { } #[derive(Clone)] -struct ClientKey { +pub struct ClientKey { sk_rlwe: RlweSecret, sk_lwe: LweSecret, } @@ -219,25 +308,17 @@ impl ClientKey { } } -// impl WithLocal for ClientKey { -// fn with_local(func: F) -> R -// where -// F: Fn(&Self) -> R, -// { -// CLIENT_KEY.with_borrow(|client_key| func(client_key)) -// } - -// fn with_local_mut(func: F) -> R -// where -// F: Fn(&mut Self) -> R, -// { -// CLIENT_KEY.with_borrow_mut(|client_key| func(client_key)) -// } -// } - -// fn set_client_key(key: &ClientKey) { -// ClientKey::with_local_mut(|k| *k = key.clone()) -// } +impl Encryptor> for ClientKey { + fn encrypt(&self, m: &bool) -> Vec { + BoolEvaluator::with_local(|e| e.sk_encrypt(*m, self)) + } +} + +impl Decryptor> for ClientKey { + fn decrypt(&self, c: &Vec) -> bool { + BoolEvaluator::with_local(|e| e.sk_decrypt(c, self)) + } +} struct MultiPartyDecryptionShare { share: E, @@ -325,7 +406,7 @@ struct SeededMultiPartyServerKey { } /// Seeded single party server key -struct SeededServerKey { +pub struct SeededServerKey { /// Rgsw cts of LWE secret elements pub(crate) rgsw_cts: Vec, /// Auto keys @@ -376,8 +457,18 @@ impl SeededServerKey, S> { } } +impl SeededServerKey>, BoolParameters, [u8; 32]> { + pub fn set_server_key(&self) { + set_server_key(ServerKeyEvaluationDomain::< + _, + DefaultSecureRng, + NttBackendU64, + >::from(self)); + } +} + /// Server key in evaluation domain -struct ServerKeyEvaluationDomain { +pub(crate) struct ServerKeyEvaluationDomain { /// Rgsw cts of LWE secret elements rgsw_cts: Vec, /// Galois keys @@ -643,7 +734,7 @@ struct BoolPbsInfo { impl PbsInfo for BoolPbsInfo where - M::MatElement: PrimInt + WrappingSub + NumInfo + Debug + FromPrimitive, + M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive, RlweModOp: ArithmeticOps + VectorOps, LweModOp: ArithmeticOps + VectorOps, NttOp: Ntt, @@ -708,7 +799,7 @@ where } } -struct BoolEvaluator +pub(crate) struct BoolEvaluator where M: Matrix, { @@ -728,7 +819,8 @@ impl BoolEvaluator BoolEvaluator where M: MatrixEntity + MatrixMut, - M::MatElement: PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub, + M::MatElement: + PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub + SampleUniform, NttOp: Ntt, RlweModOp: ArithmeticOps + VectorOps @@ -738,10 +830,6 @@ where + GetModulus>, M::R: TryConvertFrom1<[i32], CiphertextModulus> + RowEntity + Debug, ::R: RowMut, - DefaultSecureRng: RandomFillGaussianInModulus<[M::MatElement], CiphertextModulus> - + RandomFillUniformInModulus<[M::MatElement], CiphertextModulus> - + RandomGaussianElementInModulus> - + NewWithSeed, { fn new(parameters: BoolParameters) -> Self where @@ -1219,7 +1307,7 @@ where let mut rlwe = M::zeros(2, ring_size); // sample error rlwe.iter_rows_mut().for_each(|ri| { - RandomFillGaussianInModulus::random_fill( + RandomFillGaussianInModulus::<[M::MatElement], CiphertextModulus>::random_fill( rng, &self.pbs_info.parameters.rlwe_q(), ri.as_mut(), @@ -1468,192 +1556,245 @@ where parameters: parameters, } } +} +impl BoolEvaluator +where + M: MatrixMut + MatrixEntity, + M::R: RowMut + RowEntity, + M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo, + RlweModOp: VectorOps + + ArithmeticOps + + GetModulus>, + LweModOp: VectorOps + + ArithmeticOps + + GetModulus>, + NttOp: Ntt, +{ /// Returns c0 + c1 + Q/4 - fn _add_and_shift_lwe_cts(&self, c0: &M::R, c1: &M::R) -> M::R - where - M::R: Clone, - { - let mut c_out = M::R::zeros(c0.as_ref().len()); + fn _add_and_shift_lwe_cts(&self, c0: &mut M::R, c1: &M::R) { let modop = &self.pbs_info.rlwe_modop; - izip!( - c_out.as_mut().iter_mut(), - c0.as_ref().iter(), - c1.as_ref().iter() - ) - .for_each(|(o, i0, i1)| { - *o = modop.add(i0, i1); - }); + modop.elwise_add_mut(c0.as_mut(), c1.as_ref()); // +Q/4 - c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.pbs_info.rlwe_qby4); - c_out + c0.as_mut()[0] = modop.add(&c0.as_ref()[0], &self.pbs_info.rlwe_qby4); } /// Returns 2(c0 - c1) + Q/4 - fn _subtract_double_and_shift_lwe_cts(&self, c0: &M::R, c1: &M::R) -> M::R - where - M::R: Clone, - { - let mut c_out = c0.clone(); + fn _subtract_double_and_shift_lwe_cts(&self, c0: &mut M::R, c1: &M::R) { let modop = &self.pbs_info.rlwe_modop; // c0 - c1 - modop.elwise_sub_mut(c_out.as_mut(), c1.as_ref()); + modop.elwise_sub_mut(c0.as_mut(), c1.as_ref()); // double - c_out.as_mut().iter_mut().for_each(|v| *v = modop.add(v, v)); - c_out + c0.as_mut().iter_mut().for_each(|v| *v = modop.add(v, v)); } +} - pub fn nand( +impl BooleanGates for BoolEvaluator +where + M: MatrixMut + MatrixEntity, + M::R: RowMut + RowEntity + Clone, + M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo, + RlweModOp: VectorOps + + ArithmeticOps + + GetModulus>, + LweModOp: VectorOps + + ArithmeticOps + + GetModulus>, + NttOp: Ntt, +{ + type Ciphertext = M::R; + type Key = ServerKeyEvaluationDomain; + + fn nand_inplace( &mut self, - c0: &M::R, + c0: &mut M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - ) -> M::R - where - M::R: Clone, - { - let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + ) { + self._add_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, &self.nand_test_vec, - &mut c_out, + c0, server_key, &mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.decomposition_matrix, ); - - c_out } - pub fn and( + fn and_inplace( &mut self, - c0: &M::R, + c0: &mut M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - ) -> M::R - where - M::R: Clone, - { - let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + ) { + self._add_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, &self.and_test_vec, - &mut c_out, + c0, server_key, &mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.decomposition_matrix, ); - - c_out } - pub fn or( + fn or_inplace( &mut self, - c0: &M::R, + c0: &mut M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - ) -> M::R - where - M::R: Clone, - { - let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + ) { + self._add_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, &self.or_test_vec, - &mut c_out, + c0, server_key, &mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.decomposition_matrix, ); - - c_out } - pub fn nor( + fn nor_inplace( &mut self, - c0: &M::R, + c0: &mut M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - ) -> M::R - where - M::R: Clone, - { - let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + ) { + self._add_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, &self.nor_test_vec, - &mut c_out, + c0, server_key, &mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.decomposition_matrix, - ); - - c_out + ) } - pub fn xor( + fn xor_inplace( &mut self, - c0: &M::R, + c0: &mut M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - ) -> M::R - where - M::R: Clone, - { - let mut c_out = self._subtract_double_and_shift_lwe_cts(c0, c1); + ) { + self._subtract_double_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, &self.xor_test_vec, - &mut c_out, + c0, server_key, &mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.decomposition_matrix, ); - - c_out } - pub fn xnor( + fn xnor_inplace( &mut self, - c0: &M::R, + c0: &mut M::R, c1: &M::R, server_key: &ServerKeyEvaluationDomain, - ) -> M::R - where - M::R: Clone, - { - let mut c_out = self._subtract_double_and_shift_lwe_cts(c0, c1); + ) { + self._subtract_double_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, &self.xnor_test_vec, - &mut c_out, + c0, server_key, &mut self.scratch_memory.lwe_vector, &mut self.scratch_memory.decomposition_matrix, ); - - c_out } - pub fn not(&mut self, c0: &M::R) -> M::R - where - ::R: FromIterator<::MatElement>, - { + fn not_inplace(&mut self, c0: &mut M::R) { let modop = &self.pbs_info.rlwe_modop; - c0.as_ref().iter().map(|v| modop.neg(v)).collect() + c0.as_mut().iter_mut().for_each(|v| *v = modop.neg(v)); + } + + fn and( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.and_inplace(&mut out, c1, key); + out + } + + fn nand( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.nand_inplace(&mut out, c1, key); + out + } + + fn or( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.or_inplace(&mut out, c1, key); + out + } + + fn nor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.nor_inplace(&mut out, c1, key); + out + } + + fn xnor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.xnor_inplace(&mut out, c1, key); + out + } + + fn xor( + &mut self, + c0: &Self::Ciphertext, + c1: &Self::Ciphertext, + key: &Self::Key, + ) -> Self::Ciphertext { + let mut out = c0.clone(); + self.xor_inplace(&mut out, c1, key); + out + } + + fn not(&mut self, c: &Self::Ciphertext) -> Self::Ciphertext { + let mut out = c.clone(); + self.not_inplace(&mut out); + out } } @@ -1662,7 +1803,7 @@ where /// gk_to_si: [g^0, ..., g^{q/2-1}, -g^0, -g^1, .., -g^{q/2-1}] fn blind_rotation< MT: IsTrivial + MatrixMut, - Mmut: MatrixMut + Matrix, + Mmut: MatrixMut, D: Decomposer, NttOp: Ntt, ModOp: ArithmeticOps + VectorOps, @@ -1780,11 +1921,7 @@ fn blind_rotation< /// - key switching /// - mod down /// - blind rotate -fn pbs< - M: Matrix + MatrixMut + MatrixEntity, - P: PbsInfo, - K: PbsKey, ->( +fn pbs, K: PbsKey>( pbs_info: &P, test_vec: &M::R, lwe_in: &mut M::R, @@ -1793,7 +1930,7 @@ fn pbs< scratch_blind_rotate_matrix: &mut M, ) where ::R: RowMut, - M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display, + M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display, { let rlwe_q = pbs_info.rlwe_q(); let lwe_q = pbs_info.lwe_q(); @@ -2002,7 +2139,9 @@ fn sample_extract>( p_in: &[El], p_out: &mut [El], @@ -2092,6 +2231,13 @@ impl WithLocal for PBSTracer>> { { PBS_TRACER.with_borrow_mut(|t| func(t)) } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + PBS_TRACER.with_borrow_mut(|t| func(t)) + } } #[cfg(test)] diff --git a/src/bool/mod.rs b/src/bool/mod.rs index bfa8111..468272e 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -1,2 +1,2 @@ -mod evaluator; -mod parameters; +pub(crate) mod evaluator; +pub(crate) mod parameters; diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 4bb2d36..41e7c3f 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -3,7 +3,7 @@ use num_traits::{ConstZero, FromPrimitive, PrimInt, ToPrimitive, Zero}; use crate::{backend::Modulus, decomposer::Decomposer}; #[derive(Clone, PartialEq)] -pub(super) struct BoolParameters { +pub struct BoolParameters { rlwe_q: CiphertextModulus, lwe_q: CiphertextModulus, br_q: usize, @@ -280,12 +280,12 @@ where } } -pub(super) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { +pub(crate) const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { rlwe_q: CiphertextModulus::new_non_native(268369921u64), lwe_q: CiphertextModulus::new_non_native(1 << 16), - br_q: 1 << 10, - rlwe_n: PolynomialSize(1 << 10), - lwe_n: LweDimension(493), + br_q: 1 << 8, + rlwe_n: PolynomialSize(1 << 8), + lwe_n: LweDimension(10), lwe_decomposer_base: DecompostionLogBase(4), lwe_decomposer_count: DecompositionCount(4), rlrg_decomposer_base: DecompostionLogBase(7), diff --git a/src/decomposer.rs b/src/decomposer.rs index a490019..56b2d39 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -92,7 +92,7 @@ impl DefaultDecomposer { } } -impl Decomposer +impl Decomposer for DefaultDecomposer { type Element = T; diff --git a/src/lib.rs b/src/lib.rs index 8272510..43e6a5f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,6 +144,10 @@ impl Row for Vec { type Element = T; } +impl Row for [T] { + type Element = T; +} + impl RowMut for Vec {} impl RowEntity for Vec { @@ -151,3 +155,11 @@ impl RowEntity for Vec { vec![T::zero(); col] } } + +trait Encryptor { + fn encrypt(&self, m: &M) -> C; +} + +trait Decryptor { + fn decrypt(&self, c: &C) -> M; +} diff --git a/src/main.rs b/src/main.rs index e7a11a9..826cd6a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ fn main() { + let mut v = Vec::with_capacity(10); + v[0] = 1; println!("Hello, world!"); } diff --git a/src/random.rs b/src/random.rs index 6750356..acc743e 100644 --- a/src/random.rs +++ b/src/random.rs @@ -180,4 +180,11 @@ impl WithLocal for DefaultSecureRng { { DEFAULT_RNG.with_borrow_mut(|r| func(r)) } + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R, + { + DEFAULT_RNG.with_borrow_mut(|r| func(r)) + } } diff --git a/src/shortint.rs b/src/shortint.rs deleted file mode 100644 index 1191312..0000000 --- a/src/shortint.rs +++ /dev/null @@ -1,14 +0,0 @@ -use itertools::izip; - -use crate::Matrix; - -struct FheUint8 { - data: M, -} - -fn add(a: FheUint8, b: FheUint8) { - // CALL THE EVALUATOR - izip!(a.data.iter_rows(), b.data.iter_rows()).for_each(|(a_bit, b_bit)| { - // A ^ B - }); -} diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs new file mode 100644 index 0000000..ddd9a41 --- /dev/null +++ b/src/shortint/mod.rs @@ -0,0 +1,262 @@ +use itertools::Itertools; + +use crate::{ + bool::evaluator::{BoolEvaluator, ClientKey, ServerKeyEvaluationDomain, BOOL_SERVER_KEY}, + utils::{Global, WithLocal}, + Decryptor, Encryptor, +}; +use ops::{ + arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, + eight_bit_mul, +}; + +mod ops; +mod types; + +type FheUint8 = types::FheUint8>; + +fn add_mut(a: &mut FheUint8, b: &FheUint8) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = ServerKeyEvaluationDomain::global(); + arbitrary_bit_adder(e, a.data_mut(), b.data(), false, key); + }); +} + +fn sub(a: &FheUint8, b: &FheUint8) -> FheUint8 { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (out, _, _) = arbitrary_bit_subtractor(e, a.data(), b.data(), key); + FheUint8 { data: out } + }) +} + +fn mul(a: &FheUint8, b: &FheUint8) -> FheUint8 { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let out = eight_bit_mul(e, a.data(), b.data(), key); + FheUint8 { data: out } + }) +} + +fn div(a: &FheUint8, b: &FheUint8) -> (FheUint8, FheUint8) { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (quotient, remainder) = + arbitrary_bit_division_for_quotient_and_rem(e, a.data(), b.data(), key); + + (FheUint8 { data: quotient }, FheUint8 { data: remainder }) + }) +} + +impl Encryptor for ClientKey { + fn encrypt(&self, m: &u8) -> FheUint8 { + let cts = (0..8) + .into_iter() + .map(|i| { + let bit = ((m >> i) & 1) == 1; + Encryptor::>::encrypt(self, &bit) + }) + .collect_vec(); + FheUint8 { data: cts } + } +} + +impl Decryptor for ClientKey { + fn decrypt(&self, c: &FheUint8) -> u8 { + let mut out = 0u8; + c.data().iter().enumerate().for_each(|(index, bit_c)| { + let bool = Decryptor::>::decrypt(self, bit_c); + if bool { + out += 1 << index; + } + }); + out + } +} + +mod frontend { + use super::ops::{ + arbitrary_bit_adder, arbitrary_bit_division_for_quotient_and_rem, arbitrary_bit_subtractor, + eight_bit_mul, + }; + use crate::{ + bool::evaluator::{BoolEvaluator, ServerKeyEvaluationDomain}, + utils::{Global, WithLocal}, + }; + + use super::{add_mut, div, mul, FheUint8}; + + mod arithetic { + use super::*; + use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub}; + + impl AddAssign<&FheUint8> for FheUint8 { + fn add_assign(&mut self, rhs: &FheUint8) { + BoolEvaluator::with_local_mut_mut(&mut |e| { + let key = ServerKeyEvaluationDomain::global(); + arbitrary_bit_adder(e, self.data_mut(), rhs.data(), false, key); + }); + } + } + + impl Add<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn add(self, rhs: &FheUint8) -> Self::Output { + let mut a = self.clone(); + a += rhs; + a + } + } + + impl Sub<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn sub(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (out, _, _) = arbitrary_bit_subtractor(e, self.data(), self.data(), key); + FheUint8 { data: out } + }) + } + } + + impl Mul<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn mul(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let out = eight_bit_mul(e, self.data(), rhs.data(), key); + FheUint8 { data: out } + }) + } + } + + impl Div<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn div(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (quotient, _) = arbitrary_bit_division_for_quotient_and_rem( + e, + self.data(), + rhs.data(), + key, + ); + FheUint8 { data: quotient } + }) + } + } + + impl Rem<&FheUint8> for &FheUint8 { + type Output = FheUint8; + fn rem(self, rhs: &FheUint8) -> Self::Output { + BoolEvaluator::with_local_mut(|e| { + let key = ServerKeyEvaluationDomain::global(); + let (_, remainder) = arbitrary_bit_division_for_quotient_and_rem( + e, + self.data(), + rhs.data(), + key, + ); + FheUint8 { data: remainder } + }) + } + } + } + + mod booleans {} +} + +#[cfg(test)] +mod tests { + use num_traits::Euclid; + + use crate::{ + bool::{ + evaluator::{gen_keys, set_parameter_set, BoolEvaluator}, + parameters::SP_BOOL_PARAMS, + }, + shortint::{add_mut, div, mul, sub, types::FheUint8}, + Decryptor, Encryptor, + }; + + #[test] + fn qwerty() { + set_parameter_set(&SP_BOOL_PARAMS); + + let (ck, sk) = gen_keys(); + sk.set_server_key(); + + for i in 1..=255 { + for j in 0..=255 { + let m0 = i; + let m1 = j; + let c0 = ck.encrypt(&m0); + let c1 = ck.encrypt(&m1); + + assert!(ck.decrypt(&c0) == m0); + assert!(ck.decrypt(&c1) == m1); + + // Add + // let mut c_m0_plus_m1 = FheUint8 { + // data: c0.data().to_vec(), + // }; + // add_mut(&mut c_m0_plus_m1, &c1); + // let m0_plus_m1 = ck.decrypt(&c_m0_plus_m1); + // assert_eq!( + // m0_plus_m1, + // m0.wrapping_add(m1), + // "Expected {} but got {m0_plus_m1} for {i}+{j}", + // m0.wrapping_add(m1) + // ); + + // Sub + // let c_sub = sub(&c0, &c1); + // let m0_sub_m1 = ck.decrypt(&c_sub); + // dbg!(m0, m1, m0_sub_m1); + // assert_eq!( + // m0_sub_m1, + // m0.wrapping_sub(m1), + // "Expected {} but got {m0_sub_m1} for {i}-{j}", + // m0.wrapping_sub(m1) + // ); + + // Mul + // let c_m0m1 = mul(&c0, &c1); + // let m0m1 = ck.decrypt(&c_m0m1); + // assert_eq!( + // m0m1, + // m0.wrapping_mul(m1), + // "Expected {} but got {m0m1} for {i}x{j}", + // m0.wrapping_mul(m1) + // ); + + // Div + // let (c_quotient, c_rem) = div(&c0, &c1); + // let m_quotient = ck.decrypt(&c_quotient); + // let m_remainder = ck.decrypt(&c_rem); + // if j != 0 { + // let (q, r) = i.div_rem_euclid(&j); + // assert_eq!( + // m_quotient, q, + // "Expected {} but got {m_quotient} for {i}/{j}", + // q + // ); + // assert_eq!( + // m_remainder, r, + // "Expected {} but got {m_quotient} for {i}%{j}", + // r + // ); + // } else { + // assert_eq!( + // m_quotient, 255, + // "Expected 255 but got {m_quotient}. Case div by zero" + // ); + // assert_eq!( + // m_remainder, i, + // "Expected {i} but got {m_quotient}. Case div by zero" + // ) + // } + } + } + } +} diff --git a/src/shortint/ops.rs b/src/shortint/ops.rs new file mode 100644 index 0000000..31ca2af --- /dev/null +++ b/src/shortint/ops.rs @@ -0,0 +1,362 @@ +use std::mem::MaybeUninit; + +use itertools::{izip, Itertools}; +use num_traits::PrimInt; + +use crate::{ + backend::ModularOpsU64, + bool::{ + evaluator::{BoolEvaluator, BooleanGates, ClientKey, ServerKeyEvaluationDomain}, + parameters::CiphertextModulus, + }, + ntt::NttBackendU64, + random::DefaultSecureRng, + Decryptor, +}; + +pub(super) fn half_adder( + evaluator: &mut E, + a: &mut E::Ciphertext, + b: &E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + let carry = evaluator.and(a, b, key); + evaluator.xor_inplace(a, b, key); + carry +} + +pub(super) fn full_adder_plain_carry_in( + evaluator: &mut E, + a: &mut E::Ciphertext, + b: &E::Ciphertext, + carry_in: bool, + key: &E::Key, +) -> E::Ciphertext { + let mut a_and_b = evaluator.and(a, b, key); + evaluator.xor_inplace(a, b, key); //a = a ^ b + if carry_in { + // a_and_b = A & B | ((A^B) & C_in={True}) + evaluator.or_inplace(&mut a_and_b, &a, key); + } else { + // a_and_b = A & B | ((A^B) & C_in={False}) + // a_and_b = A & B + // noop + } + + // In xor if a input is 0, output equals the firt variable. If input is 1 then + // output equals !(first variable) + if carry_in { + // (A^B)^1 = !(A^B) + evaluator.not_inplace(a); + } else { + // (A^B)^0 + // no-op + } + a_and_b +} + +pub(super) fn full_adder( + evaluator: &mut E, + a: &mut E::Ciphertext, + b: &E::Ciphertext, + carry_in: &E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + let mut a_and_b = evaluator.and(a, b, key); + evaluator.xor_inplace(a, b, key); //a = a ^ b + let a_xor_b_and_c = evaluator.and(&a, carry_in, key); + evaluator.or_inplace(&mut a_and_b, &a_xor_b_and_c, key); // a_and_b = A & B | ((A^B) & C_in) + evaluator.xor_inplace(a, &carry_in, key); + a_and_b +} + +pub(super) fn arbitrary_bit_adder( + evaluator: &mut E, + a: &mut [E::Ciphertext], + b: &[E::Ciphertext], + carry_in: bool, + key: &E::Key, +) -> (E::Ciphertext, E::Ciphertext) +where + E::Ciphertext: Clone, +{ + assert!(a.len() == b.len()); + let n = a.len(); + + let mut carry = if !carry_in { + half_adder(evaluator, &mut a[0], &b[0], key) + } else { + full_adder_plain_carry_in(evaluator, &mut a[0], &b[0], true, key) + }; + + izip!(a.iter_mut(), b.iter()) + .skip(1) + .take(n - 3) + .for_each(|(a_bit, b_bit)| { + carry = full_adder(evaluator, a_bit, b_bit, &carry, key); + }); + + let carry_last_last = full_adder(evaluator, &mut a[n - 2], &b[n - 2], &carry, key); + let carry_last = full_adder(evaluator, &mut a[n - 1], &b[n - 1], &carry_last_last, key); + + (carry_last, carry_last_last) +} + +pub(super) fn arbitrary_bit_subtractor( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> (Vec, E::Ciphertext, E::Ciphertext) +where + E::Ciphertext: Clone, +{ + let mut neg_b: Vec = b.iter().map(|v| evaluator.not(v)).collect(); + let (carry_last, carry_last_last) = arbitrary_bit_adder(evaluator, &mut neg_b, &a, true, key); + return (neg_b, carry_last, carry_last_last); +} + +pub(super) fn bit_mux( + evaluator: &mut E, + selector: E::Ciphertext, + if_true: &E::Ciphertext, + if_false: &E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + // (s&a) | ((1-s)^b) + let not_selector = evaluator.not(&selector); + + let s_and_a = evaluator.and(&selector, if_true, key); + let s_and_b = evaluator.and(¬_selector, if_false, key); + evaluator.or(&s_and_a, &s_and_b, key) +} + +pub(super) fn arbitrary_bit_mux( + evaluator: &mut E, + selector: &E::Ciphertext, + if_true: &[E::Ciphertext], + if_false: &[E::Ciphertext], + key: &E::Key, +) -> Vec { + // (s&a) | ((1-s)^b) + let not_selector = evaluator.not(&selector); + + izip!(if_true.iter(), if_false.iter()) + .map(|(a, b)| { + let s_and_a = evaluator.and(&selector, a, key); + let s_and_b = evaluator.and(¬_selector, b, key); + evaluator.or(&s_and_a, &s_and_b, key) + }) + .collect() +} + +pub(super) fn eight_bit_mul( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> Vec { + assert!(a.len() == 8); + assert!(b.len() == 8); + let mut carries = Vec::with_capacity(7); + let mut out = Vec::with_capacity(8); + + for i in (0..8) { + if i == 0 { + let s = evaluator.and(&a[0], &b[0], key); + out.push(s); + } else if i == 1 { + let mut tmp0 = evaluator.and(&a[1], &b[0], key); + let tmp1 = evaluator.and(&a[0], &b[1], key); + let carry = half_adder(evaluator, &mut tmp0, &tmp1, key); + carries.push(carry); + out.push(tmp0); + } else { + let mut sum = { + let mut sum = evaluator.and(&a[i], &b[0], key); + let tmp = evaluator.and(&a[i - 1], &b[1], key); + carries[0] = full_adder(evaluator, &mut sum, &tmp, &carries[0], key); + sum + }; + + for j in 2..i { + let tmp = evaluator.and(&a[i - j], &b[j], key); + carries[j - 1] = full_adder(evaluator, &mut sum, &tmp, &carries[j - 1], key); + } + + let tmp = evaluator.and(&a[0], &b[i], key); + let carry = half_adder(evaluator, &mut sum, &tmp, key); + carries.push(carry); + + out.push(sum) + } + debug_assert!(carries.len() <= 7); + } + + out +} + +pub(super) fn arbitrary_bit_division_for_quotient_and_rem( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> (Vec, Vec) +where + E::Ciphertext: Clone, +{ + let n = a.len(); + let neg_b = b.iter().map(|v| evaluator.not(v)).collect_vec(); + + // Both remainder and quotient are initially stored in Big-endian in contract to + // the usual little endian we use. This is more friendly to vec pushes in + // division. After computing remainder and quotient, we simply reverse the + // vectors. + let mut remainder = vec![]; + let mut quotient = vec![]; + for i in 0..n { + // left shift + remainder.push(a[n - 1 - i].clone()); + + let mut subtract = remainder.clone(); + + // subtraction + // At i^th iteration remainder is only filled with i bits and the rest of the + // bits are zero. For example, at i = 1 + // 0 0 0 0 0 0 X X => remainder + // - Y Y Y Y Y Y Y Y => divisor . + // --------------- . + // Z Z Z Z Z Z Z Z => result + // For the next iteration we only care about result if divisor is <= remainder + // (which implies result <= remainder). Otherwise we care about remainder + // (recall re-storing division). Hence we optimise subtraction and + // ignore full adders for places where remainder bits are known to be false + // bits. We instead use `ANDs` to compute the carry overs, since the + // last carry over indicates whether the value has overflown (i.e. divisor <= + // remainder). Last carry out is `true` if value has not overflown, otherwise + // false. + let mut carry = + full_adder_plain_carry_in(evaluator, &mut subtract[i], &neg_b[0], true, key); + for j in 1..i + 1 { + carry = full_adder(evaluator, &mut subtract[i - j], &neg_b[j], &carry, key); + } + for j in i + 1..n { + // All I care about are the carries + evaluator.and_inplace(&mut carry, &neg_b[j], key); + } + + let not_carry = evaluator.not(&carry); + // Choose `remainder` if subtraction has overflown (i.e. carry = false). + // Otherwise choose `subtractor`. + // + // mux k^a | !(k)^b, where k is the selector. + izip!(remainder.iter_mut(), subtract.iter_mut()).for_each(|(r, s)| { + // choose `s` when carry is true, otherwise choose r + evaluator.and_inplace(s, &carry, key); + evaluator.and_inplace(r, ¬_carry, key); + evaluator.or_inplace(r, s, key); + }); + + // Set i^th MSB of quotient to 1 if carry = true, otherwise set it to 0. + // X&1 | X&0 => X&1 => X + quotient.push(carry); + } + + remainder.reverse(); + quotient.reverse(); + + (quotient, remainder) +} + +fn is_zero(evaluator: &mut E, a: &[E::Ciphertext], key: &E::Key) -> E::Ciphertext { + let mut a = a.iter().map(|v| evaluator.not(v)).collect_vec(); + let (out, rest_a) = a.split_at_mut(1); + rest_a.iter().for_each(|c| { + evaluator.and_inplace(&mut out[0], c, key); + }); + return a.remove(0); +} + +fn arbitrary_bit_equality( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + assert!(a.len() == b.len()); + let mut out = evaluator.and(&a[0], &b[0], key); + izip!(a.iter(), b.iter()).skip(1).for_each(|(abit, bbit)| { + let e = evaluator.xnor(abit, bbit, key); + evaluator.and(&mut out, &e, key); + }); + return out; +} + +/// Comaprator handle computes comparator result 2ns MSB onwards. It is +/// separated because comparator subroutine for signed and unsgind integers +/// differs only for 1st MSB and is common second MSB onwards +fn _comparator_handler_from_second_msb( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + mut comp: E::Ciphertext, + mut casc: E::Ciphertext, + key: &E::Key, +) -> E::Ciphertext { + let n = a.len(); + + // handle MSB - 1 + let mut tmp = evaluator.not(&b[n - 2]); + evaluator.and(&mut tmp, &a[n - 2], key); + evaluator.and(&mut tmp, &casc, key); + evaluator.or(&mut comp, &tmp, key); + + for i in 2..n { + // calculate cascading bit + let tmp_casc = evaluator.xnor(&a[n - 2 - i], &b[n - 2 - i], key); + evaluator.and(&mut casc, &tmp_casc, key); + + // calculate computate bit + let mut tmp = evaluator.not(&b[n - 1 - i]); + evaluator.and(&mut tmp, &a[n - 1 - i], key); + evaluator.and(&mut tmp, &casc, key); + evaluator.or(&mut comp, &tmp, key); + } + + return comp; +} + +/// Signed integer comparison is same as unsigned integer with MSB flipped. +fn arbitrary_signed_bit_comparator( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + assert!(a.len() == b.len()); + let n = a.len(); + + // handle MSB + let mut comp = evaluator.not(&a[n - 1]); + evaluator.and(&mut comp, &b[n - 1], key); // comp + let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); // casc + + return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); +} + +fn arbitrary_bit_comparator( + evaluator: &mut E, + a: &[E::Ciphertext], + b: &[E::Ciphertext], + key: &E::Key, +) -> E::Ciphertext { + assert!(a.len() == b.len()); + let n = a.len(); + + // handle MSB + let mut comp = evaluator.not(&b[n - 1]); + evaluator.and(&mut comp, &a[n - 1], key); + let casc = evaluator.xnor(&a[n - 1], &b[n - 1], key); + + return _comparator_handler_from_second_msb(evaluator, a, b, comp, casc, key); +} diff --git a/src/shortint/types.rs b/src/shortint/types.rs new file mode 100644 index 0000000..8178f36 --- /dev/null +++ b/src/shortint/types.rs @@ -0,0 +1,14 @@ +#[derive(Clone)] +pub(super) struct FheUint8 { + pub(super) data: Vec, +} + +impl FheUint8 { + pub(super) fn data(&self) -> &[C] { + &self.data + } + + pub(super) fn data_mut(&mut self) -> &mut [C] { + &mut self.data + } +} diff --git a/src/utils.rs b/src/utils.rs index 2720cc3..968eed3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -15,6 +15,14 @@ pub trait WithLocal { fn with_local_mut(func: F) -> R where F: Fn(&mut Self) -> R; + + fn with_local_mut_mut(func: &mut F) -> R + where + F: FnMut(&mut Self) -> R; +} + +pub trait Global { + fn global() -> &'static Self; } pub fn fill_random_ternary_secret_with_hamming_weight<