From ccee110b34c88d7222396fd4f179254aa5028e5b Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Tue, 4 Jun 2024 15:42:13 +0530 Subject: [PATCH] put decomposer in main.rs in different file --- src/backend.rs | 2 +- src/bool/evaluator.rs | 16 ++++++-- src/bool/mod.rs | 1 + src/bool/parameters.rs | 42 ++++++++++++++++++-- src/decomposer.rs | 62 ++++++----------------------- src/main.rs | 88 ++++++++++++++++++++++++++++++++++++++++-- src/random.rs | 29 +++++++------- src/utils.rs | 4 ++ 8 files changed, 169 insertions(+), 75 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index 845db22..04996b9 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -43,7 +43,7 @@ impl Modulus for u64 { } fn map_element_to_i64(&self, v: &Self::Element) -> i64 { assert!(v <= self, "{v} must be <= {self}"); - if *v > (self >> 1) { + if *v >= (self >> 1) { -ToPrimitive::to_i64(&(self - v)).unwrap() } else { ToPrimitive::to_i64(v).unwrap() diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index c7f938f..e50ef37 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -160,7 +160,7 @@ where } } -trait BoolEncoding { +pub(super) trait BoolEncoding { type Element; fn true_el(&self) -> Self::Element; fn false_el(&self) -> Self::Element; @@ -210,7 +210,7 @@ where } } -struct BoolPbsInfo { +pub(super) struct BoolPbsInfo { auto_decomposer: DefaultDecomposer, rlwe_rgsw_decomposer: ( DefaultDecomposer, @@ -305,7 +305,15 @@ where _phantom: PhantomData, } -impl BoolEvaluator {} +impl BoolEvaluator { + pub(super) fn parameters(&self) -> &BoolParameters { + &self.pbs_info.parameters + } + + pub(super) fn pbs_info(&self) -> &BoolPbsInfo { + &self.pbs_info + } +} impl BoolEvaluator where @@ -1687,7 +1695,7 @@ mod tests { >::new(MP_BOOL_PARAMS); let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = - _multi_party_all_keygen(&bool_evaluator, 64); + _multi_party_all_keygen(&bool_evaluator, 2); let mut m0 = true; let mut m1 = false; diff --git a/src/bool/mod.rs b/src/bool/mod.rs index a8a4f9d..c4b19fb 100644 --- a/src/bool/mod.rs +++ b/src/bool/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod evaluator; pub(crate) mod keys; +pub mod noise; pub(crate) mod parameters; pub type FheBool = Vec; diff --git a/src/bool/parameters.rs b/src/bool/parameters.rs index 1e8d7d8..0eb2a7d 100644 --- a/src/bool/parameters.rs +++ b/src/bool/parameters.rs @@ -319,22 +319,58 @@ pub(crate) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { lwe_decomposer_base: DecompostionLogBase(4), lwe_decomposer_count: DecompositionCount(5), rlrg_decomposer_base: DecompostionLogBase(12), - rlrg_decomposer_count: (DecompositionCount(2), DecompositionCount(2)), + rlrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)), rgrg_decomposer_base: DecompostionLogBase(12), - rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(4)), + rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)), auto_decomposer_base: DecompostionLogBase(12), auto_decomposer_count: DecompositionCount(5), g: 5, w: 10, }; +// pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = +// BoolParameters:: { rlwe_q: +// CiphertextModulus::new_non_native(36028797018820609), lwe_q: +// CiphertextModulus::new_non_native(1 << 20), br_q: 1 << 11, +// rlwe_n: PolynomialSize(1 << 11), +// lwe_n: LweDimension(600), +// lwe_decomposer_base: DecompostionLogBase(4), +// lwe_decomposer_count: DecompositionCount(5), +// rlrg_decomposer_base: DecompostionLogBase(11), +// rlrg_decomposer_count: (DecompositionCount(2), DecompositionCount(2)), +// rgrg_decomposer_base: DecompostionLogBase(11), +// rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(4)), +// auto_decomposer_base: DecompostionLogBase(11), +// auto_decomposer_count: DecompositionCount(2), +// g: 5, +// w: 10, +// }; + +pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters = BoolParameters:: { + rlwe_q: CiphertextModulus::new_non_native(36028797018820609), + lwe_q: CiphertextModulus::new_non_native(1 << 20), + br_q: 1 << 11, + rlwe_n: PolynomialSize(1 << 11), + lwe_n: LweDimension(500), + lwe_decomposer_base: DecompostionLogBase(4), + lwe_decomposer_count: DecompositionCount(5), + rlrg_decomposer_base: DecompostionLogBase(11), + rlrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)), + rgrg_decomposer_base: DecompostionLogBase(11), + rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)), + auto_decomposer_base: DecompostionLogBase(11), + auto_decomposer_count: DecompositionCount(5), + g: 5, + w: 10, +}; + #[cfg(test)] mod tests { use crate::utils::generate_prime; #[test] fn find_prime() { - let bits = 61; + let bits = 55; let ring_size = 1 << 11; let prime = generate_prime(bits, ring_size * 2, 1 << bits).unwrap(); dbg!(prime); diff --git a/src/decomposer.rs b/src/decomposer.rs index 56b2d39..e851efc 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -132,20 +132,21 @@ impl Decompose let full_mask = b - T::one(); let bby2 = b >> 1; - if value > (q >> 1) { + if value >= (q >> 1) { value = !(q - value) + T::one() } let mut out = Vec::with_capacity(self.d); for _ in 0..self.d { let k_i = value & full_mask; + value = (value - k_i) >> logb; - if k_i > bby2 || (k_i == bby2 && ((value & full_mask) >= bby2)) { + if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) { out.push(q - (b - k_i)); value = value + T::one(); } else { - out.push(k_i) + out.push(k_i); } } @@ -157,44 +158,6 @@ impl Decompose } } -// impl Decomposer for dyn AsRef> -// where -// DefaultDecomposer: Decomposer, -// { -// type Element = T; - -// fn new(q: Self::Element, logb: usize, d: usize) -> Self { -// DefaultDecomposer::::new(q, logb, d) -// } - -// fn decompose(&self, v: &Self::Element) -> Vec { -// todo!() -// } - -// fn decomposition_count(&self) -> usize { -// todo!() -// } -// } - -// impl>> Decomposer for U -// where -// DefaultDecomposer: Decomposer, -// { -// type Element = T; - -// fn new(q: Self::Element, logb: usize, d: usize) -> Self { -// todo!() -// } - -// fn decompose(&self, v: &Self::Element) -> Vec { -// todo!() -// } - -// fn decomposition_count(&self) -> usize { -// todo!() -// } -// } - fn round_value(value: T, ignore_bits: usize) -> T { if ignore_bits == 0 { return value; @@ -219,24 +182,23 @@ mod tests { #[test] fn decomposition_works() { - let logq = 50; - let logb = 5; - let d = 10; + let logq = 55; + let logb = 11; + let d = 5; + let ring_size = 1 << 11; let mut rng = thread_rng(); let mut stats = Stats { samples: vec![] }; - // q is prime of bits logq and i is true, other q = 1< Vec { + let b = 1u64 << logb; + let full_mask = b - 1u64; + let bby2 = b >> 1; + + if value >= (q >> 1) { + value = !(q - value) + 1; + } + + // let mut carry = 0; + // let mut out = Vec::with_capacity(d); + // for _ in 0..d { + // let k_i = carry + (value & full_mask); + // value = (value) >> logb; + // let go = thread_rng().gen_bool(1.0 / 2.0); + // if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) { + // // if (k_i == bby2 && ((value & 1) == 1)) { + // // println!("AA"); + // // } + // out.push(q - (b - k_i)); + // carry = 1; + // } else { + // // if (k_i == bby2) { + // // println!("BB"); + // // } + // out.push(k_i); + // carry = 0; + // } + // } + // println!("Last carry {carry}"); + // return out; + + let mut out = Vec::with_capacity(d); + for _ in 0..d { + let k_i = value & full_mask; + value = (value - k_i) >> logb; + + if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) { + // if (k_i == bby2 && ((value & 1) == 1)) { + // println!("AA"); + // } + out.push(q - (b - k_i)); + value += 1; + } else { + // if (k_i == bby2) { + // println!("BB"); + // } + out.push(k_i); + } + } + + return out; +} + +fn recompose(limbs: &[u64], q: u64, logb: u64) -> u64 { + let mut out = 0; + limbs.iter().enumerate().for_each(|(i, l)| { + let a = 1u128 << (logb * (i as u64)); + let a = ((a * (*l as u128)) % (q as u128)) as u64; + out = (out + a) % q; + }); + out % q +} + fn main() { - let mut v = Vec::with_capacity(10); - v[0] = 1; - println!("Hello, world!"); + // let mut v = Vec::with_capacity(10); + // v[0] = 1; + // println!("Hello, world!"); + + let mut rng = thread_rng(); + + let q = 36028797018820609u64; + let logb = 11; + let d = 5; + + for _ in 0..100000 { + let value = rng.gen_range(0..q); + let limbs = decomposer(value, q, d, logb); + // println!("{:?}", &limbs); + let value_back = recompose(&limbs, q, logb); + assert_eq!(value, value_back) + } } diff --git a/src/random.rs b/src/random.rs index 78b71ab..7b66a2a 100644 --- a/src/random.rs +++ b/src/random.rs @@ -111,15 +111,15 @@ where C: Modulus, { fn random_fill(&mut self, modulus: &C, container: &mut [T]) { - izip!( - rand_distr::Normal::new(0.0, 3.19f64) - .unwrap() - .sample_iter(&mut self.rng), - container.iter_mut() - ) - .for_each(|(from, to)| { - *to = modulus.map_element_from_f64(from); - }); + // izip!( + // rand_distr::Normal::new(0.0, 3.19f64) + // .unwrap() + // .sample_iter(&mut self.rng), + // container.iter_mut() + // ) + // .for_each(|(from, to)| { + // *to = modulus.map_element_from_f64(from); + // }); } } @@ -173,11 +173,12 @@ where impl> RandomGaussianElementInModulus for DefaultSecureRng { fn random(&mut self, modulus: &M) -> T { - modulus.map_element_from_f64( - rand_distr::Normal::new(0.0, 3.19f64) - .unwrap() - .sample(&mut self.rng), - ) + // modulus.map_element_from_f64( + // rand_distr::Normal::new(0.0, 3.19f64) + // .unwrap() + // .sample(&mut self.rng), + // ) + modulus.map_element_from_f64(0.0) } } diff --git a/src/utils.rs b/src/utils.rs index 968eed3..31b694e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -279,6 +279,10 @@ where // T: for<'a> Sum<&'a T>, T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum, { + pub(crate) fn new() -> Self { + Self { samples: vec![] } + } + pub(crate) fn mean(&self) -> f64 { self.samples.iter().sum::().to_f64().unwrap() / (self.samples.len() as f64) }