From a05e959e752cff10f6328b9eb24db2d171cf465c Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Wed, 5 Jun 2024 17:15:40 +0530 Subject: [PATCH] add decomp_iter --- benches/modulus.rs | 68 ++++++++++++++++++++++++++-- src/bool/noise.rs | 12 ++--- src/decomposer.rs | 108 +++++++++++++++++++++++++++++++++++++++++---- src/lib.rs | 1 + src/lwe.rs | 2 +- src/main.rs | 4 +- src/rgsw.rs | 5 ++- 7 files changed, 176 insertions(+), 24 deletions(-) diff --git a/benches/modulus.rs b/benches/modulus.rs index 087ad93..749ffa8 100644 --- a/benches/modulus.rs +++ b/benches/modulus.rs @@ -1,9 +1,70 @@ -use bin_rs::{ModInit, ModularOpsU64, VectorOps}; +use bin_rs::{Decomposer, DefaultDecomposer, ModInit, ModularOpsU64, VectorOps}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; -use itertools::Itertools; +use itertools::{izip, Itertools}; use rand::{thread_rng, Rng}; use rand_distr::Uniform; +pub(crate) fn decompose_r( + r: &[u64], + decomp_r: &mut [Vec], + decomposer: &DefaultDecomposer, +) { + let ring_size = r.len(); + // let d = decomposer.decomposition_count(); + // let mut count = 0; + for ri in 0..ring_size { + // let el_decomposed = decomposer.decompose(&r[ri]); + decomposer + .decompose_iter(&r[ri]) + .enumerate() + .into_iter() + .for_each(|(j, el)| { + decomp_r[j][ri] = el; + }); + } +} + +fn benchmark_decomposer(c: &mut Criterion) { + let mut group = c.benchmark_group("decomposer"); + + // let decomposers = vec![]; + // 55 + for prime in [36028797017456641] { + for ring_size in [1 << 11] { + let logb = 11; + let decomposer = DefaultDecomposer::new(prime, logb, 2); + + let mut rng = thread_rng(); + let dist = Uniform::new(0, prime); + let a = (&mut rng).sample_iter(dist).take(ring_size).collect_vec(); + + group.bench_function( + BenchmarkId::new( + "decompose", + format!( + "q={prime}/N={ring_size}/logB={logb}/d={}", + decomposer.decomposition_count() + ), + ), + |b| { + b.iter_batched_ref( + || { + ( + a.clone(), + vec![vec![0u64; ring_size]; decomposer.decomposition_count()], + ) + }, + |(r, decomp_r)| (decompose_r(r, decomp_r, &decomposer)), + criterion::BatchSize::PerIteration, + ) + }, + ); + } + } + + group.finish(); +} + fn benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("modulus"); // 55 @@ -34,5 +95,6 @@ fn benchmark(c: &mut Criterion) { group.finish(); } +criterion_group!(decomposer, benchmark_decomposer); criterion_group!(modulus, benchmark); -criterion_main!(modulus); +criterion_main!(modulus, decomposer); diff --git a/src/bool/noise.rs b/src/bool/noise.rs index d2b7491..a6ff6b0 100644 --- a/src/bool/noise.rs +++ b/src/bool/noise.rs @@ -103,13 +103,13 @@ mod test { println!("Gate time: {:?}", now.elapsed()); // mp decrypt - let decryption_shares = cks - .iter() - .map(|c| evaluator.multi_party_decryption_share(&c_out, c)) - .collect_vec(); - let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out); + // let decryption_shares = cks + // .iter() + // .map(|c| evaluator.multi_party_decryption_share(&c_out, c)) + // .collect_vec(); + // let m_out = evaluator.multi_party_decrypt(&decryption_shares, &c_out); let m_expected = (m0 ^ m1); - assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}"); + // assert_eq!(m_expected, m_out, "Expected {m_expected} but got {m_out}"); // // find noise update // { diff --git a/src/decomposer.rs b/src/decomposer.rs index 070aed0..d02efee 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -39,19 +39,33 @@ where pub trait Decomposer { type Element; + type Iter: Iterator; fn new(q: Self::Element, logb: usize, d: usize) -> Self; - //FIXME(Jay): there's no reason why it returns a vec instead of an iterator - fn decompose(&self, v: &Self::Element) -> Vec; + + fn decompose_to_vec(&self, v: &Self::Element) -> Vec; + fn decompose_iter(&self, v: &Self::Element) -> Self::Iter; fn decomposition_count(&self) -> usize; } -// TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ? pub struct DefaultDecomposer { + /// Ciphertext modulus q: T, + /// Log of ciphertext modulus logq: usize, + /// Log of base B logb: usize, + /// base B + b: T, + /// (B - 1). To simulate (% B) as &(B-1), that is extract least significant + /// logb bits + b_mask: T, + /// B/2 + bby2: T, + /// Decomposition count d: usize, + /// No. of bits to ignore in rounding ignore_bits: usize, + /// No. of limbs to ignore in rounding. Set to ceil(logq / logb) - d ignore_limbs: usize, } @@ -96,6 +110,7 @@ impl Decompose for DefaultDecomposer { type Element = T; + type Iter = DecomposerIter; fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1. @@ -113,6 +128,9 @@ impl Decompose q, logq, logb, + b: T::one() << logb, + b_mask: (T::one() << logb) - T::one(), + bby2: T::one() << (logb - 1), d, ignore_bits, ignore_limbs, @@ -120,7 +138,7 @@ impl Decompose } // TODO(Jay): Outline the caveat - fn decompose(&self, value: &T) -> Vec { + fn decompose_to_vec(&self, value: &T) -> Vec { let mut value = round_value(*value, self.ignore_bits); let q = self.q; @@ -153,6 +171,75 @@ impl Decompose fn decomposition_count(&self) -> usize { self.d } + + fn decompose_iter(&self, value: &T) -> DecomposerIter { + let mut value = round_value(*value, self.ignore_bits); + + if value >= (self.q >> 1) { + value = !(self.q - value) + T::one() + } + + DecomposerIter { + value, + q: self.q, + logb: self.logb, + b: self.b, + bby2: self.bby2, + b_mask: self.b_mask, + steps_left: self.d, + } + } +} + +impl DefaultDecomposer {} + +pub struct DecomposerIter { + /// Value to decompose + value: T, + steps_left: usize, + /// (1 << logb) - 1 (for % (1< Iterator for DecomposerIter { + type Item = T; + + fn next(&mut self) -> Option { + if self.steps_left != 0 { + self.steps_left -= 1; + let k_i = self.value & self.b_mask; + + self.value = (self.value - k_i) >> self.logb; + + if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) { + self.value = self.value + T::one(); + Some(self.q + k_i - self.b) + } else { + Some(k_i) + } + + // let carry = >::from( + // k_i > self.bby2 || (k_i == self.bby2 && ((self.value & + // T::one()) == T::one())), ); + // self.value = self.value + carry; + + // Some( + // (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i + // - (carry << self.logb), ) + + // Some(k_i) + } else { + None + } + } } fn round_value(value: T, ignore_bits: usize) -> T { @@ -197,15 +284,18 @@ mod tests { let modq_op = ModularOpsU64::new(q); for _ in 0..100000 { let value = rng.gen_range(0..q); - let limbs = decomposer.decompose(&value); - let value_back = decomposer.recompose(&limbs, &modq_op); - let rounded_value = - round_value(value, decomposer.ignore_bits) << decomposer.ignore_bits; - stats.add_more(&Vec::::try_convert_from(&limbs, &q)); + let limbs = decomposer.decompose_to_vec(&value); + let value_back = round_value( + decomposer.recompose(&limbs, &modq_op), + decomposer.ignore_bits, + ); + let rounded_value = round_value(value, decomposer.ignore_bits); assert_eq!( rounded_value, value_back, "Expected {rounded_value} got {value_back} for q={q}" ); + + stats.add_more(&Vec::::try_convert_from(&limbs, &q)); } } println!("Mean: {}", stats.mean()); diff --git a/src/lib.rs b/src/lib.rs index 4f91121..3bb54b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ mod shortint; mod utils; pub use backend::{ModInit, ModularOpsU64, VectorOps}; +pub use decomposer::{Decomposer, DecomposerIter, DefaultDecomposer}; pub use ntt::{Ntt, NttBackendU64, NttInit}; pub trait Matrix: AsRef<[Self::R]> { diff --git a/src/lwe.rs b/src/lwe.rs index e036f41..b086952 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch< .as_ref() .iter() .skip(1) - .flat_map(|ai| decomposer.decompose(ai)); + .flat_map(|ai| decomposer.decompose_to_vec(ai)); izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); }); diff --git a/src/main.rs b/src/main.rs index f415470..899422c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,8 +16,7 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec { // for _ in 0..d { // let k_i = carry + (value & full_mask); // value = (value) >> logb; - // let go = thread_rng().gen_bool(1.0 / 2.0); - // if k_i > bby2 || (k_i == bby2 && ((value & 1) == 1)) { + // if k_i > bby2 { // // if (k_i == bby2 && ((value & 1) == 1)) { // // println!("AA"); // // } @@ -31,7 +30,6 @@ fn decomposer(mut value: u64, q: u64, d: usize, logb: u64) -> Vec { // carry = 0; // } // } - // println!("Last carry {carry}"); // return out; let mut out = Vec::with_capacity(d); diff --git a/src/rgsw.rs b/src/rgsw.rs index 9e15b59..c2c13ca 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -518,7 +518,8 @@ pub(crate) fn decompose_r>( let d = decomposer.decomposition_count(); for ri in 0..ring_size { - let el_decomposed = decomposer.decompose(&r[ri]); + let el_decomposed = decomposer.decompose_to_vec(&r[ri]); + for j in 0..d { decomp_r[j].as_mut()[ri] = el_decomposed[j]; } @@ -570,7 +571,7 @@ pub(crate) fn galois_auto< .for_each(|(el_in, to_index, sign)| { let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; - let el_out_decomposed = decomposer.decompose(&el_out); + let el_out_decomposed = decomposer.decompose_to_vec(&el_out); for j in 0..d { scratch_matrix_d_ring[j].as_mut()[*to_index] = el_out_decomposed[j]; }