mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Ref. + AVX code & generic tests + benches (#85)
This commit is contained in:
committed by
GitHub
parent
99b9e3e10e
commit
56dbd29c59
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
|
||||
};
|
||||
@@ -56,8 +57,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
@@ -76,8 +77,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
@@ -132,8 +133,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
@@ -161,6 +162,12 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
assert!(
|
||||
self.rows() <= lhs.rows(),
|
||||
"self.rows()={} > lhs.rows()={}",
|
||||
self.rows(),
|
||||
lhs.rows()
|
||||
);
|
||||
}
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
@@ -188,8 +195,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace,
|
||||
VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
|
||||
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
|
||||
@@ -114,13 +114,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigAllocBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -150,8 +150,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxDftAllocBytes
|
||||
@@ -159,10 +159,15 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
self.keyswitch_internal(module, lhs, ksk, scratch);
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
|
||||
self.at_mut(row_i, 0)
|
||||
.keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
|
||||
});
|
||||
self.expand_row(module, tsk, scratch);
|
||||
}
|
||||
|
||||
@@ -178,8 +183,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxDftAllocBytes
|
||||
@@ -187,13 +192,16 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
|
||||
self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
|
||||
}
|
||||
(0..self.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
|
||||
self.at_mut(row_i, 0)
|
||||
.keyswitch_inplace(module, ksk, scratch);
|
||||
});
|
||||
self.expand_row(module, tsk, scratch);
|
||||
}
|
||||
|
||||
pub fn expand_row<DataTsk: DataRef, B: Backend>(
|
||||
@@ -206,13 +214,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigAllocBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
assert!(
|
||||
@@ -234,9 +242,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..self.rows()).for_each(|row_i| {
|
||||
// Pre-compute DFT of (a0, a1, a2)
|
||||
let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
|
||||
let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, self.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.dft(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
|
||||
});
|
||||
|
||||
(1..cols).for_each(|col_j| {
|
||||
@@ -262,8 +270,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
|
||||
let digits: usize = tsk.digits();
|
||||
|
||||
let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size());
|
||||
let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
|
||||
let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size());
|
||||
let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
|
||||
|
||||
{
|
||||
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
|
||||
@@ -295,9 +303,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
|
||||
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
|
||||
if di == 0 && col_i == 1 {
|
||||
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch3);
|
||||
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
|
||||
} else {
|
||||
module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3);
|
||||
module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -313,46 +321,19 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
// =
|
||||
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
|
||||
module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
|
||||
let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
|
||||
let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
|
||||
module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i);
|
||||
module.vec_znx_big_normalize(
|
||||
self.basek(),
|
||||
&mut self.at_mut(row_i, col_j).data,
|
||||
i,
|
||||
&tmp_idft,
|
||||
0,
|
||||
scratch3,
|
||||
scratch_3,
|
||||
);
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn keyswitch_internal<DataLhs: DataRef, DataKsk: DataRef, B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
lhs: &GGSWCiphertext<DataLhs>,
|
||||
ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
{
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
|
||||
self.at_mut(row_i, 0)
|
||||
.keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
|
||||
};
|
||||
@@ -117,6 +118,63 @@ impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
|
||||
&self,
|
||||
module: &Module<B>,
|
||||
rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
|
||||
scratch: &Scratch<B>,
|
||||
) where
|
||||
DataRhs: DataRef,
|
||||
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
|
||||
Scratch<B>: ScratchAvailable,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
rhs.rank_out(),
|
||||
"self.rank(): {} != rhs.rank_out(): {}",
|
||||
self.rank(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(rhs.n(), self.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
self.k(),
|
||||
rhs.k(),
|
||||
rhs.digits(),
|
||||
rhs.rank_in(),
|
||||
rhs.rank_out(),
|
||||
),
|
||||
"scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
self.k(),
|
||||
rhs.k(),
|
||||
rhs.digits(),
|
||||
rhs.rank_in(),
|
||||
rhs.rank_out(),
|
||||
)={}",
|
||||
scratch.available(),
|
||||
GLWECiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
self.k(),
|
||||
rhs.k(),
|
||||
rhs.digits(),
|
||||
rhs.rank_in(),
|
||||
rhs.rank_out(),
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
@@ -130,11 +188,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
@@ -143,10 +200,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
{
|
||||
self.assert_keyswitch(module, lhs, rhs, scratch);
|
||||
}
|
||||
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
|
||||
let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
|
||||
let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -162,16 +219,21 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.keyswitch(module, &*self_ptr, rhs, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
self.assert_keyswitch_inplace(module, rhs, scratch);
|
||||
}
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
|
||||
let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,8 +254,8 @@ impl<D: DataRef> GLWECiphertext<D> {
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B>,
|
||||
@@ -224,16 +286,17 @@ where
|
||||
DataRes: DataMut,
|
||||
DataIn: DataRef,
|
||||
DataVmp: DataRef,
|
||||
Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApplyDftToDft<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
|
||||
Module<B>:
|
||||
VecZnxDftAllocBytes + VecZnxDftApply<B> + VmpApplyDftToDft<B> + VecZnxIdftApplyConsume<B> + VecZnxBigAddSmallInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B>,
|
||||
{
|
||||
let cols: usize = a.cols();
|
||||
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
|
||||
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
|
||||
(0..cols - 1).for_each(|col_i| {
|
||||
module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
|
||||
});
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
|
||||
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
|
||||
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
|
||||
res_big
|
||||
}
|
||||
@@ -251,16 +314,16 @@ where
|
||||
DataIn: DataRef,
|
||||
DataVmp: DataRef,
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B>,
|
||||
{
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
|
||||
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
|
||||
|
||||
ai_dft.data_mut().fill(0);
|
||||
|
||||
@@ -277,18 +340,18 @@ where
|
||||
res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
|
||||
|
||||
(0..cols - 1).for_each(|col_i| {
|
||||
module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
|
||||
module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
|
||||
});
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
|
||||
} else {
|
||||
module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch1);
|
||||
module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
|
||||
}
|
||||
});
|
||||
|
||||
res_dft.set_size(res_dft.max_size());
|
||||
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
|
||||
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
|
||||
res_big
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
|
||||
};
|
||||
@@ -26,8 +27,8 @@ impl LWECiphertext<Vec<u8>> {
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
{
|
||||
@@ -51,8 +52,8 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
@@ -67,7 +68,7 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
|
||||
let max_k: usize = self.k().max(a.k());
|
||||
let basek: usize = self.basek();
|
||||
|
||||
let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
|
||||
let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
|
||||
glwe.data.zero();
|
||||
|
||||
let n_lwe: usize = a.n();
|
||||
@@ -78,7 +79,7 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
|
||||
glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
|
||||
});
|
||||
|
||||
glwe.keyswitch_inplace(module, &ksk.0, scratch1);
|
||||
glwe.keyswitch_inplace(module, &ksk.0, scratch_1);
|
||||
|
||||
self.sample_extract(&glwe);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user