use poulpy_hal::{ api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut}, }; use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}; impl LWE { pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S) where P: LWEPlaintextToMut, S: LWESecretToRef, M: LWEDecrypt, { module.lwe_decrypt(self, pt, sk); } } pub trait LWEDecrypt { fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) where R: LWEToMut, P: LWEPlaintextToMut, S: LWESecretToRef; } impl LWEDecrypt for Module where Self: Sized + ZnNormalizeInplace, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) where R: LWEToMut, P: LWEPlaintextToMut, S: LWESecretToRef, { let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); let sk: LWESecret<&[u8]> = sk.to_ref(); #[cfg(debug_assertions)] { assert_eq!(res.n(), sk.n()); } (0..pt.size().min(res.size())).for_each(|i| { pt.data.at_mut(0, i)[0] = res.data.at(0, i)[0] + res.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); }); self.zn_normalize_inplace( 1, res.base2k().into(), &mut pt.data, 0, ScratchOwned::alloc(size_of::()).borrow(), ); pt.base2k = res.base2k(); pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); } }