Ref. + AVX code & generic tests + benches (#85)

This commit is contained in:
Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
@@ -56,8 +57,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
@@ -76,8 +77,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
@@ -132,8 +133,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
@@ -161,6 +162,12 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
self.rank_out(),
rhs.rank_out()
);
assert!(
self.rows() <= lhs.rows(),
"self.rows()={} > lhs.rows()={}",
self.rows(),
lhs.rows()
);
}
(0..self.rank_in()).for_each(|col_i| {
@@ -188,8 +195,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,

View File

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace,
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace,
VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
@@ -114,13 +114,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ DFT<B>
+ VecZnxDftApply<B>
+ VecZnxDftCopy<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxBigNormalize<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
#[cfg(debug_assertions)]
@@ -150,8 +150,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftAllocBytes
@@ -159,10 +159,15 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
self.keyswitch_internal(module, lhs, ksk, scratch);
(0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
});
self.expand_row(module, tsk, scratch);
}
@@ -178,8 +183,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftAllocBytes
@@ -187,13 +192,16 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
}
(0..self.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.keyswitch_inplace(module, ksk, scratch);
});
self.expand_row(module, tsk, scratch);
}
pub fn expand_row<DataTsk: DataRef, B: Backend>(
@@ -206,13 +214,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ DFT<B>
+ VecZnxDftApply<B>
+ VecZnxDftCopy<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxBigNormalize<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
assert!(
@@ -234,9 +242,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// Keyswitch the j-th row of the col 0
(0..self.rows()).for_each(|row_i| {
// Pre-compute DFT of (a0, a1, a2)
let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, self.size());
(0..cols).for_each(|i| {
module.dft(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
});
(1..cols).for_each(|col_j| {
@@ -262,8 +270,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
let digits: usize = tsk.digits();
let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size());
let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size());
let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
{
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
@@ -295,9 +303,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
if di == 0 && col_i == 1 {
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch3);
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
} else {
module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3);
module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
}
});
});
@@ -313,46 +321,19 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// =
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size());
(0..cols).for_each(|i| {
module.idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_big_normalize(
self.basek(),
&mut self.at_mut(row_i, col_j).data,
i,
&tmp_idft,
0,
scratch3,
scratch_3,
);
});
})
})
}
fn keyswitch_internal<DataLhs: DataRef, DataKsk: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &GGSWCiphertext<DataLhs>,
ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
{
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
})
}
}

View File

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
};
@@ -117,6 +118,63 @@ impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
)
);
}
#[allow(dead_code)]
pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
&self,
module: &Module<B>,
rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
scratch: &Scratch<B>,
) where
DataRhs: DataRef,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
Scratch<B>: ScratchAvailable,
{
let basek: usize = self.basek();
assert_eq!(
self.rank(),
rhs.rank_out(),
"self.rank(): {} != rhs.rank_out(): {}",
self.rank(),
rhs.rank_out()
);
assert_eq!(self.basek(), basek);
assert_eq!(rhs.n(), self.n());
assert!(
scratch.available()
>= GLWECiphertext::keyswitch_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank_in(),
rhs.rank_out(),
),
"scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank_in(),
rhs.rank_out(),
)={}",
scratch.available(),
GLWECiphertext::keyswitch_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank_in(),
rhs.rank_out(),
)
);
}
}
impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
@@ -130,11 +188,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
@@ -143,10 +200,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
{
self.assert_keyswitch(module, lhs, rhs, scratch);
}
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
@@ -162,16 +219,21 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
{
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.keyswitch(module, &*self_ptr, rhs, scratch);
#[cfg(debug_assertions)]
{
self.assert_keyswitch_inplace(module, rhs, scratch);
}
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
}
@@ -192,8 +254,8 @@ impl<D: DataRef> GLWECiphertext<D> {
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B>,
@@ -224,16 +286,17 @@ where
DataRes: DataMut,
DataIn: DataRef,
DataVmp: DataRef,
Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApplyDftToDft<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
Module<B>:
VecZnxDftAllocBytes + VecZnxDftApply<B> + VmpApplyDftToDft<B> + VecZnxIdftApplyConsume<B> + VecZnxBigAddSmallInplace<B>,
Scratch<B>: TakeVecZnxDft<B>,
{
let cols: usize = a.cols();
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
(0..cols - 1).for_each(|col_i| {
module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1);
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
});
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
res_big
}
@@ -251,16 +314,16 @@ where
DataIn: DataRef,
DataVmp: DataRef,
Module<B>: VecZnxDftAllocBytes
+ DFT<B>
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ IDFTConsume<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>,
Scratch<B>: TakeVecZnxDft<B>,
{
let cols: usize = a.cols();
let size: usize = a.size();
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
ai_dft.data_mut().fill(0);
@@ -277,18 +340,18 @@ where
res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols - 1).for_each(|col_i| {
module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
});
if di == 0 {
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
} else {
module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch1);
module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
}
});
res_dft.set_size(res_dft.max_size());
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
res_big
}

View File

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
};
@@ -26,8 +27,8 @@ impl LWECiphertext<Vec<u8>> {
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
{
@@ -51,8 +52,8 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
@@ -67,7 +68,7 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
let max_k: usize = self.k().max(a.k());
let basek: usize = self.basek();
let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
glwe.data.zero();
let n_lwe: usize = a.n();
@@ -78,7 +79,7 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
});
glwe.keyswitch_inplace(module, &ksk.0, scratch1);
glwe.keyswitch_inplace(module, &ksk.0, scratch_1);
self.sample_extract(&glwe);
}