diff --git a/poulpy-backend/examples/rlwe_encrypt.rs b/poulpy-backend/examples/rlwe_encrypt.rs index b4338b9..64061e2 100644 --- a/poulpy-backend/examples/rlwe_encrypt.rs +++ b/poulpy-backend/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use itertools::izip; -use poulpy_backend::cpu_spqlios::FFT64Spqlios; +use poulpy_backend::cpu_fft64_ref::FFT64Ref; use poulpy_hal::{ api::{ ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, @@ -16,9 +16,9 @@ fn main() { let ct_size: usize = 3; let msg_size: usize = 2; let log_scale: usize = msg_size * base2k - 5; - let module: Module = Module::::new(n as u64); + let module: Module = Module::::new(n as u64); - let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -28,7 +28,7 @@ fn main() { s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: SvpPPol, FFT64Spqlios> = module.svp_ppol_alloc(s.cols()); + let mut s_dft: SvpPPol, FFT64Ref> = module.svp_ppol_alloc(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); @@ -43,7 +43,7 @@ fn main() { // Fill the second column with random values: ct = (0, a) module.vec_znx_fill_uniform(base2k, &mut ct, 1, &mut source); - let mut buf_dft: VecZnxDft, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size); + let mut buf_dft: VecZnxDft, FFT64Ref> = module.vec_znx_dft_alloc(1, ct_size); module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1); @@ -58,7 +58,7 @@ fn main() { // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - let mut buf_big: VecZnxBig, FFT64Spqlios> = module.vec_znx_big_alloc(1, ct_size); + let mut buf_big: VecZnxBig, FFT64Ref> = module.vec_znx_big_alloc(1, ct_size); module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column diff --git a/poulpy-backend/src/cpu_fft64_avx/mod.rs b/poulpy-backend/src/cpu_fft64_avx/mod.rs index d9d4e5f..4ba20c7 100644 --- a/poulpy-backend/src/cpu_fft64_avx/mod.rs +++ b/poulpy-backend/src/cpu_fft64_avx/mod.rs @@ -7,7 +7,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp; -mod zn; mod znx_avx; pub struct FFT64Avx {} diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs index 9e286b7..941ce19 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs @@ -53,10 +53,10 @@ where { fn vec_znx_normalize_impl( module: &Module, - res_basek: usize, + res_base2k: usize, res: &mut R, res_col: usize, - a_basek: usize, + a_base2k: usize, a: &A, a_col: usize, scratch: &mut Scratch, @@ -65,7 +65,7 @@ where A: VecZnxToRef, { let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::()); - vec_znx_normalize::(res_basek, res, res_col, a_basek, a, a_col, carry); + vec_znx_normalize::(res_base2k, res, res_col, a_base2k, a, a_col, carry); } } diff --git a/poulpy-backend/src/cpu_fft64_avx/zn.rs b/poulpy-backend/src/cpu_fft64_avx/zn.rs deleted file mode 100644 index 53ce1c9..0000000 --- a/poulpy-backend/src/cpu_fft64_avx/zn.rs +++ /dev/null @@ -1,73 +0,0 @@ -use poulpy_hal::{ - api::TakeSlice, - layouts::{Scratch, ZnToMut}, - oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, - reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes}, - source::Source, -}; - -use crate::cpu_fft64_avx::FFT64Avx; - -unsafe impl ZnNormalizeTmpBytesImpl for FFT64Avx { - fn zn_normalize_tmp_bytes_impl(n: usize) -> usize { - zn_normalize_tmp_bytes(n) - } -} - -unsafe impl ZnNormalizeInplaceImpl for FFT64Avx -where - Self: TakeSliceImpl, -{ - fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut, - { - let (carry, _) = scratch.take_slice(n); - zn_normalize_inplace::(n, base2k, res, res_col, carry); - } -} - -unsafe impl ZnFillUniformImpl for FFT64Avx { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - zn_fill_uniform(n, base2k, res, res_col, source); - } -} - -unsafe impl ZnFillNormalImpl for FFT64Avx { - #[allow(clippy::too_many_arguments)] - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -unsafe impl ZnAddNormalImpl for FFT64Avx { - #[allow(clippy::too_many_arguments)] - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs b/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs index 656fff2..9a53d02 100644 --- a/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs +++ b/poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs @@ -68,7 +68,6 @@ pub fn znx_extract_digit_addmul_avx(base2k: usize, lsh: usize, res: &mut [i64], #[cfg(debug_assertions)] { assert_eq!(res.len(), src.len()); - assert!(lsh < base2k); } use std::arch::x86_64::{ diff --git a/poulpy-backend/src/cpu_fft64_ref/mod.rs b/poulpy-backend/src/cpu_fft64_ref/mod.rs index 360c315..e0110a4 100644 --- a/poulpy-backend/src/cpu_fft64_ref/mod.rs +++ b/poulpy-backend/src/cpu_fft64_ref/mod.rs @@ -6,7 +6,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp; -mod zn; mod znx; #[cfg(test)] diff --git a/poulpy-backend/src/cpu_fft64_ref/zn.rs b/poulpy-backend/src/cpu_fft64_ref/zn.rs deleted file mode 100644 index 954d559..0000000 --- a/poulpy-backend/src/cpu_fft64_ref/zn.rs +++ /dev/null @@ -1,73 +0,0 @@ -use poulpy_hal::{ - api::TakeSlice, - layouts::{Scratch, ZnToMut}, - oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, - reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes}, - source::Source, -}; - -use crate::cpu_fft64_ref::FFT64Ref; - -unsafe impl ZnNormalizeTmpBytesImpl for FFT64Ref { - fn zn_normalize_tmp_bytes_impl(n: usize) -> usize { - zn_normalize_tmp_bytes(n) - } -} - -unsafe impl ZnNormalizeInplaceImpl for FFT64Ref -where - Self: TakeSliceImpl, -{ - fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut, - { - let (carry, _) = scratch.take_slice(n); - zn_normalize_inplace::(n, base2k, res, res_col, carry); - } -} - -unsafe impl ZnFillUniformImpl for FFT64Ref { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - zn_fill_uniform(n, base2k, res, res_col, source); - } -} - -unsafe impl ZnFillNormalImpl for FFT64Ref { - #[allow(clippy::too_many_arguments)] - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -unsafe impl ZnAddNormalImpl for FFT64Ref { - #[allow(clippy::too_many_arguments)] - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs index f87b264..4a1f4e3 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs @@ -5,7 +5,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; mod znx; pub struct FFT64Spqlios; diff --git a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs deleted file mode 100644 index adc84fd..0000000 --- a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs +++ /dev/null @@ -1,82 +0,0 @@ -use poulpy_hal::{ - api::TakeSlice, - layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, - oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl}, - reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform}, - source::Source, -}; - -use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64}; - -unsafe impl ZnNormalizeInplaceImpl for FFT64Spqlios -where - Self: TakeSliceImpl, -{ - fn zn_normalize_inplace_impl(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where - A: ZnToMut, - { - let mut a: Zn<&mut [u8]> = a.to_mut(); - - let (tmp_bytes, _) = scratch.take_slice(n * size_of::()); - - unsafe { - zn64::zn64_normalize_base2k_ref( - n as u64, - base2k as u64, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } -} - -unsafe impl ZnFillUniformImpl for FFT64Spqlios { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - zn_fill_uniform(n, base2k, res, res_col, source); - } -} - -unsafe impl ZnFillNormalImpl for FFT64Spqlios { - #[allow(clippy::too_many_arguments)] - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -unsafe impl ZnAddNormalImpl for FFT64Spqlios { - #[allow(clippy::too_many_arguments)] - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 87e545e..1e4dd92 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -55,7 +55,11 @@ where A: GGLWEInfos, K: GGLWEInfos, { - self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + if res_infos.glwe_layout() == a_infos.glwe_layout() { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } else { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + GLWE::bytes_of_from_infos(a_infos) + } } fn glwe_automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -79,12 +83,16 @@ where a.dsize() ); + assert_eq!(res.base2k(), a.base2k()); + let cols_out: usize = (key.rank_out() + 1).into(); let cols_in: usize = key.rank_in().into(); let p: i64 = a.p(); let p_inv: i64 = self.galois_element_inv(p); + let same_layout: bool = res.glwe_layout() == a.glwe_layout(); + { let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); let a: &GGLWE<&[u8]> = &a.to_ref(); @@ -94,18 +102,30 @@ where let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); let a_ct: GLWE<&[u8]> = a.at(row, col); - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - for i in 0..cols_out { - self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i); + if same_layout { + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); + } else { + let (mut tmp_glwe, scratch_1) = scratch.take_glwe(a); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(p, tmp_glwe.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch(&mut res_tmp, &tmp_glwe, key, scratch_1); } - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { + for i in 0..cols_out { self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); - }); + } } } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index 8644f98..0b2d977 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -34,9 +34,9 @@ impl GGSW> { impl GGSW { pub fn automorphism(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where - A: GGSWToRef, + A: GGSWToRef + GGSWInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GGLWEToGGSWKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGSWAutomorphism, { @@ -73,20 +73,21 @@ where fn ggsw_automorphism(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where - R: GGSWToMut, - A: GGSWToRef, + R: GGSWToMut + GGSWInfos, + A: GGSWToRef + GGSWInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GGLWEToGGSWKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { + assert_eq!(res.dsize(), a.dsize()); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.dnum() <= a.dnum()); + assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk)); + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let a: &GGSW<&[u8]> = &a.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - assert_eq!(res.dsize(), a.dsize()); - assert!(res.dnum() <= a.dnum()); - assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk)); - // Keyswitch the j-th row of the col 0 for row in 0..res.dnum().as_usize() { // Key-switch column 0, i.e. diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index b382197..ff9dfcd 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -7,8 +7,8 @@ use poulpy_hal::{ }; use crate::{ - GLWEKeySwitchInternal, GLWEKeyswitch, ScratchTakeCore, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, + GLWEKeySwitchInternal, GLWEKeyswitch, GLWENormalize, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, }; impl GLWE> { @@ -27,7 +27,7 @@ impl GLWE { pub fn automorphism(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -37,7 +37,7 @@ impl GLWE { pub fn automorphism_add(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -47,7 +47,7 @@ impl GLWE { pub fn automorphism_sub(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -57,7 +57,7 @@ impl GLWE { pub fn automorphism_sub_negate(&mut self, module: &M, a: &A, key: &K, scratch: &mut Scratch) where M: GLWEAutomorphism, - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -110,46 +110,46 @@ pub trait GLWEAutomorphism { fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; } @@ -164,7 +164,8 @@ where + VecZnxBigSubSmallInplace + VecZnxBigSubSmallNegateInplace + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, + + VecZnxBigNormalize + + GLWENormalize, Scratch: ScratchTakeCore, { fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize @@ -178,8 +179,8 @@ where fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -194,7 +195,7 @@ where fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { @@ -209,30 +210,58 @@ where fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_add_small_inplace(&mut res_big, i, a_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) @@ -243,22 +272,49 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_add_small_inplace(&mut res_big, i, res_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -271,22 +327,50 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, a_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -299,22 +383,50 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) @@ -325,22 +437,49 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, res_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) @@ -351,21 +490,48 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } } diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index 8554e50..9770f1d 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, }, - layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, + layouts::{Backend, DataMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VecZnxToRef}, }; use crate::{ @@ -65,6 +65,7 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); assert_eq!(tsk.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); for row in 0..res.dnum().into() { self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0)); @@ -111,28 +112,29 @@ where + VecZnxDftApply + VecZnxNormalize + VecZnxBigAddSmallInplace - + VecZnxIdftApplyConsume, + + VecZnxIdftApplyConsume + + VecZnxCopy, { fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize where R: GGSWInfos, A: GGLWEInfos, { - let base2k_in: usize = res_infos.base2k().into(); let base2k_tsk: usize = tsk_infos.base2k().into(); let rank: usize = res_infos.rank().into(); let cols: usize = rank + 1; - let res_size = res_infos.size(); - let a_size: usize = (res_infos.size() * base2k_in).div_ceil(base2k_tsk); + let res_size: usize = res_infos.size(); + let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk); - let a_dft = self.bytes_of_vec_znx_dft(cols - 1, a_size); - let res_dft = self.bytes_of_vec_znx_dft(cols, a_size); + let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size); + let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size); + let res_dft: usize = self.bytes_of_vec_znx_dft(cols, a_size); let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos); - let normalize = self.vec_znx_big_normalize_tmp_bytes(); + let normalize: usize = self.vec_znx_big_normalize_tmp_bytes(); - (a_dft + res_dft + gglwe_prod).max(normalize) + (a_0 + a_dft + res_dft + gglwe_prod).max(normalize) } fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) @@ -144,7 +146,7 @@ where let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - let base2k_in: usize = res.base2k().into(); + let base2k_res: usize = res.base2k().into(); let base2k_tsk: usize = tsk.base2k().into(); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); @@ -152,96 +154,129 @@ where let rank: usize = res.rank().into(); let cols: usize = rank + 1; - let a_size: usize = (res.size() * base2k_in).div_ceil(base2k_tsk); + let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk); + + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size); + let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size); // Keyswitch the j-th row of the col 0 for row in 0..res.dnum().as_usize() { - let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); + let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - { - let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - - if base2k_in == base2k_tsk { - for col_i in 0..cols - 1 { - self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); - for i in 0..cols - 1 { - self.vec_znx_normalize( - base2k_tsk, - &mut a_conv, - 0, - base2k_in, - glwe_mi_1.data(), - i + 1, - scratch_2, - ); - self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); - } + if base2k_res == base2k_tsk { + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); } - } - - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. - // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - for col in 1..cols { - let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); // Todo optimise - - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - self.gglwe_product_dft(&mut res_dft, &a_dft, tsk.at(col - 1), scratch_2); - - let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - self.vec_znx_big_add_small_inplace(&mut res_big, col, res.at(row, 0).data(), 0); - - for j in 0..cols { - self.vec_znx_big_normalize( - res.base2k().as_usize(), - res.at_mut(row, col).data_mut(), - j, - tsk.base2k().as_usize(), - &res_big, - j, + self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0); + } else { + for i in 0..cols - 1 { + self.vec_znx_normalize( + base2k_tsk, + &mut a_0, + 0, + base2k_res, + glwe_mi_1.data(), + i + 1, scratch_2, ); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0); } + self.vec_znx_normalize( + base2k_tsk, + &mut a_0, + 0, + base2k_res, + glwe_mi_1.data(), + 0, + scratch_2, + ); } + + ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2) + } + } +} + +fn ggsw_expand_rows_internal( + module: &M, + row: usize, + res: &mut R, + a_0: &C, + a_dft: &A, + tsk: &T, + scratch: &mut Scratch, +) where + R: GGSWToMut, + C: VecZnxToRef, + A: VecZnxDftToRef, + M: GGLWEProduct + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore, +{ + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a_0: &VecZnx<&[u8]> = &a_0.to_ref(); + let a_dft: &VecZnxDft<&[u8], BE> = &a_dft.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let cols: usize = res.rank().as_usize() + 1; + + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + for col in 1..cols { + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, tsk.size()); // Todo optimise + + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + module.gglwe_product_dft(&mut res_dft, a_dft, tsk.at(col - 1), scratch_1); + + let mut res_big: VecZnxBig<&mut [u8], BE> = module.vec_znx_idft_apply_consume(res_dft); + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0); + + for j in 0..cols { + module.vec_znx_big_normalize( + res.base2k().as_usize(), + res.at_mut(row, col).data_mut(), + j, + tsk.base2k().as_usize(), + &res_big, + j, + scratch_1, + ); } } } diff --git a/poulpy-core/src/decryption/lwe.rs b/poulpy-core/src/decryption/lwe.rs index edd727b..997d7ea 100644 --- a/poulpy-core/src/decryption/lwe.rs +++ b/poulpy-core/src/decryption/lwe.rs @@ -1,39 +1,44 @@ use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut}, + api::VecZnxNormalizeInplace, + layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, }; -use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}; +use crate::{ + ScratchTakeCore, + layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}, +}; impl LWE { - pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S) + pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch) where P: LWEPlaintextToMut, S: LWESecretToRef, - M: LWEDecrypt, + M: LWEDecrypt, + Scratch: ScratchTakeCore, { - module.lwe_decrypt(self, pt, sk); + module.lwe_decrypt(self, pt, sk, scratch); } } pub trait LWEDecrypt { - fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) - where - R: LWEToMut, - P: LWEPlaintextToMut, - S: LWESecretToRef; -} - -impl LWEDecrypt for Module -where - Self: Sized + ZnNormalizeInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, -{ - fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch) where R: LWEToMut, P: LWEPlaintextToMut, S: LWESecretToRef, + Scratch: ScratchTakeCore; +} + +impl LWEDecrypt for Module +where + Self: Sized + VecZnxNormalizeInplace, +{ + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + R: LWEToMut, + P: LWEPlaintextToMut, + S: LWESecretToRef, + Scratch: ScratchTakeCore, { let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); @@ -52,13 +57,7 @@ where .map(|(x, y)| x * y) .sum::(); }); - self.zn_normalize_inplace( - 1, - res.base2k().into(), - &mut pt.data, - 0, - ScratchOwned::alloc(size_of::()).borrow(), - ); + self.vec_znx_normalize_inplace(res.base2k().into(), &mut pt.data, 0, scratch); pt.base2k = res.base2k(); pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); } diff --git a/poulpy-core/src/encryption/compressed/gglwe.rs b/poulpy-core/src/encryption/compressed/gglwe.rs index 1dfbf58..acf2eac 100644 --- a/poulpy-core/src/encryption/compressed/gglwe.rs +++ b/poulpy-core/src/encryption/compressed/gglwe.rs @@ -143,20 +143,21 @@ where let mut source_xa = Source::new(seed); let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res); - for col_i in 0..rank_in { - for d_i in 0..dnum { + + for col_j in 0..rank_in { + for row_i in 0..dnum { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_j); self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); let (seed, mut source_xa_tmp) = source_xa.branch(); - seeds[col_i * dnum + d_i] = seed; + seeds[row_i * rank_in + col_j] = seed; self.glwe_encrypt_sk_internal( res.base2k().into(), res.k().into(), - &mut res.at_mut(d_i, col_i).data, + &mut res.at_mut(row_i, col_j).data, cols, true, Some((&tmp_pt, 0)), diff --git a/poulpy-core/src/encryption/compressed/ggsw.rs b/poulpy-core/src/encryption/compressed/ggsw.rs index 14b0de5..4a54a1a 100644 --- a/poulpy-core/src/encryption/compressed/ggsw.rs +++ b/poulpy-core/src/encryption/compressed/ggsw.rs @@ -105,8 +105,6 @@ where { let res: &mut GGSWCompressed<&mut [u8]> = &mut res.to_mut(); - println!("res.seed: {:?}", res.seed); - let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res); let mut source = Source::new(seed_xa); diff --git a/poulpy-core/src/encryption/glwe.rs b/poulpy-core/src/encryption/glwe.rs index c81833e..e74403c 100644 --- a/poulpy-core/src/encryption/glwe.rs +++ b/poulpy-core/src/encryption/glwe.rs @@ -516,8 +516,6 @@ where // ct[i] = uniform (+ pt) self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); - // println!("vec_znx_fill_uniform: {}", ct); - let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, size); // ci = ct[i] - pt diff --git a/poulpy-core/src/encryption/lwe.rs b/poulpy-core/src/encryption/lwe.rs index 7651e8d..1ddebd4 100644 --- a/poulpy-core/src/encryption/lwe.rs +++ b/poulpy-core/src/encryption/lwe.rs @@ -1,43 +1,67 @@ use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace}, - layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, + api::{VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace}, + layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut}, source::Source, }; use crate::{ + ScratchTakeCore, encryption::{SIGMA, SIGMA_BOUND}, layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut}, }; impl LWE { - pub fn encrypt_sk(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) - where + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where P: LWEPlaintextToRef, S: LWESecretToRef, M: LWEEncryptSk, + Scratch: ScratchTakeCore, { - module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe); + module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } } pub trait LWEEncryptSk { - fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) - where + fn lwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where R: LWEToMut, P: LWEPlaintextToRef, - S: LWESecretToRef; + S: LWESecretToRef, + Scratch: ScratchTakeCore; } impl LWEEncryptSk for Module where - Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Self: Sized + VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace, { - fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) - where + fn lwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where R: LWEToMut, P: LWEPlaintextToRef, S: LWESecretToRef, + Scratch: ScratchTakeCore, { let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let pt: &LWEPlaintext<&[u8]> = &pt.to_ref(); @@ -51,11 +75,11 @@ where let base2k: usize = res.base2k().into(); let k: usize = res.k().into(); - self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa); + self.vec_znx_fill_uniform(base2k, &mut res.data, 0, source_xa); - let mut tmp_znx: Zn> = Zn::alloc(1, 1, res.size()); + let mut tmp_znx: VecZnx> = VecZnx::alloc(1, 1, res.size()); - let min_size = res.size().min(pt.size()); + let min_size: usize = res.size().min(pt.size()); (0..min_size).for_each(|i| { tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] @@ -74,24 +98,9 @@ where .sum::(); }); - self.zn_add_normal( - 1, - base2k, - &mut res.data, - 0, - k, - source_xe, - SIGMA, - SIGMA_BOUND, - ); + self.vec_znx_add_normal(base2k, &mut tmp_znx, 0, k, source_xe, SIGMA, SIGMA_BOUND); - self.zn_normalize_inplace( - 1, - base2k, - &mut tmp_znx, - 0, - ScratchOwned::alloc(size_of::()).borrow(), - ); + self.vec_znx_normalize_inplace(base2k, &mut tmp_znx, 0, scratch); (0..res.size()).for_each(|i| { res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; diff --git a/poulpy-core/src/external_product/gglwe.rs b/poulpy-core/src/external_product/gglwe.rs index 437cf39..c021e67 100644 --- a/poulpy-core/src/external_product/gglwe.rs +++ b/poulpy-core/src/external_product/gglwe.rs @@ -30,8 +30,8 @@ impl GLWEAutomorphismKey { pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where M: GGLWEExternalProduct, - A: GGLWEToRef, - B: GGSWPreparedToRef, + A: GGLWEToRef + GGLWEInfos, + B: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { module.gglwe_external_product(self, a, b, scratch); @@ -62,15 +62,11 @@ where fn gglwe_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - R: GGLWEToMut, - A: GGLWEToRef, - B: GGSWPreparedToRef, + R: GGLWEToMut + GGLWEInfos, + A: GGLWEToRef + GGLWEInfos, + B: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { - let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GGLWE<&[u8]> = &a.to_ref(); - let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); - assert_eq!( res.rank_in(), a.rank_in(), @@ -92,6 +88,11 @@ where res.rank_out(), b.rank() ); + assert_eq!(res.base2k(), a.base2k()); + + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); for row in 0..res.dnum().into() { for col in 0..res.rank_in().into() { @@ -149,8 +150,8 @@ impl GLWESwitchingKey { pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where M: GGLWEExternalProduct, - A: GGLWEToRef, - B: GGSWPreparedToRef, + A: GGLWEToRef + GGLWEInfos, + B: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { module.gglwe_external_product(self, a, b, scratch); diff --git a/poulpy-core/src/external_product/ggsw.rs b/poulpy-core/src/external_product/ggsw.rs index d33055d..92ff84b 100644 --- a/poulpy-core/src/external_product/ggsw.rs +++ b/poulpy-core/src/external_product/ggsw.rs @@ -50,6 +50,8 @@ where b.rank() ); + assert_eq!(res.base2k(), a.base2k()); + assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b)); let min_dnum: usize = res.dnum().min(a.dnum()).into(); diff --git a/poulpy-core/src/external_product/glwe.rs b/poulpy-core/src/external_product/glwe.rs index 2c9fe12..ef85998 100644 --- a/poulpy-core/src/external_product/glwe.rs +++ b/poulpy-core/src/external_product/glwe.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnxBig, VecZnxDft}, }; use crate::{ @@ -30,7 +30,7 @@ impl GLWE> { impl GLWE { pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GLWEToRef, + A: GLWEToRef + GLWEInfos, B: GGSWPreparedToRef + GGSWInfos, M: GLWEExternalProduct, Scratch: ScratchTakeCore, @@ -57,20 +57,14 @@ pub trait GLWEExternalProduct { fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore; fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, - D: GGSWPreparedToRef + GGSWInfos, - Scratch: ScratchTakeCore; - fn glwe_external_product_add(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) - where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore; } @@ -84,168 +78,113 @@ where + VecZnxBigAddSmallInplace + GLWENormalize, { - fn glwe_external_product_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + fn glwe_external_product_tmp_bytes(&self, res: &R, a: &A, ggsw: &B) -> usize where R: GLWEInfos, A: GLWEInfos, B: GGSWInfos, { - let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size()); - res_dft - + self - .glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos) - .max(self.vec_znx_big_normalize_tmp_bytes()) + let cols: usize = res.rank().as_usize() + 1; + let size: usize = if a.base2k() != ggsw.base2k() { + let a_conv_infos = &GLWELayout { + n: a.n(), + base2k: ggsw.base2k(), + k: a.k(), + rank: a.rank(), + }; + self.glwe_external_product_internal_tmp_bytes(res, a_conv_infos, ggsw) + GLWE::bytes_of_from_infos(a_conv_infos) + } else { + self.glwe_external_product_internal_tmp_bytes(res, a, ggsw) + }; + + size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, ggsw.size()) } - fn glwe_external_product_inplace(&self, res: &mut R, a: &D, scratch: &mut Scratch) + fn glwe_external_product_inplace(&self, res: &mut R, ggsw: &D, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, D: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref(); + assert_eq!(ggsw.rank(), res.rank()); + assert_eq!(ggsw.n(), res.n()); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, ggsw)); - let basek_in: usize = res.base2k().into(); - let basek_ggsw: usize = rhs.base2k().into(); + let base2k_res: usize = res.base2k().as_usize(); + let base2k_ggsw: usize = ggsw.base2k().as_usize(); - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ScratchAvailable; + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise - assert_eq!(rhs.rank(), res.rank()); - assert_eq!(rhs.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs)); - } - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise - let res_big = self.glwe_external_product_internal(res_dft, res, a, scratch_1); - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - basek_in, - &mut res.data, - j, - basek_ggsw, - &res_big, - j, - scratch_1, - ); - } - } - - fn glwe_external_product(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch) - where - R: GLWEToMut, - A: GLWEToRef, - D: GGSWPreparedToRef, - Scratch: ScratchTakeCore, - { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let lhs: &GLWE<&[u8]> = &lhs.to_ref(); - - let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref(); - - let basek_ggsw: usize = rhs.base2k().into(); - let basek_out: usize = res.base2k().into(); - - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ScratchAvailable; - - assert_eq!(rhs.rank(), lhs.rank()); - assert_eq!(rhs.rank(), res.rank()); - assert_eq!(rhs.n(), res.n()); - assert_eq!(lhs.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs)); - } - - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), rhs.size()); // Todo optimise - let res_big = self.glwe_external_product_internal(res_dft, lhs, rhs, scratch_1); - - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_normalize( - basek_out, - &mut res.data, - j, - basek_ggsw, - &res_big, - j, - scratch_1, - ); - } - } - - fn glwe_external_product_add(&self, res: &mut R, a: &A, key: &D, scratch: &mut Scratch) - where - R: GLWEToMut, - A: GLWEToRef, - D: GGSWPreparedToRef, - Scratch: ScratchTakeCore, - { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); - let key: &GGSWPrepared<&[u8], BE> = &key.to_ref(); - - assert_eq!(a.base2k(), res.base2k()); - - let res_base2k: usize = res.base2k().into(); - let key_base2k: usize = key.base2k().into(); - - #[cfg(debug_assertions)] - { - use poulpy_hal::api::ScratchAvailable; - - assert_eq!(key.rank(), a.rank()); - assert_eq!(key.rank(), res.rank()); - assert_eq!(key.n(), res.n()); - assert_eq!(a.n(), res.n()); - assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, key)); - } - - if res_base2k == key_base2k { - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let mut res_big = self.glwe_external_product_internal(res_dft, a, key, scratch_1); - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_add_small_inplace(&mut res_big, j, res.data(), j); - self.vec_znx_big_normalize( - res_base2k, - &mut res.data, - j, - key_base2k, - &res_big, - j, - scratch_1, - ); - } - } else { - let (mut a_conv, scratch_1) = scratch.take_glwe(&GLWELayout { - n: a.n(), - base2k: key.base2k(), - k: a.k(), - rank: a.rank(), - }); + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_ggsw { let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { n: res.n(), - base2k: key.base2k(), + base2k: ggsw.base2k(), k: res.k(), rank: res.rank(), }); - self.glwe_normalize(&mut a_conv, a, scratch_2); self.glwe_normalize(&mut res_conv, res, scratch_2); - let (res_dft, scratch_2) = scratch_2.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let mut res_big = self.glwe_external_product_internal(res_dft, &a_conv, key, scratch_2); - for j in 0..(res.rank() + 1).into() { - self.vec_znx_big_add_small_inplace(&mut res_big, j, res_conv.data(), j); - self.vec_znx_big_normalize( - res_base2k, - &mut res.data, - j, - key_base2k, - &res_big, - j, - scratch_2, - ); - } + self.glwe_external_product_internal(res_dft, &res_conv, ggsw, scratch_2) + } else { + self.glwe_external_product_internal(res_dft, res, ggsw, scratch_1) + }; + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + j, + base2k_ggsw, + &res_big, + j, + scratch_1, + ); + } + } + + fn glwe_external_product(&self, res: &mut R, a: &A, ggsw: &G, scratch: &mut Scratch) + where + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, + G: GGSWPreparedToRef + GGSWInfos, + Scratch: ScratchTakeCore, + { + assert_eq!(ggsw.rank(), a.rank()); + assert_eq!(ggsw.rank(), res.rank()); + assert_eq!(ggsw.n(), res.n()); + assert_eq!(a.n(), res.n()); + assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, a, ggsw)); + + let base2k_a: usize = a.base2k().into(); + let base2k_ggsw: usize = ggsw.base2k().into(); + let base2k_res: usize = res.base2k().into(); + + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), ggsw.size()); // Todo optimise + + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_ggsw { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: ggsw.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + self.glwe_external_product_internal(res_dft, &a_conv, ggsw, scratch_2) + } else { + self.glwe_external_product_internal(res_dft, a, ggsw, scratch_1) + }; + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for j in 0..(res.rank() + 1).into() { + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + j, + base2k_ggsw, + &res_big, + j, + scratch_1, + ); } } } @@ -309,12 +248,7 @@ where ); let normalize_big: usize = self.vec_znx_normalize_tmp_bytes(); - if a_infos.base2k() == b_infos.base2k() { - a_dft + (vmp | normalize_big) - } else { - let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size); - (a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big - } + a_dft + vmp.max(normalize_big) } fn glwe_external_product_internal( @@ -333,69 +267,36 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let ggsw: &GGSWPrepared<&[u8], BE> = &ggsw.to_ref(); - let basek_in: usize = a.base2k().into(); - let basek_ggsw: usize = ggsw.base2k().into(); + assert_eq!(a.base2k(), ggsw.base2k()); let cols: usize = (ggsw.rank() + 1).into(); let dsize: usize = ggsw.dsize().into(); - let a_size: usize = (a.size() * basek_in).div_ceil(basek_ggsw); + let a_size: usize = a.size(); let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize)); a_dft.data_mut().fill(0); - if basek_in == basek_ggsw { - for di in 0..dsize { - // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize) - a_dft.set_size((a.size() + di) / dsize); + for di in 0..dsize { + // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize) + a_dft.set_size((a.size() + di) / dsize); - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols { - self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j); - } - - if di == 0 { - self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1); - } else { - self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1); - } - } - } else { - let (mut a_conv, scratch_3) = scratch_1.take_vec_znx(self.n(), cols, a_size); + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then + // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize); for j in 0..cols { - self.vec_znx_normalize(basek_ggsw, &mut a_conv, j, basek_in, &a.data, j, scratch_3); + self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j); } - for di in 0..dsize { - // (lhs.size() + di) / dsize = (a - (digit - di - 1)).div_ceil(dsize) - a_dft.set_size((a.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res_dft.set_size(ggsw.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols { - self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a.data, j); - } - - if di == 0 { - self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1); - } else { - self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1); - } + if di == 0 { + self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &ggsw.data, scratch_1); + } else { + self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &ggsw.data, di, scratch_1); } } diff --git a/poulpy-core/src/glwe_packer.rs b/poulpy-core/src/glwe_packer.rs index 6ebbdea..b0b9595 100644 --- a/poulpy-core/src/glwe_packer.rs +++ b/poulpy-core/src/glwe_packer.rs @@ -132,17 +132,21 @@ impl GLWEPacker { } /// Flush result to`res`. - pub fn flush(&mut self, module: &M, res: &mut R) + pub fn flush(&mut self, module: &M, res: &mut R, scratch: &mut Scratch) where - R: GLWEToMut, + R: GLWEToMut + GLWEInfos, M: GLWEPackerOps, + Scratch: ScratchTakeCore, { assert!(self.counter as u32 == self.accumulators[0].data.n()); - // Copy result GLWE into res GLWE - module.glwe_copy( - res, - &self.accumulators[module.log_n() - self.log_batch - 1].data, - ); + + let out: &GLWE> = &self.accumulators[module.log_n() - self.log_batch - 1].data; + + if out.base2k() == res.base2k() { + module.glwe_copy(res, out) + } else { + module.glwe_normalize(res, out, scratch); + } self.reset(); } @@ -244,7 +248,11 @@ fn pack_core( // No previous value -> copies and sets flags accordingly if let Some(a_ref) = a { - module.glwe_copy(&mut acc_mut_ref.data, a_ref); + if a_ref.base2k() == acc_mut_ref.data.base2k() { + module.glwe_copy(&mut acc_mut_ref.data, a_ref); + } else { + module.glwe_normalize(&mut acc_mut_ref.data, a_ref, scratch); + } acc_mut_ref.value = true } else { acc_mut_ref.value = false @@ -331,30 +339,29 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(a); + let (mut tmp, scratch_1) = scratch.take_glwe(a); // a = a * X^-t module.glwe_rotate_inplace(-t, a, scratch_1); // tmp_b = a * X^-t - b - module.glwe_sub(&mut tmp_b, a, b); - module.glwe_rsh(1, &mut tmp_b, scratch_1); - + module.glwe_sub(&mut tmp, a, b); + module.glwe_rsh(1, &mut tmp, scratch_1); // a = a * X^-t + b module.glwe_add_inplace(a, b); - module.glwe_rsh(1, a, scratch_1); - module.glwe_normalize_inplace(&mut tmp_b, scratch_1); + module.glwe_rsh(1, a, scratch_1); + module.glwe_normalize_inplace(&mut tmp, scratch_1); // tmp_b = phi(a * X^-t - b) if let Some(auto_key) = auto_keys.get_automorphism_key(gal_el) { - module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); + module.glwe_automorphism_inplace(&mut tmp, auto_key, scratch_1); } else { panic!("auto_key[{gal_el}] not found"); } // a = a * X^-t + b - phi(a * X^-t - b) - module.glwe_sub_inplace(a, &tmp_b); + module.glwe_sub_inplace(a, &tmp); module.glwe_normalize_inplace(a, scratch_1); // a = a + b * X^t - phi(a * X^-t - b) * X^t diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 1ddb565..0626c55 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -7,9 +7,16 @@ use poulpy_hal::{ use crate::{ GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, GLWETrace, ScratchTakeCore, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEToMut, GetGaloisElement}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEToMut, GetGaloisElement}, }; pub trait GLWEPacking { + fn glwe_pack_galois_elements(&self) -> Vec; + + fn glwe_pack_tmp_bytes(&self, res: &R, key: &K) -> usize + where + R: GLWEInfos, + K: GGLWEInfos; + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] fn glwe_pack( @@ -40,6 +47,22 @@ where + GLWETrace, Scratch: ScratchTakeCore, { + fn glwe_pack_galois_elements(&self) -> Vec { + self.glwe_trace_galois_elements() + } + + fn glwe_pack_tmp_bytes(&self, res: &R, key: &K) -> usize + where + R: GLWEInfos, + K: GGLWEInfos, + { + self.glwe_rotate_tmp_bytes() + .max(self.glwe_rsh_tmp_byte()) + .max(self.glwe_normalize_tmp_bytes()) + .max(self.glwe_automorphism_tmp_bytes(res, res, key)) + + GLWE::bytes_of_from_infos(res) + } + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] fn glwe_pack( diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index ad49c4e..3a27428 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ - api::{ModuleLogN, VecZnxNormalize, VecZnxNormalizeTmpBytes}, + api::{ModuleLogN, VecZnxNormalizeTmpBytes}, layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, }; use crate::{ - GLWEAutomorphism, GLWECopy, GLWEShift, ScratchTakeCore, + GLWEAutomorphism, GLWECopy, GLWENormalize, GLWEShift, ScratchTakeCore, layouts::{ GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GLWE, GLWEAutomorphismKeyHelper, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, @@ -75,7 +75,7 @@ where + GLWECopy + CyclotomicOrder + VecZnxNormalizeTmpBytes - + VecZnxNormalize, + + GLWENormalize, Scratch: ScratchTakeCore, { fn glwe_trace_galois_elements(&self) -> Vec { @@ -114,15 +114,28 @@ where K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, H: GLWEAutomorphismKeyHelper, { - let (mut tmp, scratch_1) = if a.k() > res.k() { - scratch.take_glwe(a) - } else { - scratch.take_glwe(res) - }; + let atk_layout: &GGLWELayout = &keys.automorphism_key_infos(); + + let (mut tmp, scratch_1) = scratch.take_glwe(&GLWELayout { + n: res.n(), + base2k: atk_layout.base2k(), + k: a.k().max(res.k()), + rank: res.rank(), + }); + + if a.base2k() == atk_layout.base2k() { + self.glwe_copy(&mut tmp, a); + } else { + self.glwe_normalize(&mut tmp, a, scratch_1); + } - self.glwe_copy(&mut tmp, a); self.glwe_trace_inplace(&mut tmp, skip, keys, scratch_1); - self.glwe_copy(res, &tmp); + + if res.base2k() == atk_layout.base2k() { + self.glwe_copy(res, &tmp); + } else { + self.glwe_normalize(res, &tmp, scratch_1); + } } fn glwe_trace_inplace(&self, res: &mut R, skip: usize, keys: &H, scratch: &mut Scratch) @@ -143,52 +156,15 @@ where assert_eq!(ksk_infos.rank_out(), res.rank()); if res.base2k() != ksk_infos.base2k() { - let (mut self_conv, scratch_1) = scratch.take_glwe(&GLWELayout { + let (mut res_conv, scratch_1) = scratch.take_glwe(&GLWELayout { n: self.n().into(), base2k: ksk_infos.base2k(), k: res.k(), rank: res.rank(), }); - - for j in 0..(res.rank() + 1).into() { - self.vec_znx_normalize( - ksk_infos.base2k().into(), - &mut self_conv.data, - j, - res.base2k().into(), - res.data(), - j, - scratch_1, - ); - } - - for i in skip..log_n { - self.glwe_rsh(1, &mut self_conv, scratch_1); - - let p: i64 = if i == 0 { - -1 - } else { - self.galois_element(1 << (i - 1)) - }; - - if let Some(key) = keys.get_automorphism_key(p) { - self.glwe_automorphism_add_inplace(&mut self_conv, key, scratch_1); - } else { - panic!("keys[{p}] is empty") - } - } - - for j in 0..(res.rank() + 1).into() { - self.vec_znx_normalize( - res.base2k().into(), - res.data_mut(), - j, - ksk_infos.base2k().into(), - &self_conv.data, - j, - scratch_1, - ); - } + self.glwe_normalize(&mut res_conv, res, scratch_1); + self.glwe_trace_inplace(&mut res_conv, skip, keys, scratch_1); + self.glwe_normalize(res, &res_conv, scratch_1); } else { for i in skip..log_n { self.glwe_rsh(1, res, scratch); diff --git a/poulpy-core/src/keyswitching/gglwe.rs b/poulpy-core/src/keyswitching/gglwe.rs index d837002..779ff11 100644 --- a/poulpy-core/src/keyswitching/gglwe.rs +++ b/poulpy-core/src/keyswitching/gglwe.rs @@ -21,7 +21,7 @@ impl GLWEAutomorphismKey> { impl GLWEAutomorphismKey { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GGLWEToRef + GGLWEToRef, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGLWEKeyswitch, @@ -54,7 +54,7 @@ impl GLWESwitchingKey> { impl GLWESwitchingKey { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GGLWEToRef, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGLWEKeyswitch, @@ -87,7 +87,7 @@ impl GGLWE> { impl GGLWE { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GGLWEToRef, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGLWEKeyswitch, @@ -122,14 +122,11 @@ where fn gglwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - R: GGLWEToMut, - A: GGLWEToRef, + R: GGLWEToMut + GGLWEInfos, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { - let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GGLWE<&[u8]> = &a.to_ref(); - assert_eq!( res.rank_in(), a.rank_in(), @@ -164,6 +161,10 @@ where res.dsize(), a.dsize() ); + assert_eq!(res.base2k(), a.base2k()); + + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); for row in 0..res.dnum().into() { for col in 0..res.rank_in().into() { diff --git a/poulpy-core/src/keyswitching/ggsw.rs b/poulpy-core/src/keyswitching/ggsw.rs index 3dfb0b1..b36644d 100644 --- a/poulpy-core/src/keyswitching/ggsw.rs +++ b/poulpy-core/src/keyswitching/ggsw.rs @@ -3,7 +3,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; use crate::{ GGSWExpandRows, ScratchTakeCore, keyswitching::GLWEKeyswitch, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, LWEInfos}, }; impl GGSW> { @@ -29,7 +29,7 @@ impl GGSW { pub fn keyswitch(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where A: GGSWToRef, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, @@ -39,7 +39,7 @@ impl GGSW { pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) where - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, @@ -70,7 +70,7 @@ where fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { @@ -89,7 +89,7 @@ where where R: GGSWToMut, A: GGSWToRef, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { @@ -98,6 +98,7 @@ where assert!(res.dnum() <= a.dnum()); assert_eq!(res.dsize(), a.dsize()); + assert_eq!(res.base2k(), a.base2k()); for row in 0..a.dnum().into() { // Key-switch column 0, i.e. @@ -124,14 +125,14 @@ where where R: GGSWToMut, A: GGSWToRef, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore; fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, - K: GGLWEPreparedToRef, + K: GGLWEPreparedToRef + GGLWEInfos, T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore; } diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index 72def40..5fe298e 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -4,12 +4,12 @@ use poulpy_hal::{ VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos}, }; use crate::{ - ScratchTakeCore, - layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos}, + GLWENormalize, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPrepared, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, LWEInfos}, }; impl GLWE> { @@ -27,8 +27,8 @@ impl GLWE> { impl GLWE { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GLWEToRef, - B: GGLWEPreparedToRef, + A: GLWEToRef + GLWEInfos, + B: GGLWEPreparedToRef + GGLWEInfos, M: GLWEKeyswitch, Scratch: ScratchTakeCore, { @@ -37,7 +37,7 @@ impl GLWE { pub fn keyswitch_inplace(&mut self, module: &M, a: &A, scratch: &mut Scratch) where - A: GGLWEPreparedToRef, + A: GGLWEPreparedToRef + GGLWEInfos, M: GLWEKeyswitch, Scratch: ScratchTakeCore, { @@ -47,7 +47,7 @@ impl GLWE { impl GLWEKeyswitch for Module where - Self: Sized + GLWEKeySwitchInternal + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize, + Self: Sized + GLWEKeySwitchInternal + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize + GLWENormalize, Scratch: ScratchTakeCore, { fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize @@ -57,41 +57,47 @@ where B: GGLWEInfos, { let cols: usize = res_infos.rank().as_usize() + 1; - self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) - .max(self.vec_znx_big_normalize_tmp_bytes()) - + self.bytes_of_vec_znx_dft(cols, key_infos.size()) + let size: usize = if a_infos.base2k() != key_infos.base2k() { + let a_conv_infos = &GLWELayout { + n: a_infos.n(), + base2k: key_infos.base2k(), + k: a_infos.k(), + rank: a_infos.rank(), + }; + self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_conv_infos, key_infos) + GLWE::bytes_of_from_infos(a_conv_infos) + } else { + self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) + }; + + size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, key_infos.size()) } fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, - K: GGLWEPreparedToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); - let b: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); - assert_eq!( a.rank(), - b.rank_in(), + key.rank_in(), "a.rank(): {} != b.rank_in(): {}", a.rank(), - b.rank_in() + key.rank_in() ); assert_eq!( res.rank(), - b.rank_out(), + key.rank_out(), "res.rank(): {} != b.rank_out(): {}", res.rank(), - b.rank_out() + key.rank_out() ); assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); - assert_eq!(b.n(), self.n() as u32); + assert_eq!(key.n(), self.n() as u32); - let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, b); + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, a, key); assert!( scratch.available() >= scrach_needed, @@ -99,17 +105,32 @@ where scratch.available(), ); - let basek_out: usize = res.base2k().into(); - let base2k_out: usize = b.base2k().into(); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, b, scratch_1); + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise + + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2) + } else { + self.glwe_keyswitch_internal(res_dft, a, key, scratch_1) + }; + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( - basek_out, - &mut res.data, + base2k_res, + res.data_mut(), i, - base2k_out, + base2k_key, &res_big, i, scratch_1, @@ -119,12 +140,9 @@ where fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - K: GGLWEPreparedToRef, + R: GLWEToMut + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); - assert_eq!( res.rank(), key.rank_in(), @@ -151,17 +169,32 @@ where scratch.available(), ); - let base2k_in: usize = res.base2k().into(); - let base2k_out: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().as_usize(); + let base2k_key: usize = key.base2k().as_usize(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + + let res_big: VecZnxBig<&mut [u8], BE> = if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + + self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2) + } else { + self.glwe_keyswitch_internal(res_dft, res, key, scratch_1) + }; + + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( - base2k_in, - &mut res.data, + base2k_res, + res.data_mut(), i, - base2k_out, + base2k_key, &res_big, i, scratch_1, @@ -179,14 +212,14 @@ pub trait GLWEKeyswitch { fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - A: GLWEToRef, - K: GGLWEPreparedToRef; + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos; fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where - R: GLWEToMut, - K: GGLWEPreparedToRef; + R: GLWEToMut + GLWEInfos, + K: GGLWEPreparedToRef + GGLWEInfos; } impl GLWEKeySwitchInternal for Module where @@ -216,14 +249,7 @@ where { let cols: usize = (a_infos.rank() + 1).into(); let a_size: usize = a_infos.size(); - - let a_conv = if a_infos.base2k() == key_infos.base2k() { - 0 - } else { - VecZnx::bytes_of(self.n(), 1, a_size) + self.vec_znx_normalize_tmp_bytes() - }; - - self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + a_conv + self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols - 1, a_size) } fn glwe_keyswitch_internal( @@ -241,36 +267,14 @@ where { let a: &GLWE<&[u8]> = &a.to_ref(); let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); - - let base2k_in: usize = a.base2k().into(); - let base2k_out: usize = key.base2k().into(); + assert_eq!(a.base2k(), key.base2k()); let cols: usize = (a.rank() + 1).into(); - let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); - + let a_size: usize = a.size(); let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); - - if base2k_in == base2k_out { - for col_i in 0..cols - 1 { - self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); - for i in 0..cols - 1 { - self.vec_znx_normalize( - base2k_out, - &mut a_conv, - 0, - base2k_in, - a.data(), - i + 1, - scratch_2, - ); - self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); - } + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1); } - self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1); - let mut res_big: VecZnxBig = self.vec_znx_idft_apply_consume(res); self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); res_big diff --git a/poulpy-core/src/keyswitching/lwe.rs b/poulpy-core/src/keyswitching/lwe.rs index bf5abf5..a31df66 100644 --- a/poulpy-core/src/keyswitching/lwe.rs +++ b/poulpy-core/src/keyswitching/lwe.rs @@ -83,34 +83,30 @@ where assert_eq!(ksk.n(), self.n() as u32); assert!(scratch.available() >= self.lwe_keyswitch_tmp_bytes(res, a, ksk)); - let max_k: TorusPrecision = res.k().max(a.k()); - - let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; - let (mut glwe_in, scratch_1) = scratch.take_glwe(&GLWELayout { n: ksk.n(), base2k: a.base2k(), - k: max_k, + k: a.k(), rank: Rank(1), }); glwe_in.data.zero(); - let (mut glwe_out, scratch_1) = scratch_1.take_glwe(&GLWELayout { - n: ksk.n(), - base2k: res.base2k(), - k: max_k, - rank: Rank(1), - }); - let n_lwe: usize = a.n().into(); - for i in 0..a_size { + for i in 0..a.size() { let data_lwe: &[i64] = a.data.at(0, i); glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } - self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_1); + let (mut glwe_out, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: res.base2k(), + k: res.k(), + rank: Rank(1), + }); + + self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_2); self.lwe_sample_extract(res, &glwe_out); } } diff --git a/poulpy-core/src/layouts/compressed/gglwe.rs b/poulpy-core/src/layouts/compressed/gglwe.rs index 0fb0382..bb5fe55 100644 --- a/poulpy-core/src/layouts/compressed/gglwe.rs +++ b/poulpy-core/src/layouts/compressed/gglwe.rs @@ -263,9 +263,8 @@ where let rank_in: usize = res.rank_in().into(); let dnum: usize = res.dnum().into(); - - for row_i in 0..dnum { - for col_i in 0..rank_in { + for col_i in 0..rank_in { + for row_i in 0..dnum { self.decompress_glwe(&mut res.at_mut(row_i, col_i), &other.at(row_i, col_i)); } } diff --git a/poulpy-core/src/layouts/compressed/lwe.rs b/poulpy-core/src/layouts/compressed/lwe.rs index ce4c000..0c7a27e 100644 --- a/poulpy-core/src/layouts/compressed/lwe.rs +++ b/poulpy-core/src/layouts/compressed/lwe.rs @@ -1,10 +1,10 @@ use std::fmt; use poulpy_hal::{ - api::ZnFillUniform, + api::VecZnxFillUniform, layouts::{ - Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, - ZnxViewMut, + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + ZnxView, ZnxViewMut, }, source::Source, }; @@ -13,7 +13,7 @@ use crate::layouts::{Base2K, Degree, LWE, LWEInfos, LWEToMut, TorusPrecision}; #[derive(PartialEq, Eq, Clone)] pub struct LWECompressed { - pub(crate) data: Zn, + pub(crate) data: VecZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) seed: [u8; 32], @@ -72,7 +72,7 @@ impl LWECompressed> { pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { LWECompressed { - data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, seed: [0u8; 32], @@ -87,7 +87,7 @@ impl LWECompressed> { } pub fn bytes_of(base2k: Base2K, k: TorusPrecision) -> usize { - Zn::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) + VecZnx::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) } } @@ -113,7 +113,7 @@ impl WriterTo for LWECompressed { pub trait LWEDecompress where - Self: ZnFillUniform, + Self: VecZnxFillUniform, { fn decompress_lwe(&self, res: &mut R, other: &O) where @@ -126,20 +126,14 @@ where assert_eq!(res.lwe_layout(), other.lwe_layout()); let mut source: Source = Source::new(other.seed); - self.zn_fill_uniform( - res.n().into(), - other.base2k().into(), - &mut res.data, - 0, - &mut source, - ); + self.vec_znx_fill_uniform(other.base2k().into(), &mut res.data, 0, &mut source); for i in 0..res.size() { res.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; } } } -impl LWEDecompress for Module where Self: ZnFillUniform {} +impl LWEDecompress for Module where Self: VecZnxFillUniform {} impl LWE { pub fn decompress(&mut self, module: &M, other: &O) diff --git a/poulpy-core/src/layouts/glwe_plaintext.rs b/poulpy-core/src/layouts/glwe_plaintext.rs index 3261d3d..411617d 100644 --- a/poulpy-core/src/layouts/glwe_plaintext.rs +++ b/poulpy-core/src/layouts/glwe_plaintext.rs @@ -158,3 +158,15 @@ impl GLWEPlaintextToMut for GLWEPlaintext { } } } + +impl GLWEPlaintext { + pub fn data_mut(&mut self) -> &mut VecZnx { + &mut self.data + } +} + +impl GLWEPlaintext { + pub fn data(&self) -> &VecZnx { + &self.data + } +} diff --git a/poulpy-core/src/layouts/lwe.rs b/poulpy-core/src/layouts/lwe.rs index 6f8cdce..ce50d54 100644 --- a/poulpy-core/src/layouts/lwe.rs +++ b/poulpy-core/src/layouts/lwe.rs @@ -1,7 +1,7 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos}, source::Source, }; @@ -15,7 +15,7 @@ pub trait LWEInfos { } fn k(&self) -> TorusPrecision; fn max_k(&self) -> TorusPrecision { - TorusPrecision(self.k().0 * self.size() as u32) + TorusPrecision(self.base2k().0 * self.size() as u32) } fn base2k(&self) -> Base2K; fn size(&self) -> usize { @@ -57,7 +57,7 @@ impl LWEInfos for LWELayout { } #[derive(PartialEq, Eq, Clone)] pub struct LWE { - pub(crate) data: Zn, + pub(crate) data: VecZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, } @@ -90,13 +90,13 @@ impl SetLWEInfos for LWE { } impl LWE { - pub fn data(&self) -> &Zn { + pub fn data(&self) -> &VecZnx { &self.data } } impl LWE { - pub fn data_mut(&mut self) -> &Zn { + pub fn data_mut(&mut self) -> &VecZnx { &mut self.data } } @@ -121,7 +121,7 @@ impl fmt::Display for LWE { impl FillUniform for LWE where - Zn: FillUniform, + VecZnx: FillUniform, { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); @@ -138,7 +138,7 @@ impl LWE> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { LWE { - data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), k, base2k, } @@ -152,7 +152,7 @@ impl LWE> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { - Zn::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) + VecZnx::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) } } diff --git a/poulpy-core/src/layouts/lwe_plaintext.rs b/poulpy-core/src/layouts/lwe_plaintext.rs index 966ceb5..7c0d39a 100644 --- a/poulpy-core/src/layouts/lwe_plaintext.rs +++ b/poulpy-core/src/layouts/lwe_plaintext.rs @@ -1,6 +1,6 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef, ZnxInfos}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; @@ -29,7 +29,7 @@ impl LWEInfos for LWEPlaintextLayout { } pub struct LWEPlaintext { - pub(crate) data: Zn, + pub(crate) data: VecZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, } @@ -62,7 +62,7 @@ impl LWEPlaintext> { pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { LWEPlaintext { - data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, } @@ -111,8 +111,14 @@ impl LWEPlaintextToMut for LWEPlaintext { } } +impl LWEPlaintext { + pub fn data(&self) -> &VecZnx { + &self.data + } +} + impl LWEPlaintext { - pub fn data_mut(&mut self) -> &mut Zn { + pub fn data_mut(&mut self) -> &mut VecZnx { &mut self.data } } diff --git a/poulpy-core/src/noise/gglwe.rs b/poulpy-core/src/noise/gglwe.rs index 8831517..12c907f 100644 --- a/poulpy-core/src/noise/gglwe.rs +++ b/poulpy-core/src/noise/gglwe.rs @@ -1,76 +1,90 @@ use poulpy_hal::{ - api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, VecZnxFillUniform, VecZnxSubScalarInplace}, - layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, ScratchOwned, ZnxZero}, + api::VecZnxAddScalarInplace, + layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, Stats, ZnxZero}, }; -use crate::decryption::GLWEDecrypt; -use crate::layouts::{GGLWE, GGLWEInfos, GGLWEToRef, GLWEPlaintext, LWEInfos, prepared::GLWESecretPreparedToRef}; +use crate::{ + GLWENoise, + layouts::{GGLWE, GGLWEInfos, GGLWEToRef, prepared::GLWESecretPreparedToRef}, +}; +use crate::{ScratchTakeCore, layouts::GLWEPlaintext}; impl GGLWE { - pub fn assert_noise(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: f64) + pub fn noise( + &self, + module: &M, + row: usize, + col: usize, + pt_want: &P, + sk_prepared: &S, + scratch: &mut Scratch, + ) -> Stats where S: GLWESecretPreparedToRef, P: ScalarZnxToRef, M: GGLWENoise, - Scratch: ScratchTakeBasic, { - module.gglwe_assert_noise(self, sk_prepared, pt_want, max_noise); + module.gglwe_noise(self, row, col, pt_want, sk_prepared, scratch) } } pub trait GGLWENoise { - fn gglwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + fn gglwe_noise_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_noise( + &self, + res: &R, + res_row: usize, + res_col: usize, + pt_want: &P, + sk_prepared: &S, + scratch: &mut Scratch, + ) -> Stats where R: GGLWEToRef, S: GLWESecretPreparedToRef, - P: ScalarZnxToRef, - Scratch: ScratchTakeBasic; + P: ScalarZnxToRef; } impl GGLWENoise for Module where - Module: GLWEDecrypt + VecZnxFillUniform + VecZnxSubScalarInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: ScratchAvailable + ScratchTakeBasic, + Module: VecZnxAddScalarInplace + GLWENoise, + Scratch: ScratchTakeCore, { - fn gglwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + fn gglwe_noise_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + GLWEPlaintext::bytes_of_from_infos(infos) + self.glwe_noise_tmp_bytes(infos) + } + + fn gglwe_noise( + &self, + res: &R, + res_row: usize, + res_col: usize, + pt_want: &P, + sk_prepared: &S, + scratch: &mut Scratch, + ) -> Stats where R: GGLWEToRef, S: GLWESecretPreparedToRef, P: ScalarZnxToRef, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, - Scratch: ScratchAvailable + ScratchTakeBasic, { let res: &GGLWE<&[u8]> = &res.to_ref(); - let dsize: usize = res.dsize().into(); - let base2k: usize = res.base2k().into(); - - let mut scratch: ScratchOwned = ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res)); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); - - (0..res.rank_in().into()).for_each(|col_i| { - (0..res.dnum().into()).for_each(|row_i| { - self.glwe_decrypt( - &res.at(row_i, col_i), - &mut pt, - sk_prepared, - scratch.borrow(), - ); - - self.vec_znx_sub_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, col_i); - - let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); - - println!("noise_have: {noise_have}"); - - assert!( - noise_have <= max_noise, - "noise_have: {noise_have} > max_noise: {max_noise}" - ); - - pt.data.zero(); - }); - }); + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(res); + pt.data_mut().zero(); + self.vec_znx_add_scalar_inplace( + &mut pt.data, + 0, + (dsize - 1) + res_row * dsize, + pt_want, + res_col, + ); + self.glwe_noise(&res.at(res_row, res_col), &pt, sk_prepared, scratch_1) } } diff --git a/poulpy-core/src/noise/ggsw.rs b/poulpy-core/src/noise/ggsw.rs index 0c0173f..069d4f6 100644 --- a/poulpy-core/src/noise/ggsw.rs +++ b/poulpy-core/src/noise/ggsw.rs @@ -1,45 +1,49 @@ use poulpy_hal::{ api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAlloc, - VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VecZnxSubInplace, + ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, }, - layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero}, + layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, Stats, ZnxZero}, }; -use crate::decryption::GLWEDecrypt; -use crate::layouts::prepared::GLWESecretPreparedToRef; -use crate::layouts::{GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}; +use crate::layouts::{GGSW, GGSWInfos, GGSWToRef, LWEInfos, prepared::GLWESecretPrepared}; +use crate::{GLWENoise, layouts::prepared::GLWESecretPreparedToRef}; +use crate::{ScratchTakeCore, layouts::GLWEPlaintext}; impl GGSW { - pub fn assert_noise(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: &F) + pub fn noise( + &self, + module: &M, + row: usize, + col: usize, + pt_want: &P, + sk_prepared: &S, + scratch: &mut Scratch, + ) -> Stats where S: GLWESecretPreparedToRef, P: ScalarZnxToRef, M: GGSWNoise, - F: Fn(usize) -> f64, + Scratch: ScratchTakeCore, { - module.ggsw_assert_noise(self, sk_prepared, pt_want, max_noise); - } - - pub fn print_noise(&self, module: &M, sk_prepared: &S, pt_want: &P) - where - S: GLWESecretPreparedToRef, - P: ScalarZnxToRef, - M: GGSWNoise, - { - module.ggsw_print_noise(self, sk_prepared, pt_want); + module.ggsw_noise(self, row, col, pt_want, sk_prepared, scratch) } } pub trait GGSWNoise { - fn ggsw_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: &F) + fn ggsw_noise_tmp_bytes(&self, infos: &A) -> usize where - R: GGSWToRef, - S: GLWESecretPreparedToRef, - P: ScalarZnxToRef, - F: Fn(usize) -> f64; + A: GGSWInfos; - fn ggsw_print_noise(&self, res: &R, sk_prepared: &S, pt_want: &P) + fn ggsw_noise( + &self, + res: &R, + res_row: usize, + res_col: usize, + pt_want: &P, + sk_prepared: &S, + scratch: &mut Scratch, + ) -> Stats where R: GGSWToRef, S: GLWESecretPreparedToRef, @@ -48,79 +52,39 @@ pub trait GGSWNoise { impl GGSWNoise for Module where - Module: GLWEDecrypt - + VecZnxDftAlloc - + VecZnxBigAlloc - + VecZnxAddScalarInplace - + VecZnxIdftApplyTmpA - + VecZnxSubInplace, - Scratch: ScratchTakeBasic, - ScratchOwned: ScratchOwnedBorrow + ScratchOwnedAlloc, + Module: VecZnxAddScalarInplace + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxDftBytesOf + + VecZnxBigNormalize + + VecZnxBigNormalizeTmpBytes + + GLWENoise, + Scratch: ScratchTakeCore, { - fn ggsw_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: &F) + fn ggsw_noise_tmp_bytes(&self, infos: &A) -> usize where - R: GGSWToRef, - S: GLWESecretPreparedToRef, - P: ScalarZnxToRef, - F: Fn(usize) -> f64, + A: GGSWInfos, { - let res: &GGSW<&[u8]> = &res.to_ref(); - let sk_prepared: &GLWESecretPrepared<&[u8], BE> = &sk_prepared.to_ref(); - - let base2k: usize = res.base2k().into(); - let dsize: usize = res.dsize().into(); - - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); - let mut pt_dft: VecZnxDft, BE> = self.vec_znx_dft_alloc(1, res.size()); - let mut pt_big: VecZnxBig, BE> = self.vec_znx_big_alloc(1, res.size()); - - let mut scratch: ScratchOwned = - ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res) | self.vec_znx_normalize_tmp_bytes()); - - (0..(res.rank() + 1).into()).for_each(|col_j| { - (0..res.dnum().into()).for_each(|row_i| { - self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); - self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); - self.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); - self.vec_znx_big_normalize( - base2k, - &mut pt.data, - 0, - base2k, - &pt_big, - 0, - scratch.borrow(), - ); - } - - self.glwe_decrypt( - &res.at(row_i, col_j), - &mut pt_have, - sk_prepared, - scratch.borrow(), - ); - - self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); - - let std_pt: f64 = pt_have.data.stats(base2k, 0).std().log2(); - let noise: f64 = max_noise(col_j); - assert!(std_pt <= noise, "{std_pt} > {noise}"); - - pt.data.zero(); - }); - }); + GLWEPlaintext::bytes_of_from_infos(infos) + + (self.bytes_of_vec_znx_dft(1, infos.size()) + self.vec_znx_big_normalize_tmp_bytes()) + .max(self.glwe_noise_tmp_bytes(infos)) } - fn ggsw_print_noise(&self, res: &R, sk_prepared: &S, pt_want: &P) + fn ggsw_noise( + &self, + res: &R, + res_row: usize, + res_col: usize, + pt_want: &P, + sk_prepared: &S, + scratch: &mut Scratch, + ) -> Stats where R: GGSWToRef, S: GLWESecretPreparedToRef, P: ScalarZnxToRef, + Scratch: ScratchTakeCore, { let res: &GGSW<&[u8]> = &res.to_ref(); let sk_prepared: &GLWESecretPrepared<&[u8], BE> = &sk_prepared.to_ref(); @@ -128,47 +92,19 @@ where let base2k: usize = res.base2k().into(); let dsize: usize = res.dsize().into(); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res); - let mut pt_dft: VecZnxDft, BE> = self.vec_znx_dft_alloc(1, res.size()); - let mut pt_big: VecZnxBig, BE> = self.vec_znx_big_alloc(1, res.size()); + let (mut pt, scratch_1) = scratch.take_glwe_plaintext(res); + pt.data_mut().zero(); + self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + res_row * dsize, pt_want, 0); - let mut scratch: ScratchOwned = - ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res) | self.vec_znx_normalize_tmp_bytes()); - - for col_j in 0..(res.rank() + 1).into() { - for row_i in 0..res.dnum().into() { - self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0); - - // mul with sk[col_j-1] - if col_j > 0 { - self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); - self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1); - self.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0); - self.vec_znx_big_normalize( - base2k, - &mut pt.data, - 0, - base2k, - &pt_big, - 0, - scratch.borrow(), - ); - } - - self.glwe_decrypt( - &res.at(row_i, col_j), - &mut pt_have, - sk_prepared, - scratch.borrow(), - ); - - self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); - - let std_pt: f64 = pt_have.data.stats(base2k, 0).std().log2(); - println!("col: {col_j} row: {row_i}: {std_pt}"); - pt.data.zero(); - } + // mul with sk[col_j-1] + if res_col > 0 { + let (mut pt_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, 1, res.size()); + self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0); + self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, res_col - 1); + let pt_big = self.vec_znx_idft_apply_consume(pt_dft); + self.vec_znx_big_normalize(base2k, &mut pt.data, 0, base2k, &pt_big, 0, scratch_2); } + + self.glwe_noise(&res.at(res_row, res_col), &pt, sk_prepared, scratch_1) } } diff --git a/poulpy-core/src/noise/glwe.rs b/poulpy-core/src/noise/glwe.rs index dbdeea9..5cfb512 100644 --- a/poulpy-core/src/noise/glwe.rs +++ b/poulpy-core/src/noise/glwe.rs @@ -1,83 +1,61 @@ -use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, VecZnxSubInplace}, - layouts::{Backend, DataRef, Module, Scratch, ScratchOwned, Stats}, -}; +use poulpy_hal::layouts::{Backend, DataRef, Module, Scratch, Stats}; use crate::{ - ScratchTakeCore, + GLWENormalize, GLWESub, ScratchTakeCore, decryption::GLWEDecrypt, - layouts::{GLWE, GLWEPlaintext, GLWEPlaintextToRef, GLWEToRef, LWEInfos, prepared::GLWESecretPreparedToRef}, + layouts::{GLWE, GLWEInfos, GLWEPlaintext, GLWEToRef, LWEInfos, prepared::GLWESecretPreparedToRef}, }; impl GLWE { - pub fn noise(&self, module: &M, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch) -> Stats + pub fn noise(&self, module: &M, pt_want: &P, sk_prepared: &S, scratch: &mut Scratch) -> Stats where M: GLWENoise, + P: GLWEToRef, S: GLWESecretPreparedToRef, - P: GLWEPlaintextToRef, { - module.glwe_noise(self, sk_prepared, pt_want, scratch) - } - - pub fn assert_noise(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: f64) - where - S: GLWESecretPreparedToRef, - P: GLWEPlaintextToRef, - M: GLWENoise, - { - module.glwe_assert_noise(self, sk_prepared, pt_want, max_noise); + module.glwe_noise(self, pt_want, sk_prepared, scratch) } } pub trait GLWENoise { - fn glwe_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch) -> Stats + fn glwe_noise_tmp_bytes(&self, infos: &A) -> usize where - R: GLWEToRef, - S: GLWESecretPreparedToRef, - P: GLWEPlaintextToRef; + A: GLWEInfos; - fn glwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + fn glwe_noise(&self, res: &R, pt_want: &P, sk_prepared: &S, scratch: &mut Scratch) -> Stats where - R: GLWEToRef, - S: GLWESecretPreparedToRef, - P: GLWEPlaintextToRef; + R: GLWEToRef + GLWEInfos, + P: GLWEToRef, + S: GLWESecretPreparedToRef; } impl GLWENoise for Module where - Module: GLWEDecrypt + VecZnxSubInplace + VecZnxNormalizeInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Module: GLWEDecrypt + GLWESub + GLWENormalize, Scratch: ScratchTakeCore, { - fn glwe_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, scratch: &mut Scratch) -> Stats + fn glwe_noise_tmp_bytes(&self, infos: &A) -> usize where - R: GLWEToRef, - S: GLWESecretPreparedToRef, - P: GLWEPlaintextToRef, + A: GLWEInfos, { - let res_ref: &GLWE<&[u8]> = &res.to_ref(); - - let pt_want: &GLWEPlaintext<&[u8]> = &pt_want.to_ref(); - - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(res_ref); - self.glwe_decrypt(res, &mut pt_have, sk_prepared, scratch); - self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt_want.data, 0); - self.vec_znx_normalize_inplace(res_ref.base2k().into(), &mut pt_have.data, 0, scratch); - pt_have.data.stats(res_ref.base2k().into(), 0) + GLWEPlaintext::bytes_of_from_infos(infos) + + self + .glwe_normalize_tmp_bytes() + .max(self.glwe_decrypt_tmp_bytes(infos)) } - fn glwe_assert_noise(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: f64) + fn glwe_noise(&self, res: &R, pt_want: &P, sk_prepared: &S, scratch: &mut Scratch) -> Stats where - R: GLWEToRef, + R: GLWEToRef + GLWEInfos, + P: GLWEToRef, S: GLWESecretPreparedToRef, - P: GLWEPlaintextToRef, { - let res: &GLWE<&[u8]> = &res.to_ref(); - let mut scratch: ScratchOwned = ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res)); - let noise_have: f64 = self - .glwe_noise(res, sk_prepared, pt_want, scratch.borrow()) - .std() - .log2(); - assert!(noise_have <= max_noise, "{noise_have} {max_noise}"); + let (mut pt_have, scratch_1) = scratch.take_glwe_plaintext(res); + self.glwe_decrypt(res, &mut pt_have, sk_prepared, scratch_1); + // println!("pt_have: {pt_have}"); + // println!("pt_want: {}", pt_want.to_ref()); + self.glwe_sub_inplace(&mut pt_have, pt_want); + self.glwe_normalize_inplace(&mut pt_have, scratch_1); + pt_have.data.stats(pt_have.base2k().into(), 0) } } diff --git a/poulpy-core/src/noise/mod.rs b/poulpy-core/src/noise/mod.rs index 6f8882f..e65a2ef 100644 --- a/poulpy-core/src/noise/mod.rs +++ b/poulpy-core/src/noise/mod.rs @@ -38,6 +38,32 @@ pub(crate) fn var_noise_gglwe_product( noise } +#[allow(clippy::too_many_arguments)] +#[allow(dead_code)] +pub(crate) fn var_noise_gglwe_product_v2( + n: f64, + k_ksk: usize, + dnum: usize, + dsize: usize, + base2k: usize, + var_xs: f64, + var_msg: f64, + var_a_err: f64, + var_gct_err_lhs: f64, + var_gct_err_rhs: f64, + rank_in: f64, +) -> f64 { + let base: f64 = ((dsize * base2k) as f64).exp2(); + let var_base: f64 = base * base / 12f64; + let scale: f64 = (k_ksk as f64).exp2(); + + let mut noise: f64 = (dnum as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); + noise += var_msg * var_a_err * var_base * n; + noise *= rank_in; + noise /= scale * scale; + noise +} + #[allow(clippy::too_many_arguments)] #[allow(dead_code)] pub(crate) fn log2_std_noise_gglwe_product( diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 06b8a04..1d39a9b 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -23,7 +23,7 @@ where where A: LWEInfos, { - let (data, scratch) = self.take_zn(infos.n().into(), 1, infos.size()); + let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size()); ( LWE { k: infos.k(), diff --git a/poulpy-core/src/tests/mod.rs b/poulpy-core/src/tests/mod.rs index e29f625..7708c2a 100644 --- a/poulpy-core/src/tests/mod.rs +++ b/poulpy-core/src/tests/mod.rs @@ -6,64 +6,6 @@ mod serialization; #[allow(unused_imports)] use poulpy_hal::backend_test_suite; -#[cfg(test)] -backend_test_suite!( - mod cpu_spqlios, - backend = poulpy_backend::cpu_spqlios::FFT64Spqlios, - size = 1<<8, - tests = { - //GLWE Encryption - glwe_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_sk, - glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, - glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, - glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, -// GLWE Keyswitch -glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, -glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, -// GLWE Automorphism -glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, -glwe_automorphism_inplace => crate::tests::test_suite::automorphism::test_glwe_automorphism_inplace, -// GLWE External Product -glwe_external_product => crate::tests::test_suite::external_product::test_glwe_external_product, -glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, -// GLWE Trace -glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, -glwe_packing => crate::tests::test_suite::test_glwe_packer, -// GGLWE Encryption -gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, -gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, -gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_encrypt_sk, -gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, -gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, -gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, -gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk, -// GGLWE Keyswitching -gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, -gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, -// GGLWE External Product -gglwe_switching_key_external_product => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product, -gglwe_switching_key_external_product_inplace => crate::tests::test_suite::external_product::test_gglwe_switching_key_external_product_inplace, -// GGLWE Automorphism -gglwe_automorphism_key_automorphism => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism, -gglwe_automorphism_key_automorphism_inplace => crate::tests::test_suite::automorphism::test_gglwe_automorphism_key_automorphism_inplace, -// GGSW Encryption -ggsw_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_encrypt_sk, -ggsw_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_ggsw_compressed_encrypt_sk, -// GGSW Keyswitching -ggsw_keyswitch => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch, -ggsw_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_ggsw_keyswitch_inplace, -// GGSW External Product -ggsw_external_product => crate::tests::test_suite::external_product::test_ggsw_external_product, -ggsw_external_product_inplace => crate::tests::test_suite::external_product::test_ggsw_external_product_inplace, -// GGSW Automorphism -ggsw_automorphism => crate::tests::test_suite::automorphism::test_ggsw_automorphism, -ggsw_automorphism_inplace => crate::tests::test_suite::automorphism::test_ggsw_automorphism_inplace, -// LWE -lwe_keyswitch => crate::tests::test_suite::keyswitch::test_lwe_keyswitch, -glwe_to_lwe => crate::tests::test_suite::test_glwe_to_lwe, -lwe_to_glwe => crate::tests::test_suite::test_lwe_to_glwe, -} -); #[cfg(test)] backend_test_suite!( mod cpu_ref, @@ -75,6 +17,8 @@ backend_test_suite!( glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, +// GLWE Base2k Conversion +glwe_base2k_conv => crate::tests::test_suite::test_glwe_base2k_conversion, // GLWE Keyswitch glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, @@ -86,10 +30,12 @@ glwe_external_product => crate::tests::test_suite::external_product::test_glwe_e glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, // GLWE Trace glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, -glwe_packing => crate::tests::test_suite::test_glwe_packer, +glwe_packing => crate::tests::test_suite::test_glwe_packing, +glwe_packer => crate::tests::test_suite::test_glwe_packer, // GGLWE Encryption gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, +gglwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_compressed_encrypt_sk, gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_encrypt_sk, gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, @@ -135,6 +81,8 @@ backend_test_suite!( glwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_glwe_compressed_encrypt_sk, glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, + // GLWE Base2k Conversion +glwe_base2k_conv => crate::tests::test_suite::test_glwe_base2k_conversion, // GLWE Keyswitch glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, @@ -146,10 +94,12 @@ glwe_external_product => crate::tests::test_suite::external_product::test_glwe_e glwe_external_product_inplace => crate::tests::test_suite::external_product::test_glwe_external_product_inplace, // GLWE Trace glwe_trace_inplace => crate::tests::test_suite::test_glwe_trace_inplace, -glwe_packing => crate::tests::test_suite::test_glwe_packer, +glwe_packing => crate::tests::test_suite::test_glwe_packing, +glwe_packer => crate::tests::test_suite::test_glwe_packer, // GGLWE Encryption gglwe_switching_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_encrypt_sk, gglwe_switching_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_switching_key_compressed_encrypt_sk, +gglwe_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_compressed_encrypt_sk, gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_encrypt_sk, gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs index 14aab05..50d2d1b 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -5,14 +5,14 @@ use poulpy_hal::{ }; use crate::{ - GLWEAutomorphismKeyAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, ScratchTakeCore, + GGLWENoise, GLWEAutomorphismKeyAutomorphism, GLWEAutomorphismKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEPlaintext, - GLWESecret, GLWESecretPreparedFactory, + GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWESecret, + GLWESecretPreparedFactory, prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, - noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; #[allow(clippy::too_many_arguments)] @@ -25,30 +25,31 @@ where + GaloisElement + VecZnxSubScalarInplace + GLWESecretPreparedFactory - + GLWEDecrypt, + + GGLWENoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let k_out: usize = 40; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); let p0: i64 = -1; let p1: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = (dsize + di) * base2k; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; + let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_in / (base2k * di); - let dnum_out: usize = k_out / (base2k * di); - let dnum_apply: usize = k_in.div_ceil(base2k * di); + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -57,19 +58,19 @@ where let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum_out.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), - k: k_apply.into(), - dnum: dnum_apply.into(), - dsize: di.into(), + base2k: base2k_key.into(), + k: k_ksk.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -83,13 +84,16 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) - | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos) - | GLWEAutomorphismKey::automorphism_tmp_bytes( + .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( + module, + &auto_key_apply_infos, + )) + .max(GLWEAutomorphismKey::automorphism_tmp_bytes( module, &auto_key_out_infos, &auto_key_in_infos, &auto_key_apply_infos, - ), + )), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_in); @@ -128,8 +132,6 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos); - let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_out_infos); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk for i in 0..rank { @@ -145,41 +147,37 @@ where let mut sk_auto_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto); sk_auto_dft.prepare(module, &sk_auto); - (0..auto_key_out.rank_in().into()).for_each(|col_i| { - (0..auto_key_out.dnum().into()).for_each(|row_i| { - auto_key_out - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); + let max_noise: f64 = var_noise_gglwe_product_v2( + module.n() as f64, + k_ksk, + dnum_ksk, + dsize, + base2k_key, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + ) + .sqrt() + .log2(); - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (dsize_in - 1) + row_i * dsize_in, - &sk.data, - col_i, - ); - - let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - n as f64, - base2k * di, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_out, - k_apply, - ); + for row in 0..auto_key_out.dnum().as_usize() { + for col in 0..auto_key_out.rank().as_usize() { + let noise_have = auto_key_out + .key + .noise(module, row, col, &sk.data, &sk_auto_dft, scratch.borrow()) + .std() + .log2(); assert!( - noise_have < noise_want + 0.5, - "{noise_have} {}", - noise_want + 0.5 + noise_have < max_noise + 0.5, + "{noise_have} > {}", + max_noise + 0.5 ); - }); - }); + } + } } } } @@ -194,29 +192,31 @@ where + GaloisElement + VecZnxSubScalarInplace + GLWESecretPreparedFactory - + GLWEDecrypt, + + GGLWENoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + let p0: i64 = -1; let p1: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = (dsize + di) * base2k; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_in / (base2k * di); - let dnum_apply: usize = k_in.div_ceil(base2k * di); + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), - k: k_in.into(), + base2k: base2k_out.into(), + k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), @@ -224,10 +224,10 @@ where let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), - k: k_apply.into(), - dnum: dnum_apply.into(), - dsize: di.into(), + base2k: base2k_key.into(), + k: k_ksk.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -275,8 +275,6 @@ where // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) auto_key.automorphism_inplace(module, &auto_key_apply_prepared, scratch.borrow()); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key); - let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk @@ -293,40 +291,37 @@ where let mut sk_auto_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto); sk_auto_dft.prepare(module, &sk_auto); - (0..auto_key.rank_in().into()).for_each(|col_i| { - (0..auto_key.dnum().into()).for_each(|row_i| { - auto_key - .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); - module.vec_znx_sub_scalar_inplace( - &mut pt.data, - 0, - (dsize_in - 1) + row_i * dsize_in, - &sk.data, - col_i, - ); + let max_noise: f64 = var_noise_gglwe_product_v2( + module.n() as f64, + k_ksk, + dnum_ksk, + dsize, + base2k_key, + 0.5, + 0.5, + 0f64, + SIGMA * SIGMA, + 0f64, + rank as f64, + ) + .sqrt() + .log2(); - let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - n as f64, - base2k * di, - 0.5, - 0.5, - 0f64, - SIGMA * SIGMA, - 0f64, - rank as f64, - k_in, - k_apply, - ); + for row in 0..auto_key.dnum().as_usize() { + for col in 0..auto_key.rank().as_usize() { + let noise_have = auto_key + .key + .noise(module, row, col, &sk.data, &sk_auto_dft, scratch.borrow()) + .std() + .log2(); assert!( - noise_have < noise_want + 0.5, + noise_have < max_noise + 0.5, "{noise_have} {}", - noise_want + 0.5 + max_noise + 0.5 ); - }); - }); + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index 6e2a226..8b3679b 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -8,8 +8,8 @@ use crate::{ GGLWEToGGSWKeyEncryptSk, GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWEAutomorphismKey, - GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, + GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWInfos, GGSWLayout, GLWEAutomorphismKey, + GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, prepared::{GGLWEToGGSWKeyPrepared, GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_keyswitch, @@ -29,26 +29,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 54; - let dsize: usize = k_in.div_ceil(base2k); - let p: i64 = -5; + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + let p: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * di); - let dnum_in: usize = k_in.div_euclid(base2k * di); + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_in_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -57,7 +59,7 @@ where let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -66,19 +68,19 @@ where let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -154,7 +156,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, @@ -167,7 +169,17 @@ where ) + 0.5 }; - ct_out.assert_noise(module, &sk_prepared, &pt_scalar, &max_noise); + for row in 0..ct_out.dnum().as_usize() { + for col in 0..ct_out.rank().as_usize() + 1 { + assert!( + ct_out + .noise(module, row, col, &pt_scalar, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise(col) + ) + } + } } } } @@ -187,23 +199,25 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 54; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + let p: i64 = -1; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let k_tsk: usize = k_ksk; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(di * base2k); - let dnum_in: usize = k_out.div_euclid(base2k * di); + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -212,19 +226,19 @@ where let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -293,7 +307,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, @@ -303,10 +317,22 @@ where k_out, k_ksk, k_tsk, - ) + 0.5 + ) + 4.0 }; - ct.assert_noise(module, &sk_prepared, &pt_scalar, &max_noise); + for row in 0..ct.dnum().as_usize() { + for col in 0..ct.rank().as_usize() + 1 { + let noise_have: f64 = ct + .noise(module, row, col, &pt_scalar, &sk_prepared, scratch.borrow()) + .std() + .log2(); + let noise_max: f64 = max_noise(col); + assert!( + noise_have <= noise_max, + "noise_have:{noise_have} > noise_max:{noise_max}", + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index 58f737a..aa881c2 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -5,14 +5,14 @@ use poulpy_hal::{ }; use crate::{ - GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, + GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, GLWENormalize, ScratchTakeCore, encryption::SIGMA, layouts::{ GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, - noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; pub fn test_glwe_automorphism(module: &Module) @@ -25,55 +25,59 @@ where + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory + GLWENoise - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace + + GLWENormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = 15; + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); let p: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * dsize); + let dnum: usize = k_in.div_ceil(base2k_key * dsize); let ct_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), rank: rank.into(), }; let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_out.into(), rank: rank.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), }; let mut autokey: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); let mut ct_in: GLWE> = GLWE::alloc_from_infos(&ct_in_infos); let mut ct_out: GLWE> = GLWE::alloc_from_infos(&ct_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_in_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) @@ -99,7 +103,7 @@ where ct_in.encrypt_sk( module, - &pt_want, + &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, @@ -112,22 +116,32 @@ where ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + max_dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_in, - k_ksk, - ); + ) + .sqrt() + .log2(); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); + module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_out.data, 0, scratch.borrow()); - ct_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); + assert!( + ct_out + .noise(module, &pt_out, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + 1.0 + ) } } } @@ -147,31 +161,33 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); - let p = -5; + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + + let p: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * dsize); + let dnum: usize = k_out.div_ceil(base2k_key * dsize); let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), rank: rank.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), }; let mut autokey: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); @@ -182,7 +198,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) @@ -221,22 +237,30 @@ where ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_out, - k_ksk, - ); + ) + .sqrt() + .log2(); module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); - ct.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); + assert!( + ct.noise(module, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + 1.0 + ) } } } diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index 4c2fe7f..9ee70e2 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -1,20 +1,103 @@ use poulpy_hal::{ - api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView}, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform, VecZnxNormalize}, + layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned, ZnxView}, source::Source, }; +use rug::Float; use crate::{ - GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, LWEFromGLWE, - LWEToGLWESwitchingKeyEncryptSk, ScratchTakeCore, + GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWENoise, GLWENormalize, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, + LWEFromGLWE, LWEToGLWESwitchingKeyEncryptSk, SIGMA, ScratchTakeCore, layouts::{ Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKey, - GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, + GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWEInfos, LWELayout, LWEPlaintext, LWESecret, LWEToGLWEKey, LWEToGLWEKeyLayout, LWEToGLWEKeyPrepared, LWEToGLWEKeyPreparedFactory, Rank, TorusPrecision, prepared::GLWESecretPrepared, }, }; +pub fn test_glwe_base2k_conversion(module: &Module) +where + Module: GLWEEncryptSk + + GLWEDecrypt + + GLWENormalize + + VecZnxFillUniform + + GLWESecretPreparedFactory + + GLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let n_glwe: Degree = Degree(module.n() as u32); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + for rank in 1_usize..3 { + for bases in [[12, 8], [8, 12]] { + let glwe_infos_in: GLWELayout = GLWELayout { + n: n_glwe, + base2k: Base2K(bases[0]), + k: TorusPrecision(34), + rank: Rank(rank as u32), + }; + + let glwe_infos_out: GLWELayout = GLWELayout { + n: n_glwe, + base2k: Base2K(bases[1]), + k: TorusPrecision(34), + rank: Rank(rank as u32), + }; + + let mut sk: GLWESecret> = GLWESecret::alloc(module.n().into(), rank.into()); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_prep.prepare(module, &sk); + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos_in).max(module.glwe_noise_tmp_bytes(&glwe_infos_out)), + ); + + let mut ct_in: GLWE> = GLWE::alloc_from_infos(&glwe_infos_in); + let mut ct_out: GLWE> = GLWE::alloc_from_infos(&glwe_infos_out); + + let pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos_in); + let pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos_out); + + ct_in.encrypt_sk( + module, + &pt_in, + &sk_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut data: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); + ct_in + .data() + .decode_vec_float(ct_in.base2k().into(), 0, &mut data); + + ct_out.fill_uniform(ct_out.base2k().into(), &mut source_xa); + module.glwe_normalize(&mut ct_out, &ct_in, scratch.borrow()); + + let mut data_conv: Vec = (0..module.n()).map(|_| Float::with_val(128, 0)).collect(); + ct_out + .data() + .decode_vec_float(ct_out.base2k().into(), 0, &mut data_conv); + + assert!( + ct_out + .noise(module, &pt_out, &sk_prep, scratch.borrow()) + .std() + .log2() + <= -(ct_out.k().as_u32() as f64) + SIGMA.log2() + 0.50 + ) + } + } +} + pub fn test_lwe_to_glwe(module: &Module) where Module: GLWEFromLWE @@ -22,7 +105,8 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + LWEEncryptSk - + LWEToGLWEKeyPreparedFactory, + + LWEToGLWEKeyPreparedFactory + + VecZnxNormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -38,23 +122,23 @@ where let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout { n: n_glwe, - base2k: Base2K(17), - k: TorusPrecision(51), + base2k: Base2K(13), + k: TorusPrecision(92), dnum: Dnum(2), rank_out: rank, }; let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, - base2k: Base2K(17), - k: TorusPrecision(34), + base2k: Base2K(15), + k: TorusPrecision(75), rank, }; let lwe_infos: LWELayout = LWELayout { n: n_lwe, base2k: Base2K(17), - k: TorusPrecision(34), + k: TorusPrecision(75), }; let mut scratch: ScratchOwned = ScratchOwned::alloc( @@ -78,7 +162,14 @@ where lwe_pt.encode_i64(data, k_lwe_pt); let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); - lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); + lwe_ct.encrypt_sk( + module, + &lwe_pt, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let mut ksk: LWEToGLWEKey> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos); @@ -101,7 +192,19 @@ where let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow()); - assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); + let mut lwe_pt_conv = LWEPlaintext::alloc(glwe_pt.base2k(), lwe_pt.k()); + + module.vec_znx_normalize( + glwe_pt.base2k().as_usize(), + lwe_pt_conv.data_mut(), + 0, + lwe_pt.base2k().as_usize(), + lwe_pt.data(), + 0, + scratch.borrow(), + ); + + assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt_conv.data.at(0, 0)[0]); } pub fn test_glwe_to_lwe(module: &Module) @@ -114,7 +217,8 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + GLWEToLWESwitchingKeyEncryptSk - + GLWEToLWEKeyPreparedFactory, + + GLWEToLWEKeyPreparedFactory + + VecZnxNormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -126,8 +230,8 @@ where let glwe_to_lwe_infos: GLWEToLWEKeyLayout = GLWEToLWEKeyLayout { n: n_glwe, - base2k: Base2K(17), - k: TorusPrecision(51), + base2k: Base2K(13), + k: TorusPrecision(91), dnum: Dnum(2), rank_in: rank, }; @@ -135,14 +239,14 @@ where let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, base2k: Base2K(17), - k: TorusPrecision(34), + k: TorusPrecision(72), rank, }; let lwe_infos: LWELayout = LWELayout { n: n_lwe, - base2k: Base2K(17), - k: TorusPrecision(34), + base2k: Base2K(15), + k: TorusPrecision(72), }; let mut source_xs: Source = Source::new([0u8; 32]); @@ -171,8 +275,6 @@ where let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_pt.encode_vec_i64(&data, k_lwe_pt); - println!("glwe_pt: {glwe_pt}"); - let mut glwe_ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); glwe_ct.encrypt_sk( module, @@ -202,7 +304,19 @@ where lwe_ct.from_glwe(module, &glwe_ct, a_idx, &ksk_prepared, scratch.borrow()); let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); - lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); + lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe, scratch.borrow()); - assert_eq!(glwe_pt.data.at(0, 0)[a_idx], lwe_pt.data.at(0, 0)[0]); + let mut glwe_pt_conv = GLWEPlaintext::alloc(glwe_ct.n(), lwe_pt.base2k(), lwe_pt.k()); + + module.vec_znx_normalize( + lwe_pt.base2k().as_usize(), + glwe_pt_conv.data_mut(), + 0, + glwe_ct.base2k().as_usize(), + glwe_pt.data(), + 0, + scratch.borrow(), + ); + + assert_eq!(glwe_pt_conv.data.at(0, 0)[a_idx], lwe_pt.data.at(0, 0)[0]); } diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs index 33d3bd7..9f917af 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_atk.rs @@ -9,8 +9,8 @@ use crate::{ GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GLWEAutomorphismKey, GLWEAutomorphismKeyDecompress, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecret, - GLWESecretPreparedFactory, GLWESwitchingKeyDecompress, compressed::GLWEAutomorphismKeyCompressed, + GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyDecompress, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecret, + GLWESecretPreparedFactory, GLWESwitchingKeyDecompress, LWEInfos, compressed::GLWEAutomorphismKeyCompressed, prepared::GLWESecretPrepared, }, noise::GGLWENoise, @@ -84,8 +84,26 @@ where let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, sk_out.rank()); sk_out_prepared.prepare(module, &sk_out); - atk.key - .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); + let max_noise: f64 = SIGMA.log2() - (atk.k().as_usize() as f64) + 0.5; + + for row in 0..atk.dnum().as_usize() { + for col in 0..atk.rank().as_usize() { + assert!( + atk.key + .noise( + module, + row, + col, + &sk.data, + &sk_out_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise + ) + } + } } } } @@ -106,18 +124,18 @@ where { let base2k: usize = 12; let k_ksk: usize = 60; - let dsize: usize = k_ksk.div_ceil(base2k) - 1; - for rank in 1_usize..3 { - for di in 1..dsize + 1 { + let max_dsize: usize = k_ksk.div_ceil(base2k) - 1; + for rank in 2_usize..3 { + for dsize in 1..max_dsize + 1 { let n: usize = module.n(); - let dnum: usize = (k_ksk - di * base2k) / (di * base2k); + let dnum: usize = (k_ksk - dsize * base2k) / (dsize * base2k); let atk_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -134,7 +152,7 @@ where let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&atk_infos); sk.fill_ternary_prob(0.5, &mut source_xs); - let p = -5; + let p: i64 = -5; let seed_xa: [u8; 32] = [1u8; 32]; @@ -156,8 +174,30 @@ where let mut atk: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&atk_infos); atk.decompress(module, &atk_compressed); - atk.key - .assert_noise(module, &sk_out_prepared, &sk.data, SIGMA); + let max_noise: f64 = SIGMA.log2() - (atk.k().as_usize() as f64) + 0.5; + + for row in 0..atk.dnum().as_usize() { + for col in 0..atk.rank().as_usize() { + let noise_have = atk + .key + .noise( + module, + row, + col, + &sk.data, + &sk_out_prepared, + scratch.borrow(), + ) + .std() + .log2(); + + assert!( + noise_have < max_noise + 0.5, + "row:{row} col:{col} noise_have:{noise_have} > max_noise:{}", + max_noise + 0.5 + ); + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index 2b64f02..36d0a72 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -5,12 +5,13 @@ use poulpy_hal::{ }; use crate::{ - GGLWEEncryptSk, GGLWEKeyswitch, GLWESwitchingKeyCompressedEncryptSk, GLWESwitchingKeyEncryptSk, ScratchTakeCore, + GGLWECompressedEncryptSk, GGLWEEncryptSk, GGLWEKeyswitch, GLWESwitchingKeyCompressedEncryptSk, GLWESwitchingKeyEncryptSk, + ScratchTakeCore, decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - GGLWELayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyCompressed, - GLWESwitchingKeyDecompress, + GGLWE, GGLWECompressed, GGLWEInfos, GGLWELayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, + GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress, LWEInfos, prepared::{GGLWEPreparedFactory, GLWESecretPrepared}, }, noise::GGLWENoise, @@ -74,8 +75,30 @@ where scratch.borrow(), ); - ksk.key - .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); + let max_noise: f64 = SIGMA.log2() - (ksk.k().as_usize() as f64) + 0.5; + + for row in 0..ksk.dnum().as_usize() { + for col in 0..ksk.rank_in().as_usize() { + let noise_have = ksk + .key + .noise( + module, + row, + col, + &sk_in.data, + &sk_out_prepared, + scratch.borrow(), + ) + .std() + .log2(); + + assert!( + noise_have < max_noise + 0.5, + "row:{row} col:{col} noise_have:{noise_have} > max_noise:{}", + max_noise + 0.5 + ); + } + } } } } @@ -99,18 +122,18 @@ where let n: usize = module.n(); let base2k: usize = 12; let k_ksk: usize = 54; - let dsize: usize = k_ksk / base2k; + let max_dsize: usize = k_ksk / base2k; for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let dnum: usize = (k_ksk - di * base2k) / (di * base2k); + for dsize in 1_usize..max_dsize { + let dnum: usize = (k_ksk - dsize * base2k) / (dsize * base2k); let gglwe_infos: GGLWELayout = GGLWELayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), }; @@ -148,8 +171,122 @@ where let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&gglwe_infos); ksk.decompress(module, &ksk_compressed); - ksk.key - .assert_noise(module, &sk_out_prepared, &sk_in.data, SIGMA); + let max_noise: f64 = SIGMA.log2() - (ksk.k().as_usize() as f64) + 0.5; + + for row in 0..ksk.dnum().as_usize() { + for col in 0..ksk.rank_in().as_usize() { + let noise_have = ksk + .key + .noise( + module, + row, + col, + &sk_in.data, + &sk_out_prepared, + scratch.borrow(), + ) + .std() + .log2(); + + assert!( + noise_have < max_noise + 0.5, + "row:{row} col:{col} noise_have:{noise_have} > max_noise:{}", + max_noise + 0.5 + ); + } + } + } + } + } +} + +pub fn test_gglwe_compressed_encrypt_sk(module: &Module) +where + Module: GGLWEEncryptSk + + GGLWEPreparedFactory + + GGLWEKeyswitch + + GLWEDecrypt + + GLWESecretPreparedFactory + + GLWESwitchingKeyEncryptSk + + GGLWECompressedEncryptSk + + GLWESwitchingKeyDecompress + + GGLWENoise + + VecZnxFillUniform, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let n: usize = module.n(); + let base2k: usize = 12; + let k_ksk: usize = 54; + let max_dsize: usize = k_ksk / base2k; + for rank_in in 1_usize..3 { + for rank_out in 1_usize..3 { + for dsize in 1_usize..max_dsize + 1 { + let dnum: usize = (k_ksk - dsize * base2k) / (dsize * base2k); + + let gglwe_infos: GGLWELayout = GGLWELayout { + n: n.into(), + base2k: base2k.into(), + k: k_ksk.into(), + dnum: dnum.into(), + dsize: dsize.into(), + rank_in: rank_in.into(), + rank_out: rank_out.into(), + }; + + let mut ksk_compressed: GGLWECompressed> = GGLWECompressed::alloc_from_infos(&gglwe_infos); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GGLWECompressed::encrypt_sk_tmp_bytes(module, &gglwe_infos)); + + let mut sk_in: GLWESecret> = GLWESecret::alloc(n.into(), rank_in.into()); + sk_in.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_out: GLWESecret> = GLWESecret::alloc(n.into(), rank_out.into()); + sk_out.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_out_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank_out.into()); + sk_out_prepared.prepare(module, &sk_out); + + let seed_xa = [1u8; 32]; + + ksk_compressed.encrypt_sk( + module, + &sk_in.data, + &sk_out_prepared, + seed_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut ksk: GGLWE> = GGLWE::alloc_from_infos(&gglwe_infos); + ksk.decompress(module, &ksk_compressed); + + let max_noise: f64 = SIGMA.log2() - (ksk.k().as_usize() as f64) + 0.5; + + for row in 0..ksk.dnum().as_usize() { + for col in 0..ksk.rank_in().as_usize() { + let noise_have = ksk + .noise( + module, + row, + col, + &sk_in.data, + &sk_out_prepared, + scratch.borrow(), + ) + .std() + .log2(); + + assert!( + noise_have < max_noise + 0.5, + "row:{row} col:{col} noise_have:{noise_have} > max_noise:{}", + max_noise + 0.5 + ); + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs index 884e21a..c508bf4 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs @@ -9,8 +9,9 @@ use crate::{ decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - Dsize, GGLWEDecompress, GGLWEToGGSWKey, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyDecompress, GGLWEToGGSWKeyLayout, - GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, LWEInfos, prepared::GLWESecretPrepared, + Dsize, GGLWE, GGLWEDecompress, GGLWEInfos, GGLWEToGGSWKey, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyDecompress, + GGLWEToGGSWKeyLayout, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, LWEInfos, + prepared::GLWESecretPrepared, }, }; @@ -79,9 +80,17 @@ where ); } - println!("pt_want: {}", pt_want.as_vec_znx()); - - module.gglwe_assert_noise(key.at(i), &sk_prepared, &pt_want, max_noise); + let ksk: &GGLWE> = key.at(i); + for row in 0..ksk.dnum().as_usize() { + for col in 0..ksk.rank_in().as_usize() { + assert!( + ksk.noise(module, row, col, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + ) + } + } } } } @@ -137,8 +146,31 @@ where let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); sk_tensor.prepare(module, &sk, scratch.borrow()); + let max_noise = SIGMA.log2() + 0.5 - (key.k().as_u32() as f64); + + let mut pt_want: ScalarZnx> = ScalarZnx::alloc(module.n(), rank); + for i in 0..rank { - module.gglwe_assert_noise(key.at(i), &sk_prepared, &sk_tensor.data, SIGMA + 0.5); + for j in 0..rank { + module.vec_znx_copy( + &mut pt_want.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + let ksk: &GGLWE> = key.at(i); + for row in 0..ksk.dnum().as_usize() { + for col in 0..ksk.rank_in().as_usize() { + assert!( + ksk.noise(module, row, col, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs index e53eedf..8fd0b7a 100644 --- a/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/ggsw_ct.rs @@ -8,8 +8,8 @@ use crate::{ GGSWCompressedEncryptSk, GGSWEncryptSk, GGSWNoise, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWDecompress, GGSWLayout, GLWESecret, GLWESecretPreparedFactory, compressed::GGSWCompressed, - prepared::GLWESecretPrepared, + GGSW, GGSWDecompress, GGSWInfos, GGSWLayout, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, + compressed::GGSWCompressed, prepared::GLWESecretPrepared, }, }; @@ -65,7 +65,16 @@ where let noise_f = |_col_i: usize| -(k as f64) + SIGMA.log2() + 0.5; - ct.assert_noise(module, &sk_prepared, &pt_scalar, &noise_f); + for row in 0..ct.dnum().as_usize() { + for col in 0..ct.rank().as_usize() + 1 { + assert!( + ct.noise(module, row, col, &pt_scalar, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= noise_f(col) + ) + } + } } } } @@ -126,7 +135,16 @@ where let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_infos); ct.decompress(module, &ct_compressed); - ct.assert_noise(module, &sk_prepared, &pt_scalar, &noise_f); + for row in 0..ct.dnum().as_usize() { + for col in 0..ct.rank().as_usize() + 1 { + assert!( + ct.noise(module, row, col, &pt_scalar, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= noise_f(col) + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index 7d4fcff..81901ad 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -5,8 +5,7 @@ use poulpy_hal::{ }; use crate::{ - GLWECompressedEncryptSk, GLWEEncryptPk, GLWEEncryptSk, GLWEPublicKeyGenerate, GLWESub, ScratchTakeCore, - decryption::GLWEDecrypt, + GLWECompressedEncryptSk, GLWEEncryptPk, GLWEEncryptSk, GLWENoise, GLWEPublicKeyGenerate, GLWESub, ScratchTakeCore, encryption::SIGMA, layouts::{ GLWE, GLWELayout, GLWEPlaintext, GLWEPlaintextLayout, GLWEPublicKey, GLWEPublicKeyPreparedFactory, GLWESecret, @@ -18,7 +17,7 @@ use crate::{ pub fn test_glwe_encrypt_sk(module: &Module) where - Module: GLWEEncryptSk + GLWEDecrypt + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, + Module: GLWEEncryptSk + GLWENoise + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -44,14 +43,13 @@ where let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos)); + ScratchOwned::alloc(GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos).max(module.glwe_noise_tmp_bytes(&glwe_infos))); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -70,12 +68,11 @@ where scratch.borrow(), ); - ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - - module.glwe_sub_inplace(&mut pt_want, &pt_have); - - let noise_have: f64 = pt_want.data.stats(base2k, 0).std() * (ct.k().as_u32() as f64).exp2(); - let noise_want: f64 = SIGMA; + let noise_have: f64 = ct + .noise(module, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2(); + let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5; assert!(noise_have <= noise_want + 0.2); } @@ -83,7 +80,7 @@ where pub fn test_glwe_compressed_encrypt_sk(module: &Module) where - Module: GLWECompressedEncryptSk + GLWEDecrypt + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, + Module: GLWECompressedEncryptSk + GLWENoise + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -92,7 +89,6 @@ where let k_pt: usize = 30; for rank in 1_usize..3 { - // println!("rank: {}", rank); let n: usize = module.n(); let glwe_infos: GLWELayout = GLWELayout { @@ -111,14 +107,13 @@ where let mut ct_compressed: GLWECompressed> = GLWECompressed::alloc_from_infos(&glwe_infos); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&pt_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECompressed::encrypt_sk_tmp_bytes(module, &glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), + GLWECompressed::encrypt_sk_tmp_bytes(module, &glwe_infos).max(module.glwe_noise_tmp_bytes(&glwe_infos)), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); @@ -143,24 +138,18 @@ where let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); ct.decompress(module, &ct_compressed); - ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - - module.glwe_sub_inplace(&mut pt_want, &pt_have); - - let noise_have: f64 = pt_want.data.stats(base2k, 0).std() * (ct.k().as_u32() as f64).exp2(); - let noise_want: f64 = SIGMA; - - assert!( - noise_have <= noise_want + 0.2, - "{noise_have} <= {}", - noise_want + 0.2 - ); + let noise_have: f64 = ct + .noise(module, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2(); + let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5; + assert!(noise_have <= noise_want + 0.2); } } pub fn test_glwe_encrypt_zero_sk(module: &Module) where - Module: GLWEEncryptSk + GLWEDecrypt + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, + Module: GLWEEncryptSk + GLWENoise + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -177,14 +166,17 @@ where rank: rank.into(), }; - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + let pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, &glwe_infos) | GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos)); + let mut scratch: ScratchOwned = ScratchOwned::alloc( + module + .glwe_noise_tmp_bytes(&glwe_infos) + .max(GLWE::encrypt_sk_tmp_bytes(module, &glwe_infos)), + ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -201,9 +193,13 @@ where &mut source_xe, scratch.borrow(), ); - ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - assert!((SIGMA - pt.data.stats(base2k, 0).std() * (k_ct as f64).exp2()) <= 0.2); + let noise_have: f64 = ct + .noise(module, &pt, &sk_prepared, scratch.borrow()) + .std() + .log2(); + let noise_want: f64 = SIGMA.log2() - (ct.k().as_usize() as f64) + 0.5; + assert!(noise_have <= noise_want + 0.2); } } @@ -212,7 +208,7 @@ where Module: GLWEEncryptPk + GLWEPublicKeyPreparedFactory + GLWEPublicKeyGenerate - + GLWEDecrypt + + GLWENoise + GLWESecretPreparedFactory + VecZnxFillUniform + GLWESub, @@ -233,7 +229,6 @@ where }; let mut ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); @@ -241,8 +236,11 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xu: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = - ScratchOwned::alloc(GLWE::decrypt_tmp_bytes(module, &glwe_infos) | GLWE::encrypt_pk_tmp_bytes(module, &glwe_infos)); + let mut scratch: ScratchOwned = ScratchOwned::alloc( + module + .glwe_noise_tmp_bytes(&glwe_infos) + .max(GLWE::encrypt_pk_tmp_bytes(module, &glwe_infos)), + ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); sk.fill_ternary_prob(0.5, &mut source_xs); @@ -267,17 +265,11 @@ where scratch.borrow(), ); - ct.decrypt(module, &mut pt_have, &sk_prepared, scratch.borrow()); - - module.glwe_sub_inplace(&mut pt_want, &pt_have); - - let noise_have: f64 = pt_want.data.stats(base2k, 0).std().log2(); + let noise_have: f64 = ct + .noise(module, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2(); let noise_want: f64 = ((((rank as f64) + 1.0) * n as f64 * 0.5 * SIGMA * SIGMA).sqrt()).log2() - (k_ct as f64); - - assert!( - noise_have <= noise_want + 0.2, - "{noise_have} <= {}", - noise_want + 0.2 - ); + assert!(noise_have <= noise_want + 0.2); } } diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index 26baa92..a0dcbf5 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -9,8 +9,8 @@ use crate::{ decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - Dsize, GGLWEDecompress, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWETensorKey, - GLWETensorKeyCompressed, GLWETensorKeyLayout, prepared::GLWESecretPrepared, + Dsize, GGLWEDecompress, GGLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, + GLWETensorKey, GLWETensorKeyCompressed, GLWETensorKeyLayout, LWEInfos, prepared::GLWESecretPrepared, }, }; @@ -67,7 +67,27 @@ where let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); sk_tensor.prepare(module, &sk, scratch.borrow()); - module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5); + let max_noise: f64 = SIGMA.log2() - (tensor_key.k().as_usize() as f64) + 0.5; + + for row in 0..tensor_key.dnum().as_usize() { + for col in 0..tensor_key.rank_in().as_usize() { + assert!( + tensor_key + .0 + .noise( + module, + row, + col, + &sk_tensor.data, + &sk_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise + ) + } + } } } @@ -124,6 +144,26 @@ where let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); sk_tensor.prepare(module, &sk, scratch.borrow()); - module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5); + let max_noise: f64 = SIGMA.log2() - (tensor_key.k().as_usize() as f64) + 0.5; + + for row in 0..tensor_key.dnum().as_usize() { + for col in 0..tensor_key.rank_in().as_usize() { + assert!( + tensor_key + .0 + .noise( + module, + row, + col, + &sk_tensor.data, + &sk_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise + ) + } + } } } diff --git a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs index 07a0926..7161546 100644 --- a/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/test_suite/external_product/gglwe_ksk.rs @@ -8,7 +8,8 @@ use crate::{ GGLWEExternalProduct, GGLWENoise, GGSWEncryptSk, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, + GGLWEInfos, GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, + GLWESwitchingKeyLayout, prepared::{GGSWPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_product, @@ -27,24 +28,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ggsw: usize = k_in + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ggsw: usize = k_in + base2k_key * dsize; let k_out: usize = k_in; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * di); + let dnum_in: usize = k_in / base2k_in; + let dnum: usize = k_in.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let gglwe_in_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -52,9 +57,9 @@ where let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -62,10 +67,10 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank_out.into(), }; @@ -143,7 +148,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, var_xs, var_msg, var_a0_err, @@ -155,9 +160,25 @@ where k_ggsw, ); - ct_gglwe_out - .key - .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); + for row in 0..ct_gglwe_out.dnum().as_usize() { + for col in 0..ct_gglwe_out.rank_in().as_usize() { + assert!( + ct_gglwe_out + .key + .noise( + module, + row, + col, + &sk_in.data, + &sk_out_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise + 0.5 + ) + } + } } } } @@ -176,24 +197,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ggsw: usize = k_out + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ggsw: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * di); + let dnum_in: usize = k_out / base2k_out; + let dnum: usize = k_out.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let gglwe_out_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -201,10 +225,10 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank_out.into(), }; @@ -281,7 +305,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, var_xs, var_msg, var_a0_err, @@ -293,9 +317,25 @@ where k_ggsw, ); - ct_gglwe - .key - .assert_noise(module, &sk_out_prepared, &sk_in.data, max_noise + 0.5); + for row in 0..ct_gglwe.dnum().as_usize() { + for col in 0..ct_gglwe.rank_in().as_usize() { + assert!( + ct_gglwe + .key + .noise( + module, + row, + col, + &sk_in.data, + &sk_out_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise + 0.5 + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs index b455b70..501a161 100644 --- a/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/ggsw_ct.rs @@ -8,7 +8,7 @@ use crate::{ GGSWEncryptSk, GGSWExternalProduct, GGSWNoise, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPreparedFactory, + GGSW, GGSWInfos, GGSWLayout, GGSWPreparedFactory, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, prepared::{GGSWPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_product, @@ -26,23 +26,26 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_apply: usize = k_in + base2k_key * dsize; let k_out: usize = k_in; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * di); - let dnum_in: usize = k_in.div_euclid(base2k * di); + let dnum: usize = k_in.div_ceil(base2k_key * dsize); + let dnum_in: usize = k_in / base2k_in; let dsize_in: usize = 1; let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -51,7 +54,7 @@ where let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -60,10 +63,10 @@ where let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_apply.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -130,7 +133,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, 0.5, var_msg, var_a0_err, @@ -143,7 +146,17 @@ where ) + 0.5 }; - ggsw_out.assert_noise(module, &sk_prepared, &pt_in, &max_noise); + for row in 0..ggsw_out.dnum().as_usize() { + for col in 0..ggsw_out.rank().as_usize() + 1 { + assert!( + ggsw_out + .noise(module, row, col, &pt_in, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise(col) + ) + } + } } } } @@ -160,21 +173,23 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_apply: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(di * base2k); - let dnum_in: usize = k_out.div_euclid(base2k * di); + let dnum: usize = k_out.div_ceil(dsize * base2k_key); + let dnum_in: usize = k_out / base2k_out; let dsize_in: usize = 1; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -183,10 +198,10 @@ where let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_apply.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -253,7 +268,7 @@ where let max_noise = |_col_j: usize| -> f64 { noise_ggsw_product( n as f64, - base2k * di, + base2k_key * dsize, 0.5, var_msg, var_a0_err, @@ -266,7 +281,17 @@ where ) + 0.5 }; - ggsw_out.assert_noise(module, &sk_prepared, &pt_in, &max_noise); + for row in 0..ggsw_out.dnum().as_usize() { + for col in 0..ggsw_out.rank().as_usize() + 1 { + assert!( + ggsw_out + .noise(module, row, col, &pt_in, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise(col) + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs index 0425d35..49ef81f 100644 --- a/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/external_product/glwe_ct.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ }; use crate::{ - GGSWEncryptSk, GLWEEncryptSk, GLWEExternalProduct, GLWENoise, ScratchTakeCore, + GGSWEncryptSk, GLWEEncryptSk, GLWEExternalProduct, GLWENoise, GLWENormalize, ScratchTakeCore, encryption::SIGMA, layouts::{ GGSW, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, @@ -24,41 +24,44 @@ where + GLWEEncryptSk + GLWENoise + VecZnxRotateInplace - + GLWESecretPreparedFactory, + + GLWESecretPreparedFactory + + GLWENormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 45; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = 15; + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ggsw: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ggsw: usize = k_in + base2k_key * dsize; let k_out: usize = k_ggsw; // Better capture noise let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * dsize); + let dnum: usize = k_in.div_ceil(k_ggsw * dsize); let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), rank: rank.into(), }; let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -66,16 +69,17 @@ where let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n, 1); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa); - pt_want.data.at_mut(0, 0)[1] = 1; + pt_in.data.at_mut(0, 0)[1] = 1; let k: usize = 1; @@ -104,7 +108,7 @@ where glwe_in.encrypt_sk( module, - &pt_want, + &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, @@ -116,7 +120,9 @@ where glwe_out.external_product(module, &glwe_in, &ct_ggsw_prepared, scratch.borrow()); - module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0, scratch.borrow()); + module.vec_znx_rotate_inplace(k as i64, &mut pt_in.data, 0, scratch.borrow()); + + module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); let var_gct_err_lhs: f64 = SIGMA * SIGMA; let var_gct_err_rhs: f64 = 0f64; @@ -127,7 +133,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * dsize, + base2k_key * max_dsize, 0.5, var_msg, var_a0_err, @@ -139,7 +145,13 @@ where k_ggsw, ); - glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + assert!( + glwe_out + .noise(module, &pt_out, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + 1.0 + ) } } } @@ -158,29 +170,31 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ggsw: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ggsw: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * dsize); + let dnum: usize = k_out.div_ceil(base2k_out * max_dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let ggsw_apply_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ggsw.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -194,7 +208,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); // Random input plaintext - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa); pt_want.data.at_mut(0, 0)[1] = 1; @@ -248,7 +262,7 @@ where let max_noise: f64 = noise_ggsw_product( n as f64, - base2k * dsize, + base2k_key * max_dsize, 0.5, var_msg, var_a0_err, @@ -260,7 +274,13 @@ where k_ggsw, ); - glwe_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 0.5); + assert!( + glwe_out + .noise(module, &pt_want, &sk_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + 1.0 + ) } } } diff --git a/poulpy-core/src/tests/test_suite/glwe_packer.rs b/poulpy-core/src/tests/test_suite/glwe_packer.rs index 04b08ab..e663836 100644 --- a/poulpy-core/src/tests/test_suite/glwe_packer.rs +++ b/poulpy-core/src/tests/test_suite/glwe_packer.rs @@ -33,25 +33,26 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let n: usize = module.n(); - let base2k: usize = 18; + let base2k_out: usize = 15; + let base2k_key: usize = 10; let k_ct: usize = 36; let pt_k: usize = 18; let rank: usize = 3; let dsize: usize = 1; - let k_ksk: usize = k_ct + base2k * dsize; + let k_ksk: usize = k_ct + base2k_key * dsize; - let dnum: usize = k_ct.div_ceil(base2k * dsize); + let dnum: usize = k_ct.div_ceil(base2k_key * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_ct.into(), rank: rank.into(), }; let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), rank: rank.into(), dsize: dsize.into(), @@ -134,7 +135,7 @@ where }); let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); - packer.flush(module, &mut res); + packer.flush(module, &mut res, scratch.borrow()); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut data: Vec = vec![0i64; n]; @@ -153,7 +154,7 @@ where let noise_have: f64 = pt.stats().std().log2(); assert!( - noise_have < -((k_ct - base2k) as f64), + noise_have < -((k_ct - base2k_out) as f64), "noise: {noise_have}" ); } diff --git a/poulpy-core/src/tests/test_suite/glwe_packing.rs b/poulpy-core/src/tests/test_suite/glwe_packing.rs new file mode 100644 index 0000000..666e4d4 --- /dev/null +++ b/poulpy-core/src/tests/test_suite/glwe_packing.rs @@ -0,0 +1,148 @@ +use std::collections::HashMap; + +use itertools::Itertools; +use poulpy_hal::{ + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, + source::Source, +}; + +use crate::{ + GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, GLWEPacking, GLWERotate, GLWESub, ScratchTakeCore, + layouts::{ + GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, + GLWESecret, GLWESecretPreparedFactory, + prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, + }, +}; + +pub fn test_glwe_packing(module: &Module) +where + Module: GLWEEncryptSk + + GLWEAutomorphismKeyEncryptSk + + GLWEAutomorphismKeyPreparedFactory + + GLWEPacking + + GLWESecretPreparedFactory + + GLWESub + + GLWEDecrypt + + GLWERotate + + GLWENoise, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let n: usize = module.n(); + let base2k_out: usize = 15; + let base2k_key: usize = 10; + let k_ct: usize = 36; + let pt_k: usize = base2k_out; + let rank: usize = 3; + let dsize: usize = 1; + let k_ksk: usize = k_ct + base2k_key * dsize; + + let dnum: usize = k_ct.div_ceil(base2k_key * dsize); + + let glwe_out_infos: GLWELayout = GLWELayout { + n: n.into(), + base2k: base2k_out.into(), + k: k_ct.into(), + rank: rank.into(), + }; + + let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { + n: n.into(), + base2k: base2k_key.into(), + k: k_ksk.into(), + rank: rank.into(), + dsize: dsize.into(), + dnum: dnum.into(), + }; + + let mut scratch: ScratchOwned = ScratchOwned::alloc( + GLWE::encrypt_sk_tmp_bytes(module, &glwe_out_infos) + .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( + module, &key_infos, + )) + .max(module.glwe_pack_tmp_bytes(&glwe_out_infos, &key_infos)), + ); + + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_out_infos); + sk.fill_ternary_prob(0.5, &mut source_xs); + + let mut sk_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk); + sk_prep.prepare(module, &sk); + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut data: Vec = vec![0i64; n]; + data.iter_mut().enumerate().for_each(|(i, x)| { + *x = i as i64; + }); + + pt.encode_vec_i64(&data, pt_k.into()); + + let gal_els: Vec = module.glwe_pack_galois_elements(); + + let mut auto_keys: HashMap, BE>> = HashMap::new(); + let mut tmp: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&key_infos); + gal_els.iter().for_each(|gal_el| { + tmp.encrypt_sk( + module, + *gal_el, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + let mut atk_prepared: GLWEAutomorphismKeyPrepared, BE> = + GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &tmp); + atk_prepared.prepare(module, &tmp, scratch.borrow()); + auto_keys.insert(*gal_el, atk_prepared); + }); + + let mut cts = (0..n) + .step_by(5) + .map(|_| { + let mut ct = GLWE::alloc_from_infos(&glwe_out_infos); + ct.encrypt_sk( + module, + &pt, + &sk_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + module.glwe_rotate_inplace(-5, &mut pt, scratch.borrow()); // X^-batch * pt + ct + }) + .collect_vec(); + + let mut cts_map: HashMap>> = HashMap::new(); + + for (i, ct) in cts.iter_mut().enumerate() { + cts_map.insert(5 * i, ct); + } + + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); + + module.glwe_pack(&mut res, cts_map, 0, &auto_keys, scratch.borrow()); + + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); + let mut data: Vec = vec![0i64; n]; + data.iter_mut().enumerate().for_each(|(i, x)| { + if i.is_multiple_of(5) { + *x = i as i64; + } + }); + + pt_want.encode_vec_i64(&data, pt_k.into()); + + assert!( + res.noise(module, &pt_want, &sk_prep, scratch.borrow()) + .std() + .log2() + <= ((k_ct - base2k_out) as f64) + ); +} diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index 548d1f0..59310bf 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -8,10 +8,12 @@ use crate::{ GGLWEKeyswitch, GGLWENoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, + GGLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, + GLWESwitchingKeyPreparedFactory, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; pub fn test_gglwe_switching_key_keyswitch(module: &Module) @@ -24,27 +26,29 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); - for rank_in_s0s1 in 1_usize..3 { + for rank_in_s0s1 in 1_usize..2 { for rank_out_s0s1 in 1_usize..3 { for rank_out_s1s2 in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in / base2k; - let dnum_apply: usize = k_in.div_ceil(base2k * di); let dsize_in: usize = 1; + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in_s0s1.into(), rank_out: rank_out_s0s1.into(), @@ -52,19 +56,19 @@ where let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum_apply.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank_out_s0s1.into(), rank_out: rank_out_s1s2.into(), }; let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum_apply.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in_s0s1.into(), rank_out: rank_out_s1s2.into(), @@ -85,8 +89,8 @@ where ); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( module, - &gglwe_s0s1_infos, &gglwe_s0s2_infos, + &gglwe_s0s1_infos, &gglwe_s1s2_infos, )); @@ -135,22 +139,41 @@ where scratch_apply.borrow(), ); - let max_noise: f64 = log2_std_noise_gglwe_product( - n as f64, - base2k * di, + let max_noise: f64 = var_noise_gglwe_product_v2( + module.n() as f64, + k_ksk, + dnum_ksk, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank_out_s0s1 as f64, - k_in, - k_ksk, - ); + ) + .sqrt() + .log2(); - gglwe_s0s2 - .key - .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); + for row in 0..gglwe_s0s2.dnum().as_usize() { + for col in 0..gglwe_s0s2.rank_in().as_usize() { + assert!( + gglwe_s0s2 + .key + .noise( + module, + row, + col, + &sk0.data, + &sk2_prepared, + scratch_apply.borrow() + ) + .std() + .log2() + <= max_noise + 0.5 + ) + } + } } } } @@ -168,23 +191,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * di); let dsize_in: usize = 1; + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); + let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -192,10 +219,10 @@ where let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank_out.into(), rank_out: rank_out.into(), }; @@ -263,7 +290,7 @@ where let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, - base2k * di, + base2k_key * dsize, var_xs, var_xs, 0f64, @@ -274,9 +301,25 @@ where k_ksk, ); - gglwe_s0s2 - .key - .assert_noise(module, &sk2_prepared, &sk0.data, max_noise + 0.5); + for row in 0..gglwe_s0s2.dnum().as_usize() { + for col in 0..gglwe_s0s2.rank_in().as_usize() { + assert!( + gglwe_s0s2 + .key + .noise( + module, + row, + col, + &sk0.data, + &sk2_prepared, + scratch_apply.borrow() + ) + .std() + .log2() + <= max_noise + 0.5 + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index c4191fa..f1e33d5 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -8,8 +8,8 @@ use crate::{ GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWESecret, - GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, + GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWInfos, GGSWLayout, GLWEInfos, + GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, GLWETensorKeyLayout, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, @@ -30,53 +30,57 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 54; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(di * base2k); + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank.into(), rank_out: rank.into(), }; @@ -163,7 +167,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, @@ -176,7 +180,24 @@ where ) + 0.5 }; - ggsw_out.assert_noise(module, &sk_out_prepared, &pt_scalar, &max_noise); + for row in 0..ggsw_out.dnum().as_usize() { + for col in 0..ggsw_out.rank().as_usize() + 1 { + assert!( + ggsw_out + .noise( + module, + row, + col, + &pt_scalar, + &sk_out_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise(col) + ) + } + } } } } @@ -195,43 +216,45 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 54; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let k_tsk: usize = k_ksk; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(di * base2k); - + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank.into(), rank_out: rank.into(), }; @@ -311,7 +334,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, @@ -324,7 +347,24 @@ where ) + 0.5 }; - ggsw_out.assert_noise(module, &sk_out_prepared, &pt_scalar, &max_noise); + for row in 0..ggsw_out.dnum().as_usize() { + for col in 0..ggsw_out.rank().as_usize() + 1 { + assert!( + ggsw_out + .noise( + module, + row, + col, + &pt_scalar, + &sk_out_prepared, + scratch.borrow() + ) + .std() + .log2() + <= max_noise(col) + ) + } + } } } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index 90cc543..c4ac553 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -5,14 +5,14 @@ use poulpy_hal::{ }; use crate::{ - GLWEEncryptSk, GLWEKeyswitch, GLWENoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, + GLWEEncryptSk, GLWEKeyswitch, GLWENoise, GLWENormalize, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, - GLWESwitchingKeyPreparedFactory, + GLWESwitchingKeyPreparedFactory, LWEInfos, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, - noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; #[allow(clippy::too_many_arguments)] @@ -24,43 +24,46 @@ where + GLWEKeyswitch + GLWESecretPreparedFactory + GLWESwitchingKeyPreparedFactory - + GLWENoise, + + GLWENoise + + GLWENormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 45; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = 15; + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_out: usize = k_ksk; // better capture noise let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * dsize); + let dnum: usize = k_in.div_ceil(base2k_key * dsize); let glwe_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), rank: rank_in.into(), }; let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank_out.into(), }; let ksk: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), }; @@ -68,13 +71,14 @@ where let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk); let mut glwe_in: GLWE> = GLWE::alloc_from_infos(&glwe_in_infos); let mut glwe_out: GLWE> = GLWE::alloc_from_infos(&glwe_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_in_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(pt_in.base2k().into(), &mut pt_in.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk) @@ -105,7 +109,7 @@ where glwe_in.encrypt_sk( module, - &pt_want, + &pt_in, &sk_in_prepared, &mut source_xa, &mut source_xe, @@ -118,20 +122,31 @@ where glwe_out.keyswitch(module, &glwe_in, &ksk_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank_in as f64, - k_in, - k_ksk, - ); + ) + .sqrt() + .log2(); - glwe_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); + + assert!( + glwe_out + .noise(module, &pt_out, &sk_out_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + 1.0 + ) } } } @@ -149,30 +164,31 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 45; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * dsize); - + let dnum: usize = k_out.div_ceil(base2k_key * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let ksk_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), rank_in: rank.into(), rank_out: rank.into(), }; @@ -185,7 +201,12 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform( + pt_want.base2k().into(), + &mut pt_want.data, + 0, + &mut source_xa, + ); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_infos) @@ -229,20 +250,29 @@ where glwe_out.keyswitch_inplace(module, &ksk_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_out, - k_ksk, - ); + ) + .sqrt() + .log2(); - glwe_out.assert_noise(module, &sk_out_prepared, &pt_want, max_noise + 0.5); + assert!( + glwe_out + .noise(module, &pt_want, &sk_out_prepared, scratch.borrow()) + .std() + .log2() + <= max_noise + 1.0 + ) } } } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index 7617b09..cd06749 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize}, layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView}, source::Source, }; @@ -14,21 +14,27 @@ use crate::{ pub fn test_lwe_keyswitch(module: &Module) where - Module: - LWEKeySwitch + LWESwitchingKeyEncrypt + LWEEncryptSk + LWESwitchingKeyPreparedFactory + LWEDecrypt, + Module: LWEKeySwitch + + LWESwitchingKeyEncrypt + + LWEEncryptSk + + LWESwitchingKeyPreparedFactory + + LWEDecrypt + + VecZnxNormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { let n: usize = module.n(); - let base2k: usize = 17; + let base2k_in: usize = 17; + let base2k_out: usize = 15; + let base2k_key: usize = 13; - let n_lwe_in: usize = 22; - let n_lwe_out: usize = 30; - let k_lwe_ct: usize = 2 * base2k; + let n_lwe_in: usize = module.n() >> 1; + let n_lwe_out: usize = module.n() >> 1; + let k_lwe_ct: usize = 102; let k_lwe_pt: usize = 8; - let k_ksk: usize = k_lwe_ct + base2k; - let dnum: usize = k_lwe_ct.div_ceil(base2k); + let k_ksk: usize = k_lwe_ct + base2k_key; + let dnum: usize = k_lwe_ct.div_ceil(base2k_key); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); @@ -36,21 +42,21 @@ where let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), dnum: dnum.into(), }; let lwe_in_infos: LWELayout = LWELayout { n: n_lwe_in.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_lwe_ct.into(), }; let lwe_out_infos: LWELayout = LWELayout { n: n_lwe_out.into(), k: k_lwe_ct.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), }; let mut scratch: ScratchOwned = ScratchOwned::alloc( @@ -66,7 +72,7 @@ where let data: i64 = 17; - let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k_in.into(), k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into()); let mut lwe_ct_in: LWE> = LWE::alloc_from_infos(&lwe_in_infos); @@ -76,6 +82,7 @@ where &sk_lwe_in, &mut source_xa, &mut source_xe, + scratch.borrow(), ); let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc_from_infos(&key_apply_infos); @@ -97,7 +104,18 @@ where lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow()); let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); - lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); + lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out, scratch.borrow()); - assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); + let mut lwe_pt_want: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); + module.vec_znx_normalize( + base2k_out, + lwe_pt_want.data_mut(), + 0, + base2k_in, + lwe_pt_in.data(), + 0, + scratch.borrow(), + ); + + assert_eq!(lwe_pt_want.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); } diff --git a/poulpy-core/src/tests/test_suite/mod.rs b/poulpy-core/src/tests/test_suite/mod.rs index 6777d46..6b086e7 100644 --- a/poulpy-core/src/tests/test_suite/mod.rs +++ b/poulpy-core/src/tests/test_suite/mod.rs @@ -5,8 +5,10 @@ pub mod keyswitch; mod conversion; mod glwe_packer; +mod glwe_packing; mod trace; pub use conversion::*; pub use glwe_packer::*; +pub use glwe_packing::*; pub use trace::*; diff --git a/poulpy-core/src/tests/test_suite/trace.rs b/poulpy-core/src/tests/test_suite/trace.rs index c4db7e4..b57dee7 100644 --- a/poulpy-core/src/tests/test_suite/trace.rs +++ b/poulpy-core/src/tests/test_suite/trace.rs @@ -32,26 +32,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 8; + let base2k_out: usize = 15; + let base2k_key: usize = 10; let k: usize = 54; for rank in 1_usize..3 { let n: usize = module.n(); - let k_autokey: usize = k + base2k; + let k_autokey: usize = k + base2k_key; let dsize: usize = 1; - let dnum: usize = k.div_ceil(base2k * dsize); + let dnum: usize = k.div_ceil(base2k_key * dsize); let glwe_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k.into(), rank: rank.into(), }; let key_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_autokey.into(), rank: rank.into(), dsize: dsize.into(), @@ -85,7 +86,7 @@ where .iter_mut() .for_each(|x| *x = source_xa.next_i64() & 0xFF); - module.vec_znx_fill_uniform(base2k, &mut pt_have.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_out, &mut pt_have.data, 0, &mut source_xa); glwe_out.encrypt_sk( module, @@ -121,13 +122,18 @@ where glwe_out.decrypt(module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_inplace(&mut pt_want.data, 0, &pt_have.data, 0); - module.vec_znx_normalize_inplace(base2k, &mut pt_want.data, 0, scratch.borrow()); + module.vec_znx_normalize_inplace( + pt_want.base2k().as_usize(), + &mut pt_want.data, + 0, + scratch.borrow(), + ); let noise_have: f64 = pt_want.stats().std().log2(); let mut noise_want: f64 = var_noise_gglwe_product( n as f64, - base2k, + base2k_key * dsize, 0.5, 0.5, 1.0 / 12.0, diff --git a/poulpy-core/src/utils.rs b/poulpy-core/src/utils.rs index 858cd5d..7484fcd 100644 --- a/poulpy-core/src/utils.rs +++ b/poulpy-core/src/utils.rs @@ -37,16 +37,20 @@ impl GLWEPlaintext { impl LWEPlaintext { pub fn encode_i64(&mut self, data: i64, k: TorusPrecision) { let base2k: usize = self.base2k().into(); - self.data.encode_i64(base2k, k.into(), data); + self.data.encode_coeff_i64(base2k, 0, k.into(), 0, data); } } impl LWEPlaintext { pub fn decode_i64(&self, k: TorusPrecision) -> i64 { - self.data.decode_i64(self.base2k().into(), k.into()) + self.data + .decode_coeff_i64(self.base2k().into(), 0, k.into(), 0) } pub fn decode_float(&self) -> Float { - self.data.decode_float(self.base2k().into()) + let mut out: [Float; 1] = [Float::new(self.k().as_u32())]; + self.data + .decode_vec_float(self.base2k().into(), 0, &mut out); + out[0].clone() } } diff --git a/poulpy-hal/src/api/mod.rs b/poulpy-hal/src/api/mod.rs index b024a94..9af22de 100644 --- a/poulpy-hal/src/api/mod.rs +++ b/poulpy-hal/src/api/mod.rs @@ -6,7 +6,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; pub use convolution::*; pub use module::*; @@ -16,4 +15,3 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; -pub use zn::*; diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index aef02e1..4dbb14b 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,6 +1,6 @@ use crate::{ api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, Zn}, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, }; /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. @@ -69,11 +69,6 @@ where (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) } - fn take_zn(&mut self, n: usize, cols: usize, size: usize) -> (Zn<&mut [u8]>, &mut Self) { - let (take_slice, rem_slice) = self.take_slice(Zn::bytes_of(n, cols, size)); - (Zn::from_data(take_slice, n, cols, size), rem_slice) - } - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); (VecZnx::from_data(take_slice, n, cols, size), rem_slice) diff --git a/poulpy-hal/src/api/zn.rs b/poulpy-hal/src/api/zn.rs deleted file mode 100644 index 9e8e308..0000000 --- a/poulpy-hal/src/api/zn.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::{ - layouts::{Backend, Scratch, ZnToMut}, - reference::zn::zn_normalize_tmp_bytes, - source::Source, -}; - -pub trait ZnNormalizeTmpBytes { - fn zn_normalize_tmp_bytes(&self, n: usize) -> usize { - zn_normalize_tmp_bytes(n) - } -} - -pub trait ZnNormalizeInplace { - /// Normalizes the selected column of `a`. - fn zn_normalize_inplace(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut; -} - -pub trait ZnFillUniform { - /// Fills the first `size` size with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\] - fn zn_fill_uniform(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -pub trait ZnFillNormal { - fn zn_fill_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -pub trait ZnAddNormal { - /// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn zn_add_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} diff --git a/poulpy-hal/src/delegates/mod.rs b/poulpy-hal/src/delegates/mod.rs index 85de88d..595a641 100644 --- a/poulpy-hal/src/delegates/mod.rs +++ b/poulpy-hal/src/delegates/mod.rs @@ -5,4 +5,3 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; diff --git a/poulpy-hal/src/delegates/zn.rs b/poulpy-hal/src/delegates/zn.rs deleted file mode 100644 index 6a4c999..0000000 --- a/poulpy-hal/src/delegates/zn.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::{ - api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes}, - layouts::{Backend, Module, Scratch, ZnToMut}, - oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, - source::Source, -}; - -impl ZnNormalizeTmpBytes for Module -where - B: Backend + ZnNormalizeTmpBytesImpl, -{ - fn zn_normalize_tmp_bytes(&self, n: usize) -> usize { - B::zn_normalize_tmp_bytes_impl(n) - } -} - -impl ZnNormalizeInplace for Module -where - B: Backend + ZnNormalizeInplaceImpl, -{ - fn zn_normalize_inplace(&self, n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where - A: ZnToMut, - { - B::zn_normalize_inplace_impl(n, base2k, a, a_col, scratch) - } -} - -impl ZnFillUniform for Module -where - B: Backend + ZnFillUniformImpl, -{ - fn zn_fill_uniform(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - B::zn_fill_uniform_impl(n, base2k, res, res_col, source); - } -} - -impl ZnFillNormal for Module -where - B: Backend + ZnFillNormalImpl, -{ - fn zn_fill_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - B::zn_fill_normal_impl(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -impl ZnAddNormal for Module -where - B: Backend + ZnAddNormalImpl, -{ - fn zn_add_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - B::zn_add_normal_impl(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index 37a4677..6934eec 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -2,7 +2,7 @@ use itertools::izip; use rug::{Assign, Float}; use crate::{ - layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, reference::znx::{ ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef, ZnxZero, get_carry_i128, get_digit_i128, znx_zero_ref, @@ -245,90 +245,6 @@ impl VecZnx { } } -impl Zn { - pub fn encode_i64(&mut self, base2k: usize, k: usize, data: i64) { - let size: usize = k.div_ceil(base2k); - - #[cfg(debug_assertions)] - { - let a: Zn<&mut [u8]> = self.to_mut(); - assert!( - size <= a.size(), - "invalid argument k.div_ceil(base2k)={} > a.size()={}", - size, - a.size() - ); - } - - let mut a: Zn<&mut [u8]> = self.to_mut(); - let a_size = a.size(); - - for j in 0..a_size { - a.at_mut(0, j)[0] = 0 - } - - a.at_mut(0, size - 1)[0] = data; - - let mut carry: Vec = vec![0i64; 1]; - let k_rem: usize = (base2k - (k % base2k)) % base2k; - - for j in (0..size).rev() { - let slice = &mut a.at_mut(0, j)[..1]; - - if j == size - 1 { - ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry); - } else if j == 0 { - ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry); - } else { - ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry); - } - } - } -} - -impl Zn { - pub fn decode_i64(&self, base2k: usize, k: usize) -> i64 { - let a: Zn<&[u8]> = self.to_ref(); - let size: usize = k.div_ceil(base2k); - let mut res: i64 = 0; - let rem: usize = base2k - (k % base2k); - (0..size).for_each(|j| { - let x: i64 = a.at(0, j)[0]; - if j == size - 1 && rem != base2k { - let k_rem: usize = (base2k - rem) % base2k; - let scale: i64 = 1 << rem as i64; - res = (res << k_rem) + div_round(x, scale); - } else { - res = (res << base2k) + x; - } - }); - res - } - - pub fn decode_float(&self, base2k: usize) -> Float { - let a: Zn<&[u8]> = self.to_ref(); - let size: usize = a.size(); - let prec: u32 = (base2k * size) as u32; - - // 2^{base2k} - let base: Float = Float::with_val(prec, (1 << base2k) as f64); - let mut res: Float = Float::with_val(prec, (1 << base2k) as f64); - - // y[i] = sum x[j][i] * 2^{-base2k*j} - (0..size).for_each(|i| { - if i == 0 { - res.assign(a.at(0, size - i - 1)[0]); - res /= &base; - } else { - res += Float::with_val(prec, a.at(0, size - i - 1)[0]); - res /= &base; - } - }); - - res - } -} - #[inline] pub fn div_round(a: i64, b: i64) -> i64 { assert!(b != 0, "division by zero"); diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index 7d4600b..d164234 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -10,7 +10,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; mod znx_base; pub use mat_znx::*; @@ -24,7 +23,6 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; -pub use zn::*; pub use znx_base::*; pub trait Data = PartialEq + Eq + Sized + Default; diff --git a/poulpy-hal/src/layouts/stats.rs b/poulpy-hal/src/layouts/stats.rs index bb4f9d2..b4bfe0b 100644 --- a/poulpy-hal/src/layouts/stats.rs +++ b/poulpy-hal/src/layouts/stats.rs @@ -32,7 +32,7 @@ impl VecZnx { data.iter().for_each(|x| { avg.add_assign_round(x, Round::Nearest); - max.max_mut(&Float::with_val(53, x.abs_ref())); + max.max_mut(&Float::with_val(prec, x.abs_ref())); }); avg.div_assign_round(Float::with_val(prec, data.len()), Round::Nearest); data.iter_mut().for_each(|x| { diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs deleted file mode 100644 index 00f8067..0000000 --- a/poulpy-hal/src/layouts/zn.rs +++ /dev/null @@ -1,273 +0,0 @@ -use std::{ - fmt, - hash::{DefaultHasher, Hasher}, -}; - -use crate::{ - alloc_aligned, - layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, - }, - source::Source, -}; - -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use rand::RngCore; - -#[repr(C)] -#[derive(PartialEq, Eq, Clone, Copy, Hash)] -pub struct Zn { - pub data: D, - pub n: usize, - pub cols: usize, - pub size: usize, - pub max_size: usize, -} - -impl DigestU64 for Zn { - fn digest_u64(&self) -> u64 { - let mut h: DefaultHasher = DefaultHasher::new(); - h.write(self.data.as_ref()); - h.write_usize(self.n); - h.write_usize(self.cols); - h.write_usize(self.size); - h.write_usize(self.max_size); - h.finish() - } -} - -impl ToOwnedDeep for Zn { - type Owned = Zn>; - fn to_owned_deep(&self) -> Self::Owned { - Zn { - data: self.data.as_ref().to_vec(), - n: self.n, - cols: self.cols, - size: self.size, - max_size: self.max_size, - } - } -} - -impl fmt::Debug for Zn { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl ZnxInfos for Zn { - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for Zn { - fn sl(&self) -> usize { - self.n() * self.cols() - } -} - -impl DataView for Zn { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for Zn { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl ZnxView for Zn { - type Scalar = i64; -} - -impl Zn> { - pub fn rsh_tmp_bytes(n: usize) -> usize { - n * std::mem::size_of::() - } -} - -impl ZnxZero for Zn { - fn zero(&mut self) { - self.raw_mut().fill(0) - } - fn zero_at(&mut self, i: usize, j: usize) { - self.at_mut(i, j).fill(0); - } -} - -impl Zn> { - pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - n * cols * size * size_of::() - } - - pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); - Self { - data, - n, - cols, - size, - max_size: size, - } - } - - pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, cols, size)); - Self { - data, - n, - cols, - size, - max_size: size, - } - } -} - -impl Zn { - pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { - Self { - data, - n, - cols, - size, - max_size: size, - } - } -} - -impl fmt::Display for Zn { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "Zn(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {col}:")?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {size}: [")?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{coeff}")?; - } - - if coeffs.len() > max_show { - write!(f, ", ... ({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} - -impl FillUniform for Zn { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - match log_bound { - 64 => source.fill_bytes(self.data.as_mut()), - 0 => panic!("invalid log_bound, cannot be zero"), - _ => { - let mask: u64 = (1u64 << log_bound) - 1; - for x in self.raw_mut().iter_mut() { - let r = source.next_u64() & mask; - *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound); - } - } - } - } -} - -pub type ZnOwned = Zn>; -pub type ZnMut<'a> = Zn<&'a mut [u8]>; -pub type ZnRef<'a> = Zn<&'a [u8]>; - -pub trait ZnToRef { - fn to_ref(&self) -> Zn<&[u8]>; -} - -impl ZnToRef for Zn { - fn to_ref(&self) -> Zn<&[u8]> { - Zn { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - size: self.size, - max_size: self.max_size, - } - } -} - -pub trait ZnToMut { - fn to_mut(&mut self) -> Zn<&mut [u8]>; -} - -impl ZnToMut for Zn { - fn to_mut(&mut self) -> Zn<&mut [u8]> { - Zn { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - size: self.size, - max_size: self.max_size, - } - } -} - -impl ReaderFrom for Zn { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.n = reader.read_u64::()? as usize; - self.cols = reader.read_u64::()? as usize; - self.size = reader.read_u64::()? as usize; - self.max_size = reader.read_u64::()? as usize; - let len: usize = reader.read_u64::()? as usize; - let buf: &mut [u8] = self.data.as_mut(); - if buf.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - format!("self.data.len()={} != read len={}", buf.len(), len), - )); - } - reader.read_exact(&mut buf[..len])?; - Ok(()) - } -} - -impl WriterTo for Zn { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.n as u64)?; - writer.write_u64::(self.cols as u64)?; - writer.write_u64::(self.size as u64)?; - writer.write_u64::(self.max_size as u64)?; - let buf: &[u8] = self.data.as_ref(); - writer.write_u64::(buf.len() as u64)?; - writer.write_all(buf)?; - Ok(()) - } -} diff --git a/poulpy-hal/src/oep/mod.rs b/poulpy-hal/src/oep/mod.rs index dac0def..bc53c0e 100644 --- a/poulpy-hal/src/oep/mod.rs +++ b/poulpy-hal/src/oep/mod.rs @@ -5,7 +5,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; pub use module::*; pub use scratch::*; @@ -14,4 +13,3 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; -pub use zn::*; diff --git a/poulpy-hal/src/oep/zn.rs b/poulpy-hal/src/oep/zn.rs deleted file mode 100644 index d2e03ad..0000000 --- a/poulpy-hal/src/oep/zn.rs +++ /dev/null @@ -1,70 +0,0 @@ -use crate::{ - layouts::{Backend, Scratch, ZnToMut}, - source::Source, -}; - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnNormalizeTmpBytesImpl { - fn zn_normalize_tmp_bytes_impl(n: usize) -> usize; -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnNormalizeInplace] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnNormalizeInplaceImpl { - fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut; -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnFillUniform] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnFillUniformImpl { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnFillNormal] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnFillNormalImpl { - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnAddNormal] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnAddNormalImpl { - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} diff --git a/poulpy-hal/src/reference/fft64/vec_znx_big.rs b/poulpy-hal/src/reference/fft64/vec_znx_big.rs index 233eb5c..64a643e 100644 --- a/poulpy-hal/src/reference/fft64/vec_znx_big.rs +++ b/poulpy-hal/src/reference/fft64/vec_znx_big.rs @@ -235,10 +235,10 @@ pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize { } pub fn vec_znx_big_normalize( - res_basek: usize, + res_base2k: usize, res: &mut R, res_col: usize, - a_basek: usize, + a_base2k: usize, a: &A, a_col: usize, carry: &mut [i64], @@ -267,7 +267,7 @@ pub fn vec_znx_big_normalize( max_size: a.max_size, }; - vec_znx_normalize::<_, _, BE>(res_basek, res, res_col, a_basek, &a_vznx, a_col, carry); + vec_znx_normalize::<_, _, BE>(res_base2k, res, res_col, a_base2k, &a_vznx, a_col, carry); } pub fn vec_znx_big_add_normal_ref>( diff --git a/poulpy-hal/src/reference/mod.rs b/poulpy-hal/src/reference/mod.rs index 9fd1500..9bd1154 100644 --- a/poulpy-hal/src/reference/mod.rs +++ b/poulpy-hal/src/reference/mod.rs @@ -1,4 +1,3 @@ pub mod fft64; pub mod vec_znx; -pub mod zn; pub mod znx; diff --git a/poulpy-hal/src/reference/vec_znx/normalize.rs b/poulpy-hal/src/reference/vec_znx/normalize.rs index a62c106..139c8a5 100644 --- a/poulpy-hal/src/reference/vec_znx/normalize.rs +++ b/poulpy-hal/src/reference/vec_znx/normalize.rs @@ -53,6 +53,8 @@ pub fn vec_znx_normalize( let res_size: usize = res.size(); let a_size: usize = a.size(); + let carry: &mut [i64] = &mut carry[..2 * n]; + if res_base2k == a_base2k { if a_size > res_size { for j in (res_size..a_size).rev() { @@ -95,9 +97,9 @@ pub fn vec_znx_normalize( // Get carry for limbs of a that have higher precision than res for j in (a_min_size..a_size).rev() { if j == a_size - 1 { - ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry); + ZNXARI::znx_normalize_first_step_carry_only(a_base2k, 0, a.at(a_col, j), carry); } else { - ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry); + ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, 0, a.at(a_col, j), carry); } } @@ -118,6 +120,10 @@ pub fn vec_znx_normalize( // for the current limb. let mut res_left: usize = res_base2k; + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } + for j in (0..a_min_size).rev() { // Trackers: wow much of a_norm is left to // be flushed on res. @@ -196,10 +202,6 @@ pub fn vec_znx_normalize( } } } - - for j in res_min_size..res_size { - ZNXARI::znx_zero(res.at_mut(res_col, j)); - } } } @@ -380,14 +382,6 @@ fn test_vec_znx_normalize_conv() { err.sub_assign_round(&data_res[i], Round::Nearest); err = err.abs(); - // println!( - // "want: {} have: {} tmp: {} (want-have): {}", - // data_want[i].to_f64(), - // data_res[i].to_f64(), - // data_tmp[i].to_f64(), - // err.to_f64() - // ); - let err_log2: f64 = err .clone() .max(&Float::with_val(prec as u32, 1e-60)) diff --git a/poulpy-hal/src/reference/zn/mod.rs b/poulpy-hal/src/reference/zn/mod.rs deleted file mode 100644 index d4838c3..0000000 --- a/poulpy-hal/src/reference/zn/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod normalization; -mod sampling; - -pub use normalization::*; -pub use sampling::*; diff --git a/poulpy-hal/src/reference/zn/normalization.rs b/poulpy-hal/src/reference/zn/normalization.rs deleted file mode 100644 index 83cfeb7..0000000 --- a/poulpy-hal/src/reference/zn/normalization.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes}, - layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut}, - reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef}, - source::Source, -}; - -pub fn zn_normalize_tmp_bytes(n: usize) -> usize { - n * size_of::() -} - -pub fn zn_normalize_inplace(n: usize, base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) -where - R: ZnToMut, - ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(carry.len(), res.n()); - } - - let res_size: usize = res.size(); - - for j in (0..res_size).rev() { - let out = &mut res.at_mut(res_col, j)[..n]; - - if j == res_size - 1 { - ARI::znx_normalize_first_step_inplace(base2k, 0, out, carry); - } else if j == 0 { - ARI::znx_normalize_final_step_inplace(base2k, 0, out, carry); - } else { - ARI::znx_normalize_middle_step_inplace(base2k, 0, out, carry); - } - } -} - -pub fn test_zn_normalize_inplace(module: &Module) -where - Module: ZnNormalizeInplace + ZnNormalizeTmpBytes, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, -{ - let mut source: Source = Source::new([0u8; 32]); - let cols: usize = 2; - let base2k: usize = 12; - - let n = 33; - - let mut carry: Vec = vec![0i64; zn_normalize_tmp_bytes(n)]; - - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n())); - - for res_size in [1, 2, 6, 11] { - let mut res_0: Zn> = Zn::alloc(n, cols, res_size); - let mut res_1: Zn> = Zn::alloc(n, cols, res_size); - - res_0 - .raw_mut() - .iter_mut() - .for_each(|x| *x = source.next_i32() as i64); - res_1.raw_mut().copy_from_slice(res_0.raw()); - - // Reference - for i in 0..cols { - zn_normalize_inplace::<_, ZnxRef>(n, base2k, &mut res_0, i, &mut carry); - module.zn_normalize_inplace(n, base2k, &mut res_1, i, scratch.borrow()); - } - - assert_eq!(res_0.raw(), res_1.raw()); - } -} diff --git a/poulpy-hal/src/reference/zn/sampling.rs b/poulpy-hal/src/reference/zn/sampling.rs deleted file mode 100644 index b376dcc..0000000 --- a/poulpy-hal/src/reference/zn/sampling.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::{ - layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut}, - reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref}, - source::Source, -}; - -pub fn zn_fill_uniform(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) -where - R: ZnToMut, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - for j in 0..res.size() { - znx_fill_uniform_ref(base2k, &mut res.at_mut(res_col, j)[..n], source) - } -} - -#[allow(clippy::too_many_arguments)] -pub fn zn_fill_normal( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, -) where - R: ZnToMut, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(base2k) - 1; - let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_fill_normal_f64_ref( - &mut res.at_mut(res_col, limb)[..n], - sigma * scale, - bound * scale, - source, - ) -} - -#[allow(clippy::too_many_arguments)] -pub fn zn_add_normal( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, -) where - R: ZnToMut, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(base2k) - 1; - let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_add_normal_f64_ref( - &mut res.at_mut(res_col, limb)[..n], - sigma * scale, - bound * scale, - source, - ) -} diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 55d779f..8ececbe 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -1,8 +1,8 @@ use poulpy_core::{ GLWENormalize, layouts::{ - GGLWEToGGSWKeyLayout, GGSW, GGSWLayout, GLWE, GLWEAutomorphismKeyLayout, GLWELayout, GLWEPlaintext, GLWESecret, LWE, - LWEInfos, LWELayout, LWEPlaintext, LWESecret, + GGLWEToGGSWKeyLayout, GGSW, GGSWInfos, GGSWLayout, GLWE, GLWEAutomorphismKeyLayout, GLWEInfos, GLWELayout, GLWEPlaintext, + GLWESecret, LWE, LWELayout, LWEPlaintext, LWESecret, prepared::{GGSWPrepared, GLWESecretPrepared}, }, }; @@ -15,7 +15,7 @@ use poulpy_backend::FFT64Avx as BackendImpl; use poulpy_backend::FFT64Ref as BackendImpl; use poulpy_hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace}, layouts::{Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut}, source::Source, }; @@ -155,20 +155,21 @@ fn main() { pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit // Normalize plaintext to nicely print coefficients - module.zn_normalize_inplace( - pt_lwe.n().into(), - base2k, - pt_lwe.data_mut(), - 0, - scratch.borrow(), - ); + module.vec_znx_normalize_inplace(base2k, pt_lwe.data_mut(), 0, scratch.borrow()); println!("pt_lwe: {pt_lwe}"); // LWE ciphertext let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); // Encrypt LWE Plaintext - ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + ct_lwe.encrypt_sk( + &module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let now: Instant = Instant::now(); @@ -211,7 +212,23 @@ fn main() { pt_ggsw.at_mut(0, 0)[0] = data; // Prints noise of GGSW(data) - res.print_noise(&module, &sk_glwe_prepared, &pt_ggsw); + for row in 0..res.dnum().as_usize() { + for col in 0..res.rank().as_usize() + 1 { + println!( + "row:{row} col:{col} -> {}", + res.noise( + &module, + row, + col, + &pt_ggsw, + &sk_glwe_prepared, + scratch.borrow() + ) + .std() + .log2() + ) + } + } // Tests RLWE(1) * GGSW(data) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs index 59a745c..45b2fb3 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_2w_to_1w.rs @@ -98,27 +98,6 @@ impl<'a, T: UnsignedInteger, BE: Backend> BitSize for FheUintHelper<'a, T, BE> { } } -pub struct JoinedBits { - pub lo: A, - pub hi: B, - pub split: usize, // 32 in your example -} - -impl GetGGSWBit for JoinedBits -where - BE: Backend, - A: GetGGSWBit, - B: GetGGSWBit, -{ - fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> { - if bit < self.split { - self.lo.get_bit(bit) - } else { - self.hi.get_bit(bit - self.split) - } - } -} - #[macro_export] macro_rules! define_bdd_2w_to_1w_trait { ($(#[$meta:meta])* $vis:vis $trait_name:ident, $method_name:ident) => { diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index adb4def..7853673 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -1,9 +1,10 @@ use poulpy_core::{ - GLWEAdd, GLWECopy, GLWEDecrypt, GLWEEncryptSk, GLWENoise, GLWEPacking, GLWERotate, GLWESub, GLWETrace, LWEFromGLWE, - ScratchTakeCore, + GLWEAdd, GLWECopy, GLWEDecrypt, GLWEEncryptSk, GLWEKeyswitch, GLWENoise, GLWEPacking, GLWERotate, GLWESub, GLWETrace, + LWEFromGLWE, ScratchTakeCore, layouts::{ - Base2K, Degree, GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEAutomorphismKeyHelper, GLWEInfos, GLWEPlaintextLayout, - GLWESecretPreparedToRef, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToMut, Rank, TorusPrecision, + Base2K, Degree, GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEAutomorphismKeyHelper, GLWEInfos, GLWELayout, + GLWEPlaintextLayout, GLWESecretPreparedToRef, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToMut, Rank, + TorusPrecision, }, }; use poulpy_hal::{ @@ -146,7 +147,7 @@ impl FheUint { data_bits[T::bit_index(i) << log_gap] = want.bit(i) as i64 } pt.encode_vec_i64(&data_bits, TorusPrecision(2)); - self.bits.noise(module, sk, &pt, scratch_1) + self.bits.noise(module, &pt, sk, scratch_1) } pub fn decrypt(&self, module: &M, sk: &S, scratch: &mut Scratch) -> T @@ -323,16 +324,41 @@ where impl ScratchTakeBDD for Scratch where Self: ScratchTakeCore {} impl FheUint { - pub fn get_bit_lwe(&self, module: &M, bit: usize, res: &mut R, ks: &K, scratch: &mut Scratch) - where + pub fn get_bit_lwe( + &self, + module: &M, + bit: usize, + res: &mut R, + ks_glwe: Option<&KGLWE>, + ks_lwe: &KLWE, + scratch: &mut Scratch, + ) where R: LWEToMut, - K: GGLWEPreparedToRef + GGLWEInfos, - M: ModuleLogN + LWEFromGLWE + GLWERotate, + KGLWE: GGLWEPreparedToRef + GGLWEInfos, + KLWE: GGLWEPreparedToRef + GGLWEInfos, + M: ModuleLogN + LWEFromGLWE + GLWERotate + GLWEKeyswitch, Scratch: ScratchTakeCore, { let log_gap: usize = module.log_n() - T::LOG_BITS as usize; - res.to_mut() - .from_glwe(module, self, T::bit_index(bit) << log_gap, ks, scratch); + if let Some(ks_glwe) = ks_glwe { + let (mut res_tmp, scratch_1) = scratch.take_glwe(&GLWELayout { + n: self.n(), + base2k: ks_lwe.base2k(), + k: ks_lwe.k().min(self.k()), + rank: ks_lwe.rank_out(), + }); + module.glwe_keyswitch(&mut res_tmp, self, ks_glwe, scratch_1); + res.to_mut().from_glwe( + module, + &res_tmp, + T::bit_index(bit) << log_gap, + ks_lwe, + scratch_1, + ); + } else { + res.to_mut() + .from_glwe(module, self, T::bit_index(bit) << log_gap, ks_lwe, scratch); + } } pub fn get_bit_glwe(&self, module: &M, bit: usize, res: &mut R, keys: &H, scratch: &mut Scratch) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs index bc4f74c..8cad75b 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared.rs @@ -333,7 +333,7 @@ where K: BDDKeyHelper + BDDKeyInfos, { let bit_end = bit_start + bit_count; - let (cbt, ks) = key.get_cbt_key(); + let (cbt, ks_glwe, ks_lwe) = key.get_cbt_key(); assert!(bit_end <= T::BITS as usize); @@ -363,7 +363,14 @@ where let (mut tmp_ggsw, scratch_1) = scratch_thread.take_ggsw(ggsw_infos); let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (local_bit, dst) in res_bits_chunk.iter_mut().enumerate() { - bits.get_bit_lwe(self, start + local_bit, &mut tmp_lwe, ks, scratch_2); + bits.get_bit_lwe( + self, + start + local_bit, + &mut tmp_lwe, + ks_glwe, + ks_lwe, + scratch_2, + ); cbt.execute_to_constant(self, &mut tmp_ggsw, &tmp_lwe, 1, 1, scratch_2); dst.prepare(self, &tmp_ggsw, scratch_2); } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs index d536d57..46e3254 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint_prepared_debug.rs @@ -13,8 +13,8 @@ use poulpy_core::{ layouts::{GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos}, }; -use poulpy_hal::api::ModuleN; -use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; +use poulpy_hal::api::{ModuleN, ScratchTakeBasic}; +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, Stats, ZnxZero}; pub struct FheUintPreparedDebug { pub(crate) bits: Vec>, @@ -81,31 +81,29 @@ impl GGSWInfos for FheUintPreparedDebug { } impl FheUintPreparedDebug { - pub fn print_noise(&self, module: &M, sk: &S, want: T) + pub fn noise( + &self, + module: &M, + row: usize, + col: usize, + want: T, + sk: &S, + scratch: &mut Scratch, + ) -> Vec where S: GLWESecretPreparedToRef, M: GGSWNoise, + Scratch: ScratchTakeCore, { + let mut stats = Vec::new(); for (i, ggsw) in self.bits.iter().enumerate() { - use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut}; - let mut pt_want = ScalarZnx::alloc(self.n().into(), 1); + use poulpy_hal::layouts::ZnxViewMut; + let (mut pt_want, scratch_1) = scratch.take_scalar_znx(self.n().into(), 1); + pt_want.zero(); pt_want.at_mut(0, 0)[0] = want.bit(i) as i64; - ggsw.print_noise(module, sk, &pt_want); - } - } - - pub fn assert_noise(&self, module: &M, sk: &S, want: T, max_noise: &F) - where - S: GLWESecretPreparedToRef, - M: GGSWNoise, - F: Fn(usize) -> f64, - { - for (i, ggsw) in self.bits.iter().enumerate() { - use poulpy_hal::layouts::{ScalarZnx, ZnxViewMut}; - let mut pt_want = ScalarZnx::alloc(self.n().into(), 1); - pt_want.at_mut(0, 0)[0] = want.bit(i) as i64; - ggsw.assert_noise(module, sk, &pt_want, max_noise); + stats.push(ggsw.noise(module, row, col, &pt_want, sk, scratch_1)); } + stats } } @@ -128,7 +126,14 @@ where let (_, scratch_1) = scratch.take_ggsw(res); let (mut tmp_lwe, scratch_2) = scratch_1.take_lwe(bits); for (bit, dst) in res.bits.iter_mut().enumerate() { - bits.get_bit_lwe(self, bit, &mut tmp_lwe, &key.ks, scratch_2); + bits.get_bit_lwe( + self, + bit, + &mut tmp_lwe, + key.ks_glwe.as_ref(), + &key.ks_lwe, + scratch_2, + ); key.cbt .execute_to_constant(self, dst, &tmp_lwe, 1, 1, scratch_2); } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index cb09748..b036317 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -9,7 +9,11 @@ use crate::tfhe::{ }, }; -use poulpy_core::layouts::{GGLWEInfos, GLWEAutomorphismKeyHelper, GLWEAutomorphismKeyPrepared}; +use poulpy_core::GLWESwitchingKeyEncryptSk; +use poulpy_core::layouts::{ + GGLWEInfos, GLWEAutomorphismKeyHelper, GLWEAutomorphismKeyPrepared, GLWESecret, GLWESwitchingKey, GLWESwitchingKeyLayout, + GLWESwitchingKeyPrepared, +}; use poulpy_core::{ GLWEToLWESwitchingKeyEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ @@ -24,13 +28,15 @@ use poulpy_hal::{ pub trait BDDKeyInfos { fn cbt_infos(&self) -> CircuitBootstrappingKeyLayout; - fn ks_infos(&self) -> GLWEToLWEKeyLayout; + fn ks_lwe_infos(&self) -> GLWEToLWEKeyLayout; + fn ks_glwe_infos(&self) -> Option; } #[derive(Debug, Clone, Copy)] pub struct BDDKeyLayout { pub cbt: CircuitBootstrappingKeyLayout, - pub ks: GLWEToLWEKeyLayout, + pub ks_glwe: Option, + pub ks_lwe: GLWEToLWEKeyLayout, } impl BDDKeyInfos for BDDKeyLayout { @@ -38,8 +44,12 @@ impl BDDKeyInfos for BDDKeyLayout { self.cbt } - fn ks_infos(&self) -> GLWEToLWEKeyLayout { - self.ks + fn ks_glwe_infos(&self) -> Option { + self.ks_glwe + } + + fn ks_lwe_infos(&self) -> GLWEToLWEKeyLayout { + self.ks_lwe } } @@ -49,7 +59,8 @@ where BRA: BlindRotationAlgo, { pub(crate) cbt: CircuitBootstrappingKey, - pub(crate) ks: GLWEToLWEKey, + pub(crate) ks_glwe: Option>, + pub(crate) ks_lwe: GLWEToLWEKey, } impl BDDKey, BRA> @@ -60,9 +71,16 @@ where where A: BDDKeyInfos, { + let ks_glwe: Option>> = if let Some(ks_infos) = &infos.ks_glwe_infos() { + Some(GLWESwitchingKey::alloc_from_infos(ks_infos)) + } else { + None + }; + Self { cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()), - ks: GLWEToLWEKey::alloc_from_infos(&infos.ks_infos()), + ks_glwe: ks_glwe, + ks_lwe: GLWEToLWEKey::alloc_from_infos(&infos.ks_lwe_infos()), } } } @@ -88,7 +106,7 @@ pub trait BDDKeyEncryptSk { impl BDDKeyEncryptSk for Module where - Self: CircuitBootstrappingKeyEncryptSk + GLWEToLWESwitchingKeyEncryptSk, + Self: CircuitBootstrappingKeyEncryptSk + GLWEToLWESwitchingKeyEncryptSk + GLWESwitchingKeyEncryptSk, Scratch: ScratchTakeCore, { fn bdd_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize @@ -96,7 +114,7 @@ where A: BDDKeyInfos, { self.circuit_bootstrapping_key_encrypt_sk_tmp_bytes(&infos.cbt_infos()) - .max(self.glwe_to_lwe_key_encrypt_sk_tmp_bytes(&infos.ks_infos())) + .max(self.glwe_to_lwe_key_encrypt_sk_tmp_bytes(&infos.ks_lwe_infos())) } fn bdd_key_encrypt_sk( @@ -112,8 +130,17 @@ where S0: LWESecretToRef + GetDistribution + LWEInfos, S1: GLWESecretToRef + GetDistribution + GLWEInfos, { - res.ks - .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + if let Some(key) = &mut res.ks_glwe { + let mut sk_out: GLWESecret> = GLWESecret::alloc(sk_glwe.n(), key.rank_out()); + sk_out.fill_ternary_prob(0.5, source_xe); + key.encrypt_sk(self, sk_glwe, &sk_out, source_xa, source_xe, scratch); + res.ks_lwe + .encrypt_sk(self, sk_lwe, &sk_out, source_xa, source_xe, scratch); + } else { + res.ks_lwe + .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + } + res.cbt .encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } @@ -145,7 +172,8 @@ where BE: Backend, { pub(crate) cbt: CircuitBootstrappingKeyPrepared, - pub(crate) ks: GLWEToLWEKeyPrepared, + pub(crate) ks_glwe: Option>, + pub(crate) ks_lwe: GLWEToLWEKeyPrepared, } impl BDDKeyInfos for BDDKeyPrepared { @@ -156,13 +184,28 @@ impl BDDKeyInfos for BDDKeyPrep layout_tsk: self.cbt.tsk_infos(), } } - fn ks_infos(&self) -> GLWEToLWEKeyLayout { + fn ks_glwe_infos(&self) -> Option { + if let Some(ks_glwe) = &self.ks_glwe { + Some(GLWESwitchingKeyLayout { + n: ks_glwe.n(), + base2k: ks_glwe.base2k(), + k: ks_glwe.k(), + rank_in: ks_glwe.rank_in(), + rank_out: ks_glwe.rank_out(), + dnum: ks_glwe.dnum(), + dsize: ks_glwe.dsize(), + }) + } else { + None + } + } + fn ks_lwe_infos(&self) -> GLWEToLWEKeyLayout { GLWEToLWEKeyLayout { - n: self.ks.n(), - base2k: self.ks.base2k(), - k: self.ks.k(), - rank_in: self.ks.rank_in(), - dnum: self.ks.dnum(), + n: self.ks_lwe.n(), + base2k: self.ks_lwe.base2k(), + k: self.ks_lwe.k(), + rank_in: self.ks_lwe.rank_in(), + dnum: self.ks_lwe.dnum(), } } } @@ -187,9 +230,19 @@ where where A: BDDKeyInfos, { + let ks_glwe = if let Some(ks_glwe_infos) = &infos.ks_glwe_infos() { + Some(GLWESwitchingKeyPrepared::alloc_from_infos( + self, + ks_glwe_infos, + )) + } else { + None + }; + BDDKeyPrepared { cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()), - ks: GLWEToLWEKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), + ks_glwe, + ks_lwe: GLWEToLWEKeyPrepared::alloc_from_infos(self, &infos.ks_lwe_infos()), } } @@ -198,7 +251,7 @@ where A: BDDKeyInfos, { self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos()) - .max(self.prepare_glwe_to_lwe_key_tmp_bytes(&infos.ks_infos())) + .max(self.prepare_glwe_to_lwe_key_tmp_bytes(&infos.ks_lwe_infos())) } fn prepare_bdd_key(&self, res: &mut BDDKeyPrepared, other: &BDDKey, scratch: &mut Scratch) @@ -208,7 +261,16 @@ where Scratch: ScratchTakeCore, { res.cbt.prepare(self, &other.cbt, scratch); - res.ks.prepare(self, &other.ks, scratch); + + if let Some(key_prep) = &mut res.ks_glwe { + if let Some(other) = &other.ks_glwe { + key_prep.prepare(self, other, scratch); + } else { + panic!("incompatible keys: res has Some(ks_glwe) but other has none") + } + } + + res.ks_lwe.prepare(self, &other.ks_lwe, scratch); } } impl BDDKeyPreparedFactory for Module where @@ -231,9 +293,10 @@ impl BDDKeyHelper f &self, ) -> ( &CircuitBootstrappingKeyPrepared, + Option<&GLWESwitchingKeyPrepared>, &GLWEToLWEKeyPrepared, ) { - (&self.cbt, &self.ks) + (&self.cbt, self.ks_glwe.as_ref(), &self.ks_lwe) } } @@ -242,6 +305,7 @@ pub trait BDDKeyHelper { &self, ) -> ( &CircuitBootstrappingKeyPrepared, + Option<&GLWESwitchingKeyPrepared>, &GLWEToLWEKeyPrepared, ); } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs index 7d19305..f1b8c7a 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/fheuint.rs @@ -49,8 +49,6 @@ where a_enc.sext(module, j, keys, scratch.borrow()); - // println!("{:08x} -> {:08x} {:08x}", a, sext(a, j), a_enc.decrypt(module, sk, scratch.borrow())); - assert_eq!( sext(a, ((1 + j as u32) << 3) - 1), a_enc.decrypt(module, sk, scratch.borrow()) @@ -70,8 +68,6 @@ where a_enc.sext(module, j, keys, scratch.borrow()); - // println!("{:08x} -> {:08x} {:08x}", a, sext(a, j), a_enc.decrypt(module, sk, scratch.borrow())); - assert_eq!( sext(a, ((1 + j as u32) << 3) - 1), a_enc.decrypt(module, sk, scratch.borrow()) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs index 42ea9aa..d265a6c 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs @@ -1,8 +1,8 @@ use poulpy_core::{ GGSWEncryptSk, GGSWNoise, GLWEDecrypt, GLWEEncryptSk, SIGMA, ScratchTakeCore, layouts::{ - Base2K, Dnum, Dsize, GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecretPrepared, GLWESecretPreparedFactory, LWEInfos, - Rank, TorusPrecision, + Base2K, Dnum, Dsize, GGSW, GGSWInfos, GGSWLayout, GGSWPreparedFactory, GLWEInfos, GLWESecretPrepared, + GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision, }, }; use poulpy_hal::{ @@ -15,7 +15,7 @@ use rand::RngCore; use crate::tfhe::{ bdd_arithmetic::{ FheUintPrepared, GGSWBlindRotation, - tests::test_suite::{TEST_BASE2K, TEST_RANK, TestContext}, + tests::test_suite::{TEST_FHEUINT_BASE2K, TEST_RANK, TestContext}, }, blind_rotation::BlindRotationAlgo, }; @@ -37,7 +37,7 @@ where let module: &Module = &test_context.module; let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; - let base2k: Base2K = TEST_BASE2K.into(); + let base2k: Base2K = TEST_FHEUINT_BASE2K.into(); let rank: Rank = TEST_RANK.into(); let k_ggsw_res: TorusPrecision = TorusPrecision(39); let k_ggsw_apply: TorusPrecision = TorusPrecision(52); @@ -76,8 +76,6 @@ where let k: u32 = source.next_u32(); - // println!("k: {k}"); - let mut k_enc_prep: FheUintPrepared, u32, BE> = FheUintPrepared::, u32, BE>::alloc_from_infos(module, &ggsw_k_infos); k_enc_prep.encrypt_sk( @@ -133,9 +131,23 @@ where module.vec_znx_rotate_inplace(-rot, &mut scalar_want.as_vec_znx_mut(), 0, scratch.borrow()); - // res.print_noise(&module, &sk_glwe_prep, &scalar_want); - - res.assert_noise(module, sk_glwe_prep, &scalar_want, &max_noise); + for row in 0..res.dnum().as_usize() { + for col in 0..res.rank().as_usize() + 1 { + assert!( + res.noise( + module, + row, + col, + &scalar_want, + sk_glwe_prep, + scratch.borrow() + ) + .std() + .log2() + <= max_noise(col) + ) + } + } bit_step += digit; bit_start += digit; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs index 54baf97..394d30e 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -15,7 +15,7 @@ use rand::RngCore; use crate::tfhe::{ bdd_arithmetic::{ FheUintPrepared, GLWEBlindRotation, - tests::test_suite::{TEST_BASE2K, TEST_RANK, TestContext}, + tests::test_suite::{TEST_FHEUINT_BASE2K, TEST_RANK, TestContext}, }, blind_rotation::BlindRotationAlgo, }; @@ -35,7 +35,7 @@ where let module: &Module = &test_context.module; let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; - let base2k: Base2K = TEST_BASE2K.into(); + let base2k: Base2K = TEST_FHEUINT_BASE2K.into(); let rank: Rank = TEST_RANK.into(); let k_glwe: TorusPrecision = TorusPrecision(26); let k_ggsw: TorusPrecision = TorusPrecision(39); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs index d453789..c9c2e66 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_selection.rs @@ -17,7 +17,7 @@ use rand::RngCore; use crate::tfhe::{ bdd_arithmetic::{ FheUintPrepared, GLWEBlinSelection, - tests::test_suite::{TEST_BASE2K, TEST_RANK, TestContext}, + tests::test_suite::{TEST_FHEUINT_BASE2K, TEST_RANK, TestContext}, }, blind_rotation::BlindRotationAlgo, }; @@ -37,7 +37,7 @@ where let module: &Module = &test_context.module; let sk_glwe_prep: &GLWESecretPrepared, BE> = &test_context.sk_glwe; - let base2k: Base2K = TEST_BASE2K.into(); + let base2k: Base2K = TEST_FHEUINT_BASE2K.into(); let rank: Rank = TEST_RANK.into(); let k_glwe: TorusPrecision = TorusPrecision(26); let k_ggsw: TorusPrecision = TorusPrecision(39); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 74fb6ce..a11a51e 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -41,7 +41,8 @@ use poulpy_core::{ ScratchTakeCore, layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEToGGSWKeyLayout, GGSWLayout, GLWEAutomorphismKeyLayout, GLWELayout, GLWESecret, - GLWESecretPrepared, GLWESecretPreparedFactory, GLWEToLWEKeyLayout, LWESecret, Rank, TorusPrecision, + GLWESecretPrepared, GLWESecretPreparedFactory, GLWESwitchingKeyLayout, GLWEToLWEKeyLayout, LWESecret, Rank, + TorusPrecision, }, }; @@ -137,7 +138,11 @@ impl TestContext { pub(crate) const TEST_N_GLWE: u32 = 256; pub(crate) const TEST_N_LWE: u32 = 77; -pub(crate) const TEST_BASE2K: u32 = 13; +pub(crate) const TEST_FHEUINT_BASE2K: u32 = 13; +pub(crate) const TEST_BRK_BASE2K: u32 = 12; +pub(crate) const TEST_ATK_BASE2K: u32 = 11; +pub(crate) const TEST_TSK_BASE2K: u32 = 10; +pub(crate) const TEST_LWE_BASE2K: u32 = 4; pub(crate) const TEST_K_GLWE: u32 = 26; pub(crate) const TEST_K_GGSW: u32 = 39; pub(crate) const TEST_BLOCK_SIZE: u32 = 7; @@ -145,14 +150,14 @@ pub(crate) const TEST_RANK: u32 = 2; pub(crate) static TEST_GLWE_INFOS: GLWELayout = GLWELayout { n: Degree(TEST_N_GLWE), - base2k: Base2K(TEST_BASE2K), + base2k: Base2K(TEST_FHEUINT_BASE2K), k: TorusPrecision(TEST_K_GLWE), rank: Rank(TEST_RANK), }; pub(crate) static TEST_GGSW_INFOS: GGSWLayout = GGSWLayout { n: Degree(TEST_N_GLWE), - base2k: Base2K(TEST_BASE2K), + base2k: Base2K(TEST_FHEUINT_BASE2K), k: TorusPrecision(TEST_K_GGSW), rank: Rank(TEST_RANK), dnum: Dnum(2), @@ -164,33 +169,42 @@ pub(crate) static TEST_BDD_KEY_LAYOUT: BDDKeyLayout = BDDKeyLayout { layout_brk: BlindRotationKeyLayout { n_glwe: Degree(TEST_N_GLWE), n_lwe: Degree(TEST_N_LWE), - base2k: Base2K(TEST_BASE2K), + base2k: Base2K(TEST_BRK_BASE2K), k: TorusPrecision(52), - dnum: Dnum(3), + dnum: Dnum(4), rank: Rank(TEST_RANK), }, layout_atk: GLWEAutomorphismKeyLayout { n: Degree(TEST_N_GLWE), - base2k: Base2K(TEST_BASE2K), + base2k: Base2K(TEST_ATK_BASE2K), k: TorusPrecision(52), rank: Rank(TEST_RANK), - dnum: Dnum(3), + dnum: Dnum(4), dsize: Dsize(1), }, layout_tsk: GGLWEToGGSWKeyLayout { n: Degree(TEST_N_GLWE), - base2k: Base2K(TEST_BASE2K), + base2k: Base2K(TEST_TSK_BASE2K), k: TorusPrecision(52), rank: Rank(TEST_RANK), - dnum: Dnum(3), + dnum: Dnum(4), dsize: Dsize(1), }, }, - ks: GLWEToLWEKeyLayout { + ks_glwe: Some(GLWESwitchingKeyLayout { n: Degree(TEST_N_GLWE), - base2k: Base2K(TEST_BASE2K), - k: TorusPrecision(39), + base2k: Base2K(TEST_LWE_BASE2K), + k: TorusPrecision(20), rank_in: Rank(TEST_RANK), - dnum: Dnum(2), + rank_out: Rank(1), + dnum: Dnum(3), + dsize: Dsize(1), + }), + ks_lwe: GLWEToLWEKeyLayout { + n: Degree(TEST_N_GLWE), + base2k: Base2K(TEST_LWE_BASE2K), + k: TorusPrecision(16), + rank_in: Rank(1), + dnum: Dnum(3), }, }; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs index c99dd2a..384bf3f 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/prepare.rs @@ -1,10 +1,10 @@ use poulpy_core::{ GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWENoise, SIGMA, ScratchTakeCore, - layouts::{GGSWLayout, GLWELayout, GLWESecretPreparedFactory, LWEInfos, prepared::GLWESecretPrepared}, + layouts::{GGSWInfos, GGSWLayout, GLWEInfos, GLWELayout, GLWESecretPreparedFactory, LWEInfos, prepared::GLWESecretPrepared}, }; use poulpy_hal::{ api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Backend, Module, Scratch, ScratchOwned}, + layouts::{Backend, Module, Scratch, ScratchOwned, Stats}, source::Source, }; use rand::RngCore; @@ -13,7 +13,7 @@ use crate::tfhe::{ bdd_arithmetic::{ BDDKeyEncryptSk, BDDKeyPrepared, BDDKeyPreparedFactory, ExecuteBDDCircuit2WTo1W, FheUint, FheUintPrepare, FheUintPrepareDebug, FheUintPreparedDebug, FheUintPreparedEncryptSk, FheUintPreparedFactory, - tests::test_suite::{TEST_BASE2K, TEST_GGSW_INFOS, TEST_GLWE_INFOS, TestContext}, + tests::test_suite::{TEST_GGSW_INFOS, TEST_GLWE_INFOS, TestContext}, }, blind_rotation::{BlindRotationAlgo, BlindRotationKey, BlindRotationKeyFactory}, }; @@ -73,7 +73,7 @@ where c_enc_prep_debug.prepare(module, &c_enc, bdd_key_prepared, scratch_2.borrow()); let max_noise = |col_i: usize| { - let mut noise: f64 = -(ggsw_infos.size() as f64 * TEST_BASE2K as f64) + SIGMA.log2() + 2.0; + let mut noise: f64 = -(ggsw_infos.size() as f64 * ggsw_infos.base2k().as_usize() as f64) + SIGMA.log2() + 2.0; noise += 0.5 * ggsw_infos.log_n() as f64; if col_i != 0 { noise += 0.5 * ggsw_infos.log_n() as f64 @@ -81,7 +81,17 @@ where noise }; - // c_enc_prep_debug.print_noise(module, sk_glwe_prep, value); - - c_enc_prep_debug.assert_noise(module, sk_glwe_prep, value, &max_noise); + for row in 0..c_enc_prep_debug.dnum().as_usize() { + for col in 0..c_enc_prep_debug.rank().as_usize() + 1 { + let stats: Vec = c_enc_prep_debug.noise(module, row, col, value, sk_glwe_prep, scratch.borrow()); + for (i, stat) in stats.iter().enumerate() { + let noise_have: f64 = stat.std().log2(); + let noise_max: f64 = max_noise(col); + assert!( + noise_have <= noise_max, + "bit: {i} noise_have: {noise_have} > noise_max: {noise_max}" + ) + } + } + } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 1651987..9f91538 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -111,7 +111,14 @@ pub fn test_blind_rotation( pt_lwe.encode_i64(x, (log_message_modulus + 1).into()); - lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let f = |x: i64| -> i64 { 2 * x + 1 }; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 62b706a..0bb2d83 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -2,14 +2,14 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ModuleLogN, ModuleN, ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, - layouts::{Backend, DataRef, Module, Scratch, ScratchOwned, ToOwnedDeep}, + layouts::{Backend, DataRef, Module, Scratch, ScratchOwned}, }; use poulpy_core::{ - GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, + GGSWExpandRows, GGSWFromGGLWE, GLWECopy, GLWEDecrypt, GLWENormalize, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, layouts::{ Dsize, GGLWE, GGLWEInfos, GGLWELayout, GGLWEPreparedToRef, GGSWInfos, GGSWToMut, GLWEAutomorphismKeyHelper, GLWEInfos, - GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank, + GLWELayout, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos, LWEToRef, Rank, }, }; @@ -114,10 +114,12 @@ where + BlindRotationExecute + GLWETrace + GLWEPacking - + GGSWFromGGLWE + GLWESecretPreparedFactory + GLWEDecrypt - + GLWERotate, + + GLWERotate + + GLWENormalize + + GLWECopy + + GGSWExpandRows, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, { @@ -240,7 +242,10 @@ pub fn circuit_bootstrap_core( + GLWESecretPreparedFactory + GLWEDecrypt + GLWERotate - + ModuleLogN, + + ModuleLogN + + GLWENormalize + + GLWECopy + + GGSWExpandRows, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, { @@ -248,27 +253,22 @@ pub fn circuit_bootstrap_core( let lwe: &LWE<&[u8]> = &lwe.to_ref(); assert_eq!(res.n(), key.brk.n()); - assert_eq!(lwe.base2k(), key.brk.base2k()); - assert_eq!(res.base2k(), key.brk.base2k()); - let n: usize = res.n().into(); - let base2k: usize = res.base2k().into(); - let dnum: usize = res.dnum().into(); - let rank: usize = res.rank().into(); - let k: usize = res.k().into(); + let base2k_res: usize = res.base2k().as_usize(); + let dnum_res: usize = res.dnum().into(); - let alpha: usize = dnum.next_power_of_two(); + let alpha: usize = dnum_res.next_power_of_two(); let mut f: Vec = vec![0i64; (1 << log_domain) * alpha]; if to_exponent { - (0..dnum).for_each(|i| { - f[i] = 1 << (base2k * (dnum - 1 - i)); + (0..dnum_res).for_each(|i| { + f[i] = 1 << (base2k_res * (dnum_res - 1 - i)); }); } else { (0..1 << log_domain).for_each(|j| { - (0..dnum).for_each(|i| { - f[j * alpha + i] = j as i64 * (1 << (base2k * (dnum - 1 - i))); + (0..dnum_res).for_each(|i| { + f[j * alpha + i] = j as i64 * (1 << (base2k_res * (dnum_res - 1 - i))); }); }); } @@ -276,71 +276,79 @@ pub fn circuit_bootstrap_core( let lut_infos: LookUpTableLayout = LookUpTableLayout { n: module.n().into(), extension_factor, - k: (base2k * dnum).into(), - base2k: base2k.into(), + k: (base2k_res * dnum_res).into(), + base2k: key.brk.base2k(), }; // Lut precision, basically must be able to hold the decomposition power basis of the GGSW let mut lut: LookupTable = LookupTable::alloc(&lut_infos); - lut.set(module, &f, base2k * dnum); + lut.set(module, &f, base2k_res * dnum_res); if to_exponent { lut.set_rotation_direction(LookUpTableRotationDirection::Right); } - // TODO: separate GGSW k from output of blind rotation k - let (mut res_glwe, scratch_1) = scratch.take_glwe(res); - - let gglwe_infos: GGLWELayout = GGLWELayout { - n: n.into(), - base2k: base2k.into(), - k: k.into(), - dnum: dnum.into(), - dsize: Dsize(1), - rank_in: rank.max(1).into(), - rank_out: rank.into(), + let glwe_brk_layout = &GLWELayout { + n: key.brk.n(), + base2k: key.brk.base2k(), + k: key.brk.k(), + rank: key.brk.rank(), }; - let (mut tmp_gglwe, scratch_2) = scratch_1.take_gglwe(&gglwe_infos); + let atk_layout: &GGLWELayout = &key.atk.automorphism_key_infos(); - key.brk.execute(module, &mut res_glwe, lwe, &lut, scratch_2); + let glwe_atk_layout: &GLWELayout = &GLWELayout { + n: glwe_brk_layout.n(), + base2k: atk_layout.base2k(), + k: glwe_brk_layout.k(), + rank: glwe_brk_layout.rank(), + }; + + let (mut res_glwe_atk_layout, scratch_1) = scratch.take_glwe(glwe_atk_layout); + + // Execute blind rotation over BRK layout and returns result over ATK layout + { + let (mut res_glwe_brk_layout, scratch_2) = scratch_1.take_glwe(glwe_brk_layout); + key.brk + .execute(module, &mut res_glwe_brk_layout, lwe, &lut, scratch_2); + + if res_glwe_brk_layout.base2k() == res_glwe_atk_layout.base2k() { + module.glwe_copy(&mut res_glwe_atk_layout, &res_glwe_brk_layout); + } else { + module.glwe_normalize(&mut res_glwe_atk_layout, &res_glwe_brk_layout, scratch_2); + } + } let gap: usize = 2 * lut.drift / lut.extension_factor(); let log_gap_in: usize = (usize::BITS - (gap * alpha - 1).leading_zeros()) as _; - (0..dnum).for_each(|i| { - let mut tmp_glwe: GLWE<&mut [u8]> = tmp_gglwe.at_mut(i, 0); + for i in 0..dnum_res { + let mut res_row: GLWE<&mut [u8]> = res.at_mut(i, 0); if to_exponent { // Isolates i-th LUT and moves coefficients according to requested gap. post_process( module, - &mut tmp_glwe, - &res_glwe, + &mut res_row, + &res_glwe_atk_layout, log_gap_in, log_gap_out, log_domain, &key.atk, - scratch_2, + scratch_1, ); } else { - tmp_glwe.trace(module, 0, &res_glwe, &key.atk, scratch_2); + module.glwe_trace(&mut res_row, 0, &res_glwe_atk_layout, &key.atk, scratch_1); } - // let sk_glwe: &poulpy_core::layouts::GLWESecret<&[u8]> = &sk_glwe.to_ref(); - // let sk_glwe_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, sk_glwe.rank()); - // let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&res_glwe); - // res_glwe.decrypt(module, &mut pt, &sk_glwe_prepared, scratch_2); - // println!("pt[{i}]: {}", pt); - - if i < dnum { - module.glwe_rotate_inplace(-(gap as i64), &mut res_glwe, scratch_2); + if i < dnum_res { + module.glwe_rotate_inplace(-(gap as i64), &mut res_glwe_atk_layout, scratch_1); } - }); + } // Expands GGLWE to GGSW using GGLWE(s^2) - res.from_gglwe(module, &tmp_gglwe, &key.tsk, scratch_2); + module.ggsw_expand_row(res, &key.tsk, scratch); } #[allow(clippy::too_many_arguments)] @@ -354,47 +362,48 @@ fn post_process( auto_keys: &H, scratch: &mut Scratch, ) where - R: GLWEToMut, - A: GLWEToRef, + R: GLWEToMut + GLWEInfos, + A: GLWEToRef + GLWEInfos, H: GLWEAutomorphismKeyHelper, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - M: ModuleLogN + GLWETrace + GLWEPacking + GLWERotate, + M: ModuleLogN + GLWETrace + GLWEPacking + GLWERotate + GLWECopy, Scratch: ScratchTakeCore, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GLWE<&[u8]> = &a.to_ref(); - - let mut cts: HashMap>> = HashMap::new(); - - // First partial trace, vanishes all coefficients which are not multiples of gap_in - // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0] - res.trace( - module, - module.log_n() - log_gap_in + 1, - a, - auto_keys, - scratch, - ); - // TODO: optimize with packing and final partial trace // If gap_out < gap_in, then we need to repack, i.e. reduce the cap between coefficients. if log_gap_in != log_gap_out { + let (mut a_trace, scratch_1) = scratch.take_glwe(a); + + // First partial trace, vanishes all coefficients which are not multiples of gap_in + // [1, 1, 1, 1, 0, 0, 0, ..., 0, 0, -1, -1, -1, -1] -> [1, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0] + module.glwe_trace( + &mut a_trace, + module.log_n() - log_gap_in + 1, + a, + auto_keys, + scratch_1, + ); + let steps: usize = 1 << log_domain; // TODO: from Scratch - let mut cts_vec: Vec>> = Vec::new(); + let (mut cts_vec, scratch_2) = scratch_1.take_glwe_slice(steps, a); - for i in 0..steps { + for (i, ct) in cts_vec.iter_mut().enumerate().take(steps) { if i != 0 { - module.glwe_rotate_inplace(-(1 << log_gap_in), res, scratch); + module.glwe_rotate_inplace(-(1 << log_gap_in), &mut a_trace, scratch_2); } - cts_vec.push(res.to_owned_deep()); + + module.glwe_copy(ct, &a_trace); } + let mut cts: HashMap> = HashMap::new(); for (i, ct) in cts_vec.iter_mut().enumerate().take(steps) { cts.insert(i * (1 << log_gap_out), ct); } - module.glwe_pack(res, cts, log_gap_out, auto_keys, scratch); + module.glwe_pack(res, cts, log_gap_out, auto_keys, scratch_2); + } else { + module.glwe_trace(res, module.log_n() - log_gap_in + 1, a, auto_keys, scratch); } } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index a75d211..c55f2a7 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -17,8 +17,8 @@ use crate::tfhe::{ use poulpy_core::{ GGSWNoise, GLWEDecrypt, GLWEEncryptSk, GLWEExternalProduct, LWEEncryptSk, ScratchTakeCore, layouts::{ - Dsize, GGLWEToGGSWKeyLayout, GGSWLayout, GGSWPreparedFactory, GLWEAutomorphismKeyLayout, GLWESecretPreparedFactory, - LWELayout, + Dsize, GGLWEToGGSWKeyLayout, GGSWInfos, GGSWLayout, GGSWPreparedFactory, GLWEAutomorphismKeyLayout, GLWEInfos, + GLWESecretPreparedFactory, LWELayout, }, }; @@ -46,7 +46,11 @@ where Scratch: ScratchTakeCore, { let n_glwe: usize = module.n(); - let base2k: usize = 17; + let base2k_res: usize = 15; + let base2k_lwe: usize = 14; + let base2k_brk: usize = 13; + let base2k_tsk: usize = 12; + let base2k_atk: usize = 11; let extension_factor: usize = 1; let rank: usize = 1; @@ -55,36 +59,36 @@ where let k_lwe_ct: usize = 22; let block_size: usize = 7; - let k_brk: usize = 5 * base2k; + let k_ggsw_res: usize = 4 * base2k_res; + let rows_ggsw_res: usize = 3; + + let k_brk: usize = k_ggsw_res + base2k_brk; let rows_brk: usize = 4; - let k_atk: usize = 5 * base2k; + let k_atk: usize = k_ggsw_res + base2k_tsk; let rows_atk: usize = 4; - let k_tsk: usize = 5 * base2k; + let k_tsk: usize = k_ggsw_res + base2k_atk; let rows_tsk: usize = 4; - let k_ggsw_res: usize = 4 * base2k; - let rows_ggsw_res: usize = 2; - let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), - base2k: base2k.into(), + base2k: base2k_lwe.into(), }; let cbt_infos: CircuitBootstrappingKeyLayout = CircuitBootstrappingKeyLayout { layout_brk: BlindRotationKeyLayout { n_glwe: n_glwe.into(), n_lwe: n_lwe.into(), - base2k: base2k.into(), + base2k: base2k_brk.into(), k: k_brk.into(), dnum: rows_brk.into(), rank: rank.into(), }, layout_atk: GLWEAutomorphismKeyLayout { n: n_glwe.into(), - base2k: base2k.into(), + base2k: base2k_atk.into(), k: k_atk.into(), dnum: rows_atk.into(), rank: rank.into(), @@ -92,7 +96,7 @@ where }, layout_tsk: GGLWEToGGSWKeyLayout { n: n_glwe.into(), - base2k: base2k.into(), + base2k: base2k_tsk.into(), k: k_tsk.into(), dnum: rows_tsk.into(), dsize: Dsize(1), @@ -102,7 +106,7 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), - base2k: base2k.into(), + base2k: base2k_res.into(), k: k_ggsw_res.into(), dnum: rows_ggsw_res.into(), dsize: Dsize(1), @@ -126,13 +130,20 @@ where let data: i64 = 1; - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k_lwe.into(), k_lwe_pt.into()); pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); println!("pt_lwe: {pt_lwe}"); let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); - ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + ct_lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let now: Instant = Instant::now(); let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); @@ -179,11 +190,26 @@ where scratch.borrow(), ); - res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); - + for row in 0..res.dnum().as_usize() { + for col in 0..res.rank().as_usize() + 1 { + println!( + "row:{row} col:{col} -> {}", + res.noise( + module, + row, + col, + &pt_ggsw, + &sk_glwe_prepared, + scratch.borrow() + ) + .std() + .log2() + ) + } + } let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos); let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); - pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - 2); + pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k_res - 2); ct_glwe.encrypt_sk( module, @@ -227,7 +253,11 @@ where Scratch: ScratchTakeCore, { let n_glwe: usize = module.n(); - let base2k: usize = 14; + let base2k_res: usize = 15; + let base2k_lwe: usize = 14; + let base2k_brk: usize = 13; + let base2k_tsk: usize = 12; + let base2k_atk: usize = 11; let extension_factor: usize = 1; let rank: usize = 1; @@ -236,36 +266,36 @@ where let k_lwe_ct: usize = 13; let block_size: usize = 7; - let k_brk: usize = 5 * base2k; - let rows_brk: usize = 3; + let k_ggsw_res: usize = 4 * base2k_res; + let rows_ggsw_res: usize = 3; - let k_atk: usize = 5 * base2k; + let k_brk: usize = k_ggsw_res + base2k_brk; + let rows_brk: usize = 4; + + let k_atk: usize = k_ggsw_res + base2k_tsk; let rows_atk: usize = 4; - let k_tsk: usize = 5 * base2k; + let k_tsk: usize = k_ggsw_res + base2k_atk; let rows_tsk: usize = 4; - let k_ggsw_res: usize = 4 * base2k; - let rows_ggsw_res: usize = 3; - let lwe_infos: LWELayout = LWELayout { n: n_lwe.into(), k: k_lwe_ct.into(), - base2k: base2k.into(), + base2k: base2k_lwe.into(), }; let cbt_infos: CircuitBootstrappingKeyLayout = CircuitBootstrappingKeyLayout { layout_brk: BlindRotationKeyLayout { n_glwe: n_glwe.into(), n_lwe: n_lwe.into(), - base2k: base2k.into(), + base2k: base2k_brk.into(), k: k_brk.into(), dnum: rows_brk.into(), rank: rank.into(), }, layout_atk: GLWEAutomorphismKeyLayout { n: n_glwe.into(), - base2k: base2k.into(), + base2k: base2k_atk.into(), k: k_atk.into(), dnum: rows_atk.into(), rank: rank.into(), @@ -273,7 +303,7 @@ where }, layout_tsk: GGLWEToGGSWKeyLayout { n: n_glwe.into(), - base2k: base2k.into(), + base2k: base2k_tsk.into(), k: k_tsk.into(), dnum: rows_tsk.into(), dsize: Dsize(1), @@ -283,7 +313,7 @@ where let ggsw_infos: GGSWLayout = GGSWLayout { n: n_glwe.into(), - base2k: base2k.into(), + base2k: base2k_res.into(), k: k_ggsw_res.into(), dnum: rows_ggsw_res.into(), dsize: Dsize(1), @@ -307,13 +337,20 @@ where let data: i64 = 1; - let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); + let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(base2k_lwe.into(), k_lwe_pt.into()); pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); println!("pt_lwe: {pt_lwe}"); let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); - ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + ct_lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let now: Instant = Instant::now(); let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); @@ -351,11 +388,27 @@ where let mut pt_ggsw: ScalarZnx> = ScalarZnx::alloc(n_glwe, 1); pt_ggsw.at_mut(0, 0)[0] = data; - res.print_noise(module, &sk_glwe_prepared, &pt_ggsw); + for row in 0..res.dnum().as_usize() { + for col in 0..res.rank().as_usize() + 1 { + println!( + "row:{row} col:{col} -> {}", + res.noise( + module, + row, + col, + &pt_ggsw, + &sk_glwe_prepared, + scratch.borrow() + ) + .std() + .log2() + ) + } + } let mut ct_glwe: GLWE> = GLWE::alloc_from_infos(&ggsw_infos); let mut pt_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ggsw_infos); - pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k - k_lwe_pt - 1); + pt_glwe.data.at_mut(0, 0)[0] = 1 << (base2k_res - k_lwe_pt - 1); ct_glwe.encrypt_sk( module, @@ -377,5 +430,6 @@ where // Parameters are set such that the first limb should be noiseless. let mut pt_want: Vec = vec![0i64; module.n()]; pt_want[0] = pt_glwe.data.at(0, 0)[0] * data; + println!("pt_res: {pt_res}"); assert_eq!(pt_res.data.at(0, 0), pt_want); }