// Noise-measurement helpers for GGSW ciphertexts: decrypt every (row, column) entry,
// subtract the expected plaintext and report / check log2 of the noise std.
//
// NOTE(review): every generic parameter list in this copy appears to have been stripped
// (e.g. `GLWEPlaintext>` and `VecZnxDft, B>` carry a dangling `>`, `Module` and
// `ScratchOwned` appear without parameters, and `Backend` is imported but never named).
// The tokens below are reproduced as found; restore the `<...>` lists from version
// control before building.
use backend::hal::{
    api::{
        ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddScalarInplace, VecZnxBigAlloc, VecZnxBigNormalize,
        VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VecZnxStd,
        VecZnxSubABInplace, ZnxZero,
    },
    layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft},
    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
};

use crate::{GGSWCiphertext, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecretExec, Infos};

/// Module capabilities required by `GGSWCiphertext::assert_noise` and
/// `GGSWCiphertext::print_noise` (both name it in their `where` clauses).
///
/// Presumably supplies the DFT/big allocations, the DFT-to-big conversion and the big
/// normalization used below, on top of `GLWEDecryptFamily` for decryption — confirm
/// against the trait definitions. `VecZnxBigNormalizeTmpBytes` is required here but is
/// not called anywhere in this file's visible code.
///
/// This is a trait alias (`trait Name = A + B;`), which needs the nightly
/// `trait_alias` feature.
pub trait GGSWAssertNoiseFamily = GLWEDecryptFamily
    + VecZnxBigAlloc
    + VecZnxDftAlloc
    + VecZnxBigNormalizeTmpBytes
    + VecZnxBigNormalize
    + VecZnxDftToVecZnxBigTmpA;

// Bounded noise check.
impl GGSWCiphertext {
    /// Asserts that every entry of this GGSW ciphertext decrypts with noise under a
    /// caller-supplied bound, printing each measurement as it goes.
    ///
    /// For every column `col_j` in `0..=rank` and row `row_i` in `0..rows`:
    /// 1. The expected plaintext `pt` is built by adding `pt_want` (column 0) into `pt`
    ///    at position `(digits - 1) + row_i * digits`.
    /// 2. For `col_j > 0`, `pt` is then multiplied by secret-key column `col_j - 1`
    ///    through the DFT domain and normalized back.
    /// 3. `self.at(row_i, col_j)` is decrypted with `sk_exec` into `pt_have`, and the
    ///    expected plaintext is subtracted from it.
    /// 4. `log2` of `vec_znx_std` of the difference is compared with
    ///    `max_noise(col_j)`.
    ///
    /// # Arguments
    /// * `module` - backend module that performs all of the arithmetic above.
    /// * `sk_exec` - secret key used both for decryption and for the `col_j > 0`
    ///   multiplication (its `data` column `col_j - 1` is applied).
    /// * `pt_want` - scalar message expected to be encrypted in each entry.
    /// * `max_noise` - called with the column index; returns the largest acceptable
    ///   *log2* standard deviation for that column.
    ///
    /// # Panics
    /// * If any measured `log2(std)` is greater than `max_noise(col_j)`.
    /// * If `self.digits()` is 0: `digits - 1` underflows on `usize` (panics when
    ///   overflow checks are enabled, e.g. debug builds).
    pub fn assert_noise(
        &self,
        module: &Module,
        sk_exec: &GLWESecretExec,
        pt_want: &ScalarZnx,
        max_noise: F,
    ) where
        DataSk: DataRef,
        DataScalar: DataRef,
        Module: GGSWAssertNoiseFamily + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd,
        B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl,
        F: Fn(usize) -> f64,
    {
        let basek: usize = self.basek();
        let k: usize = self.k();
        let digits: usize = self.digits();

        // `pt` holds the expected plaintext and `pt_have` the decrypted one.
        // NOTE(review): the first entry relies on `GLWEPlaintext::alloc` returning zeroed
        // data — TODO confirm; `pt` is explicitly re-zeroed after every entry.
        let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k);
        let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k);
        // One-column DFT / big buffers used only for the secret-key multiplication.
        let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(self.n(), 1, self.size());
        let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(self.n(), 1, self.size());

        // NOTE(review): `|` is bitwise OR, not `+` or `max`. Because `a | b >= max(a, b)`,
        // the buffer covers each requirement on its own. That is enough here, since every
        // call below takes its own `scratch.borrow()` one at a time; `.max(...)` would
        // state the intent more plainly.
        // NOTE(review): the scratch passed to `vec_znx_big_normalize` is sized with
        // `vec_znx_normalize_tmp_bytes`, not `vec_znx_big_normalize_tmp_bytes` (which
        // `GGSWAssertNoiseFamily` requires). Confirm the decrypt term covers it.
        let mut scratch: ScratchOwned = ScratchOwned::alloc(
            GLWECiphertext::decrypt_scratch_space(module, self.n(), basek, k) | module.vec_znx_normalize_tmp_bytes(self.n()),
        );

        // Visit every entry: columns 0..=rank (column 0 included), rows 0..rows.
        (0..self.rank() + 1).for_each(|col_j| {
            (0..self.rows()).for_each(|row_i| {
                // Accumulate `pt_want` into `pt` at position `(digits - 1) + row_i * digits`
                // (presumably a limb index — confirm against `vec_znx_add_scalar_inplace`).
                module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);

                // mul with sk[col_j-1]
                // Column 0 expects the bare message. Column j > 0 expects it multiplied by
                // secret-key column j - 1, computed in the DFT domain and normalized back
                // into `pt`.
                // NOTE(review): the meaning of the leading `1, 0` arguments to
                // `vec_znx_dft_from_vec_znx` is not visible here — confirm.
                if col_j > 0 {
                    module.vec_znx_dft_from_vec_znx(1, 0, &mut pt_dft, 0, &pt.data, 0);
                    module.svp_apply_inplace(&mut pt_dft, 0, &sk_exec.data, col_j - 1);
                    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
                    module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
                }

                // NOTE(review): `sk_exec` is already a reference; `&sk_exec` relies on
                // deref coercion (clippy `needless_borrow`).
                self.at(row_i, col_j)
                    .decrypt(module, &mut pt_have, &sk_exec, scratch.borrow());

                // Remove the expected plaintext so `pt_have` holds only the difference
                // (the noise, when decryption is otherwise correct).
                module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);

                // Noise is measured as log2 of the standard deviation, so the bound from
                // `max_noise` is compared on the same log2 scale.
                let std_pt: f64 = module.vec_znx_std(basek, &pt_have.data, 0).log2();
                let noise: f64 = max_noise(col_j);
                // NOTE(review): unconditional stdout output from an assertion helper.
                println!("{} {}", std_pt, noise);
                assert!(std_pt <= noise, "{} > {}", std_pt, noise);

                // Clear the expected plaintext before the next entry accumulates into it.
                pt.data.zero();
            });
        });
    }
}

// Unbounded noise report.
impl GGSWCiphertext {
    /// Prints `log2` of the decryption-noise standard deviation for every entry of
    /// this GGSW ciphertext, one value per line.
    ///
    /// Performs exactly the same per-entry measurement as `assert_noise` (expected
    /// plaintext from `pt_want`, multiplied by secret-key column `col_j - 1` for
    /// `col_j > 0`, subtracted from the decryption of `self.at(row_i, col_j)`), but
    /// takes no bound and never asserts.
    ///
    /// # Arguments
    /// * `module` - backend module that performs all of the arithmetic.
    /// * `sk_exec` - secret key used for decryption and for the `col_j > 0`
    ///   multiplication.
    /// * `pt_want` - scalar message expected to be encrypted in each entry.
    ///
    /// # Panics
    /// * If `self.digits()` is 0: `digits - 1` underflows on `usize` (panics when
    ///   overflow checks are enabled).
    pub fn print_noise(
        &self,
        module: &Module,
        sk_exec: &GLWESecretExec,
        pt_want: &ScalarZnx,
    ) where
        DataSk: DataRef,
        DataScalar: DataRef,
        Module: GGSWAssertNoiseFamily + VecZnxAddScalarInplace + VecZnxSubABInplace + VecZnxStd,
        B: TakeVecZnxDftImpl + TakeVecZnxBigImpl + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl,
    {
        let basek: usize = self.basek();
        let k: usize = self.k();
        let digits: usize = self.digits();

        // `pt` holds the expected plaintext and `pt_have` the decrypted one.
        // NOTE(review): the first entry relies on `GLWEPlaintext::alloc` returning zeroed
        // data — TODO confirm.
        let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k);
        let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k);
        // One-column DFT / big buffers used only for the secret-key multiplication.
        let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(self.n(), 1, self.size());
        let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(self.n(), 1, self.size());

        // NOTE(review): sized with `module.n()`, whereas `assert_noise` uses `self.n()` for
        // the same term. Confirm the two are always equal, or align them.
        // `|` is bitwise OR (>= max of both operands), not a sum.
        let mut scratch: ScratchOwned = ScratchOwned::alloc(
            GLWECiphertext::decrypt_scratch_space(module, self.n(), basek, k) | module.vec_znx_normalize_tmp_bytes(module.n()),
        );

        // Visit every entry: columns 0..=rank, rows 0..rows.
        (0..self.rank() + 1).for_each(|col_j| {
            (0..self.rows()).for_each(|row_i| {
                // Accumulate `pt_want` at position `(digits - 1) + row_i * digits`
                // (presumably a limb index — confirm).
                module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);

                // mul with sk[col_j-1]
                // Only columns j > 0 carry the message multiplied by secret-key column j - 1.
                if col_j > 0 {
                    module.vec_znx_dft_from_vec_znx(1, 0, &mut pt_dft, 0, &pt.data, 0);
                    module.svp_apply_inplace(&mut pt_dft, 0, &sk_exec.data, col_j - 1);
                    module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
                    module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
                }

                self.at(row_i, col_j)
                    .decrypt(module, &mut pt_have, &sk_exec, scratch.borrow());

                // Leave only the difference between decrypted and expected plaintext.
                module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);

                // Report log2 of the noise standard deviation for this entry.
                let std_pt: f64 = module.vec_znx_std(basek, &pt_have.data, 0).log2();
                println!("{}", std_pt);

                // Clear the expected plaintext before the next entry accumulates into it.
                pt.data.zero();
            });
        });
    }
}