renamed vmp API closer to spqlios

This commit is contained in:
Pro7ech
2025-08-25 11:58:51 +02:00
parent 1551f7a6f0
commit a1b865709d
38 changed files with 431 additions and 393 deletions

View File

@@ -5,7 +5,8 @@ use poulpy_hal::{
TakeVecZnxDftSlice, TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftSubABInplace,
VecZnxDftZero, VecZnxIDFTTmpBytes, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSubABInplace, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero},
};
@@ -32,7 +33,7 @@ pub fn cggi_blind_rotate_scratch_space<B: Backend>(
) -> usize
where
Module<B>: VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxIDFTTmpBytes
@@ -47,7 +48,7 @@ where
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
let acc_dft_add: usize = vmp_res;
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize = if extension_factor > 1 {
VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor
} else {
@@ -70,7 +71,7 @@ where
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -88,8 +89,8 @@ where
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ IDFTConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes,
@@ -132,7 +133,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -151,7 +152,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VecZnxBigNormalize<B>
+ VmpApply<B>,
+ VmpApplyDftToDft<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
let n_glwe: usize = brk.n();
@@ -220,7 +221,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
// vmp_res = DFT(acc) * BRK[i]
(0..extension_factor).for_each(|i| {
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], skii.data(), scratch5);
module.vmp_apply_dft_to_dft(&mut vmp_res[i], &acc_dft[i], skii.data(), scratch5);
});
// Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1)
@@ -299,7 +300,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -317,7 +318,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyDftToDft<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
@@ -377,7 +378,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;
// vmp_res = DFT(acc) * BRK[i]
module.vmp_apply(&mut vmp_res, &acc_dft, skii.data(), scratch4);
module.vmp_apply_dft_to_dft(&mut vmp_res, &acc_dft, skii.data(), scratch4);
// DFT(X^ai -1) * (DFT(acc) * BRK[i])
(0..cols).for_each(|i| {
@@ -413,7 +414,7 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -431,8 +432,8 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ IDFTConsume<B>
+ VecZnxBigNormalize<B>
+ VecZnxNormalizeTmpBytes,

View File

@@ -5,8 +5,8 @@ use poulpy_hal::{
VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd,
VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftSubABInplace, VecZnxDftZero, VecZnxFillUniform, VecZnxIDFTTmpBytes,
VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace,
VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace,
},
layouts::{Backend, Module, ScratchOwned, ZnxView},
oep::{
@@ -31,7 +31,7 @@ where
Module<B>: VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxIDFTTmpBytes
+ IDFT<B>
@@ -64,8 +64,8 @@ where
+ VecZnxSub
+ VmpPMatAlloc<B>
+ VmpPrepare<B>
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ ZnFillUniform
+ ZnAddNormal
+ ZnNormalizeInplace<B>,

View File

@@ -7,7 +7,7 @@ use poulpy_hal::{
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNegateInplace, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace,
VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes,
VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ToOwnedDeep},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
@@ -40,10 +40,10 @@ where
+ VecZnxCopy
+ VecZnxSubABInplace
+ VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
@@ -138,10 +138,10 @@ pub fn circuit_bootstrap_core<DRes, DLwe, DBrk, BRA: BlindRotationAlgo, B>(
+ VecZnxCopy
+ VecZnxSubABInplace
+ VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
@@ -263,10 +263,10 @@ fn post_process<DataRes, DataA, B: Backend>(
+ VecZnxCopy
+ VecZnxSubABInplace
+ VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
@@ -335,10 +335,10 @@ pub fn pack<D: DataMut, B: Backend>(
+ VecZnxCopy
+ VecZnxSubABInplace
+ VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
@@ -414,10 +414,10 @@ fn combine<A: DataMut, D: DataMut, DataAK: DataRef, B: Backend>(
+ VecZnxCopy
+ VecZnxSubABInplace
+ VecZnxDftAllocBytes
+ VmpApplyTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>

View File

@@ -8,8 +8,8 @@ use poulpy_hal::{
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
VecZnxDftAddInplace, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxFillUniform, VecZnxNegateInplace,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace,
VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare,
ZnAddNormal, ZnFillUniform, ZnNormalizeInplace,
VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace,
},
layouts::{Backend, Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut},
oep::{
@@ -66,9 +66,9 @@ where
+ VmpPrepare<B>
+ SvpPrepare<B>
+ SvpPPolAlloc<B>
+ VmpApplyTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ SvpPPolAllocBytes
+ VecZnxRotateInplace
+ VecZnxBigAutomorphismInplace<B>
@@ -247,9 +247,9 @@ where
+ VmpPrepare<B>
+ SvpPrepare<B>
+ SvpPPolAlloc<B>
+ VmpApplyTmpBytes
+ VmpApply<B>
+ VmpApplyAdd<B>
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ SvpPPolAllocBytes
+ VecZnxRotateInplace
+ VecZnxBigAutomorphismInplace<B>