Browse Source

put decomposer in main.rs in different file

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
ccee110b34
8 changed files with 169 additions and 75 deletions
  1. +1
    -1
      src/backend.rs
  2. +12
    -4
      src/bool/evaluator.rs
  3. +1
    -0
      src/bool/mod.rs
  4. +39
    -3
      src/bool/parameters.rs
  5. +12
    -50
      src/decomposer.rs
  6. +85
    -3
      src/main.rs
  7. +15
    -14
      src/random.rs
  8. +4
    -0
      src/utils.rs

+ 1
- 1
src/backend.rs

@ -43,7 +43,7 @@ impl Modulus for u64 {
} }
fn map_element_to_i64(&self, v: &Self::Element) -> i64 { fn map_element_to_i64(&self, v: &Self::Element) -> i64 {
assert!(v <= self, "{v} must be <= {self}"); assert!(v <= self, "{v} must be <= {self}");
if *v > (self >> 1) {
if *v >= (self >> 1) {
-ToPrimitive::to_i64(&(self - v)).unwrap() -ToPrimitive::to_i64(&(self - v)).unwrap()
} else { } else {
ToPrimitive::to_i64(v).unwrap() ToPrimitive::to_i64(v).unwrap()

+ 12
- 4
src/bool/evaluator.rs

@ -160,7 +160,7 @@ where
} }
} }
trait BoolEncoding {
pub(super) trait BoolEncoding {
type Element; type Element;
fn true_el(&self) -> Self::Element; fn true_el(&self) -> Self::Element;
fn false_el(&self) -> Self::Element; fn false_el(&self) -> Self::Element;
@ -210,7 +210,7 @@ where
} }
} }
struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
auto_decomposer: DefaultDecomposer<M::MatElement>, auto_decomposer: DefaultDecomposer<M::MatElement>,
rlwe_rgsw_decomposer: ( rlwe_rgsw_decomposer: (
DefaultDecomposer<M::MatElement>, DefaultDecomposer<M::MatElement>,
@ -305,7 +305,15 @@ where
_phantom: PhantomData<M>, _phantom: PhantomData<M>,
} }
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp> {}
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp> {
pub(super) fn parameters(&self) -> &BoolParameters<M::MatElement> {
&self.pbs_info.parameters
}
pub(super) fn pbs_info(&self) -> &BoolPbsInfo<M, NttOp, RlweModOp, LweModOp> {
&self.pbs_info
}
}
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp> impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp>
where where
@ -1687,7 +1695,7 @@ mod tests {
>::new(MP_BOOL_PARAMS); >::new(MP_BOOL_PARAMS);
let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) = let (parties, collective_pk, _, _, server_key_eval, ideal_client_key) =
_multi_party_all_keygen(&bool_evaluator, 64);
_multi_party_all_keygen(&bool_evaluator, 2);
let mut m0 = true; let mut m0 = true;
let mut m1 = false; let mut m1 = false;

+ 1
- 0
src/bool/mod.rs

@ -1,5 +1,6 @@
pub(crate) mod evaluator; pub(crate) mod evaluator;
pub(crate) mod keys; pub(crate) mod keys;
pub mod noise;
pub(crate) mod parameters; pub(crate) mod parameters;
pub type FheBool = Vec<u64>; pub type FheBool = Vec<u64>;

+ 39
- 3
src/bool/parameters.rs

@ -319,22 +319,58 @@ pub(crate) const MP_BOOL_PARAMS: BoolParameters = BoolParameters:: {
lwe_decomposer_base: DecompostionLogBase(4), lwe_decomposer_base: DecompostionLogBase(4),
lwe_decomposer_count: DecompositionCount(5), lwe_decomposer_count: DecompositionCount(5),
rlrg_decomposer_base: DecompostionLogBase(12), rlrg_decomposer_base: DecompostionLogBase(12),
rlrg_decomposer_count: (DecompositionCount(2), DecompositionCount(2)),
rlrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)),
rgrg_decomposer_base: DecompostionLogBase(12), rgrg_decomposer_base: DecompostionLogBase(12),
rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(4)),
rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)),
auto_decomposer_base: DecompostionLogBase(12), auto_decomposer_base: DecompostionLogBase(12),
auto_decomposer_count: DecompositionCount(5), auto_decomposer_count: DecompositionCount(5),
g: 5, g: 5,
w: 10, w: 10,
}; };
// pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters<u64> =
// BoolParameters::<u64> { rlwe_q:
// CiphertextModulus::new_non_native(36028797018820609), lwe_q:
// CiphertextModulus::new_non_native(1 << 20), br_q: 1 << 11,
// rlwe_n: PolynomialSize(1 << 11),
// lwe_n: LweDimension(600),
// lwe_decomposer_base: DecompostionLogBase(4),
// lwe_decomposer_count: DecompositionCount(5),
// rlrg_decomposer_base: DecompostionLogBase(11),
// rlrg_decomposer_count: (DecompositionCount(2), DecompositionCount(2)),
// rgrg_decomposer_base: DecompostionLogBase(11),
// rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(4)),
// auto_decomposer_base: DecompostionLogBase(11),
// auto_decomposer_count: DecompositionCount(2),
// g: 5,
// w: 10,
// };
pub(crate) const SMALL_MP_BOOL_PARAMS: BoolParameters<u64> = BoolParameters::<u64> {
rlwe_q: CiphertextModulus::new_non_native(36028797018820609),
lwe_q: CiphertextModulus::new_non_native(1 << 20),
br_q: 1 << 11,
rlwe_n: PolynomialSize(1 << 11),
lwe_n: LweDimension(500),
lwe_decomposer_base: DecompostionLogBase(4),
lwe_decomposer_count: DecompositionCount(5),
rlrg_decomposer_base: DecompostionLogBase(11),
rlrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)),
rgrg_decomposer_base: DecompostionLogBase(11),
rgrg_decomposer_count: (DecompositionCount(5), DecompositionCount(5)),
auto_decomposer_base: DecompostionLogBase(11),
auto_decomposer_count: DecompositionCount(5),
g: 5,
w: 10,
};
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::utils::generate_prime; use crate::utils::generate_prime;
#[test] #[test]
fn find_prime() { fn find_prime() {
let bits = 61;
let bits = 55;
let ring_size = 1 << 11; let ring_size = 1 << 11;
let prime = generate_prime(bits, ring_size * 2, 1 << bits).unwrap(); let prime = generate_prime(bits, ring_size * 2, 1 << bits).unwrap();
dbg!(prime); dbg!(prime);

+ 12
- 50
src/decomposer.rs

@ -132,20 +132,21 @@ impl Decompose
let full_mask = b - T::one(); let full_mask = b - T::one();
let bby2 = b >> 1; let bby2 = b >> 1;
if value > (q >> 1) {
if value >= (q >> 1) {
value = !(q - value) + T::one() value = !(q - value) + T::one()
} }
let mut out = Vec::with_capacity(self.d); let mut out = Vec::with_capacity(self.d);
for _ in 0..self.d { for _ in 0..self.d {
let k_i = value & full_mask; let k_i = value & full_mask;
value = (value - k_i) >> logb; value = (value - k_i) >> logb;
if k_i > bby2 || (k_i == bby2 && ((value & full_mask) >= bby2)) {
if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) {
out.push(q - (b - k_i)); out.push(q - (b - k_i));
value = value + T::one(); value = value + T::one();
} else { } else {
out.push(k_i)
out.push(k_i);
} }
} }
@ -157,44 +158,6 @@ impl Decompose
} }
} }
// impl<T> Decomposer for dyn AsRef<DefaultDecomposer<T>>
// where
// DefaultDecomposer<T>: Decomposer<Element = T>,
// {
// type Element = T;
// fn new(q: Self::Element, logb: usize, d: usize) -> Self {
// DefaultDecomposer::<T>::new(q, logb, d)
// }
// fn decompose(&self, v: &Self::Element) -> Vec<Self::Element> {
// todo!()
// }
// fn decomposition_count(&self) -> usize {
// todo!()
// }
// }
// impl<U: AsRef<DefaultDecomposer<T>>> Decomposer for U
// where
// DefaultDecomposer<T>: Decomposer,
// {
// type Element = T;
// fn new(q: Self::Element, logb: usize, d: usize) -> Self {
// todo!()
// }
// fn decompose(&self, v: &Self::Element) -> Vec<Self::Element> {
// todo!()
// }
// fn decomposition_count(&self) -> usize {
// todo!()
// }
// }
fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T { fn round_value<T: PrimInt>(value: T, ignore_bits: usize) -> T {
if ignore_bits == 0 { if ignore_bits == 0 {
return value; return value;
@ -219,24 +182,23 @@ mod tests {
#[test] #[test]
fn decomposition_works() { fn decomposition_works() {
let logq = 50;
let logb = 5;
let d = 10;
let logq = 55;
let logb = 11;
let d = 5;
let ring_size = 1 << 11;
let mut rng = thread_rng(); let mut rng = thread_rng();
let mut stats = Stats { samples: vec![] }; let mut stats = Stats { samples: vec![] };
// q is prime of bits logq and i is true, other q = 1<<logq
// FIXME: Test fails when q is prime, albeit the difference is minute
for i in [false] {
for i in [true] {
let q = if i { let q = if i {
generate_prime(logq, 1 << 4, 1u64 << logq).unwrap()
generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap()
} else { } else {
1u64 << logq 1u64 << logq
}; };
let decomposer = DefaultDecomposer::new(q, logb, d); let decomposer = DefaultDecomposer::new(q, logb, d);
let modq_op = ModularOpsU64::new(q); let modq_op = ModularOpsU64::new(q);
for _ in 0..1000 {
for _ in 0..100000 {
let value = rng.gen_range(0..q); let value = rng.gen_range(0..q);
let limbs = decomposer.decompose(&value); let limbs = decomposer.decompose(&value);
let value_back = decomposer.recompose(&limbs, &modq_op); let value_back = decomposer.recompose(&limbs, &modq_op);
@ -250,6 +212,6 @@ mod tests {
} }
} }
println!("Mean: {}", stats.mean()); println!("Mean: {}", stats.mean());
println!("Std: {}", stats.std_dev());
println!("Std: {}", stats.std_dev().abs().log2());
} }
} }

+ 85
- 3
src/main.rs

@ -1,5 +1,87 @@
use std::os::unix::thread;
use rand::{thread_rng, Rng};
fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec<u64> {
let b = 1u64 << logb;
let full_mask = b - 1u64;
let bby2 = b >> 1;
if value >= (q >> 1) {
value = !(q - value) + 1;
}
// let mut carry = 0;
// let mut out = Vec::with_capacity(d);
// for _ in 0..d {
// let k_i = carry + (value & full_mask);
// value = (value) >> logb;
// let go = thread_rng().gen_bool(1.0 / 2.0);
// if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) {
// // if (k_i == bby2 && ((value & 1) == 1)) {
// // println!("AA");
// // }
// out.push(q - (b - k_i));
// carry = 1;
// } else {
// // if (k_i == bby2) {
// // println!("BB");
// // }
// out.push(k_i);
// carry = 0;
// }
// }
// println!("Last carry {carry}");
// return out;
let mut out = Vec::with_capacity(d);
for _ in 0..d {
let k_i = value & full_mask;
value = (value - k_i) >> logb;
if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) {
// if (k_i == bby2 && ((value & 1) == 1)) {
// println!("AA");
// }
out.push(q - (b - k_i));
value += 1;
} else {
// if (k_i == bby2) {
// println!("BB");
// }
out.push(k_i);
}
}
return out;
}
fn recompose(limbs: &[u64], q: u64, logb: u64) -> u64 {
let mut out = 0;
limbs.iter().enumerate().for_each(|(i, l)| {
let a = 1u128 << (logb * (i as u64));
let a = ((a * (*l as u128)) % (q as u128)) as u64;
out = (out + a) % q;
});
out % q
}
fn main() { fn main() {
let mut v = Vec::with_capacity(10);
v[0] = 1;
println!("Hello, world!");
// let mut v = Vec::with_capacity(10);
// v[0] = 1;
// println!("Hello, world!");
let mut rng = thread_rng();
let q = 36028797018820609u64;
let logb = 11;
let d = 5;
for _ in 0..100000 {
let value = rng.gen_range(0..q);
let limbs = decomposer(value, q, d, logb);
// println!("{:?}", &limbs);
let value_back = recompose(&limbs, q, logb);
assert_eq!(value, value_back)
}
} }

+ 15
- 14
src/random.rs

@ -111,15 +111,15 @@ where
C: Modulus<Element = T>, C: Modulus<Element = T>,
{ {
fn random_fill(&mut self, modulus: &C, container: &mut [T]) { fn random_fill(&mut self, modulus: &C, container: &mut [T]) {
izip!(
rand_distr::Normal::new(0.0, 3.19f64)
.unwrap()
.sample_iter(&mut self.rng),
container.iter_mut()
)
.for_each(|(from, to)| {
*to = modulus.map_element_from_f64(from);
});
// izip!(
// rand_distr::Normal::new(0.0, 3.19f64)
// .unwrap()
// .sample_iter(&mut self.rng),
// container.iter_mut()
// )
// .for_each(|(from, to)| {
// *to = modulus.map_element_from_f64(from);
// });
} }
} }
@ -173,11 +173,12 @@ where
impl<T, M: Modulus<Element = T>> RandomGaussianElementInModulus<T, M> for DefaultSecureRng { impl<T, M: Modulus<Element = T>> RandomGaussianElementInModulus<T, M> for DefaultSecureRng {
fn random(&mut self, modulus: &M) -> T { fn random(&mut self, modulus: &M) -> T {
modulus.map_element_from_f64(
rand_distr::Normal::new(0.0, 3.19f64)
.unwrap()
.sample(&mut self.rng),
)
// modulus.map_element_from_f64(
// rand_distr::Normal::new(0.0, 3.19f64)
// .unwrap()
// .sample(&mut self.rng),
// )
modulus.map_element_from_f64(0.0)
} }
} }

+ 4
- 0
src/utils.rs

@ -279,6 +279,10 @@ where
// T: for<'a> Sum<&'a T>, // T: for<'a> Sum<&'a T>,
T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T>, T: for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T>,
{ {
pub(crate) fn new() -> Self {
Self { samples: vec![] }
}
pub(crate) fn mean(&self) -> f64 { pub(crate) fn mean(&self) -> f64 {
self.samples.iter().sum::<T>().to_f64().unwrap() / (self.samples.len() as f64) self.samples.iter().sum::<T>().to_f64().unwrap() / (self.samples.len() as f64)
} }

Loading…
Cancel
Save