From a20a3e8e77e04a9e7aadc7765d68937eae39013a Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Wed, 5 Jun 2024 17:39:35 +0530 Subject: [PATCH] decompose with bit hacks without brnaching speeds up bootstrappoing by 15ms --- src/bool/evaluator.rs | 15 +++++++++++---- src/decomposer.rs | 42 ++++++++++++++++++++++++++---------------- src/lwe.rs | 2 +- src/rgsw.rs | 13 ------------- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/bool/evaluator.rs b/src/bool/evaluator.rs index 8bae269..5f28000 100644 --- a/src/bool/evaluator.rs +++ b/src/bool/evaluator.rs @@ -230,7 +230,7 @@ pub(super) struct BoolPbsInfo { impl PbsInfo for BoolPbsInfo where - M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive, + M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From, RlweModOp: ArithmeticOps + VectorOps, LweModOp: ArithmeticOps + VectorOps, NttOp: Ntt, @@ -319,8 +319,14 @@ impl BoolEvaluator BoolEvaluator where M: MatrixEntity + MatrixMut, - M::MatElement: - PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub + SampleUniform, + M::MatElement: PrimInt + + Debug + + Display + + NumInfo + + FromPrimitive + + WrappingSub + + SampleUniform + + From, NttOp: Ntt, RlweModOp: ArithmeticOps + VectorOps @@ -1108,7 +1114,8 @@ impl BooleanGates for BoolEvaluator, RlweModOp: VectorOps + ArithmeticOps + GetModulus>, diff --git a/src/decomposer.rs b/src/decomposer.rs index 4f43b0e..4337bef 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -106,7 +106,7 @@ impl DefaultDecomposer { } } -impl Decomposer +impl> Decomposer for DefaultDecomposer { type Element = T; @@ -182,6 +182,7 @@ impl Decompose DecomposerIter { value, q: self.q, + logq: self.logq, logb: self.logb, b: self.b, bby2: self.bby2, @@ -205,11 +206,13 @@ pub struct DecomposerIter { bby2: T, /// Ciphertext modulus q: T, + /// Log of ciphertext modulus + logq: usize, /// b = 1 << logb b: T, } -impl Iterator for DecomposerIter { +impl> Iterator for DecomposerIter { type Item = T; fn next(&mut self) -> Option { @@ -219,20 +222,27 @@ impl Iterator for DecomposerIter { self.value = (self.value - k_i) >> self.logb; - if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) { - self.value = self.value + T::one(); - Some(self.q + k_i - self.b) - } else { - Some(k_i) - } - - // let carry = >::from( - // k_i > self.bby2 || (k_i == self.bby2 && ((self.value & - // T::one()) == T::one())), ); - // self.value = self.value + carry; - - // Some((self.q & ((carry << 55) - (T::one() & carry))) + k_i - - // (carry << self.logb)) + // if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & + // T::one()) == T::one())) { self.value = self.value + // + T::one(); Some(self.q + k_i - self.b) + // } else { + // Some(k_i) + // } + + // Following is without branching impl of the commented version above. It + // happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other + // parameters as well but I haven't tested) by roughly 15ms. + // Suprisingly the improvement does not show up when I benchmark + // `decomposer_iter` in isolation. Putting this remark here as a + // future task to investiage (TODO). + let carry = >::from( + k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())), + ); + self.value = self.value + carry; + + Some( + (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i - (carry << self.logb), + ) // Some(k_i) } else { diff --git a/src/lwe.rs b/src/lwe.rs index b086952..ca74629 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch< .as_ref() .iter() .skip(1) - .flat_map(|ai| decomposer.decompose_to_vec(ai)); + .flat_map(|ai| decomposer.decompose_iter(ai)); izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); }); diff --git a/src/rgsw.rs b/src/rgsw.rs index 1be5e26..0e09eda 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -515,21 +515,14 @@ pub(crate) fn decompose_r>( R::Element: Copy, { let ring_size = r.len(); - let d = decomposer.decomposition_count(); for ri in 0..ring_size { - // let el_decomposed = decomposer.decompose_to_vec(&r[ri]); - decomposer .decompose_iter(&r[ri]) .enumerate() .for_each(|(index, el)| { decomp_r[index].as_mut()[ri] = el; }); - - // for j in 0..d { - // decomp_r[j].as_mut()[ri] = el_decomposed[j]; - // } } } @@ -578,18 +571,12 @@ pub(crate) fn galois_auto< .for_each(|(el_in, to_index, sign)| { let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; - // let el_out_decomposed = decomposer.decompose_to_vec(&el_out); - decomposer .decompose_iter(&el_out) .enumerate() .for_each(|(index, el)| { scratch_matrix_d_ring[index].as_mut()[*to_index] = el; }); - - // for j in 0..d { - // scratch_matrix_d_ring[j].as_mut()[*to_index] = - // el_out_decomposed[j]; } }); // transform decomposed a(X^k) to evaluation domain