Browse Source

decompose with bit hacks without branching speeds up bootstrapping by 15ms

par-agg-key-shares
Janmajaya Mall 10 months ago
parent
commit
a20a3e8e77
4 changed files with 38 additions and 34 deletions
  1. +11
    -4
      src/bool/evaluator.rs
  2. +26
    -16
      src/decomposer.rs
  3. +1
    -1
      src/lwe.rs
  4. +0
    -13
      src/rgsw.rs

+ 11
- 4
src/bool/evaluator.rs

@ -230,7 +230,7 @@ pub(super) struct BoolPbsInfo {
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp> impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
where where
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive,
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool>,
RlweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>, RlweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>, LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement>, NttOp: Ntt<Element = M::MatElement>,
@ -319,8 +319,14 @@ impl BoolEvaluator
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp> impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp>
where where
M: MatrixEntity + MatrixMut, M: MatrixEntity + MatrixMut,
M::MatElement:
PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub + SampleUniform,
M::MatElement: PrimInt
+ Debug
+ Display
+ NumInfo
+ FromPrimitive
+ WrappingSub
+ SampleUniform
+ From<bool>,
NttOp: Ntt<Element = M::MatElement>, NttOp: Ntt<Element = M::MatElement>,
RlweModOp: ArithmeticOps<Element = M::MatElement> RlweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>
@ -1108,7 +1114,8 @@ impl BooleanGates for BoolEvaluator
where where
M: MatrixMut + MatrixEntity, M: MatrixMut + MatrixEntity,
M::R: RowMut + RowEntity + Clone, M::R: RowMut + RowEntity + Clone,
M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo,
M::MatElement:
PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo + From<bool>,
RlweModOp: VectorOps<Element = M::MatElement> RlweModOp: VectorOps<Element = M::MatElement>
+ ArithmeticOps<Element = M::MatElement> + ArithmeticOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>, + GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,

+ 26
- 16
src/decomposer.rs

@ -106,7 +106,7 @@ impl DefaultDecomposer {
} }
} }
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decomposer
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<bool>> Decomposer
for DefaultDecomposer<T> for DefaultDecomposer<T>
{ {
type Element = T; type Element = T;
@ -182,6 +182,7 @@ impl Decompose
DecomposerIter { DecomposerIter {
value, value,
q: self.q, q: self.q,
logq: self.logq,
logb: self.logb, logb: self.logb,
b: self.b, b: self.b,
bby2: self.bby2, bby2: self.bby2,
@ -205,11 +206,13 @@ pub struct DecomposerIter {
bby2: T, bby2: T,
/// Ciphertext modulus /// Ciphertext modulus
q: T, q: T,
/// Log of ciphertext modulus
logq: usize,
/// b = 1 << logb /// b = 1 << logb
b: T, b: T,
} }
impl<T: PrimInt> Iterator for DecomposerIter<T> {
impl<T: PrimInt + From<bool>> Iterator for DecomposerIter<T> {
type Item = T; type Item = T;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
@ -219,20 +222,27 @@ impl Iterator for DecomposerIter {
self.value = (self.value - k_i) >> self.logb; self.value = (self.value - k_i) >> self.logb;
if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) {
self.value = self.value + T::one();
Some(self.q + k_i - self.b)
} else {
Some(k_i)
}
// let carry = <T as From<bool>>::from(
// k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
// T::one()) == T::one())), );
// self.value = self.value + carry;
// Some((self.q & ((carry << 55) - (T::one() & carry))) + k_i -
// (carry << self.logb))
// if k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
// T::one()) == T::one())) { self.value = self.value
// + T::one(); Some(self.q + k_i - self.b)
// } else {
// Some(k_i)
// }
// The following is a branch-free implementation of the commented-out
// version above. It speeds up bootstrapping for `SMALL_MP_BOOL_PARAMS` (and
// possibly other parameters, though I haven't tested them) by roughly 15ms.
// Surprisingly, the improvement does not show up when I benchmark
// `decomposer_iter` in isolation. Leaving this remark here as a
// future task to investigate (TODO).
let carry = <T as From<bool>>::from(
k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())),
);
self.value = self.value + carry;
Some(
(self.q & ((carry << self.logq) - (T::one() & carry))) + k_i - (carry << self.logb),
)
// Some(k_i) // Some(k_i)
} else { } else {

+ 1
- 1
src/lwe.rs

@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch<
.as_ref() .as_ref()
.iter() .iter()
.skip(1) .skip(1)
.flat_map(|ai| decomposer.decompose_to_vec(ai));
.flat_map(|ai| decomposer.decompose_iter(ai));
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j); operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
}); });

+ 0
- 13
src/rgsw.rs

@ -515,21 +515,14 @@ pub(crate) fn decompose_r>(
R::Element: Copy, R::Element: Copy,
{ {
let ring_size = r.len(); let ring_size = r.len();
let d = decomposer.decomposition_count();
for ri in 0..ring_size { for ri in 0..ring_size {
// let el_decomposed = decomposer.decompose_to_vec(&r[ri]);
decomposer decomposer
.decompose_iter(&r[ri]) .decompose_iter(&r[ri])
.enumerate() .enumerate()
.for_each(|(index, el)| { .for_each(|(index, el)| {
decomp_r[index].as_mut()[ri] = el; decomp_r[index].as_mut()[ri] = el;
}); });
// for j in 0..d {
// decomp_r[j].as_mut()[ri] = el_decomposed[j];
// }
} }
} }
@ -578,18 +571,12 @@ pub(crate) fn galois_auto<
.for_each(|(el_in, to_index, sign)| { .for_each(|(el_in, to_index, sign)| {
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in }; let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
// let el_out_decomposed = decomposer.decompose_to_vec(&el_out);
decomposer decomposer
.decompose_iter(&el_out) .decompose_iter(&el_out)
.enumerate() .enumerate()
.for_each(|(index, el)| { .for_each(|(index, el)| {
scratch_matrix_d_ring[index].as_mut()[*to_index] = el; scratch_matrix_d_ring[index].as_mut()[*to_index] = el;
}); });
// for j in 0..d {
// scratch_matrix_d_ring[j].as_mut()[*to_index] =
// el_out_decomposed[j]; }
}); });
// transform decomposed a(X^k) to evaluation domain // transform decomposed a(X^k) to evaluation domain

Loading…
Cancel
Save