use poulpy_hal::{ api::{ ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VecZnxSubABInplace, }, layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl}, }; use crate::layouts::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared}; impl GGSWCiphertext { pub fn assert_noise( &self, module: &Module, sk_prepared: &GLWESecretPrepared, pt_want: &ScalarZnx, max_noise: F, ) where DataSk: DataRef, DataScalar: DataRef, Module: VecZnxDftAllocBytes + VecZnxBigAllocBytes + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + VecZnxBigAlloc + VecZnxDftAlloc + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubABInplace, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, F: Fn(usize) -> f64, { let basek: usize = self.basek(); let k: usize = self.k(); let digits: usize = self.digits(); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes()); (0..self.rank() + 1).for_each(|col_j| { (0..self.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0); // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow()); } self.at(row_i, col_j) .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0); let std_pt: f64 = pt_have.data.std(basek, 0).log2(); let noise: f64 = max_noise(col_j); println!("{} {}", std_pt, noise); assert!(std_pt <= noise, "{} > {}", std_pt, noise); pt.data.zero(); }); }); } } impl GGSWCiphertext { pub fn print_noise( &self, module: &Module, sk_prepared: &GLWESecretPrepared, pt_want: &ScalarZnx, ) where DataSk: DataRef, DataScalar: DataRef, Module: VecZnxDftAllocBytes + VecZnxBigAllocBytes + VecZnxDftApply + SvpApplyDftToDftInplace + VecZnxIdftApplyConsume + VecZnxBigAddInplace + VecZnxBigAddSmallInplace + VecZnxBigNormalize + VecZnxNormalizeTmpBytes + VecZnxBigAlloc + VecZnxDftAlloc + VecZnxBigNormalizeTmpBytes + VecZnxIdftApplyTmpA + VecZnxAddScalarInplace + VecZnxSubABInplace, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { let basek: usize = self.basek(); let k: usize = self.k(); let digits: usize = self.digits(); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes()); (0..self.rank() + 1).for_each(|col_j| { (0..self.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0); // mul with sk[col_j-1] if col_j > 0 { module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow()); } self.at(row_i, col_j) .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0); let std_pt: f64 = pt_have.data.std(basek, 0).log2(); println!("col: {} row: {}: {}", col_j, row_i, std_pt); pt.data.zero(); }); }); } }