Added API in poulpy for updated vmp_add (+tests)

This commit is contained in:
Jean-Philippe Bossuat
2025-06-04 11:39:11 +02:00
parent fcdc8f53d3
commit 159cd8025f
14 changed files with 216 additions and 82 deletions

View File

@@ -7,26 +7,28 @@ use sampling::source::Source;
use crate::{
AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow,
TensorKey, derive_size,
TensorKey, div_ceil,
};
pub struct GGSWCiphertext<C, B: Backend> {
pub data: MatZnxDft<C, B>,
pub basek: usize,
pub k: usize,
pub(crate) data: MatZnxDft<C, B>,
pub(crate) basek: usize,
pub(crate) k: usize,
pub(crate) digits: usize,
}
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
basek: basek,
data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)),
basek,
k: k,
digits,
}
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize {
module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k))
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k))
}
}
@@ -50,11 +52,15 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
pub fn rank(&self) -> usize {
self.data.cols_out() - 1
}
pub fn digits(&self) -> usize {
self.digits
}
}
impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
let size = derive_size(basek, k);
let size = div_ceil(basek, k);
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size)
@@ -68,8 +74,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tsk_k: usize,
rank: usize,
) -> usize {
let tsk_size: usize = derive_size(basek, tsk_k);
let self_size: usize = derive_size(basek, self_k);
let tsk_size: usize = div_ceil(basek, tsk_k);
let self_size: usize = div_ceil(basek, self_k);
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
@@ -87,7 +93,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
rank: usize,
) -> usize {
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
+ module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, in_k))
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k))
}
pub fn keyswitch_scratch_space(
@@ -99,7 +105,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tsk_k: usize,
rank: usize,
) -> usize {
let out_size: usize = derive_size(basek, out_k);
let out_size: usize = div_ceil(basek, out_k);
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
@@ -130,7 +136,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
rank: usize,
) -> usize {
let cols: usize = rank + 1;
let out_size: usize = derive_size(basek, out_k);
let out_size: usize = div_ceil(basek, out_k);
let res: usize = module.bytes_of_vec_znx(cols, out_size);
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
@@ -199,6 +205,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let basek: usize = self.basek();
let k: usize = self.k();
let rank: usize = self.rank();
let digits: usize = self.digits();
let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k);
let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank);
@@ -207,7 +214,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
tmp_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0);
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
(0..rank + 1).for_each(|col_j| {