Various improvement to memory management and API

[module]: added enum for backend
[VecZnx, VecZnxDft, VecZnxBig, VmpPMat]: added ptr to data
[VecZnxBorrow]: removed
[VecZnxAPI]: removed
This commit is contained in:
Jean-Philippe Bossuat
2025-03-17 12:07:40 +01:00
parent 97a1559bf2
commit 46c577409e
28 changed files with 896 additions and 1064 deletions

View File

@@ -1,11 +1,11 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemCommon, VecZnxCommon},
elem::{Elem, ElemCommon},
keys::SecretKey,
parameters::Parameters,
plaintext::Plaintext,
};
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxBigOps, VecZnxDft, VecZnxDftOps};
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps};
use std::cmp::min;
pub struct Decryptor {
@@ -32,30 +32,24 @@ impl Parameters {
)
}
pub fn decrypt_rlwe<T>(
pub fn decrypt_rlwe(
&self,
res: &mut Plaintext<T>,
ct: &Ciphertext<T>,
res: &mut Plaintext,
ct: &Ciphertext<VecZnx>,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
) {
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
}
}
pub fn decrypt_rlwe<T>(
pub fn decrypt_rlwe(
module: &Module,
res: &mut Elem<T>,
a: &Elem<T>,
res: &mut Elem<VecZnx>,
a: &Elem<VecZnx>,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) where
T: VecZnxCommon<Owned = T>,
Elem<T>: ElemCommon<T>,
{
) {
let cols: usize = a.cols();
assert!(
@@ -65,9 +59,11 @@ pub fn decrypt_rlwe<T>(
decrypt_rlwe_tmp_byte(module, cols)
);
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols);
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) =
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(a.cols(), tmp_bytes);
let mut res_dft: VecZnxDft =
VecZnxDft::from_bytes_borrow(module, a.cols(), tmp_bytes_vec_znx_dft);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk)
@@ -77,12 +73,7 @@ pub fn decrypt_rlwe<T>(
// res_big <- ct[1] x sk + ct[0]
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
// res <- normalize(ct[1] x sk + ct[0])
module.vec_znx_big_normalize(
a.log_base2k(),
res.at_mut(0),
&res_big,
&mut tmp_bytes[res_dft_bytes..],
);
module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize);
res.log_base2k = a.log_base2k();
res.log_q = min(res.log_q(), a.log_q());