mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-11 16:41:29 +01:00
Decompose with bit hacks without branching; speeds up bootstrapping by 15ms
This commit is contained in:
@@ -230,7 +230,7 @@ pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
|
||||
|
||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
||||
where
|
||||
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive,
|
||||
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool>,
|
||||
RlweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||
NttOp: Ntt<Element = M::MatElement>,
|
||||
@@ -319,8 +319,14 @@ impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, L
|
||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp>
|
||||
where
|
||||
M: MatrixEntity + MatrixMut,
|
||||
M::MatElement:
|
||||
PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub + SampleUniform,
|
||||
M::MatElement: PrimInt
|
||||
+ Debug
|
||||
+ Display
|
||||
+ NumInfo
|
||||
+ FromPrimitive
|
||||
+ WrappingSub
|
||||
+ SampleUniform
|
||||
+ From<bool>,
|
||||
NttOp: Ntt<Element = M::MatElement>,
|
||||
RlweModOp: ArithmeticOps<Element = M::MatElement>
|
||||
+ VectorOps<Element = M::MatElement>
|
||||
@@ -1108,7 +1114,8 @@ impl<M, NttOp, RlweModOp, LweModOp> BooleanGates for BoolEvaluator<M, NttOp, Rlw
|
||||
where
|
||||
M: MatrixMut + MatrixEntity,
|
||||
M::R: RowMut + RowEntity + Clone,
|
||||
M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo,
|
||||
M::MatElement:
|
||||
PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo + From<bool>,
|
||||
RlweModOp: VectorOps<Element = M::MatElement>
|
||||
+ ArithmeticOps<Element = M::MatElement>
|
||||
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
|
||||
|
||||
@@ -106,7 +106,7 @@ impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decomposer
|
||||
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<bool>> Decomposer
|
||||
for DefaultDecomposer<T>
|
||||
{
|
||||
type Element = T;
|
||||
@@ -182,6 +182,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
||||
DecomposerIter {
|
||||
value,
|
||||
q: self.q,
|
||||
logq: self.logq,
|
||||
logb: self.logb,
|
||||
b: self.b,
|
||||
bby2: self.bby2,
|
||||
@@ -205,11 +206,13 @@ pub struct DecomposerIter<T> {
|
||||
bby2: T,
|
||||
/// Ciphertext modulus
|
||||
q: T,
|
||||
/// Log of ciphertext modulus
|
||||
logq: usize,
|
||||
/// b = 1 << logb
|
||||
b: T,
|
||||
}
|
||||
|
||||
impl<T: PrimInt> Iterator for DecomposerIter<T> {
|
||||
impl<T: PrimInt + From<bool>> Iterator for DecomposerIter<T> {
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
@@ -219,20 +222,27 @@ impl<T: PrimInt> Iterator for DecomposerIter<T> {
|
||||
|
||||
self.value = (self.value - k_i) >> self.logb;
|
||||
|
||||
if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) {
|
||||
self.value = self.value + T::one();
|
||||
Some(self.q + k_i - self.b)
|
||||
} else {
|
||||
Some(k_i)
|
||||
}
|
||||
// if k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
|
||||
// T::one()) == T::one())) { self.value = self.value
|
||||
// + T::one(); Some(self.q + k_i - self.b)
|
||||
// } else {
|
||||
// Some(k_i)
|
||||
// }
|
||||
|
||||
// let carry = <T as From<bool>>::from(
|
||||
// k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
|
||||
// T::one()) == T::one())), );
|
||||
// self.value = self.value + carry;
|
||||
// Following is without branching impl of the commented version above. It
|
||||
// happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other
|
||||
// parameters as well but I haven't tested) by roughly 15ms.
|
||||
// Surprisingly the improvement does not show up when I benchmark
|
||||
// `decomposer_iter` in isolation. Putting this remark here as a
|
||||
// future task to investigate (TODO).
|
||||
let carry = <T as From<bool>>::from(
|
||||
k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())),
|
||||
);
|
||||
self.value = self.value + carry;
|
||||
|
||||
// Some((self.q & ((carry << 55) - (T::one() & carry))) + k_i -
|
||||
// (carry << self.logb))
|
||||
Some(
|
||||
(self.q & ((carry << self.logq) - (T::one() & carry))) + k_i - (carry << self.logb),
|
||||
)
|
||||
|
||||
// Some(k_i)
|
||||
} else {
|
||||
|
||||
@@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch<
|
||||
.as_ref()
|
||||
.iter()
|
||||
.skip(1)
|
||||
.flat_map(|ai| decomposer.decompose_to_vec(ai));
|
||||
.flat_map(|ai| decomposer.decompose_iter(ai));
|
||||
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
||||
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
||||
});
|
||||
|
||||
13
src/rgsw.rs
13
src/rgsw.rs
@@ -515,21 +515,14 @@ pub(crate) fn decompose_r<R: RowMut, D: Decomposer<Element = R::Element>>(
|
||||
R::Element: Copy,
|
||||
{
|
||||
let ring_size = r.len();
|
||||
let d = decomposer.decomposition_count();
|
||||
|
||||
for ri in 0..ring_size {
|
||||
// let el_decomposed = decomposer.decompose_to_vec(&r[ri]);
|
||||
|
||||
decomposer
|
||||
.decompose_iter(&r[ri])
|
||||
.enumerate()
|
||||
.for_each(|(index, el)| {
|
||||
decomp_r[index].as_mut()[ri] = el;
|
||||
});
|
||||
|
||||
// for j in 0..d {
|
||||
// decomp_r[j].as_mut()[ri] = el_decomposed[j];
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -578,18 +571,12 @@ pub(crate) fn galois_auto<
|
||||
.for_each(|(el_in, to_index, sign)| {
|
||||
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
|
||||
|
||||
// let el_out_decomposed = decomposer.decompose_to_vec(&el_out);
|
||||
|
||||
decomposer
|
||||
.decompose_iter(&el_out)
|
||||
.enumerate()
|
||||
.for_each(|(index, el)| {
|
||||
scratch_matrix_d_ring[index].as_mut()[*to_index] = el;
|
||||
});
|
||||
|
||||
// for j in 0..d {
|
||||
// scratch_matrix_d_ring[j].as_mut()[*to_index] =
|
||||
// el_out_decomposed[j]; }
|
||||
});
|
||||
|
||||
// transform decomposed a(X^k) to evaluation domain
|
||||
|
||||
Reference in New Issue
Block a user