use poulpy_hal::{ api::{ ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; use crate::layouts::{GGLWEInfos, GLWE, GLWELayout, LWE, LWEInfos, Rank, TorusPrecision, prepared::LWESwitchingKeyPrepared}; impl LWE> { pub fn keyswitch_scratch_space( module: &Module, out_infos: &OUT, in_infos: &IN, key_infos: &KEY, ) -> usize where OUT: LWEInfos, IN: LWEInfos, KEY: GGLWEInfos, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDftTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxNormalizeTmpBytes, { let max_k: TorusPrecision = in_infos.k().max(out_infos.k()); let glwe_in_infos: GLWELayout = GLWELayout { n: module.n().into(), base2k: in_infos.base2k(), k: max_k, rank: Rank(1), }; let glwe_out_infos: GLWELayout = GLWELayout { n: module.n().into(), base2k: out_infos.base2k(), k: max_k, rank: Rank(1), }; let glwe_in: usize = GLWE::bytes_of_from_infos(module, &glwe_in_infos); let glwe_out: usize = GLWE::bytes_of_from_infos(module, &glwe_out_infos); let ks: usize = GLWE::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos); glwe_in + glwe_out + ks } } impl LWE { pub fn keyswitch( &mut self, module: &Module, a: &LWE, ksk: &LWESwitchingKeyPrepared, scratch: &mut Scratch, ) where A: DataRef, DKs: DataRef, Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VmpApplyDftToDft + VmpApplyDftToDftAdd + VecZnxDftApply + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxNormalize + VecZnxNormalizeTmpBytes + VecZnxCopy, Scratch: ScratchAvailable, { #[cfg(debug_assertions)] { assert!(self.n() <= module.n() as u32); assert!(a.n() <= module.n() as u32); assert!(scratch.available() >= LWE::keyswitch_scratch_space(module, self, a, ksk)); } let max_k: TorusPrecision = self.k().max(a.k()); let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; let (mut glwe_in, scratch_1) = scratch.take_glwe_ct(&GLWELayout { n: ksk.n(), base2k: a.base2k(), k: max_k, rank: Rank(1), }); glwe_in.data.zero(); let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct(&GLWELayout { n: ksk.n(), base2k: self.base2k(), k: max_k, rank: Rank(1), }); let n_lwe: usize = a.n().into(); for i in 0..a_size { let data_lwe: &[i64] = a.data.at(0, i); glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } glwe_out.keyswitch(module, &glwe_in, &ksk.0, scratch_1); self.sample_extract(&glwe_out); } }