mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added API in poulpy for updated vmp_add (+tests)
This commit is contained in:
@@ -7,26 +7,28 @@ use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow,
|
||||
TensorKey, derive_size,
|
||||
TensorKey, div_ceil,
|
||||
};
|
||||
|
||||
pub struct GGSWCiphertext<C, B: Backend> {
|
||||
pub data: MatZnxDft<C, B>,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
pub(crate) data: MatZnxDft<C, B>,
|
||||
pub(crate) basek: usize,
|
||||
pub(crate) k: usize,
|
||||
pub(crate) digits: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
||||
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)),
|
||||
basek,
|
||||
k: k,
|
||||
digits,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize {
|
||||
module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k))
|
||||
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
|
||||
module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,11 +52,15 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.data.cols_out() - 1
|
||||
}
|
||||
|
||||
pub fn digits(&self) -> usize {
|
||||
self.digits
|
||||
}
|
||||
}
|
||||
|
||||
impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
|
||||
let size = derive_size(basek, k);
|
||||
let size = div_ceil(basek, k);
|
||||
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
|
||||
+ module.bytes_of_vec_znx(rank + 1, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
@@ -68,8 +74,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tsk_size: usize = derive_size(basek, tsk_k);
|
||||
let self_size: usize = derive_size(basek, self_k);
|
||||
let tsk_size: usize = div_ceil(basek, tsk_k);
|
||||
let self_size: usize = div_ceil(basek, self_k);
|
||||
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
|
||||
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
|
||||
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
|
||||
@@ -87,7 +93,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, in_k))
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k))
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
@@ -99,7 +105,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let out_size: usize = derive_size(basek, out_k);
|
||||
let out_size: usize = div_ceil(basek, out_k);
|
||||
|
||||
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
@@ -130,7 +136,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let cols: usize = rank + 1;
|
||||
let out_size: usize = derive_size(basek, out_k);
|
||||
let out_size: usize = div_ceil(basek, out_k);
|
||||
let res: usize = module.bytes_of_vec_znx(cols, out_size);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
@@ -199,6 +205,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
let rank: usize = self.rank();
|
||||
let digits: usize = self.digits();
|
||||
|
||||
let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k);
|
||||
let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank);
|
||||
@@ -207,7 +214,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
tmp_pt.data.zero();
|
||||
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0);
|
||||
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
|
||||
|
||||
(0..rank + 1).for_each(|col_j| {
|
||||
|
||||
Reference in New Issue
Block a user