prototype trait for Elem<T> + new ciphertext for VmPPmat

This commit is contained in:
Jean-Philippe Bossuat
2025-02-18 11:04:13 +01:00
parent fdc2f3ac42
commit d486e89761
21 changed files with 767 additions and 811 deletions

View File

@@ -1,12 +1,12 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemBasics},
elem::{Elem, ElemVecZnx, VecZnxCommon},
keys::SecretKey,
parameters::Parameters,
plaintext::Plaintext,
};
use base2k::{
Infos, VecZnx, Module, SvpPPol, SvpPPolOps, VecZnxApi, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
Infos, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
};
use std::cmp::min;
@@ -34,13 +34,16 @@ impl Parameters {
)
}
pub fn decrypt_rlwe_thread_safe(
pub fn decrypt_rlwe_thread_safe<T>(
&self,
res: &mut Plaintext<VecZnx>,
ct: &Ciphertext,
res: &mut Plaintext<T>,
ct: &Ciphertext<T>,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) {
) where
T: VecZnxCommon,
Elem<T>: Infos + ElemVecZnx<T>,
{
decrypt_rlwe_thread_safe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
}
}
@@ -52,26 +55,29 @@ pub fn decrypt_rlwe_thread_safe<T>(
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) where
T: VecZnxApi + Infos,
T: VecZnxCommon,
Elem<T>: Infos + ElemVecZnx<T>,
{
let cols: usize = a.cols();
assert!(
tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, a.limbs()),
tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, cols),
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}",
tmp_bytes.len(),
decrypt_rlwe_thread_safe_tmp_byte(module, a.limbs())
decrypt_rlwe_thread_safe_tmp_byte(module, cols)
);
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(a.limbs());
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(cols);
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(a.limbs(), tmp_bytes);
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(a.cols(), tmp_bytes);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk)
module.svp_apply_dft(&mut res_dft, sk, &a.value[1], a.limbs());
module.svp_apply_dft(&mut res_dft, sk, a.at(1), cols);
// res_big <- ct[1] x sk
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, a.limbs());
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
// res_big <- ct[1] x sk + ct[0]
module.vec_znx_big_add_small_inplace(&mut res_big, &a.value[0]);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
// res <- normalize(ct[1] x sk + ct[0])
module.vec_znx_big_normalize(
a.log_base2k(),