rename raw dft ops

This commit is contained in:
Pro7ech
2025-08-25 09:08:27 +02:00
parent 62448e0293
commit 1551f7a6f0
69 changed files with 666 additions and 715 deletions

View File

@@ -1,7 +1,7 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
};
@@ -56,8 +56,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
@@ -76,8 +76,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
@@ -132,8 +132,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
@@ -188,8 +188,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,

View File

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromVecZnx,
VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace,
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace,
VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
};
@@ -107,13 +107,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftFromVecZnx<B>
+ DFT<B>
+ VecZnxDftCopy<B>
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftToVecZnxBigTmpA<B>,
+ IDFTTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
#[cfg(debug_assertions)]
@@ -143,8 +143,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftAllocBytes
@@ -152,7 +152,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftToVecZnxBigTmpA<B>,
+ IDFTTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
self.keyswitch_internal(module, lhs, ksk, scratch);
@@ -171,8 +171,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftAllocBytes
@@ -180,7 +180,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftToVecZnxBigTmpA<B>,
+ IDFTTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
unsafe {
@@ -199,13 +199,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftFromVecZnx<B>
+ DFT<B>
+ VecZnxDftCopy<B>
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftToVecZnxBigTmpA<B>,
+ IDFTTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
assert!(
@@ -229,7 +229,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// Pre-compute DFT of (a0, a1, a2)
let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
(0..cols).for_each(|i| {
module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
module.dft(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
});
(1..cols).for_each(|col_j| {
@@ -308,7 +308,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
(0..cols).for_each(|i| {
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_big_normalize(
self.basek(),
&mut self.at_mut(row_i, col_j).data,
@@ -334,8 +334,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,

View File

@@ -1,7 +1,7 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
};
@@ -133,8 +133,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VmpApplyTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
@@ -162,8 +162,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VmpApplyTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
@@ -192,8 +192,8 @@ impl<D: DataRef> GLWECiphertext<D> {
+ VmpApplyTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B>,
@@ -224,17 +224,16 @@ where
DataRes: DataMut,
DataIn: DataRef,
DataVmp: DataRef,
Module<B>:
VecZnxDftAllocBytes + VecZnxDftFromVecZnx<B> + VmpApply<B> + VecZnxDftToVecZnxBigConsume<B> + VecZnxBigAddSmallInplace<B>,
Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApply<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
Scratch<B>: TakeVecZnxDft<B>,
{
let cols: usize = a.cols();
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
(0..cols - 1).for_each(|col_i| {
module.vec_znx_dft_from_vec_znx(1, 0, &mut ai_dft, col_i, a, col_i + 1);
module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
res_big
}
@@ -251,12 +250,7 @@ where
DataRes: DataMut,
DataIn: DataRef,
DataVmp: DataRef,
Module<B>: VecZnxDftAllocBytes
+ VecZnxDftFromVecZnx<B>
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ VecZnxBigAddSmallInplace<B>,
Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApply<B> + VmpApplyAdd<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
Scratch<B>: TakeVecZnxDft<B>,
{
let cols: usize = a.cols();
@@ -278,7 +272,7 @@ where
res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
(0..cols - 1).for_each(|col_i| {
module.vec_znx_dft_from_vec_znx(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
});
if di == 0 {
@@ -289,7 +283,7 @@ where
});
res_dft.set_size(res_dft.max_size());
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
res_big
}

View File

@@ -1,7 +1,7 @@
use poulpy_hal::{
api::{
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
};
@@ -26,8 +26,8 @@ impl LWECiphertext<Vec<u8>> {
+ VmpApplyTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
{
@@ -51,8 +51,8 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftToVecZnxBigConsume<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,