renamed vmp API closer to spqlios

This commit is contained in:
Pro7ech
2025-08-25 11:58:51 +02:00
parent 1551f7a6f0
commit a1b865709d
38 changed files with 431 additions and 393 deletions

View File

@@ -5,7 +5,8 @@ use poulpy_hal::{
TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftSubABInplace,
VecZnxDftZero, VecZnxIDFTTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSubABInplace, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero},
};
@@ -32,7 +33,7 @@ pub fn cggi_blind_rotate_scratch_space<B: Backend>(
) -> usize
where
Module<B>: VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxIDFTTmpBytes
@@ -47,7 +48,7 @@ where
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
let acc_dft_add: usize = vmp_res;
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize = if extension_factor > 1 {
VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor
} else {
@@ -70,7 +71,7 @@ where
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -88,8 +89,8 @@ where
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ IDFTConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes,
@@ -132,7 +133,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -151,7 +152,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VecZnxBigNormalize<B>
+ VmpApply<B>,
+ VmpApplyDftToDft<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
let n_glwe: usize = brk.n();
@@ -220,7 +221,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
// vmp_res = DFT(acc) * BRK[i]
(0..extension_factor).for_each(|i| {
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], skii.data(), scratch5);
module.vmp_apply_dft_to_dft(&mut vmp_res[i], &acc_dft[i], skii.data(), scratch5);
});
// Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1)
@@ -299,7 +300,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -317,7 +318,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyDftToDft<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
@@ -377,7 +378,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;
// vmp_res = DFT(acc) * BRK[i]
module.vmp_apply(&mut vmp_res, &acc_dft, skii.data(), scratch4);
module.vmp_apply_dft_to_dft(&mut vmp_res, &acc_dft, skii.data(), scratch4);
// DFT(X^ai -1) * (DFT(acc) * BRK[i])
(0..cols).for_each(|i| {
@@ -413,7 +414,7 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -431,8 +432,8 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ IDFTConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes,