diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 87a4cda..88c3f29 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -636,7 +636,6 @@ struct BoolPbsInfo { rlwe_modop: RlweModOp, lwe_modop: LweModOp, embedding_factor: usize, - nand_test_vec: M::R, rlwe_qby4: M::MatElement, rlwe_auto_maps: Vec<(Vec, Vec)>, parameters: BoolParameters, @@ -715,6 +714,12 @@ where { pbs_info: BoolPbsInfo, scratch_memory: ScratchMemory, + nand_test_vec: M::R, + and_test_vec: M::R, + or_test_vec: M::R, + nor_test_vec: M::R, + xor_test_vec: M::R, + xnor_test_vec: M::R, _phantom: PhantomData, } @@ -764,39 +769,79 @@ where let rlwe_modop = RlweModOp::new(*parameters.rlwe_q()); let lwe_modop = LweModOp::new(*parameters.lwe_q()); - // set test vectors let q = *parameters.br_q(); let qby2 = q >> 1; let qby8 = q >> 3; - let mut nand_test_vec = M::R::zeros(qby2); // Q/8 (Q: rlwe_q) let true_m_el = parameters.rlwe_q().true_el(); // -Q/8 let false_m_el = parameters.rlwe_q().false_el(); - for i in 0..qby2 { - if i < (3 * qby8) { - nand_test_vec.as_mut()[i] = true_m_el; - } else { - nand_test_vec.as_mut()[i] = false_m_el; - } - } - - // v(X) -> v(X^{-g}) let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize)); - let mut nand_test_vec_autog = M::R::zeros(qby2); - izip!( - nand_test_vec.as_ref().iter(), - auto_map_index.iter(), - auto_map_sign.iter() - ) - .for_each(|(v, to_index, to_sign)| { - if !to_sign { - // negate - nand_test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v); - } else { - nand_test_vec_autog.as_mut()[*to_index] = *v; + + let init_test_vec = |partition_el: usize, + before_partition_el: M::MatElement, + after_partition_el: M::MatElement| { + let mut test_vec = M::R::zeros(qby2); + for i in 0..qby2 { + if i < partition_el { + test_vec.as_mut()[i] = before_partition_el; + } else { + test_vec.as_mut()[i] = after_partition_el; + } } - }); + + // v(X) -> v(X^{-g}) + let mut test_vec_autog = M::R::zeros(qby2); + izip!( + test_vec.as_ref().iter(), + auto_map_index.iter(), + auto_map_sign.iter() + ) + .for_each(|(v, to_index, to_sign)| { + if !to_sign { + // negate + test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v); + } else { + test_vec_autog.as_mut()[*to_index] = *v; + } + }); + + return test_vec_autog; + }; + + let nand_test_vec = init_test_vec(3 * qby8, true_m_el, false_m_el); + let and_test_vec = init_test_vec(3 * qby8, false_m_el, true_m_el); + let or_test_vec = init_test_vec(qby8, false_m_el, true_m_el); + let nor_test_vec = init_test_vec(qby8, true_m_el, false_m_el); + let xor_test_vec = init_test_vec(qby8, false_m_el, true_m_el); + let xnor_test_vec = init_test_vec(qby8, true_m_el, false_m_el); + + // // set test vectors + // let mut nand_test_vec = M::R::zeros(qby2); + // for i in 0..qby2 { + // if i < (3 * qby8) { + // nand_test_vec.as_mut()[i] = true_m_el; + // } else { + // nand_test_vec.as_mut()[i] = false_m_el; + // } + // } + + // // v(X) -> v(X^{-g}) + // let (auto_map_index, auto_map_sign) = generate_auto_map(qby2, -(g as isize)); + // let mut nand_test_vec_autog = M::R::zeros(qby2); + // izip!( + // nand_test_vec.as_ref().iter(), + // auto_map_index.iter(), + // auto_map_sign.iter() + // ) + // .for_each(|(v, to_index, to_sign)| { + // if !to_sign { + // // negate + // nand_test_vec_autog.as_mut()[*to_index] = rlwe_modop.neg(v); + // } else { + // nand_test_vec_autog.as_mut()[*to_index] = *v; + // } + // }); // auto map indices and sign let mut rlwe_auto_maps = vec![]; @@ -819,7 +864,6 @@ where lwe_modop, rlwe_modop, rlwe_nttop, - nand_test_vec: nand_test_vec_autog, rlwe_qby4, rlwe_auto_maps, parameters: parameters, @@ -828,6 +872,12 @@ where BoolEvaluator { pbs_info, scratch_memory, + nand_test_vec, + and_test_vec, + or_test_vec, + nor_test_vec, + xnor_test_vec, + xor_test_vec, _phantom: PhantomData, } } @@ -1419,13 +1469,11 @@ where } } - // TODO(Jay): scratch spaces must be thread local. Don't pass them as arguments - pub fn nand( - &mut self, - c0: &M::R, - c1: &M::R, - server_key: &ServerKeyEvaluationDomain, - ) -> M::R { + /// Returns c0 + c1 + Q/4 + fn _add_and_shift_lwe_cts(&self, c0: &M::R, c1: &M::R) -> M::R + where + M::R: Clone, + { let mut c_out = M::R::zeros(c0.as_ref().len()); let modop = &self.pbs_info.rlwe_modop; izip!( @@ -1438,11 +1486,111 @@ where }); // +Q/4 c_out.as_mut()[0] = modop.add(&c_out.as_ref()[0], &self.pbs_info.rlwe_qby4); + c_out + } + + /// Returns 2(c0 - c1) + Q/4 + fn _subtract_double_and_shift_lwe_cts(&self, c0: &M::R, c1: &M::R) -> M::R + where + M::R: Clone, + { + let mut c_out = c0.clone(); + let modop = &self.pbs_info.rlwe_modop; + // c0 - c1 + modop.elwise_sub_mut(c_out.as_mut(), c1.as_ref()); + + // double + c_out.as_mut().iter_mut().for_each(|v| *v = modop.add(v, v)); + c_out + } + + pub fn nand( + &mut self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKeyEvaluationDomain, + ) -> M::R + where + M::R: Clone, + { + let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.nand_test_vec, + &mut c_out, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + + c_out + } + + pub fn and( + &mut self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKeyEvaluationDomain, + ) -> M::R + where + M::R: Clone, + { + let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.and_test_vec, + &mut c_out, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + + c_out + } + + pub fn or( + &mut self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKeyEvaluationDomain, + ) -> M::R + where + M::R: Clone, + { + let mut c_out = self._add_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.or_test_vec, + &mut c_out, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + + c_out + } + + pub fn nor( + &mut self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKeyEvaluationDomain, + ) -> M::R + where + M::R: Clone, + { + let mut c_out = self._add_and_shift_lwe_cts(c0, c1); // PBS pbs( &self.pbs_info, - &self.pbs_info.nand_test_vec, + &self.nor_test_vec, &mut c_out, server_key, &mut self.scratch_memory.lwe_vector, @@ -1451,6 +1599,62 @@ where c_out } + + pub fn xor( + &mut self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKeyEvaluationDomain, + ) -> M::R + where + M::R: Clone, + { + let mut c_out = self._subtract_double_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.xor_test_vec, + &mut c_out, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + + c_out + } + + pub fn xnor( + &mut self, + c0: &M::R, + c1: &M::R, + server_key: &ServerKeyEvaluationDomain, + ) -> M::R + where + M::R: Clone, + { + let mut c_out = self._subtract_double_and_shift_lwe_cts(c0, c1); + + // PBS + pbs( + &self.pbs_info, + &self.xnor_test_vec, + &mut c_out, + server_key, + &mut self.scratch_memory.lwe_vector, + &mut self.scratch_memory.decomposition_matrix, + ); + + c_out + } + + pub fn not(&mut self, c0: &M::R) -> M::R + where + ::R: FromIterator<::MatElement>, + { + let modop = &self.pbs_info.rlwe_modop; + c0.as_ref().iter().map(|v| modop.neg(v)).collect() + } } /// LMKCY+ Blind rotation @@ -1956,6 +2160,7 @@ mod tests { let mut m0 = false; let mut m1 = true; + let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); @@ -2051,6 +2256,44 @@ mod tests { } } + #[test] + fn bool_xor() { + let mut bool_evaluator = BoolEvaluator::< + Vec>, + NttBackendU64, + ModularOpsU64>, + ModularOpsU64>, + >::new(SP_BOOL_PARAMS); + + // println!("{:?}", bool_evaluator.nand_test_vec); + let client_key = bool_evaluator.client_key(); + let seeded_server_key = bool_evaluator.server_key(&client_key); + let server_key_eval_domain = + ServerKeyEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( + &seeded_server_key, + ); + + let mut m0 = false; + let mut m1 = true; + + let mut ct0 = bool_evaluator.sk_encrypt(m0, &client_key); + let mut ct1 = bool_evaluator.sk_encrypt(m1, &client_key); + + for _ in 0..1000 { + let ct_back = bool_evaluator.xor(&ct0, &ct1, &server_key_eval_domain); + let m_out = (m0 ^ m1); + + let m_back = bool_evaluator.sk_decrypt(&ct_back, &client_key); + assert!(m_out == m_back, "Expected {m_out}, got {m_back}"); + + m1 = m0; + m0 = m_out; + + ct1 = ct0; + ct0 = ct_back; + } + } + #[test] fn multi_party_encryption_decryption() { let bool_evaluator = BoolEvaluator::< diff --git a/src/lib.rs b/src/lib.rs index 3d15566..8272510 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ mod ntt; mod num; mod random; mod rgsw; +mod shortint; mod utils; pub trait Matrix: AsRef<[Self::R]> { diff --git a/src/shortint.rs b/src/shortint.rs new file mode 100644 index 0000000..1191312 --- /dev/null +++ b/src/shortint.rs @@ -0,0 +1,14 @@ +use itertools::izip; + +use crate::Matrix; + +struct FheUint8 { + data: M, +} + +fn add(a: FheUint8, b: FheUint8) { + // CALL THE EVALUATOR + izip!(a.data.iter_rows(), b.data.iter_rows()).for_each(|(a_bit, b_bit)| { + // A ^ B + }); +}