use poulpy_hal::{
    api::{
        ModuleN, ScratchAvailable, ScratchTakeBasic, SvpApplyDftToDft, SvpApplyDftToDftInplace,
        SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxAddInplace, VecZnxAddNormal,
        VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxFillUniform,
        VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeInplace,
        VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
    },
    layouts::{Backend, DataMut, DataRef, Module, Scratch},
    oep::{
        VecZnxAddScalarInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxDftApplyImpl,
        SvpApplyDftToDftImpl, VecZnxIdftApplyTmpAImpl, VecZnxBigNormalizeImpl,
    },
    source::Source,
};

use crate::{
    ScratchTakeCore,
    layouts::{
        GetDist, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWESwitchingKey, LWEInfos,
        Rank, TensorKey, TensorKeyToMut, prepared::GLWESecretPrepared,
    },
    encryption::gglwe_ksk::GLWESwitchingKeyEncryptSk,
};

impl TensorKey<Vec<u8>> {
    /// Returns the scratch space, in bytes, required by [`TensorKey::encrypt_sk`].
    pub fn encrypt_sk_tmp_bytes<B, M, A>(module: &M, infos: &A) -> usize
    where
        B: Backend,
        A: GGLWEInfos,
        M: GGLWETensorKeyEncryptSk<B>,
    {
        module.gglwe_tensor_key_encrypt_sk_tmp_bytes(infos)
    }

    // pub fn encrypt_sk_tmp_bytes<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
    // where
    //     A: GGLWEInfos,
    //     Module<B>: ModuleN
    //         + SvpPPolBytesOf
    //         + SvpPPolAlloc
    //         + VecZnxNormalizeTmpBytes
    //         + VecZnxDftBytesOf
    //         + VecZnxBigBytesOf,
    // {
    //     GLWESecretPrepared::bytes_of(module, infos.rank_out())
    //         + module.bytes_of_vec_znx_dft(infos.rank_out().into(), 1)
    //         + module.bytes_of_vec_znx_big(1, 1)
    //         + module.bytes_of_vec_znx_dft(1, 1)
    //         + GLWESecret::bytes_of(module, Rank(1))
    //         + GLWESwitchingKey::encrypt_sk_tmp_bytes(module, infos)
    // }
}

impl<DataSelf: DataMut> TensorKey<DataSelf> {
    /// Encrypts this [`TensorKey`] under the GLWE secret `sk`: for every pair of columns
    /// `(i, j)` with `i <= j`, the product `s_i * s_j` is encrypted as a GLWE switching
    /// key under `sk`.
    pub fn encrypt_sk<B, M, DataSk>(
        &mut self,
        module: &M,
        sk: &GLWESecret<DataSk>,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        B: Backend,
        M: GGLWETensorKeyEncryptSk<B>,
        DataSk: DataRef,
        GLWESecret<DataSk>: GetDist,
        Scratch<B>: ScratchAvailable + ScratchTakeCore,
    {
        module.gglwe_tensor_key_encrypt_sk(self, sk, source_xa, source_xe, scratch);
    }

    // pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
    //     &mut self,
    //     module: &Module<B>,
    //     sk: &GLWESecret<DataSk>,
    //     source_xa: &mut Source,
    //     source_xe: &mut Source,
    //     scratch: &mut Scratch<B>,
    // ) where
    //     GLWESecret<DataSk>: GetDist,
    //     Module<B>: ModuleN
    //         + SvpApplyDftToDft
    //         + VecZnxIdftApplyTmpA
    //         + VecZnxAddScalarInplace
    //         + VecZnxDftBytesOf
    //         + VecZnxBigNormalize
    //         + VecZnxDftApply
    //         + SvpApplyDftToDftInplace
    //         + VecZnxIdftApplyConsume
    //         + VecZnxNormalizeTmpBytes
    //         + VecZnxFillUniform
    //         + VecZnxSubInplace
    //         + VecZnxAddInplace
    //         + VecZnxNormalizeInplace
    //         + VecZnxAddNormal
    //         + VecZnxNormalize
    //         + VecZnxSub
    //         + SvpPrepare
    //         + VecZnxSwitchRing
    //         + SvpPPolBytesOf
    //         + VecZnxBigAllocBytesImpl
    //         + VecZnxBigBytesOf
    //         + SvpPPolAlloc,
    //     Scratch<B>: ScratchTakeBasic + ScratchTakeCore,
    // {
    //     #[cfg(debug_assertions)]
    //     {
    //         assert_eq!(self.rank_out(), sk.rank());
    //         assert_eq!(self.n(), sk.n());
    //     }
    //
    //     // let n: RingDegree = sk.n();
    //     let rank: Rank = self.rank_out();
    //     let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(module, rank);
    //     sk_dft_prep.prepare(module, sk);
    //     let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(module, rank.into(), 1);
    //     (0..rank.into()).for_each(|i| {
    //         module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
    //     });
    //     let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(module, 1, 1);
    //     let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(module, Rank(1));
    //     let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(module, 1, 1);
    //     (0..rank.into()).for_each(|i| {
    //         (i..rank.into()).for_each(|j| {
    //             module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
    //             module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
    //             module.vec_znx_big_normalize(
    //                 self.base2k().into(),
    //                 &mut sk_ij.data.as_vec_znx_mut(),
    //                 0,
    //                 self.base2k().into(),
    //                 &sk_ij_big,
    //                 0,
    //                 scratch_5,
    //             );
    //             self.at_mut(i, j)
    //                 .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5);
    //         });
    //     })
    // }
}

/// Module-side entry point for tensor-key encryption: encrypts all products `s_i * s_j`
/// (for `i <= j`) of a GLWE secret's columns as GLWE switching keys under that secret.
pub trait GGLWETensorKeyEncryptSk<B: Backend>
where
    Self: Sized
        + ModuleN
        + SvpPPolBytesOf
        + SvpPPolAlloc
        + VecZnxNormalizeTmpBytes
        + VecZnxDftBytesOf
        + VecZnxBigBytesOf,
{
    /// Returns the scratch space, in bytes, required by
    /// [`GGLWETensorKeyEncryptSk::gglwe_tensor_key_encrypt_sk`].
    fn gglwe_tensor_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGLWEInfos;

    /// Encrypts the tensor key `res` under the GLWE secret `sk`.
    fn gglwe_tensor_key_encrypt_sk<R, S>(
        &self,
        res: &mut R,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        R: TensorKeyToMut,
        S: GLWESecretToRef + GetDist;
}

impl<B: Backend> GGLWETensorKeyEncryptSk<B> for Module<B>
where
    Module<B>: ModuleN
        + SvpPPolBytesOf
        + SvpPPolAlloc
        + VecZnxNormalizeTmpBytes
        + VecZnxDftBytesOf
        + VecZnxBigBytesOf
        + VecZnxAddScalarInplaceImpl
        + VecZnxDftApply
        + VecZnxDftApplyImpl
        + SvpApplyDftToDftImpl
        + GLWESwitchingKeyEncryptSk
        + SvpApplyDftToDft
        + VecZnxIdftApplyTmpAImpl
        + VecZnxBigNormalizeImpl
        + VecZnxIdftApplyTmpA
        + VecZnxBigNormalize
        + SvpPrepare,
    Scratch<B>: ScratchTakeBasic + ScratchTakeCore,
{
    fn gglwe_tensor_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGLWEInfos,
    {
        // Scratch layout: prepared secret + DFT of the secret + temporaries for one
        // product (big + small + DFT) + the per-entry switching-key encryption.
        GLWESecretPrepared::bytes_of(self, infos.rank_out())
            + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1)
            + self.bytes_of_vec_znx_big(1, 1)
            + self.bytes_of_vec_znx_dft(1, 1)
            + GLWESecret::bytes_of(self, Rank(1))
            + GLWESwitchingKey::encrypt_sk_tmp_bytes(self, infos)
    }

    fn gglwe_tensor_key_encrypt_sk<R, S>(
        &self,
        res: &mut R,
        sk: &S,
        source_xa: &mut Source,
        source_xe: &mut Source,
        scratch: &mut Scratch<B>,
    ) where
        R: TensorKeyToMut,
        S: GLWESecretToRef + GetDist,
    {
        let res: &mut TensorKey<&mut [u8]> = &mut res.to_mut();

        // let n: RingDegree = sk.n();
        let rank: Rank = res.rank_out();

        // Prepared (SVP) form of the secret, used as the right-hand operand of the products.
        let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, rank);
        sk_dft_prep.prepare(self, sk);

        let sk: &GLWESecret<&[u8]> = &sk.to_ref();

        #[cfg(debug_assertions)]
        {
            assert_eq!(res.rank_out(), sk.rank());
            assert_eq!(res.n(), sk.n());
        }

        // DFT of every column of the secret (left-hand operand of the products).
        let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank.into(), 1);
        (0..rank.into()).for_each(|i| {
            self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
        });

        let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
        let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self, Rank(1));
        let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1);

        // For every pair (i, j) with i <= j: compute s_i * s_j, normalize it to the
        // output base-2k decomposition, and encrypt it into the (i, j) switching key.
        (0..rank.into()).for_each(|i| {
            (i..rank.into()).for_each(|j| {
                self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
                self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
                self.vec_znx_big_normalize(
                    res.base2k().into(),
                    &mut sk_ij.data.as_vec_znx_mut(),
                    0,
                    res.base2k().into(),
                    &sk_ij_big,
                    0,
                    scratch_5,
                );
                res.at_mut(i, j)
                    .encrypt_sk(self, &sk_ij, sk, source_xa, source_xe, scratch_5);
            });
        })
    }
}
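
// Usage sketch (illustration only, not part of this file's API): the constructors and
// helpers below (`ScratchOwned::alloc`, `GLWESecret::alloc`, `TensorKey::alloc`, the
// secret-sampling call) are hypothetical placeholders; only `encrypt_sk_tmp_bytes` and
// `encrypt_sk` are defined above.
//
//     let tmp_bytes = TensorKey::encrypt_sk_tmp_bytes(&module, &infos);
//     let mut scratch = ScratchOwned::alloc(tmp_bytes);                 // hypothetical
//     let mut sk = GLWESecret::alloc(&module, infos.rank_out());        // hypothetical
//     sk.fill_ternary_prob(0.5, &mut source_xs);                        // hypothetical
//     let mut tensor_key = TensorKey::alloc(&module, &infos);           // hypothetical
//     tensor_key.encrypt_sk(&module, &sk, &mut source_xa, &mut source_xe, scratch.borrow());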