mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
81 lines
2.8 KiB
Rust
81 lines
2.8 KiB
Rust
use poulpy_hal::{
|
|
api::{
|
|
SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply,
|
|
VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
|
|
},
|
|
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch},
|
|
};
|
|
|
|
use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared};
|
|
|
|
impl GLWE<Vec<u8>> {
|
|
pub fn decrypt_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
|
|
where
|
|
A: GLWEInfos,
|
|
Module<B>: VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf,
|
|
{
|
|
let size: usize = infos.size();
|
|
(module.vec_znx_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_dft(1, size)
|
|
}
|
|
}
|
|
|
|
impl<DataSelf: DataRef> GLWE<DataSelf> {
|
|
pub fn decrypt<DataPt: DataMut, DataSk: DataRef, B: Backend>(
|
|
&self,
|
|
module: &Module<B>,
|
|
pt: &mut GLWEPlaintext<DataPt>,
|
|
sk: &GLWESecretPrepared<DataSk, B>,
|
|
scratch: &mut Scratch<B>,
|
|
) where
|
|
Module<B>: VecZnxDftApply<B>
|
|
+ SvpApplyDftToDftInplace<B>
|
|
+ VecZnxIdftApplyConsume<B>
|
|
+ VecZnxBigAddInplace<B>
|
|
+ VecZnxBigAddSmallInplace<B>
|
|
+ VecZnxBigNormalize<B>,
|
|
Scratch<B>:,
|
|
{
|
|
#[cfg(debug_assertions)]
|
|
{
|
|
assert_eq!(self.rank(), sk.rank());
|
|
assert_eq!(self.n(), sk.n());
|
|
assert_eq!(pt.n(), sk.n());
|
|
}
|
|
|
|
let cols: usize = (self.rank() + 1).into();
|
|
|
|
let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct
|
|
c0_big.data_mut().fill(0);
|
|
|
|
{
|
|
(1..cols).for_each(|i| {
|
|
// ci_dft = DFT(a[i]) * DFT(s[i])
|
|
let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct
|
|
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i);
|
|
module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
|
|
let ci_big = module.vec_znx_idft_apply_consume(ci_dft);
|
|
|
|
// c0_big += a[i] * s[i]
|
|
module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
|
|
});
|
|
}
|
|
|
|
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
|
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0);
|
|
|
|
// pt = norm(BIG(m + e))
|
|
module.vec_znx_big_normalize(
|
|
self.base2k().into(),
|
|
&mut pt.data,
|
|
0,
|
|
self.base2k().into(),
|
|
&c0_big,
|
|
0,
|
|
scratch_1,
|
|
);
|
|
|
|
pt.base2k = self.base2k();
|
|
pt.k = pt.k().min(self.k());
|
|
}
|
|
}
|