From ab7b1ca40f35203c787c71d228162d3f5141fd45 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 23 Jun 2024 13:38:04 +0700 Subject: [PATCH] fix rounding in decom --- src/bool/evaluator.rs | 31 ++++++++++---- src/bool/parameters.rs | 22 ++++++++++ src/decomposer.rs | 97 ++++++++++++++++++++++-------------------- src/rgsw/mod.rs | 95 +++++++++++++++++++++++++++++++---------- 4 files changed, 168 insertions(+), 77 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index c514cba..f3452d5 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -326,8 +326,14 @@ pub(super) struct BoolPbsInfo { impl PbsInfo for BoolPbsInfo where - M::MatElement: - PrimInt + WrappingSub + NumInfo + FromPrimitive + From + Display + WrappingAdd, + M::MatElement: PrimInt + + WrappingSub + + NumInfo + + FromPrimitive + + From + + Display + + WrappingAdd + + Debug, RlweModOp: ArithmeticOps + ShoupMatrixFMA, LweModOp: ArithmeticOps + VectorOps, NttOp: Ntt, @@ -2003,7 +2009,8 @@ where + WrappingSub + NumInfo + From - + WrappingAdd, + + WrappingAdd + + Debug, RlweModOp: VectorOps + ArithmeticOps + ShoupMatrixFMA, @@ -2195,7 +2202,9 @@ mod tests { SP_TEST_BOOL_PARAMS, }, }, + evaluator, ntt::NttBackendU64, + parameters::OPTIMISED_SMALL_MP_BOOL_PARAMS, random::{RandomElementInModulus, DEFAULT_RNG}, rgsw::{ self, measure_noise, public_key_encrypt_rlwe, secret_key_encrypt_rlwe, @@ -2216,11 +2225,11 @@ mod tests { ModularOpsU64>, ModularOpsU64>, ShoupServerKeyEvaluationDomain>>, - >::new(SMALL_MP_BOOL_PARAMS); + >::new(OPTIMISED_SMALL_MP_BOOL_PARAMS); // let (_, collective_pk, _, _, server_key_eval, ideal_client_key) = // _multi_party_all_keygen(&bool_evaluator, 20); - let no_of_parties = 16; + let no_of_parties = 2; let lwe_q = bool_evaluator.pbs_info.parameters.lwe_q(); let rlwe_q = bool_evaluator.pbs_info.parameters.rlwe_q(); let lwe_n = bool_evaluator.pbs_info.parameters.lwe_n().0; @@ -2269,7 +2278,7 @@ mod tests { }); // check noise in freshly encrypted RLWE ciphertext (ie var_fresh) - if true { + if false { let mut rng = DefaultSecureRng::new(); let mut check = Stats { samples: vec![] }; for _ in 0..10 { @@ -2343,7 +2352,7 @@ mod tests { bool_evaluator.aggregate_multi_party_server_key_shares(&server_key_shares); // Check noise in RGSW ciphertexts of ideal LWE secret elements - if false { + if true { let mut check = Stats { samples: vec![] }; izip!(ideal_lwe_sk.iter(), seeded_server_key.rgsw_cts().iter()).for_each( |(s_i, rgsw_ct_i)| { @@ -2361,6 +2370,10 @@ mod tests { Vec::::try_convert_from(ideal_rlwe_sk.as_slice(), rlwe_q); rlwe_modop.elwise_neg_mut(&mut neg_s_eval); rlwe_nttop.forward(&mut neg_s_eval); + // let tmp_decomp = bool_evaluator + // .parameters() + // .rgsw_rgsw_decomposer::>(); + // let tmp_gadget = tmp_decomp.a().gadget_vector() for j in 0..rlwe_rgsw_decomposer.a().decomposition_count() { // RLWE(B^{j} * -s[X]*X^{s_lwe[i]}) @@ -2616,7 +2629,7 @@ mod tests { // check noise in RLWE(X^k) after sending RLWE(X) -> RLWE(X^k)using collective // auto key - if true { + if false { let mut check = Stats { samples: vec![] }; let br_q = bool_evaluator.pbs_info.br_q(); let g = bool_evaluator.pbs_info.g(); @@ -2692,7 +2705,7 @@ mod tests { // Check noise growth in ksk // TODO check in LWE key switching keys - if true { + if false { // 1. encrypt LWE ciphertext // 2. Key switching // 3. diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 851c5fb..f8336e4 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -486,6 +486,28 @@ pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: = BoolParameters:: { + rlwe_q: CiphertextModulus::new_non_native(18014398509404161), + lwe_q: CiphertextModulus::new_non_native(1 << 15), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(500), + lwe_decomposer_params: (DecompostionLogBase(1), DecompositionCount(11)), + rlrg_decomposer_params: ( + DecompostionLogBase(24), + (DecompositionCount(1), DecompositionCount(1)), + ), + rgrg_decomposer_params: Some(( + DecompostionLogBase(12), + (DecompositionCount(3), DecompositionCount(3)), + )), + auto_decomposer_params: (DecompostionLogBase(20), DecompositionCount(1)), + non_interactive_ui_to_s_key_switch_decomposer: None, + g: 5, + w: 10, + variant: ParameterVariant::MultiParty, +}; + pub(crate) const NON_INTERACTIVE_SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { rlwe_q: CiphertextModulus::new_non_native(36028797018820609), lwe_q: CiphertextModulus::new_non_native(1 << 20), diff --git a/src/decomposer.rs b/src/decomposer.rs index ba883b4..5592f09 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -11,6 +11,7 @@ use std::{ use crate::backend::{ArithmeticOps, ModularOpsU64}; fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { + assert!(logq >= (logb * d)); let ignored_bits = logq - (logb * d); (0..d) @@ -114,7 +115,8 @@ impl< + WrappingAdd + NumInfo + From - + Display, + + Display + + Debug, > Decomposer for DefaultDecomposer { type Element = T; @@ -128,6 +130,11 @@ impl< (T::BITS - q.leading_zeros()) as usize }; + assert!( + logq >= (logb * d), + "Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}" + ); + let ignore_bits = logq - (logb * d); DefaultDecomposer { @@ -144,20 +151,19 @@ impl< // TODO(Jay): Outline the caveat fn decompose_to_vec(&self, value: &T) -> Vec { - let mut value = round_value(*value, self.ignore_bits); - let q = self.q; let logb = self.logb; let b = T::one() << logb; let full_mask = b - T::one(); let bby2 = b >> 1; + let mut value = *value; if value >= (q >> 1) { value = !(q - value) + T::one() } - + value = round_value(value, self.ignore_bits); let mut out = Vec::with_capacity(self.d); - for _ in 0..self.d { + for _ in 0..(self.d) { let k_i = value & full_mask; value = (value - k_i) >> logb; @@ -178,11 +184,11 @@ impl< } fn decompose_iter(&self, value: &T) -> DecomposerIter { - let mut value = round_value(*value, self.ignore_bits); - + let mut value = *value; if value >= (self.q >> 1) { value = !(self.q - value) + T::one() } + value = round_value(value, self.ignore_bits); DecomposerIter { value, @@ -283,50 +289,49 @@ mod tests { #[test] fn decomposition_works() { - let logq = 55; - let logb = 12; - let d = 4; let ring_size = 1 << 11; let mut rng = thread_rng(); - let mut stats = vec![Stats::new(); d]; - for i in [true] { - let q = if i { - generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap() - } else { - 1u64 << logq - }; - let decomposer = DefaultDecomposer::new(q, logb, d); - dbg!(decomposer.ignore_bits); - let modq_op = ModularOpsU64::new(q); - for _ in 0..100000 { - let value = rng.gen_range(0..q); - let limbs = decomposer.decompose_to_vec(&value); - // let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec(); - // assert_eq!(limbs, limbs_from_iter); - let value_back = round_value( - decomposer.recompose(&limbs, &modq_op), - decomposer.ignore_bits, - ); - let rounded_value = round_value(value, decomposer.ignore_bits); - // assert_eq!( - // rounded_value, value_back, - // "Expected {rounded_value} got {value_back} for q={q}" - // ); - - izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| { - s.add_more(&vec![q.map_element_to_i64(l)]); - }); + for logq in [37, 55] { + let logb = 11; + let d = 3; + let mut stats = vec![Stats::new(); d]; + + for i in [true] { + let q = if i { + generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap() + } else { + 1u64 << logq + }; + let decomposer = DefaultDecomposer::new(q, logb, d); + dbg!(decomposer.ignore_bits); + let modq_op = ModularOpsU64::new(q); + for _ in 0..1000000 { + let value = rng.gen_range(0..q); + let limbs = decomposer.decompose_to_vec(&value); + let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec(); + assert_eq!(limbs, limbs_from_iter); + let value_back = round_value( + decomposer.recompose(&limbs, &modq_op), + decomposer.ignore_bits, + ); + let rounded_value = round_value(value, decomposer.ignore_bits); + assert!((rounded_value as i64 - value_back as i64).abs() <= 1,); + + izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| { + s.add_more(&vec![q.map_element_to_i64(l)]); + }); + } } - } - stats.iter().enumerate().for_each(|(index, s)| { - println!( - "Limb {index} - Mean: {}, Std: {}", - s.mean(), - s.std_dev().abs().log2() - ); - }); + stats.iter().enumerate().for_each(|(index, s)| { + println!( + "Limb {index} - Mean: {}, Std: {}", + s.mean(), + s.std_dev().abs().log2() + ); + }); + } } } diff --git a/src/rgsw/mod.rs b/src/rgsw/mod.rs index 9b711c5..87b6ef3 100644 --- a/src/rgsw/mod.rs +++ b/src/rgsw/mod.rs @@ -1156,8 +1156,8 @@ pub(crate) mod tests { let logq = 55; let ring_size = 1 << 11; let q = generate_prime(logq, ring_size as u64, 1u64 << logq).unwrap(); - let d = 12; - let logb = 4; + let d = 2; + let logb = 12; let decomposer = DefaultDecomposer::new(q, logb, d); let ntt_op = NttBackendU64::new(&q, ring_size as usize); @@ -1169,16 +1169,42 @@ pub(crate) mod tests { for _ in 0..10 { let mut a = vec![0u64; ring_size]; RandomFillUniformInModulus::random_fill(&mut rng, &q, a.as_mut()); - let mut e = vec![1u64; ring_size]; - // RandomFillGaussianInModulus::random_fill(&mut rng, &q, e.as_mut()); + let mut m = vec![0u64; ring_size]; + RandomFillGaussianInModulus::random_fill(&mut rng, &q, m.as_mut()); + + let mut sk = vec![0u64; ring_size]; + RandomFillGaussianInModulus::random_fill(&mut rng, &q, sk.as_mut()); + let mut sk_eval = sk.clone(); + ntt_op.forward(sk_eval.as_mut_slice()); let gadget_vector = decomposer.gadget_vector(); // ksk (beta e) - let mut ksk = vec![vec![0u64; ring_size]; decomposer.decomposition_count()]; - izip!(ksk.iter_rows_mut(), gadget_vector.iter()).for_each(|(row, beta)| { - row.as_mut_slice().copy_from_slice(e.as_ref()); - mod_op.elwise_scalar_mul_mut(row.as_mut_slice(), beta); + let mut ksk_part_b = vec![vec![0u64; ring_size]; decomposer.decomposition_count()]; + let mut ksk_part_a = vec![vec![0u64; ring_size]; decomposer.decomposition_count()]; + izip!( + ksk_part_b.iter_rows_mut(), + ksk_part_a.iter_rows_mut(), + gadget_vector.iter() + ) + .for_each(|(part_b, part_a, beta)| { + RandomFillUniformInModulus::random_fill(&mut rng, &q, part_a.as_mut()); + + // a * s + let mut tmp = part_a.to_vec(); + ntt_op.forward(tmp.as_mut()); + mod_op.elwise_mul_mut(tmp.as_mut(), sk_eval.as_ref()); + ntt_op.backward(tmp.as_mut()); + + // a*s + e + beta m + RandomFillGaussianInModulus::random_fill(&mut rng, &q, part_b.as_mut()); + // println!("E: {:?}", &part_b); + // a*s + e + mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref()); + // a*s + e + beta m + let mut tmp = m.to_vec(); + mod_op.elwise_scalar_mul_mut(tmp.as_mut_slice(), beta); + mod_op.elwise_add_mut(part_b.as_mut_slice(), tmp.as_ref()); }); // decompose a @@ -1195,35 +1221,60 @@ pub(crate) mod tests { // println!("Last limb"); - // decomp_a * ksk(beta e) - ksk.iter_mut() + // decomp_a * ksk(beta m) + ksk_part_b + .iter_mut() + .for_each(|r| ntt_op.forward(r.as_mut_slice())); + ksk_part_a + .iter_mut() .for_each(|r| ntt_op.forward(r.as_mut_slice())); decomposed_a .iter_mut() .for_each(|r| ntt_op.forward(r.as_mut_slice())); - let mut out = vec![0u64; ring_size]; - izip!(decomposed_a.iter(), ksk.iter()).for_each(|(a, b)| { - // out += a * b - let mut a_clone = a.clone(); - mod_op.elwise_mul_mut(a_clone.as_mut_slice(), b.as_ref()); - mod_op.elwise_add_mut(out.as_mut_slice(), a_clone.as_ref()); - }); - ntt_op.backward(out.as_mut_slice()); + let mut out = vec![vec![0u64; ring_size]; 2]; + izip!(decomposed_a.iter(), ksk_part_b.iter(), ksk_part_a.iter()).for_each( + |(d_a, part_b, part_a)| { + // out_a += d_a * part_a + let mut d_a_clone = d_a.clone(); + mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_a.as_ref()); + mod_op.elwise_add_mut(out[0].as_mut_slice(), d_a_clone.as_ref()); + + // out_b += d_a * part_b + let mut d_a_clone = d_a.clone(); + mod_op.elwise_mul_mut(d_a_clone.as_mut_slice(), part_b.as_ref()); + mod_op.elwise_add_mut(out[1].as_mut_slice(), d_a_clone.as_ref()); + }, + ); + out.iter_mut() + .for_each(|r| ntt_op.backward(r.as_mut_slice())); + + let out_back = { + // decrypt + // a*s + ntt_op.forward(out[0].as_mut()); + mod_op.elwise_mul_mut(out[0].as_mut(), sk_eval.as_ref()); + ntt_op.backward(out[0].as_mut()); + + // b - a*s + let tmp = (out[0]).clone(); + mod_op.elwise_sub_mut(out[1].as_mut(), tmp.as_ref()); + out.remove(1) + }; let out_expected = { let mut a_clone = a.clone(); - let mut e_clone = e.clone(); + let mut m_clone = m.clone(); ntt_op.forward(a_clone.as_mut_slice()); - ntt_op.forward(e_clone.as_mut_slice()); + ntt_op.forward(m_clone.as_mut_slice()); - mod_op.elwise_mul_mut(a_clone.as_mut_slice(), e_clone.as_mut_slice()); + mod_op.elwise_mul_mut(a_clone.as_mut_slice(), m_clone.as_mut_slice()); ntt_op.backward(a_clone.as_mut_slice()); a_clone }; let mut diff = out_expected; - mod_op.elwise_sub_mut(diff.as_mut_slice(), out.as_ref()); + mod_op.elwise_sub_mut(diff.as_mut_slice(), out_back.as_ref()); stats.add_more(&Vec::::try_convert_from(diff.as_ref(), &q)); }