mirror of
https://github.com/arnaucube/phantom-zone.git
synced 2026-01-11 16:41:29 +01:00
decompose with bit hacks without brnaching speeds up bootstrappoing by 15ms
This commit is contained in:
@@ -230,7 +230,7 @@ pub(super) struct BoolPbsInfo<M: Matrix, Ntt, RlweModOp, LweModOp> {
|
|||||||
|
|
||||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
impl<M: Matrix, NttOp, RlweModOp, LweModOp> PbsInfo for BoolPbsInfo<M, NttOp, RlweModOp, LweModOp>
|
||||||
where
|
where
|
||||||
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive,
|
M::MatElement: PrimInt + WrappingSub + NumInfo + FromPrimitive + From<bool>,
|
||||||
RlweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
RlweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||||
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
LweModOp: ArithmeticOps<Element = M::MatElement> + VectorOps<Element = M::MatElement>,
|
||||||
NttOp: Ntt<Element = M::MatElement>,
|
NttOp: Ntt<Element = M::MatElement>,
|
||||||
@@ -319,8 +319,14 @@ impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, L
|
|||||||
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp>
|
impl<M: Matrix, NttOp, RlweModOp, LweModOp> BoolEvaluator<M, NttOp, RlweModOp, LweModOp>
|
||||||
where
|
where
|
||||||
M: MatrixEntity + MatrixMut,
|
M: MatrixEntity + MatrixMut,
|
||||||
M::MatElement:
|
M::MatElement: PrimInt
|
||||||
PrimInt + Debug + Display + NumInfo + FromPrimitive + WrappingSub + SampleUniform,
|
+ Debug
|
||||||
|
+ Display
|
||||||
|
+ NumInfo
|
||||||
|
+ FromPrimitive
|
||||||
|
+ WrappingSub
|
||||||
|
+ SampleUniform
|
||||||
|
+ From<bool>,
|
||||||
NttOp: Ntt<Element = M::MatElement>,
|
NttOp: Ntt<Element = M::MatElement>,
|
||||||
RlweModOp: ArithmeticOps<Element = M::MatElement>
|
RlweModOp: ArithmeticOps<Element = M::MatElement>
|
||||||
+ VectorOps<Element = M::MatElement>
|
+ VectorOps<Element = M::MatElement>
|
||||||
@@ -1108,7 +1114,8 @@ impl<M, NttOp, RlweModOp, LweModOp> BooleanGates for BoolEvaluator<M, NttOp, Rlw
|
|||||||
where
|
where
|
||||||
M: MatrixMut + MatrixEntity,
|
M: MatrixMut + MatrixEntity,
|
||||||
M::R: RowMut + RowEntity + Clone,
|
M::R: RowMut + RowEntity + Clone,
|
||||||
M::MatElement: PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo,
|
M::MatElement:
|
||||||
|
PrimInt + FromPrimitive + One + Copy + Zero + Display + WrappingSub + NumInfo + From<bool>,
|
||||||
RlweModOp: VectorOps<Element = M::MatElement>
|
RlweModOp: VectorOps<Element = M::MatElement>
|
||||||
+ ArithmeticOps<Element = M::MatElement>
|
+ ArithmeticOps<Element = M::MatElement>
|
||||||
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
|
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decomposer
|
impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo + From<bool>> Decomposer
|
||||||
for DefaultDecomposer<T>
|
for DefaultDecomposer<T>
|
||||||
{
|
{
|
||||||
type Element = T;
|
type Element = T;
|
||||||
@@ -182,6 +182,7 @@ impl<T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + NumInfo> Decompose
|
|||||||
DecomposerIter {
|
DecomposerIter {
|
||||||
value,
|
value,
|
||||||
q: self.q,
|
q: self.q,
|
||||||
|
logq: self.logq,
|
||||||
logb: self.logb,
|
logb: self.logb,
|
||||||
b: self.b,
|
b: self.b,
|
||||||
bby2: self.bby2,
|
bby2: self.bby2,
|
||||||
@@ -205,11 +206,13 @@ pub struct DecomposerIter<T> {
|
|||||||
bby2: T,
|
bby2: T,
|
||||||
/// Ciphertext modulus
|
/// Ciphertext modulus
|
||||||
q: T,
|
q: T,
|
||||||
|
/// Log of ciphertext modulus
|
||||||
|
logq: usize,
|
||||||
/// b = 1 << logb
|
/// b = 1 << logb
|
||||||
b: T,
|
b: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: PrimInt> Iterator for DecomposerIter<T> {
|
impl<T: PrimInt + From<bool>> Iterator for DecomposerIter<T> {
|
||||||
type Item = T;
|
type Item = T;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
@@ -219,20 +222,27 @@ impl<T: PrimInt> Iterator for DecomposerIter<T> {
|
|||||||
|
|
||||||
self.value = (self.value - k_i) >> self.logb;
|
self.value = (self.value - k_i) >> self.logb;
|
||||||
|
|
||||||
if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) {
|
// if k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
|
||||||
self.value = self.value + T::one();
|
// T::one()) == T::one())) { self.value = self.value
|
||||||
Some(self.q + k_i - self.b)
|
// + T::one(); Some(self.q + k_i - self.b)
|
||||||
} else {
|
// } else {
|
||||||
Some(k_i)
|
// Some(k_i)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// let carry = <T as From<bool>>::from(
|
// Following is without branching impl of the commented version above. It
|
||||||
// k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
|
// happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other
|
||||||
// T::one()) == T::one())), );
|
// parameters as well but I haven't tested) by roughly 15ms.
|
||||||
// self.value = self.value + carry;
|
// Suprisingly the improvement does not show up when I benchmark
|
||||||
|
// `decomposer_iter` in isolation. Putting this remark here as a
|
||||||
|
// future task to investiage (TODO).
|
||||||
|
let carry = <T as From<bool>>::from(
|
||||||
|
k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())),
|
||||||
|
);
|
||||||
|
self.value = self.value + carry;
|
||||||
|
|
||||||
// Some((self.q & ((carry << 55) - (T::one() & carry))) + k_i -
|
Some(
|
||||||
// (carry << self.logb))
|
(self.q & ((carry << self.logq) - (T::one() & carry))) + k_i - (carry << self.logb),
|
||||||
|
)
|
||||||
|
|
||||||
// Some(k_i)
|
// Some(k_i)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ pub(crate) fn lwe_key_switch<
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.iter()
|
.iter()
|
||||||
.skip(1)
|
.skip(1)
|
||||||
.flat_map(|ai| decomposer.decompose_to_vec(ai));
|
.flat_map(|ai| decomposer.decompose_iter(ai));
|
||||||
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| {
|
||||||
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
operator.elwise_fma_scalar_mut(lwe_out.as_mut(), beta_ij_lwe.as_ref(), &ai_j);
|
||||||
});
|
});
|
||||||
|
|||||||
13
src/rgsw.rs
13
src/rgsw.rs
@@ -515,21 +515,14 @@ pub(crate) fn decompose_r<R: RowMut, D: Decomposer<Element = R::Element>>(
|
|||||||
R::Element: Copy,
|
R::Element: Copy,
|
||||||
{
|
{
|
||||||
let ring_size = r.len();
|
let ring_size = r.len();
|
||||||
let d = decomposer.decomposition_count();
|
|
||||||
|
|
||||||
for ri in 0..ring_size {
|
for ri in 0..ring_size {
|
||||||
// let el_decomposed = decomposer.decompose_to_vec(&r[ri]);
|
|
||||||
|
|
||||||
decomposer
|
decomposer
|
||||||
.decompose_iter(&r[ri])
|
.decompose_iter(&r[ri])
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(index, el)| {
|
.for_each(|(index, el)| {
|
||||||
decomp_r[index].as_mut()[ri] = el;
|
decomp_r[index].as_mut()[ri] = el;
|
||||||
});
|
});
|
||||||
|
|
||||||
// for j in 0..d {
|
|
||||||
// decomp_r[j].as_mut()[ri] = el_decomposed[j];
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -578,18 +571,12 @@ pub(crate) fn galois_auto<
|
|||||||
.for_each(|(el_in, to_index, sign)| {
|
.for_each(|(el_in, to_index, sign)| {
|
||||||
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
|
let el_out = if !*sign { mod_op.neg(el_in) } else { *el_in };
|
||||||
|
|
||||||
// let el_out_decomposed = decomposer.decompose_to_vec(&el_out);
|
|
||||||
|
|
||||||
decomposer
|
decomposer
|
||||||
.decompose_iter(&el_out)
|
.decompose_iter(&el_out)
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.for_each(|(index, el)| {
|
.for_each(|(index, el)| {
|
||||||
scratch_matrix_d_ring[index].as_mut()[*to_index] = el;
|
scratch_matrix_d_ring[index].as_mut()[*to_index] = el;
|
||||||
});
|
});
|
||||||
|
|
||||||
// for j in 0..d {
|
|
||||||
// scratch_matrix_d_ring[j].as_mut()[*to_index] =
|
|
||||||
// el_out_decomposed[j]; }
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// transform decomposed a(X^k) to evaluation domain
|
// transform decomposed a(X^k) to evaluation domain
|
||||||
|
|||||||
Reference in New Issue
Block a user