use poulpy_hal::{ api::{ScratchTakeBasic, SvpPPolBytesOf}, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, }; use crate::{ GLWEDecrypt, ScratchTakeCore, layouts::{GLWEInfos, GLWEPlaintext, GLWESecretPrepared, GLWESecretTensor, GLWESecretTensorPrepared, GLWETensor}, }; impl GLWETensor> { pub fn decrypt_tmp_bytes(module: &M, a_infos: &A) -> usize where A: GLWEInfos, M: GLWETensorDecrypt, { module.glwe_tensor_decrypt_tmp_bytes(a_infos) } } impl GLWETensor { pub fn decrypt( &self, module: &M, pt: &mut GLWEPlaintext

, sk: &GLWESecretPrepared, sk_tensor: &GLWESecretTensorPrepared, scratch: &mut Scratch, ) where P: DataMut, S0: DataRef, S1: DataRef, M: GLWETensorDecrypt, Scratch: ScratchTakeBasic, { module.glwe_tensor_decrypt(self, pt, sk, sk_tensor, scratch); } } pub trait GLWETensorDecrypt { fn glwe_tensor_decrypt_tmp_bytes(&self, infos: &A) -> usize where A: GLWEInfos; fn glwe_tensor_decrypt( &self, res: &GLWETensor, pt: &mut GLWEPlaintext

, sk: &GLWESecretPrepared, sk_tensor: &GLWESecretTensorPrepared, scratch: &mut Scratch, ) where R: DataRef, P: DataMut, S0: DataRef, S1: DataRef; } impl GLWETensorDecrypt for Module where Self: GLWEDecrypt + SvpPPolBytesOf, Scratch: ScratchTakeCore, { fn glwe_tensor_decrypt_tmp_bytes(&self, infos: &A) -> usize where A: GLWEInfos, { self.glwe_decrypt_tmp_bytes(infos) } fn glwe_tensor_decrypt( &self, res: &GLWETensor, pt: &mut GLWEPlaintext

, sk: &GLWESecretPrepared, sk_tensor: &GLWESecretTensorPrepared, scratch: &mut Scratch, ) where R: DataRef, P: DataMut, S0: DataRef, S1: DataRef, { let rank: usize = sk.rank().as_usize(); let (mut sk_grouped, scratch_1) = scratch.take_glwe_secret_prepared(self, (GLWESecretTensor::pairs(rank) + rank).into()); for i in 0..rank { sk_grouped.data.at_mut(i, 0).copy_from_slice(sk.data.at(i, 0)); } for i in 0..sk_grouped.rank().as_usize() - rank { sk_grouped.data.at_mut(i + rank, 0).copy_from_slice(sk_tensor.data.at(i, 0)); } self.glwe_decrypt(res, pt, &sk_grouped, scratch_1); } }