diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 2456fcf..70fbe33 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -2191,31 +2191,11 @@ mod tests { RgswCiphertext, RgswCiphertextEvaluationDomain, SeededRgswCiphertext, SeededRlweCiphertext, }, - utils::{negacyclic_mul, Stats}, + utils::{negacyclic_mul, tests::Stats}, }; use super::*; - #[test] - fn bool_encrypt_decrypt_works() { - let bool_evaluator = BoolEvaluator::< - Vec>, - NttBackendU64, - ModularOpsU64>, - ModularOpsU64>, - ShoupServerKeyEvaluationDomain>>, - >::new(SP_TEST_BOOL_PARAMS); - let client_key = bool_evaluator.client_key(); - - let mut m = true; - for _ in 0..1000 { - let lwe_ct = bool_evaluator.sk_encrypt(m, &client_key); - let m_back = bool_evaluator.sk_decrypt(&lwe_ct, &client_key); - assert_eq!(m, m_back); - m = !m; - } - } - #[test] fn noise_tester() { let bool_evaluator = BoolEvaluator::< diff --git a/src/bool/ni_mp_api.rs b/src/bool/ni_mp_api.rs index 97610bc..108ac03 100644 --- a/src/bool/ni_mp_api.rs +++ b/src/bool/ni_mp_api.rs @@ -167,11 +167,10 @@ mod impl_enc_dec { use crate::{ bool::{evaluator::BoolEncoding, keys::NonInteractiveMultiPartyClientKey}, pbs::{sample_extract, PbsInfo, WithShoupRepr}, - random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, + random::{NewWithSeed, RandomFillUniformInModulus}, rgsw::{key_switch, secret_key_encrypt_rlwe}, - utils::{TryConvertFrom1, WithLocal}, - Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, - RowEntity, RowMut, + utils::TryConvertFrom1, + Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, }; use itertools::Itertools; use num_traits::{ToPrimitive, Zero}; @@ -359,13 +358,10 @@ mod tests { use crate::{ backend::{GetModulus, Modulus}, - bool::{ - evaluator::{BoolEncoding, BooleanGates}, - keys::SinglePartyClientKey, - }, + bool::{evaluator::BooleanGates, keys::SinglePartyClientKey}, lwe::decrypt_lwe, rgsw::decrypt_rlwe, - utils::{Stats, TryConvertFrom1}, + utils::{tests::Stats, TryConvertFrom1}, ArithmeticOps, Encoder, Encryptor, KeySwitchWithId, ModInit, MultiPartyDecryptor, NttInit, Row, VectorOps, }; @@ -448,9 +444,11 @@ mod tests { ct.extract(0) }; - for _ in 0..100 { + for _ in 0..1000 { + // let now = std::time::Instant::now(); let ct_out = BoolEvaluator::with_local_mut(|e| e.xor(&ct0, &ct1, RuntimeServerKey::global())); + // println!("Time: {:?}", now.elapsed()); let decryption_shares = cks .iter() @@ -458,7 +456,7 @@ mod tests { .collect_vec(); let m_out = cks[0].aggregate_decryption_shares(&ct_out, &decryption_shares); - let m_expected = (m0 ^ m1); + let m_expected = m0 ^ m1; { let noise = measure_noise_lwe( diff --git a/src/decomposer.rs b/src/decomposer.rs index 40bf941..dc0c43c 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -275,7 +275,7 @@ mod tests { use crate::{ backend::{ModInit, ModularOpsU64}, decomposer::round_value, - utils::{generate_prime, Stats, TryConvertFrom1}, + utils::{generate_prime, tests::Stats, TryConvertFrom1}, }; use super::{Decomposer, DefaultDecomposer}; @@ -288,7 +288,7 @@ mod tests { let ring_size = 1 << 11; let mut rng = thread_rng(); - let mut stats = Stats { samples: vec![] }; + let mut stats = Stats::new(); for i in [true] { let q = if i { diff --git a/src/main.rs b/src/main.rs index 899422c..f328e4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,85 +1 @@ -use std::os::unix::thread; - -use rand::{thread_rng, Rng}; - -fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec { - let b = 1u64 << logb; - let full_mask = b - 1u64; - let bby2 = b >> 1; - - if value >= (q >> 1) { - value = !(q - value) + 1; - } - - // let mut carry = 0; - // let mut out = Vec::with_capacity(d); - // for _ in 0..d { - // let k_i = carry + (value & full_mask); - // value = (value) >> logb; - // if k_i > bby2 { - // // if (k_i == bby2 && ((value & 1) == 1)) { - // // println!("AA"); - // // } - // out.push(q - (b - k_i)); - // carry = 1; - // } else { - // // if (k_i == bby2) { - // // println!("BB"); - // // } - // out.push(k_i); - // carry = 0; - // } - // } - // return out; - - let mut out = Vec::with_capacity(d); - for _ in 0..d { - let k_i = value & full_mask; - value = (value - k_i) >> logb; - - if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) { - // if (k_i == bby2 && ((value & 1) == 1)) { - // println!("AA"); - // } - out.push(q - (b - k_i)); - value += 1; - } else { - // if (k_i == bby2) { - // println!("BB"); - // } - out.push(k_i); - } - } - - return out; -} - -fn recompose(limbs: &[u64], q: u64, logb: u64) -> u64 { - let mut out = 0; - limbs.iter().enumerate().for_each(|(i, l)| { - let a = 1u128 << (logb * (i as u64)); - let a = ((a * (*l as u128)) % (q as u128)) as u64; - out = (out + a) % q; - }); - out % q -} - -fn main() { - // let mut v = Vec::with_capacity(10); - // v[0] = 1; - // println!("Hello, world!"); - - let mut rng = thread_rng(); - - let q = 36028797018820609u64; - let logb = 11; - let d = 5; - - for _ in 0..100000 { - let value = rng.gen_range(0..q); - let limbs = decomposer(value, q, d, logb); - // println!("{:?}", &limbs); - let value_back = recompose(&limbs, q, logb); - assert_eq!(value, value_back) - } -} +fn main() {} diff --git a/src/multi_party.rs b/src/multi_party.rs index 10fcccf..4c9361c 100644 --- a/src/multi_party.rs +++ b/src/multi_party.rs @@ -4,10 +4,10 @@ use itertools::izip; use crate::{ backend::{GetModulus, VectorOps}, - ntt::{self, Ntt}, - random::{NewWithSeed, RandomFillGaussianInModulus, RandomFillUniformInModulus}, + ntt::Ntt, + random::{RandomFillGaussianInModulus, RandomFillUniformInModulus}, utils::TryConvertFrom1, - Decomposer, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, }; pub(crate) fn public_key_share< @@ -213,7 +213,7 @@ where let mut scratch_space = M::R::zeros(ring_size); - izip!(zero_encs.iter_rows_mut()).for_each(|(e_zero)| { + izip!(zero_encs.iter_rows_mut()).for_each(|e_zero| { // sample a_i RandomFillUniformInModulus::random_fill(p_rng, q, e_zero.as_mut()); diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 1d52b94..2780db1 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -512,13 +512,13 @@ pub(crate) mod tests { use crate::{ backend::{GetModulus, ModInit, ModularOpsU64, Modulus, VectorOps}, decomposer::{Decomposer, DefaultDecomposer, RlweDecomposer}, - ntt::{self, Ntt, NttBackendU64, NttInit}, - random::{DefaultSecureRng, NewWithSeed, RandomFillUniformInModulus}, + ntt::{Ntt, NttBackendU64, NttInit}, + random::{DefaultSecureRng, RandomFillUniformInModulus}, rgsw::{ galois_auto_shoup, rlwe_by_rgsw_shoup, ShoupAutoKeyEvaluationDomain, ShoupRgswCiphertextEvaluationDomain, }, - utils::{generate_prime, negacyclic_mul, Stats, TryConvertFrom1}, + utils::{generate_prime, negacyclic_mul, tests::Stats, TryConvertFrom1}, Matrix, Secret, }; diff --git a/src/rgsw/runtime.rs b/src/rgsw/runtime.rs index 44fbaef..e8d19c9 100644 --- a/src/rgsw/runtime.rs +++ b/src/rgsw/runtime.rs @@ -2,10 +2,10 @@ use itertools::izip; use num_traits::Zero; use crate::{ - backend::{ArithmeticOps, GetModulus, Modulus, ShoupMatrixFMA, VectorOps}, + backend::{ArithmeticOps, GetModulus, ShoupMatrixFMA, VectorOps}, decomposer::{Decomposer, RlweDecomposer}, ntt::Ntt, - Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, + Matrix, MatrixEntity, MatrixMut, RowEntity, RowMut, }; use super::IsTrivial; diff --git a/src/shortint/enc_dec.rs b/src/shortint/enc_dec.rs index cce46f0..ef614d7 100644 --- a/src/shortint/enc_dec.rs +++ b/src/shortint/enc_dec.rs @@ -3,7 +3,7 @@ use itertools::Itertools; use crate::{ bool::BoolEvaluator, random::{DefaultSecureRng, RandomFillUniformInModulus}, - utils::{TryConvertFrom1, WithLocal}, + utils::WithLocal, Decryptor, Encryptor, KeySwitchWithId, Matrix, MatrixEntity, MatrixMut, MultiPartyDecryptor, RowMut, SampleExtractor, }; diff --git a/src/shortint/mod.rs b/src/shortint/mod.rs index 8e26783..a073c13 100644 --- a/src/shortint/mod.rs +++ b/src/shortint/mod.rs @@ -205,16 +205,3 @@ mod frontend { } } } - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use num_traits::Euclid; - - use crate::{ - bool::set_parameter_set, shortint::enc_dec::FheUint8, utils::WithLocal, Decryptor, - Encryptor, MultiPartyDecryptor, - }; - - use super::*; -} diff --git a/src/utils.rs b/src/utils.rs index ac97d89..4a91449 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -86,7 +86,7 @@ pub fn fill_random_ternary_secret_with_hamming_weight< let mut secret_indices = (0..size).into_iter().map(|i| i).collect_vec(); let mut bit_index = 0; let mut byte_index = 0; - for _ in 0..hamming_weight { + for i in 0..hamming_weight { let s_index = RandomElementInModulus::::random(rng, &secret_indices.len()); let curr_bit = (bytes[byte_index] >> bit_index) & 1; @@ -97,7 +97,7 @@ pub fn fill_random_ternary_secret_with_hamming_weight< } secret_indices[s_index] = *secret_indices.last().unwrap(); - secret_indices.truncate(secret_indices.len()); + secret_indices.truncate(secret_indices.len() - 1); if bit_index == 7 { bit_index = 0; @@ -232,79 +232,73 @@ impl TryConvertFrom1<[P::Element], P> for Vec { } } -pub(crate) struct Stats { - pub(crate) samples: Vec, -} +#[cfg(test)] +pub(crate) mod tests { + use std::fmt::Debug; -impl Stats -where - // T: for<'a> Sum<&'a T>, - T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum, -{ - pub(crate) fn new() -> Self { - Self { samples: vec![] } - } + use num_traits::{FromPrimitive, PrimInt}; + use rand::thread_rng; + + use crate::random::DefaultSecureRng; + + use super::fill_random_ternary_secret_with_hamming_weight; - pub(crate) fn mean(&self) -> f64 { - self.samples.iter().sum::().to_f64().unwrap() / (self.samples.len() as f64) + pub(crate) struct Stats { + pub(crate) samples: Vec, } - pub(crate) fn std_dev(&self) -> f64 { - let mean = self.mean(); + impl Stats + where + // T: for<'a> Sum<&'a T>, + T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum, + { + pub(crate) fn new() -> Self { + Self { samples: vec![] } + } - // diff - let diff_sq = self - .samples - .iter() - .map(|v| { - let t = v.to_f64().unwrap() - mean; - t * t - }) - .into_iter() - .sum::(); + pub(crate) fn mean(&self) -> f64 { + self.samples.iter().sum::().to_f64().unwrap() / (self.samples.len() as f64) + } - (diff_sq / (self.samples.len() as f64)).sqrt() - } + pub(crate) fn std_dev(&self) -> f64 { + let mean = self.mean(); + + // diff + let diff_sq = self + .samples + .iter() + .map(|v| { + let t = v.to_f64().unwrap() - mean; + t * t + }) + .into_iter() + .sum::(); + + (diff_sq / (self.samples.len() as f64)).sqrt() + } - pub(crate) fn add_more(&mut self, values: &[T]) { - self.samples.extend(values.iter()); + pub(crate) fn add_more(&mut self, values: &[T]) { + self.samples.extend(values.iter()); + } } -} -#[cfg(test)] -mod tests { - - use super::is_probably_prime; - // let n = 1 << (11 + 1); - // let mut start = 1 << 55; - // while start < (1 << 56) { - // if start % n == 1 { - // break; - // } - // start += 1; - // } - - // let mut prime = None; - // while start < (1 << 56) { - // if is_probably_prime(start) { - // dbg!(start); - // prime = Some(start); - // break; - // } - // dbg!(start); - // start += (n); - // } #[test] - fn gg() { - let q = 30; - for i in 0..1000 { - let x = (1u64 << (q * 2)) + (i * (1 << q)) + 1; - let is_prime = is_probably_prime(x); - if is_prime { - println!("{x} = 2^{} + {i} * 2^{q} + 1", 2 * q); - } + fn ternary_secret_has_correct_hw() { + let mut rng = DefaultSecureRng::new(); + for n in 4..15 { + let ring_size = 1 << n; + let mut out = vec![0i32; ring_size]; + fill_random_ternary_secret_with_hamming_weight(&mut out, ring_size >> 1, &mut rng); + + // check hamming weight of out equals ring_size/2 + let mut non_zeros = 0; + out.iter().for_each(|i| { + if *i != 0 { + non_zeros += 1; + } + }); + + assert_eq!(ring_size >> 1, non_zeros); } - - // println!("{:?}", prime); } }