use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxView, ZnxViewMut, }, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, VecZnx}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, source::Source, }; use crate::{ encryption::{SIGMA, SIGMA_BOUND}, layouts::{Infos, LWECiphertext, LWEPlaintext, LWESecret}, }; impl LWECiphertext { pub fn encrypt_sk( &mut self, module: &Module, pt: &LWEPlaintext, sk: &LWESecret, source_xa: &mut Source, source_xe: &mut Source, ) where DataPt: DataRef, DataSk: DataRef, Module: VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { #[cfg(debug_assertions)] { assert_eq!(self.n(), sk.n()) } let basek: usize = self.basek(); let k: usize = self.k(); module.vec_znx_fill_uniform(basek, &mut self.data, 0, k, source_xa); let mut tmp_znx: VecZnx> = VecZnx::alloc(1, 1, self.size()); let min_size = self.size().min(pt.size()); (0..min_size).for_each(|i| { tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] - self.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); }); (min_size..self.size()).for_each(|i| { tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..] .iter() .zip(sk.data.at(0, 0)) .map(|(x, y)| x * y) .sum::(); }); module.vec_znx_add_normal(basek, &mut self.data, 0, k, source_xe, SIGMA, SIGMA_BOUND); module.vec_znx_normalize_inplace( basek, &mut tmp_znx, 0, ScratchOwned::alloc(size_of::()).borrow(), ); (0..self.size()).for_each(|i| { self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; }); } }