Added API in poulpy for updated vmp_add (+tests)

This commit is contained in:
Jean-Philippe Bossuat
2025-06-04 11:39:11 +02:00
parent fcdc8f53d3
commit 159cd8025f
14 changed files with 216 additions and 82 deletions

View File

@@ -4,7 +4,7 @@ use backend::{
};
use sampling::source::Source;
use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, derive_size};
use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, div_ceil};
pub struct GLWECiphertextFourier<C, B: Backend> {
pub data: VecZnxDft<C, B>,
@@ -15,14 +15,14 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
data: module.new_vec_znx_dft(rank + 1, div_ceil(basek, k)),
basek: basek,
k: k,
}
}
pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, k))
module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, k))
}
}
@@ -51,16 +51,16 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k))
module.bytes_of_vec_znx(1, div_ceil(basek, k))
+ (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k);
let size: usize = div_ceil(basek, k);
(module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, size)
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
@@ -99,9 +99,9 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
rank: usize,
) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let out_size: usize = derive_size(basek, out_k);
let in_size: usize = derive_size(basek, in_k);
let ggsw_size: usize = derive_size(basek, ggsw_k);
let out_size: usize = div_ceil(basek, out_k);
let in_size: usize = div_ceil(basek, in_k);
let ggsw_size: usize = div_ceil(basek, ggsw_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();