diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index eab6a44..b10f4b4 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -42,8 +42,8 @@ use crate::{ fill_random_ternary_secret_with_hamming_weight, generate_prime, mod_exponent, puncture_p_rng, Global, TryConvertFrom1, WithLocal, }, - Decryptor, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, RowEntity, - RowMut, Secret, + Decryptor, Encoder, Encryptor, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, Row, + RowEntity, RowMut, Secret, }; use super::{ @@ -292,6 +292,19 @@ where } } +impl Encoder for B +where + B: BoolEncoding, +{ + fn encode(&self, v: bool) -> B::Element { + if v { + self.true_el() + } else { + self.false_el() + } + } +} + pub(super) struct BoolPbsInfo { auto_decomposer: DefaultDecomposer, rlwe_rgsw_decomposer: ( diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs index 2a72099..cc87cb1 100644 --- a/src/bool/ni_mp_api.rs +++ b/src/bool/ni_mp_api.rs @@ -372,11 +372,11 @@ mod impl_enc_dec { mod tests { use impl_enc_dec::NonInteractiveBatchedFheBools; use itertools::{izip, Itertools}; - use num_traits::ToPrimitive; + use num_traits::{FromPrimitive, PrimInt, ToPrimitive, Zero}; use rand::{thread_rng, RngCore}; use crate::{ - backend::Modulus, + backend::{GetModulus, Modulus}, bool::{ evaluator::{BoolEncoding, BooleanGates}, keys::SinglePartyClientKey, @@ -384,12 +384,49 @@ mod tests { lwe::decrypt_lwe, rgsw::decrypt_rlwe, utils::{Stats, TryConvertFrom1}, - ArithmeticOps, Encryptor, KeySwitchWithId, ModInit, MultiPartyDecryptor, NttInit, - VectorOps, + ArithmeticOps, Encoder, Encryptor, KeySwitchWithId, ModInit, MultiPartyDecryptor, NttInit, + Row, VectorOps, }; use super::*; + pub(crate) fn ideal_sk_rlwe(cks: &[ClientKey]) -> Vec { + let mut ideal_rlwe_sk = cks[0].sk_rlwe(); + cks.iter().for_each(|k| { + let sk_rlwe = k.sk_rlwe(); + izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| { + *a = *a + b; + }); + }); + ideal_rlwe_sk + } + + pub(crate) fn measure_noise_lwe< + R: Row, + S, + Modop: ArithmeticOps + + GetModulus, Element = R::Element>, + >( + lwe_ct: R, + m_expected: R::Element, + sk: &[S], + modop: &Modop, + ) -> f64 + where + R: TryConvertFrom1<[S], CiphertextModulus>, + R::Element: Zero + FromPrimitive + PrimInt, + { + let noisy_m = decrypt_lwe(&lwe_ct, &sk, modop); + let noise = modop.sub(&m_expected, &noisy_m); + modop + .modulus() + .map_element_to_i64(&noise) + .abs() + .to_f64() + .unwrap() + .log2() + } + #[test] fn non_interactive_mp_bool_nand() { set_parameter_set(ParameterSelector::NonInteractiveMultiPartyLessThanOrEqualTo16); @@ -411,16 +448,9 @@ mod tests { seeded_server_key.set_server_key(); let parameters = BoolEvaluator::with_local(|e| e.parameters().clone()); - let nttop = NttBackendU64::new(parameters.rlwe_q(), parameters.rlwe_n().0); - let rlwe_q_modop = ModularOpsU64::new(*parameters.rlwe_q()); + let rlwe_modop = parameters.default_rlwe_modop(); - let mut ideal_rlwe_sk = vec![0i32; parameters.rlwe_n().0]; - cks.iter().for_each(|k| { - let sk_rlwe = k.sk_rlwe(); - izip!(ideal_rlwe_sk.iter_mut(), sk_rlwe.iter()).for_each(|(a, b)| { - *a = *a + b; - }); - }); + let ideal_sk_rlwe = ideal_sk_rlwe(&cks); let mut m0 = false; let mut m1 = true; @@ -449,12 +479,8 @@ mod tests { let m_expected = (m0 ^ m1); { - let noisy_m = decrypt_lwe(&ct_out, &ideal_rlwe_sk, &rlwe_q_modop); - let noise = if m_expected { - rlwe_q_modop.sub(¶meters.rlwe_q().true_el(), &noisy_m) - } else { - rlwe_q_modop.sub(¶meters.rlwe_q().false_el(), &noisy_m) - }; + let noisy_m = decrypt_lwe(&ct_out, &ideal_sk_rlwe, &rlwe_modop); + let noise = rlwe_modop.sub(¶meters.rlwe_q().encode(m_expected), &noisy_m); println!( "Noise: {}", parameters @@ -464,7 +490,10 @@ mod tests { .to_f64() .unwrap() .log2() - ) + ); + // let noise = measure_noise_lwe(ct_out, + // parameters.rlwe_q().encode(m_expected), &ideal_sk_rlwe, + // &rlwe_modop); println!("Noise: {noise}"); } assert!(m_out == m_expected, "Expected {m_expected} but got {m_out}"); @@ -513,13 +542,7 @@ mod tests { let message = m .iter() - .map(|b| { - if *b { - parameters.rlwe_q().true_el() - } else { - parameters.rlwe_q().false_el() - } - }) + .map(|b| parameters.rlwe_q().encode(*b)) .collect_vec(); let mut m_out = vec![0u64; parameters.rlwe_n().0]; diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 30f11ae..f13a98b 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -531,7 +531,19 @@ pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters = Boo }; #[cfg(test)] mod tests { - use crate::utils::generate_prime; + + impl BoolParameters { + pub(crate) fn default_rlwe_modop(&self) -> ModularOpsU64> { + ModularOpsU64::new(self.rlwe_q) + } + pub(crate) fn default_rlwe_nttop(&self) -> NttBackendU64 { + NttBackendU64::new(&self.rlwe_q, self.rlwe_n.0) + } + } + + use crate::{utils::generate_prime, ModInit, ModularOpsU64, Ntt, NttBackendU64, NttInit}; + + use super::{BoolParameters, CiphertextModulus}; #[test] fn find_prime() { diff --git a/src/lib.rs b/src/lib.rs index 08c45d5..8c7d242 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -183,3 +183,7 @@ pub trait MultiPartyDecryptor { pub trait KeySwitchWithId { fn key_switch(&self, user_id: usize) -> C; } + +pub(crate) trait Encoder { + fn encode(&self, v: F) -> T; +} diff --git a/src/utils.rs b/src/utils.rs index 2cd02a3..ac97d89 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -273,6 +273,7 @@ where #[cfg(test)] mod tests { + use super::is_probably_prime; // let n = 1 << (11 + 1); // let mut start = 1 << 55;