glwe and lwe decryption

This commit is contained in:
Rasoul Akhavan Mahdavi
2025-10-15 18:36:59 -04:00
parent 2ea59310fb
commit 2f2c7aef00
2 changed files with 132 additions and 53 deletions

View File

@@ -1,80 +1,135 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, ModuleN, ScratchTakeBasic,
SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxBigBytesOf,
VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
}, },
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch}, layouts::{Backend, DataMut, DataViewMut, Module, Scratch},
}; };
use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; use crate::{
layouts::{
GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, GLWEToMut, GLWEPlaintextToMut,
prepared::{GLWESecretPreparedToRef, GLWESecretPrepared},
}
};
impl GLWE<Vec<u8>> { impl GLWE<Vec<u8>> {
pub fn decrypt_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize pub fn decrypt_tmp_bytes<A, M, BE: Backend>(module: &M, a_infos: &A) -> usize
where where
A: GLWEInfos, A: GLWEInfos,
Module<B>: VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, M: GLWEDecryption<BE>,
{ {
let size: usize = infos.size(); module.glwe_decrypt_tmp_bytes(a_infos)
(module.vec_znx_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_dft(1, size)
} }
} }
impl<DataSelf: DataRef> GLWE<DataSelf> { impl<DataSelf: DataMut> GLWE<DataSelf> {
pub fn decrypt<DataPt: DataMut, DataSk: DataRef, B: Backend>( pub fn decrypt<P, S, M, BE: Backend>(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
&self, where
module: &Module<B>, P: GLWEPlaintextToMut,
pt: &mut GLWEPlaintext<DataPt>, S: GLWESecretPreparedToRef<BE>,
sk: &GLWESecretPrepared<DataSk, B>, M: GLWEDecryption<BE>,
scratch: &mut Scratch<B>, Scratch<BE>: ScratchTakeBasic,
) where
Module<B>: VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>:,
{ {
module.glwe_decrypt(self, pt, sk, scratch);
}
}
pub trait GLWEDecryption<BE: Backend>
where
Self: Sized
+ ModuleN
+ VecZnxDftBytesOf
+ VecZnxNormalizeTmpBytes
+ VecZnxBigBytesOf
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
{
fn glwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GLWEInfos
{
let size: usize = infos.size();
(self.vec_znx_normalize_tmp_bytes() | self.bytes_of_vec_znx_dft(1, size)) + self.bytes_of_vec_znx_dft(1, size)
}
fn glwe_decrypt<R, P, S>(
&self,
res: &mut R,
pt: &mut P,
sk: &S,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut,
P: GLWEPlaintextToMut,
S: GLWESecretPreparedToRef<BE>,
Scratch<BE>: ScratchTakeBasic,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_ref();
let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), sk.rank()); assert_eq!(res.rank(), sk.rank());
assert_eq!(self.n(), sk.n()); assert_eq!(res.n(), sk.n());
assert_eq!(pt.n(), sk.n()); assert_eq!(pt.n(), sk.n());
} }
let cols: usize = (self.rank() + 1).into(); let cols: usize = (res.rank() + 1).into();
let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size()); // TODO optimize size when pt << ct
c0_big.data_mut().fill(0); c0_big.data_mut().fill(0);
{ {
(1..cols).for_each(|i| { (1..cols).for_each(|i| {
// ci_dft = DFT(a[i]) * DFT(s[i]) // ci_dft = DFT(a[i]) * DFT(s[i])
let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); // TODO optimize size when pt << ct
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i); self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &res.data, i);
module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
let ci_big = module.vec_znx_idft_apply_consume(ci_dft); let ci_big = self.vec_znx_idft_apply_consume(ci_dft);
// c0_big += a[i] * s[i] // c0_big += a[i] * s[i]
module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); self.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
}); });
} }
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0);
// pt = norm(BIG(m + e)) // pt = norm(BIG(m + e))
module.vec_znx_big_normalize( self.vec_znx_big_normalize(
self.base2k().into(), res.base2k().into(),
&mut pt.data, &mut pt.data,
0, 0,
self.base2k().into(), res.base2k().into(),
&c0_big, &c0_big,
0, 0,
scratch_1, scratch_1,
); );
pt.base2k = self.base2k(); pt.base2k = res.base2k();
pt.k = pt.k().min(self.k()); pt.k = pt.k().min(res.k());
} }
}
impl <BE: Backend> GLWEDecryption<BE> for Module<BE> where
Self: ModuleN
+ VecZnxDftBytesOf
+ VecZnxNormalizeTmpBytes
+ VecZnxBigBytesOf
+ VecZnxDftApply<BE>
+ SvpApplyDftToDftInplace<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
{
} }

View File

@@ -4,40 +4,64 @@ use poulpy_hal::{
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
}; };
use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret}; use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret, LWEToMut, LWEPlaintextToMut, LWESecretToRef};
impl<DataSelf> LWE<DataSelf> impl<DataSelf: DataRef + DataMut> LWE<DataSelf>
where
DataSelf: DataRef,
{ {
pub fn decrypt<DataPt, DataSk, B>(&self, module: &Module<B>, pt: &mut LWEPlaintext<DataPt>, sk: &LWESecret<DataSk>) pub fn decrypt<P, S, M, B>(&mut self, module: &M, pt: &mut P, sk: S)
where where
DataPt: DataMut, P: LWEPlaintextToMut,
DataSk: DataRef, S: LWESecretToRef,
Module<B>: ZnNormalizeInplace<B>, M: LWEDecrypt<B>,
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>, B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
{ {
#[cfg(debug_assertions)] module.lwe_decrypt(self, pt, sk);
{ }
assert_eq!(self.n(), sk.n());
} }
(0..pt.size().min(self.size())).for_each(|i| { pub trait LWEDecrypt<BE: Backend>
pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] where
+ self.data.at(0, i)[1..] Self: Sized + ZnNormalizeInplace<BE>
{
fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: S)
where
R: LWEToMut,
P: LWEPlaintextToMut,
S: LWESecretToRef,
BE: Backend + ScratchOwnedAllocImpl<BE> + ScratchOwnedBorrowImpl<BE>,
{
let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut();
let sk: LWESecret<&[u8]> = sk.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), sk.n());
}
(0..pt.size().min(res.size())).for_each(|i| {
pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0]
+ res.data.at(0, i)[1..]
.iter() .iter()
.zip(sk.data.at(0, 0)) .zip(sk.data.at(0, 0))
.map(|(x, y)| x * y) .map(|(x, y)| x * y)
.sum::<i64>(); .sum::<i64>();
}); });
module.zn_normalize_inplace( self.zn_normalize_inplace(
1, 1,
self.base2k().into(), res.base2k().into(),
&mut pt.data, &mut pt.data,
0, 0,
ScratchOwned::alloc(size_of::<i64>()).borrow(), ScratchOwned::alloc(size_of::<i64>()).borrow(),
); );
pt.base2k = self.base2k(); pt.base2k = res.base2k();
pt.k = crate::layouts::TorusPrecision(self.k().0.min(pt.size() as u32 * self.base2k().0)); pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0));
} }
} }
impl<BE: Backend> LWEDecrypt<BE> for Module<BE> where
Self: Sized + ZnNormalizeInplace<BE>
{
}