diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 47bcd9b..e667e91 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -1829,7 +1829,7 @@ mod tests { RgswCiphertext, RgswCiphertextEvaluationDomain, SeededRgswCiphertext, SeededRlweCiphertext, }, - utils::negacyclic_mul, + utils::{negacyclic_mul, Stats}, }; use super::*; @@ -2258,41 +2258,6 @@ mod tests { } } - struct Stats { - samples: Vec, - } - - impl Stats - where - // T: for<'a> Sum<&'a T>, - T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum, - { - fn mean(&self) -> f64 { - self.samples.iter().sum::().to_f64().unwrap() / (self.samples.len() as f64) - } - - fn std_dev(&self) -> f64 { - let mean = self.mean(); - - // diff - let diff_sq = self - .samples - .iter() - .map(|v| { - let t = v.to_f64().unwrap() - mean; - t * t - }) - .into_iter() - .sum::(); - - (diff_sq / (self.samples.len() as f64)).sqrt() - } - - fn add_more(&mut self, values: &[T]) { - self.samples.extend(values.iter()); - } - } - #[test] fn tester() { // pub(super) const TEST_MP_BOOL_PARAMS: BoolParameters = diff --git a/src/rgsw.rs b/src/rgsw.rs index a46bbe9..b945a0e 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -1557,7 +1557,7 @@ pub(crate) mod tests { RgswCiphertext, RgswCiphertextEvaluationDomain, RlweCiphertext, RlwePublicKey, SeededAutoKey, SeededRgswCiphertext, SeededRlweCiphertext, SeededRlwePublicKey, }, - utils::{generate_prime, negacyclic_mul, TryConvertFrom}, + utils::{generate_prime, negacyclic_mul, Stats, TryConvertFrom}, Matrix, Secret, }; @@ -2065,9 +2065,9 @@ pub(crate) mod tests { ) ]; - // _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &decomposer, q); + _measure_noise_rgsw(&rgsw_carrym, &carry_m, s.values(), &decomposer, q); - for i in 0..1 { + for i in 0..2 { let mut m = vec![0u64; ring_size as usize]; m[thread_rng().gen_range(0..ring_size) as usize] = if (i & 1) == 1 { q - 1 } else { 1 }; let rgsw_m = RgswCiphertextEvaluationDomain::<_, DefaultSecureRng, NttBackendU64>::from( @@ -2111,4 +2111,108 @@ pub(crate) mod tests { println!("RLWE(m) x RGSW(carry_m): {noise}"); } } + + #[test] + fn some_work() { + let logq = 50; + let ring_size = 1 << 10; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let d_rgsw = 10; + let logb = 5; + let decomposer = ( + DefaultDecomposer::new(q, logb, d_rgsw), + DefaultDecomposer::new(q, logb, d_rgsw), + ); + + let ntt_op = NttBackendU64::new(q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + let mut rng = DefaultSecureRng::new_seeded([0u8; 32]); + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut check = Stats { samples: vec![] }; + + for _ in 0..100 { + let mut m0 = vec![0u64; ring_size as usize]; + m0[thread_rng().gen_range(0..ring_size) as usize] = 1; + let mut m1 = vec![0u64; ring_size as usize]; + m1[thread_rng().gen_range(0..ring_size) as usize] = 1; + + let mut rgsw_ct0 = { + let seeded_rgsw_ct = + _sk_encrypt_rgsw(&m0, s.values(), &decomposer, &mod_op, &ntt_op); + RgswCiphertextEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) + }; + let rgsw_ct1 = { + let seeded_rgsw_ct = + _sk_encrypt_rgsw(&m1, s.values(), &decomposer, &mod_op, &ntt_op); + RgswCiphertextEvaluationDomain::>, DefaultSecureRng, NttBackendU64>::from(&seeded_rgsw_ct) + }; + + // RGSW x RGSW + // send RGSW(m0) to coefficient domain + rgsw_ct0 + .data + .iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut_slice())); + let mut scratch_matrix = vec![ + vec![0u64; ring_size as usize]; + std::cmp::max( + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count() + ) + decomposer.a().decomposition_count() * 2 + + decomposer.b().decomposition_count() * 2 + ]; + rgsw_by_rgsw_inplace( + &mut rgsw_ct0.data, + &rgsw_ct1.data, + &decomposer, + &mut scratch_matrix, + &ntt_op, + &mod_op, + ); + let mut rgsw_m0m1 = rgsw_ct0; + // Back to Evaluation for RLWExRGSW + rgsw_m0m1 + .data + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut_slice())); + + // Sample m2, encrypt it as RLWE(m2) and multiply RLWE(m2)xRGSW(m0m1) + let mut m2 = vec![0u64; ring_size as usize]; + RandomUniformDist::random_fill(&mut rng, &q, m2.as_mut_slice()); + let mut rlwe_in_ct = { _sk_encrypt_rlwe(&m2, s.values(), &ntt_op, &mod_op) }; + let mut scratch_space = vec![ + vec![0u64; ring_size as usize]; + std::cmp::max( + decomposer.a().decomposition_count(), + decomposer.b().decomposition_count() + ) + 2 + ]; + rlwe_by_rgsw( + &mut rlwe_in_ct, + &rgsw_m0m1.data, + &mut scratch_space, + &decomposer, + &ntt_op, + &mod_op, + ); + + // Decrypt RLWE(m0m1m2) + let mut m0m1m2_back = vec![0u64; ring_size as usize]; + decrypt_rlwe(&rlwe_in_ct, s.values(), &mut m0m1m2_back, &ntt_op, &mod_op); + + // Calculate m0m1m2 + let mul_mod = |v0: &u64, v1: &u64| ((*v0 as u128 * *v1 as u128) % q as u128) as u64; + let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, q); + let m0m1m2 = negacyclic_mul(&m2, &m0m1, mul_mod, q); + + // diff + mod_op.elwise_sub_mut(m0m1m2_back.as_mut_slice(), m0m1m2.as_ref()); + + check.add_more(&Vec::::try_convert_from(&m0m1m2_back, &q)); + } + + println!("Std: {}", check.std_dev().abs().log2()); + } } diff --git a/src/utils.rs b/src/utils.rs index c6e0d6f..2d29fb7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,7 +1,7 @@ -use std::usize; +use std::{fmt::Debug, usize}; use itertools::Itertools; -use num_traits::{PrimInt, Signed}; +use num_traits::{FromPrimitive, PrimInt, Signed}; use crate::RandomUniformDist; pub trait WithLocal { @@ -228,3 +228,38 @@ impl TryConvertFrom<[u64]> for Vec { .collect_vec() } } + +pub(crate) struct Stats { + pub(crate) samples: Vec, +} + +impl Stats +where + // T: for<'a> Sum<&'a T>, + T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum, +{ + pub(crate) fn mean(&self) -> f64 { + self.samples.iter().sum::().to_f64().unwrap() / (self.samples.len() as f64) + } + + pub(crate) fn std_dev(&self) -> f64 { + let mean = self.mean(); + + // diff + let diff_sq = self + .samples + .iter() + .map(|v| { + let t = v.to_f64().unwrap() - mean; + t * t + }) + .into_iter() + .sum::(); + + (diff_sq / (self.samples.len() as f64)).sqrt() + } + + pub(crate) fn add_more(&mut self, values: &[T]) { + self.samples.extend(values.iter()); + } +}