diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index 4306d33..d0ff213 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -1,80 +1,135 @@ use poulpy_hal::{ api::{ - SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, + ModuleN, ScratchTakeBasic, + SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxDftApply, VecZnxBigBytesOf, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes, }, - layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch}, }; -use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; +use crate::{ + layouts::{ + GLWE, GLWEInfos, GLWEPlaintext, LWEInfos, GLWEToMut, GLWEPlaintextToMut, + prepared::{GLWESecretPreparedToRef, GLWESecretPrepared}, + } +}; impl GLWE> { - pub fn decrypt_tmp_bytes(module: &Module, infos: &A) -> usize + pub fn decrypt_tmp_bytes(module: &M, a_infos: &A) -> usize where A: GLWEInfos, - Module: VecZnxDftBytesOf + VecZnxNormalizeTmpBytes + VecZnxDftBytesOf, + M: GLWEDecryption, { - let size: usize = infos.size(); - (module.vec_znx_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_dft(1, size) + module.glwe_decrypt_tmp_bytes(a_infos) } } -impl GLWE { - pub fn decrypt( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk: &GLWESecretPrepared, - scratch: &mut Scratch, - ) where - Module: VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxBigAddInplace - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch:, +impl GLWE { + pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + P: GLWEPlaintextToMut, + S: GLWESecretPreparedToRef, + M: GLWEDecryption, + Scratch: ScratchTakeBasic, { + module.glwe_decrypt(self, pt, sk, scratch); + } +} + +pub trait GLWEDecryption +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxBigBytesOf + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize +{ + fn glwe_decrypt_tmp_bytes(&self, infos: &A) -> usize + where + A: GLWEInfos + { + let size: usize = infos.size(); + (self.vec_znx_normalize_tmp_bytes() | self.bytes_of_vec_znx_dft(1, size)) + self.bytes_of_vec_znx_dft(1, size) + } + + fn glwe_decrypt( + &self, + res: &mut R, + pt: &mut P, + sk: &S, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + P: GLWEPlaintextToMut, + S: GLWESecretPreparedToRef, + Scratch: ScratchTakeBasic, + { + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), sk.n()); assert_eq!(pt.n(), sk.n()); } - let cols: usize = (self.rank() + 1).into(); + let cols: usize = (res.rank() + 1).into(); - let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct + let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size()); // TODO optimize size when pt << ct c0_big.data_mut().fill(0); { (1..cols).for_each(|i| { // ci_dft = DFT(a[i]) * DFT(s[i]) - let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n().into(), 1, self.size()); // TODO optimize size when pt << ct - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i); - module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big = module.vec_znx_idft_apply_consume(ci_dft); + let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); // TODO optimize size when pt << ct + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &res.data, i); + self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big = self.vec_znx_idft_apply_consume(ci_dft); // c0_big += a[i] * s[i] - module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); + self.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0); }); } // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0); + self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0); // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize( - self.base2k().into(), + self.vec_znx_big_normalize( + res.base2k().into(), &mut pt.data, 0, - self.base2k().into(), + res.base2k().into(), &c0_big, 0, scratch_1, ); - pt.base2k = self.base2k(); - pt.k = pt.k().min(self.k()); + pt.base2k = res.base2k(); + pt.k = pt.k().min(res.k()); } + } + +impl GLWEDecryption for Module where + Self: ModuleN + + VecZnxDftBytesOf + + VecZnxNormalizeTmpBytes + + VecZnxBigBytesOf + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxBigAddInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize +{ +} \ No newline at end of file diff --git a/poulpy-core/src/decryption/lwe_ct.rs b/poulpy-core/src/decryption/lwe_ct.rs index ade21e3..1042b72 100644 --- a/poulpy-core/src/decryption/lwe_ct.rs +++ b/poulpy-core/src/decryption/lwe_ct.rs @@ -4,40 +4,64 @@ use poulpy_hal::{ oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; -use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret}; +use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWESecret, LWEToMut, LWEPlaintextToMut, LWESecretToRef}; -impl LWE -where - DataSelf: DataRef, +impl LWE { - pub fn decrypt(&self, module: &Module, pt: &mut LWEPlaintext, sk: &LWESecret) + pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: S) where - DataPt: DataMut, - DataSk: DataRef, - Module: ZnNormalizeInplace, + P: LWEPlaintextToMut, + S: LWESecretToRef, + M: LWEDecrypt, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { + module.lwe_decrypt(self, pt, sk); + } +} + +pub trait LWEDecrypt +where + Self: Sized + ZnNormalizeInplace +{ + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: S) + where + R: LWEToMut, + P: LWEPlaintextToMut, + S: LWESecretToRef, + BE: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, + { + + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); + let sk: LWESecret<&[u8]> = sk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), sk.n()); } - (0..pt.size().min(self.size())).for_each(|i| { - pt.data.at_mut(0, i)[0] = self.data.at(0, i)[0] - + self.data.at(0, i)[1..] + (0..pt.size().min(res.size())).for_each(|i| { + pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0] + + res.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); }); - module.zn_normalize_inplace( + self.zn_normalize_inplace( 1, - self.base2k().into(), + res.base2k().into(), &mut pt.data, 0, ScratchOwned::alloc(size_of::()).borrow(), ); - pt.base2k = self.base2k(); - pt.k = crate::layouts::TorusPrecision(self.k().0.min(pt.size() as u32 * self.base2k().0)); + pt.base2k = res.base2k(); + pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); } } + +impl LWEDecrypt for Module where + Self: Sized + ZnNormalizeInplace +{ + +} \ No newline at end of file