From f39e3e28655239ee9011a16842be35188d1afb1a Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Tue, 18 Nov 2025 01:08:20 +0100 Subject: [PATCH] Remove Zn (replaced by VecZnx), add more cross-base2k ops & tests --- poulpy-backend/examples/rlwe_encrypt.rs | 12 +- poulpy-backend/src/cpu_fft64_avx/mod.rs | 1 - poulpy-backend/src/cpu_fft64_avx/zn.rs | 73 ---- poulpy-backend/src/cpu_fft64_ref/mod.rs | 1 - poulpy-backend/src/cpu_fft64_ref/zn.rs | 73 ---- poulpy-backend/src/cpu_spqlios/fft64/mod.rs | 1 - poulpy-backend/src/cpu_spqlios/fft64/zn.rs | 82 ---- poulpy-core/src/automorphism/gglwe_atk.rs | 38 +- poulpy-core/src/automorphism/ggsw_ct.rs | 19 +- poulpy-core/src/automorphism/glwe_ct.rs | 352 +++++++++++++----- poulpy-core/src/conversion/gglwe_to_ggsw.rs | 219 ++++++----- poulpy-core/src/decryption/lwe.rs | 51 ++- poulpy-core/src/encryption/lwe.rs | 73 ++-- poulpy-core/src/external_product/gglwe.rs | 23 +- poulpy-core/src/external_product/ggsw.rs | 2 + poulpy-core/src/keyswitching/gglwe.rs | 17 +- poulpy-core/src/keyswitching/ggsw.rs | 3 +- poulpy-core/src/keyswitching/glwe.rs | 20 +- poulpy-core/src/keyswitching/lwe.rs | 24 +- poulpy-core/src/layouts/compressed/lwe.rs | 24 +- poulpy-core/src/layouts/glwe_plaintext.rs | 12 + poulpy-core/src/layouts/lwe.rs | 14 +- poulpy-core/src/layouts/lwe_plaintext.rs | 14 +- poulpy-core/src/noise/mod.rs | 4 +- poulpy-core/src/scratch.rs | 2 +- .../test_suite/automorphism/gglwe_atk.rs | 130 ++++--- .../tests/test_suite/automorphism/ggsw_ct.rs | 68 ++-- .../tests/test_suite/automorphism/glwe_ct.rs | 89 +++-- .../src/tests/test_suite/conversion.rs | 67 +++- .../tests/test_suite/keyswitch/gglwe_ct.rs | 78 ++-- .../src/tests/test_suite/keyswitch/ggsw_ct.rs | 72 ++-- .../src/tests/test_suite/keyswitch/glwe_ct.rs | 1 - .../src/tests/test_suite/keyswitch/lwe_ct.rs | 48 ++- poulpy-core/src/utils.rs | 10 +- poulpy-hal/src/api/mod.rs | 2 - poulpy-hal/src/api/scratch.rs | 7 +- poulpy-hal/src/api/zn.rs | 58 --- poulpy-hal/src/delegates/mod.rs | 1 - poulpy-hal/src/delegates/zn.rs | 81 ---- poulpy-hal/src/layouts/encoding.rs | 86 +---- poulpy-hal/src/layouts/mod.rs | 2 - poulpy-hal/src/layouts/zn.rs | 273 -------------- poulpy-hal/src/oep/mod.rs | 2 - poulpy-hal/src/oep/zn.rs | 70 ---- poulpy-hal/src/reference/mod.rs | 1 - poulpy-hal/src/reference/vec_znx/normalize.rs | 2 + poulpy-hal/src/reference/zn/mod.rs | 5 - poulpy-hal/src/reference/zn/normalization.rs | 72 ---- poulpy-hal/src/reference/zn/sampling.rs | 75 ---- .../examples/circuit_bootstrapping.rs | 21 +- .../tests/generic_blind_rotation.rs | 9 +- .../tests/circuit_bootstrapping.rs | 18 +- 52 files changed, 952 insertions(+), 1550 deletions(-) delete mode 100644 poulpy-backend/src/cpu_fft64_avx/zn.rs delete mode 100644 poulpy-backend/src/cpu_fft64_ref/zn.rs delete mode 100644 poulpy-backend/src/cpu_spqlios/fft64/zn.rs delete mode 100644 poulpy-hal/src/api/zn.rs delete mode 100644 poulpy-hal/src/delegates/zn.rs delete mode 100644 poulpy-hal/src/layouts/zn.rs delete mode 100644 poulpy-hal/src/oep/zn.rs delete mode 100644 poulpy-hal/src/reference/zn/mod.rs delete mode 100644 poulpy-hal/src/reference/zn/normalization.rs delete mode 100644 poulpy-hal/src/reference/zn/sampling.rs diff --git a/poulpy-backend/examples/rlwe_encrypt.rs b/poulpy-backend/examples/rlwe_encrypt.rs index b4338b9..64061e2 100644 --- a/poulpy-backend/examples/rlwe_encrypt.rs +++ b/poulpy-backend/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use itertools::izip; -use poulpy_backend::cpu_spqlios::FFT64Spqlios; +use poulpy_backend::cpu_fft64_ref::FFT64Ref; use poulpy_hal::{ api::{ ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, @@ -16,9 +16,9 @@ fn main() { let ct_size: usize = 3; let msg_size: usize = 2; let log_scale: usize = msg_size * base2k - 5; - let module: Module = Module::::new(n as u64); + let module: Module = Module::::new(n as u64); - let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); + let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -28,7 +28,7 @@ fn main() { s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: SvpPPol, FFT64Spqlios> = module.svp_ppol_alloc(s.cols()); + let mut s_dft: SvpPPol, FFT64Ref> = module.svp_ppol_alloc(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); @@ -43,7 +43,7 @@ fn main() { // Fill the second column with random values: ct = (0, a) module.vec_znx_fill_uniform(base2k, &mut ct, 1, &mut source); - let mut buf_dft: VecZnxDft, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size); + let mut buf_dft: VecZnxDft, FFT64Ref> = module.vec_znx_dft_alloc(1, ct_size); module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1); @@ -58,7 +58,7 @@ fn main() { // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - let mut buf_big: VecZnxBig, FFT64Spqlios> = module.vec_znx_big_alloc(1, ct_size); + let mut buf_big: VecZnxBig, FFT64Ref> = module.vec_znx_big_alloc(1, ct_size); module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column diff --git a/poulpy-backend/src/cpu_fft64_avx/mod.rs b/poulpy-backend/src/cpu_fft64_avx/mod.rs index d9d4e5f..4ba20c7 100644 --- a/poulpy-backend/src/cpu_fft64_avx/mod.rs +++ b/poulpy-backend/src/cpu_fft64_avx/mod.rs @@ -7,7 +7,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp; -mod zn; mod znx_avx; pub struct FFT64Avx {} diff --git a/poulpy-backend/src/cpu_fft64_avx/zn.rs b/poulpy-backend/src/cpu_fft64_avx/zn.rs deleted file mode 100644 index 53ce1c9..0000000 --- a/poulpy-backend/src/cpu_fft64_avx/zn.rs +++ /dev/null @@ -1,73 +0,0 @@ -use poulpy_hal::{ - api::TakeSlice, - layouts::{Scratch, ZnToMut}, - oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, - reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes}, - source::Source, -}; - -use crate::cpu_fft64_avx::FFT64Avx; - -unsafe impl ZnNormalizeTmpBytesImpl for FFT64Avx { - fn zn_normalize_tmp_bytes_impl(n: usize) -> usize { - zn_normalize_tmp_bytes(n) - } -} - -unsafe impl ZnNormalizeInplaceImpl for FFT64Avx -where - Self: TakeSliceImpl, -{ - fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut, - { - let (carry, _) = scratch.take_slice(n); - zn_normalize_inplace::(n, base2k, res, res_col, carry); - } -} - -unsafe impl ZnFillUniformImpl for FFT64Avx { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - zn_fill_uniform(n, base2k, res, res_col, source); - } -} - -unsafe impl ZnFillNormalImpl for FFT64Avx { - #[allow(clippy::too_many_arguments)] - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -unsafe impl ZnAddNormalImpl for FFT64Avx { - #[allow(clippy::too_many_arguments)] - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-backend/src/cpu_fft64_ref/mod.rs b/poulpy-backend/src/cpu_fft64_ref/mod.rs index 360c315..e0110a4 100644 --- a/poulpy-backend/src/cpu_fft64_ref/mod.rs +++ b/poulpy-backend/src/cpu_fft64_ref/mod.rs @@ -6,7 +6,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp; -mod zn; mod znx; #[cfg(test)] diff --git a/poulpy-backend/src/cpu_fft64_ref/zn.rs b/poulpy-backend/src/cpu_fft64_ref/zn.rs deleted file mode 100644 index 954d559..0000000 --- a/poulpy-backend/src/cpu_fft64_ref/zn.rs +++ /dev/null @@ -1,73 +0,0 @@ -use poulpy_hal::{ - api::TakeSlice, - layouts::{Scratch, ZnToMut}, - oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, - reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes}, - source::Source, -}; - -use crate::cpu_fft64_ref::FFT64Ref; - -unsafe impl ZnNormalizeTmpBytesImpl for FFT64Ref { - fn zn_normalize_tmp_bytes_impl(n: usize) -> usize { - zn_normalize_tmp_bytes(n) - } -} - -unsafe impl ZnNormalizeInplaceImpl for FFT64Ref -where - Self: TakeSliceImpl, -{ - fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut, - { - let (carry, _) = scratch.take_slice(n); - zn_normalize_inplace::(n, base2k, res, res_col, carry); - } -} - -unsafe impl ZnFillUniformImpl for FFT64Ref { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - zn_fill_uniform(n, base2k, res, res_col, source); - } -} - -unsafe impl ZnFillNormalImpl for FFT64Ref { - #[allow(clippy::too_many_arguments)] - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -unsafe impl ZnAddNormalImpl for FFT64Ref { - #[allow(clippy::too_many_arguments)] - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs index f87b264..4a1f4e3 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs @@ -5,7 +5,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; mod znx; pub struct FFT64Spqlios; diff --git a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs deleted file mode 100644 index adc84fd..0000000 --- a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs +++ /dev/null @@ -1,82 +0,0 @@ -use poulpy_hal::{ - api::TakeSlice, - layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, - oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl}, - reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform}, - source::Source, -}; - -use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64}; - -unsafe impl ZnNormalizeInplaceImpl for FFT64Spqlios -where - Self: TakeSliceImpl, -{ - fn zn_normalize_inplace_impl(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where - A: ZnToMut, - { - let mut a: Zn<&mut [u8]> = a.to_mut(); - - let (tmp_bytes, _) = scratch.take_slice(n * size_of::()); - - unsafe { - zn64::zn64_normalize_base2k_ref( - n as u64, - base2k as u64, - a.at_mut_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - a.at_ptr(a_col, 0), - a.size() as u64, - a.sl() as u64, - tmp_bytes.as_mut_ptr(), - ); - } - } -} - -unsafe impl ZnFillUniformImpl for FFT64Spqlios { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - zn_fill_uniform(n, base2k, res, res_col, source); - } -} - -unsafe impl ZnFillNormalImpl for FFT64Spqlios { - #[allow(clippy::too_many_arguments)] - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -unsafe impl ZnAddNormalImpl for FFT64Spqlios { - #[allow(clippy::too_many_arguments)] - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 87e545e..1e4dd92 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -55,7 +55,11 @@ where A: GGLWEInfos, K: GGLWEInfos, { - self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + if res_infos.glwe_layout() == a_infos.glwe_layout() { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + } else { + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + GLWE::bytes_of_from_infos(a_infos) + } } fn glwe_automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -79,12 +83,16 @@ where a.dsize() ); + assert_eq!(res.base2k(), a.base2k()); + let cols_out: usize = (key.rank_out() + 1).into(); let cols_in: usize = key.rank_in().into(); let p: i64 = a.p(); let p_inv: i64 = self.galois_element_inv(p); + let same_layout: bool = res.glwe_layout() == a.glwe_layout(); + { let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); let a: &GGLWE<&[u8]> = &a.to_ref(); @@ -94,18 +102,30 @@ where let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); let a_ct: GLWE<&[u8]> = a.at(row, col); - // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) - for i in 0..cols_out { - self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i); + if same_layout { + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); + } else { + let (mut tmp_glwe, scratch_1) = scratch.take_glwe(a); + + // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + for i in 0..cols_out { + self.vec_znx_automorphism(p, tmp_glwe.data_mut(), i, &a_ct.data, i); + } + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + self.glwe_keyswitch(&mut res_tmp, &tmp_glwe, key, scratch_1); } - // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) - self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); - // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) - (0..cols_out).for_each(|i| { + for i in 0..cols_out { self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); - }); + } } } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index 8644f98..0b2d977 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -34,9 +34,9 @@ impl GGSW> { impl GGSW { pub fn automorphism(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where - A: GGSWToRef, + A: GGSWToRef + GGSWInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GGLWEToGGSWKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGSWAutomorphism, { @@ -73,20 +73,21 @@ where fn ggsw_automorphism(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) where - R: GGSWToMut, - A: GGSWToRef, + R: GGSWToMut + GGSWInfos, + A: GGSWToRef + GGSWInfos, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GGLWEToGGSWKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { + assert_eq!(res.dsize(), a.dsize()); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.dnum() <= a.dnum()); + assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk)); + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let a: &GGSW<&[u8]> = &a.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - assert_eq!(res.dsize(), a.dsize()); - assert!(res.dnum() <= a.dnum()); - assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk)); - // Keyswitch the j-th row of the col 0 for row in 0..res.dnum().as_usize() { // Key-switch column 0, i.e. diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index b382197..8daa416 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -7,8 +7,8 @@ use poulpy_hal::{ }; use crate::{ - GLWEKeySwitchInternal, GLWEKeyswitch, ScratchTakeCore, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, + GLWEKeySwitchInternal, GLWEKeyswitch, GLWENormalize, ScratchTakeCore, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, }; impl GLWE> { @@ -164,7 +164,8 @@ where + VecZnxBigSubSmallInplace + VecZnxBigSubSmallNegateInplace + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, + + VecZnxBigNormalize + + GLWENormalize, Scratch: ScratchTakeCore, { fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize @@ -217,22 +218,50 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_add_small_inplace(&mut res_big, i, a_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) @@ -243,22 +272,49 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_add_small_inplace(&mut res_big, i, res_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -271,22 +327,50 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, a_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -299,22 +383,50 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + let base2k_a: usize = a.base2k().into(); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_a != base2k_key { + let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: a.n(), + base2k: key.base2k(), + k: a.k(), + rank: a.rank(), + }); + self.glwe_normalize(&mut a_conv, a, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) @@ -325,22 +437,49 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, res_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) @@ -351,21 +490,48 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + let base2k_key: usize = key.base2k().into(); + let base2k_res: usize = res.base2k().into(); - for i in 0..res.rank().as_usize() + 1 { - self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); - self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); - self.vec_znx_big_normalize( - res.base2k().into(), - res.data_mut(), - i, - key.base2k().into(), - &res_big, - i, - scratch_1, - ); - } + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size + + if base2k_res != base2k_key { + let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: res.n(), + base2k: key.base2k(), + k: res.k(), + rank: res.rank(), + }); + self.glwe_normalize(&mut res_conv, res, scratch_2); + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res_conv.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_2, + ); + } + } else { + let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); + self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); + self.vec_znx_big_normalize( + base2k_res, + res.data_mut(), + i, + base2k_key, + &res_big, + i, + scratch_1, + ); + } + }; } } diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index 8554e50..9770f1d 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -1,9 +1,9 @@ use poulpy_hal::{ api::{ ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, - VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, }, - layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, + layouts::{Backend, DataMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VecZnxToRef}, }; use crate::{ @@ -65,6 +65,7 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); assert_eq!(tsk.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); for row in 0..res.dnum().into() { self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0)); @@ -111,28 +112,29 @@ where + VecZnxDftApply + VecZnxNormalize + VecZnxBigAddSmallInplace - + VecZnxIdftApplyConsume, + + VecZnxIdftApplyConsume + + VecZnxCopy, { fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize where R: GGSWInfos, A: GGLWEInfos, { - let base2k_in: usize = res_infos.base2k().into(); let base2k_tsk: usize = tsk_infos.base2k().into(); let rank: usize = res_infos.rank().into(); let cols: usize = rank + 1; - let res_size = res_infos.size(); - let a_size: usize = (res_infos.size() * base2k_in).div_ceil(base2k_tsk); + let res_size: usize = res_infos.size(); + let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk); - let a_dft = self.bytes_of_vec_znx_dft(cols - 1, a_size); - let res_dft = self.bytes_of_vec_znx_dft(cols, a_size); + let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size); + let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size); + let res_dft: usize = self.bytes_of_vec_znx_dft(cols, a_size); let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos); - let normalize = self.vec_znx_big_normalize_tmp_bytes(); + let normalize: usize = self.vec_znx_big_normalize_tmp_bytes(); - (a_dft + res_dft + gglwe_prod).max(normalize) + (a_0 + a_dft + res_dft + gglwe_prod).max(normalize) } fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) @@ -144,7 +146,7 @@ where let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - let base2k_in: usize = res.base2k().into(); + let base2k_res: usize = res.base2k().into(); let base2k_tsk: usize = tsk.base2k().into(); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); @@ -152,96 +154,129 @@ where let rank: usize = res.rank().into(); let cols: usize = rank + 1; - let a_size: usize = (res.size() * base2k_in).div_ceil(base2k_tsk); + let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk); + + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size); + let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size); // Keyswitch the j-th row of the col 0 for row in 0..res.dnum().as_usize() { - let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); + let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - { - let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - - if base2k_in == base2k_tsk { - for col_i in 0..cols - 1 { - self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); - for i in 0..cols - 1 { - self.vec_znx_normalize( - base2k_tsk, - &mut a_conv, - 0, - base2k_in, - glwe_mi_1.data(), - i + 1, - scratch_2, - ); - self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); - } + if base2k_res == base2k_tsk { + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); } - } - - // Example for rank 3: - // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. - // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) - for col in 1..cols { - let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); // Todo optimise - - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - self.gglwe_product_dft(&mut res_dft, &a_dft, tsk.at(col - 1), scratch_2); - - let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); - - // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i - // - // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) - // + - // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) - // = - // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - self.vec_znx_big_add_small_inplace(&mut res_big, col, res.at(row, 0).data(), 0); - - for j in 0..cols { - self.vec_znx_big_normalize( - res.base2k().as_usize(), - res.at_mut(row, col).data_mut(), - j, - tsk.base2k().as_usize(), - &res_big, - j, + self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0); + } else { + for i in 0..cols - 1 { + self.vec_znx_normalize( + base2k_tsk, + &mut a_0, + 0, + base2k_res, + glwe_mi_1.data(), + i + 1, scratch_2, ); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0); } + self.vec_znx_normalize( + base2k_tsk, + &mut a_0, + 0, + base2k_res, + glwe_mi_1.data(), + 0, + scratch_2, + ); } + + ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2) + } + } +} + +fn ggsw_expand_rows_internal( + module: &M, + row: usize, + res: &mut R, + a_0: &C, + a_dft: &A, + tsk: &T, + scratch: &mut Scratch, +) where + R: GGSWToMut, + C: VecZnxToRef, + A: VecZnxDftToRef, + M: GGLWEProduct + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace + VecZnxBigNormalize, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore, +{ + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a_0: &VecZnx<&[u8]> = &a_0.to_ref(); + let a_dft: &VecZnxDft<&[u8], BE> = &a_dft.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let cols: usize = res.rank().as_usize() + 1; + + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + for col in 1..cols { + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, tsk.size()); // Todo optimise + + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 + // + // # Example for col=1 + // + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + module.gglwe_product_dft(&mut res_dft, a_dft, tsk.at(col - 1), scratch_1); + + let mut res_big: VecZnxBig<&mut [u8], BE> = module.vec_znx_idft_apply_consume(res_dft); + + // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i + // + // (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2) + // + + // (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) + // = + // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) + module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0); + + for j in 0..cols { + module.vec_znx_big_normalize( + res.base2k().as_usize(), + res.at_mut(row, col).data_mut(), + j, + tsk.base2k().as_usize(), + &res_big, + j, + scratch_1, + ); } } } diff --git a/poulpy-core/src/decryption/lwe.rs b/poulpy-core/src/decryption/lwe.rs index edd727b..997d7ea 100644 --- a/poulpy-core/src/decryption/lwe.rs +++ b/poulpy-core/src/decryption/lwe.rs @@ -1,39 +1,44 @@ use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut}, + api::VecZnxNormalizeInplace, + layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut}, }; -use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}; +use crate::{ + ScratchTakeCore, + layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}, +}; impl LWE { - pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S) + pub fn decrypt(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch) where P: LWEPlaintextToMut, S: LWESecretToRef, - M: LWEDecrypt, + M: LWEDecrypt, + Scratch: ScratchTakeCore, { - module.lwe_decrypt(self, pt, sk); + module.lwe_decrypt(self, pt, sk, scratch); } } pub trait LWEDecrypt { - fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) - where - R: LWEToMut, - P: LWEPlaintextToMut, - S: LWESecretToRef; -} - -impl LWEDecrypt for Module -where - Self: Sized + ZnNormalizeInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, -{ - fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S) + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch) where R: LWEToMut, P: LWEPlaintextToMut, S: LWESecretToRef, + Scratch: ScratchTakeCore; +} + +impl LWEDecrypt for Module +where + Self: Sized + VecZnxNormalizeInplace, +{ + fn lwe_decrypt(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch) + where + R: LWEToMut, + P: LWEPlaintextToMut, + S: LWESecretToRef, + Scratch: ScratchTakeCore, { let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); @@ -52,13 +57,7 @@ where .map(|(x, y)| x * y) .sum::(); }); - self.zn_normalize_inplace( - 1, - res.base2k().into(), - &mut pt.data, - 0, - ScratchOwned::alloc(size_of::()).borrow(), - ); + self.vec_znx_normalize_inplace(res.base2k().into(), &mut pt.data, 0, scratch); pt.base2k = res.base2k(); pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); } diff --git a/poulpy-core/src/encryption/lwe.rs b/poulpy-core/src/encryption/lwe.rs index 7651e8d..1ddebd4 100644 --- a/poulpy-core/src/encryption/lwe.rs +++ b/poulpy-core/src/encryption/lwe.rs @@ -1,43 +1,67 @@ use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace}, - layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, + api::{VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace}, + layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut}, source::Source, }; use crate::{ + ScratchTakeCore, encryption::{SIGMA, SIGMA_BOUND}, layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut}, }; impl LWE { - pub fn encrypt_sk(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) - where + pub fn encrypt_sk( + &mut self, + module: &M, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where P: LWEPlaintextToRef, S: LWESecretToRef, M: LWEEncryptSk, + Scratch: ScratchTakeCore, { - module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe); + module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } } pub trait LWEEncryptSk { - fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) - where + fn lwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where R: LWEToMut, P: LWEPlaintextToRef, - S: LWESecretToRef; + S: LWESecretToRef, + Scratch: ScratchTakeCore; } impl LWEEncryptSk for Module where - Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Self: Sized + VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace, { - fn lwe_encrypt_sk(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) - where + fn lwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where R: LWEToMut, P: LWEPlaintextToRef, S: LWESecretToRef, + Scratch: ScratchTakeCore, { let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let pt: &LWEPlaintext<&[u8]> = &pt.to_ref(); @@ -51,11 +75,11 @@ where let base2k: usize = res.base2k().into(); let k: usize = res.k().into(); - self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa); + self.vec_znx_fill_uniform(base2k, &mut res.data, 0, source_xa); - let mut tmp_znx: Zn> = Zn::alloc(1, 1, res.size()); + let mut tmp_znx: VecZnx> = VecZnx::alloc(1, 1, res.size()); - let min_size = res.size().min(pt.size()); + let min_size: usize = res.size().min(pt.size()); (0..min_size).for_each(|i| { tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] @@ -74,24 +98,9 @@ where .sum::(); }); - self.zn_add_normal( - 1, - base2k, - &mut res.data, - 0, - k, - source_xe, - SIGMA, - SIGMA_BOUND, - ); + self.vec_znx_add_normal(base2k, &mut tmp_znx, 0, k, source_xe, SIGMA, SIGMA_BOUND); - self.zn_normalize_inplace( - 1, - base2k, - &mut tmp_znx, - 0, - ScratchOwned::alloc(size_of::()).borrow(), - ); + self.vec_znx_normalize_inplace(base2k, &mut tmp_znx, 0, scratch); (0..res.size()).for_each(|i| { res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; diff --git a/poulpy-core/src/external_product/gglwe.rs b/poulpy-core/src/external_product/gglwe.rs index 437cf39..c021e67 100644 --- a/poulpy-core/src/external_product/gglwe.rs +++ b/poulpy-core/src/external_product/gglwe.rs @@ -30,8 +30,8 @@ impl GLWEAutomorphismKey { pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where M: GGLWEExternalProduct, - A: GGLWEToRef, - B: GGSWPreparedToRef, + A: GGLWEToRef + GGLWEInfos, + B: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { module.gglwe_external_product(self, a, b, scratch); @@ -62,15 +62,11 @@ where fn gglwe_external_product(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - R: GGLWEToMut, - A: GGLWEToRef, - B: GGSWPreparedToRef, + R: GGLWEToMut + GGLWEInfos, + A: GGLWEToRef + GGLWEInfos, + B: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { - let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GGLWE<&[u8]> = &a.to_ref(); - let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); - assert_eq!( res.rank_in(), a.rank_in(), @@ -92,6 +88,11 @@ where res.rank_out(), b.rank() ); + assert_eq!(res.base2k(), a.base2k()); + + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); + let b: &GGSWPrepared<&[u8], BE> = &b.to_ref(); for row in 0..res.dnum().into() { for col in 0..res.rank_in().into() { @@ -149,8 +150,8 @@ impl GLWESwitchingKey { pub fn external_product(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where M: GGLWEExternalProduct, - A: GGLWEToRef, - B: GGSWPreparedToRef, + A: GGLWEToRef + GGLWEInfos, + B: GGSWPreparedToRef + GGSWInfos, Scratch: ScratchTakeCore, { module.gglwe_external_product(self, a, b, scratch); diff --git a/poulpy-core/src/external_product/ggsw.rs b/poulpy-core/src/external_product/ggsw.rs index d33055d..92ff84b 100644 --- a/poulpy-core/src/external_product/ggsw.rs +++ b/poulpy-core/src/external_product/ggsw.rs @@ -50,6 +50,8 @@ where b.rank() ); + assert_eq!(res.base2k(), a.base2k()); + assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b)); let min_dnum: usize = res.dnum().min(a.dnum()).into(); diff --git a/poulpy-core/src/keyswitching/gglwe.rs b/poulpy-core/src/keyswitching/gglwe.rs index d837002..779ff11 100644 --- a/poulpy-core/src/keyswitching/gglwe.rs +++ b/poulpy-core/src/keyswitching/gglwe.rs @@ -21,7 +21,7 @@ impl GLWEAutomorphismKey> { impl GLWEAutomorphismKey { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GGLWEToRef + GGLWEToRef, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGLWEKeyswitch, @@ -54,7 +54,7 @@ impl GLWESwitchingKey> { impl GLWESwitchingKey { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GGLWEToRef, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGLWEKeyswitch, @@ -87,7 +87,7 @@ impl GGLWE> { impl GGLWE { pub fn keyswitch(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch) where - A: GGLWEToRef, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, M: GGLWEKeyswitch, @@ -122,14 +122,11 @@ where fn gglwe_keyswitch(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where - R: GGLWEToMut, - A: GGLWEToRef, + R: GGLWEToMut + GGLWEInfos, + A: GGLWEToRef + GGLWEInfos, B: GGLWEPreparedToRef + GGLWEInfos, Scratch: ScratchTakeCore, { - let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GGLWE<&[u8]> = &a.to_ref(); - assert_eq!( res.rank_in(), a.rank_in(), @@ -164,6 +161,10 @@ where res.dsize(), a.dsize() ); + assert_eq!(res.base2k(), a.base2k()); + + let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GGLWE<&[u8]> = &a.to_ref(); for row in 0..res.dnum().into() { for col in 0..res.rank_in().into() { diff --git a/poulpy-core/src/keyswitching/ggsw.rs b/poulpy-core/src/keyswitching/ggsw.rs index 3dfb0b1..3a5efb3 100644 --- a/poulpy-core/src/keyswitching/ggsw.rs +++ b/poulpy-core/src/keyswitching/ggsw.rs @@ -3,7 +3,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; use crate::{ GGSWExpandRows, ScratchTakeCore, keyswitching::GLWEKeyswitch, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, LWEInfos}, }; impl GGSW> { @@ -98,6 +98,7 @@ where assert!(res.dnum() <= a.dnum()); assert_eq!(res.dsize(), a.dsize()); + assert_eq!(res.base2k(), a.base2k()); for row in 0..a.dnum().into() { // Key-switch column 0, i.e. diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index be670e6..5bb5875 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -57,21 +57,19 @@ where B: GGLWEInfos, { let cols: usize = res_infos.rank().as_usize() + 1; - let size: usize = self - .glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) - .max(self.vec_znx_big_normalize_tmp_bytes()) - + self.bytes_of_vec_znx_dft(cols, key_infos.size()); - - if a_infos.base2k() != key_infos.base2k() { - size + GLWE::bytes_of_from_infos(&GLWELayout { + let size: usize = if a_infos.base2k() != key_infos.base2k() { + let a_conv_infos = &GLWELayout { n: a_infos.n(), base2k: key_infos.base2k(), k: a_infos.k(), rank: a_infos.rank(), - }) + }; + self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_conv_infos, key_infos) + GLWE::bytes_of_from_infos(a_conv_infos) } else { - size - } + self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) + }; + + size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, key_infos.size()) } fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -256,7 +254,7 @@ where { let cols: usize = (a_infos.rank() + 1).into(); let a_size: usize = a_infos.size(); - self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols - 1, a_size) } fn glwe_keyswitch_internal( diff --git a/poulpy-core/src/keyswitching/lwe.rs b/poulpy-core/src/keyswitching/lwe.rs index bf5abf5..a31df66 100644 --- a/poulpy-core/src/keyswitching/lwe.rs +++ b/poulpy-core/src/keyswitching/lwe.rs @@ -83,34 +83,30 @@ where assert_eq!(ksk.n(), self.n() as u32); assert!(scratch.available() >= self.lwe_keyswitch_tmp_bytes(res, a, ksk)); - let max_k: TorusPrecision = res.k().max(a.k()); - - let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize; - let (mut glwe_in, scratch_1) = scratch.take_glwe(&GLWELayout { n: ksk.n(), base2k: a.base2k(), - k: max_k, + k: a.k(), rank: Rank(1), }); glwe_in.data.zero(); - let (mut glwe_out, scratch_1) = scratch_1.take_glwe(&GLWELayout { - n: ksk.n(), - base2k: res.base2k(), - k: max_k, - rank: Rank(1), - }); - let n_lwe: usize = a.n().into(); - for i in 0..a_size { + for i in 0..a.size() { let data_lwe: &[i64] = a.data.at(0, i); glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } - self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_1); + let (mut glwe_out, scratch_2) = scratch_1.take_glwe(&GLWELayout { + n: ksk.n(), + base2k: res.base2k(), + k: res.k(), + rank: Rank(1), + }); + + self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_2); self.lwe_sample_extract(res, &glwe_out); } } diff --git a/poulpy-core/src/layouts/compressed/lwe.rs b/poulpy-core/src/layouts/compressed/lwe.rs index ce4c000..0c7a27e 100644 --- a/poulpy-core/src/layouts/compressed/lwe.rs +++ b/poulpy-core/src/layouts/compressed/lwe.rs @@ -1,10 +1,10 @@ use std::fmt; use poulpy_hal::{ - api::ZnFillUniform, + api::VecZnxFillUniform, layouts::{ - Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, - ZnxViewMut, + Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos, + ZnxView, ZnxViewMut, }, source::Source, }; @@ -13,7 +13,7 @@ use crate::layouts::{Base2K, Degree, LWE, LWEInfos, LWEToMut, TorusPrecision}; #[derive(PartialEq, Eq, Clone)] pub struct LWECompressed { - pub(crate) data: Zn, + pub(crate) data: VecZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, pub(crate) seed: [u8; 32], @@ -72,7 +72,7 @@ impl LWECompressed> { pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { LWECompressed { - data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, seed: [0u8; 32], @@ -87,7 +87,7 @@ impl LWECompressed> { } pub fn bytes_of(base2k: Base2K, k: TorusPrecision) -> usize { - Zn::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) + VecZnx::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) } } @@ -113,7 +113,7 @@ impl WriterTo for LWECompressed { pub trait LWEDecompress where - Self: ZnFillUniform, + Self: VecZnxFillUniform, { fn decompress_lwe(&self, res: &mut R, other: &O) where @@ -126,20 +126,14 @@ where assert_eq!(res.lwe_layout(), other.lwe_layout()); let mut source: Source = Source::new(other.seed); - self.zn_fill_uniform( - res.n().into(), - other.base2k().into(), - &mut res.data, - 0, - &mut source, - ); + self.vec_znx_fill_uniform(other.base2k().into(), &mut res.data, 0, &mut source); for i in 0..res.size() { res.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; } } } -impl LWEDecompress for Module where Self: ZnFillUniform {} +impl LWEDecompress for Module where Self: VecZnxFillUniform {} impl LWE { pub fn decompress(&mut self, module: &M, other: &O) diff --git a/poulpy-core/src/layouts/glwe_plaintext.rs b/poulpy-core/src/layouts/glwe_plaintext.rs index 3261d3d..411617d 100644 --- a/poulpy-core/src/layouts/glwe_plaintext.rs +++ b/poulpy-core/src/layouts/glwe_plaintext.rs @@ -158,3 +158,15 @@ impl GLWEPlaintextToMut for GLWEPlaintext { } } } + +impl GLWEPlaintext { + pub fn data_mut(&mut self) -> &mut VecZnx { + &mut self.data + } +} + +impl GLWEPlaintext { + pub fn data(&self) -> &VecZnx { + &self.data + } +} diff --git a/poulpy-core/src/layouts/lwe.rs b/poulpy-core/src/layouts/lwe.rs index dd8a3f4..ce50d54 100644 --- a/poulpy-core/src/layouts/lwe.rs +++ b/poulpy-core/src/layouts/lwe.rs @@ -1,7 +1,7 @@ use std::fmt; use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos}, source::Source, }; @@ -57,7 +57,7 @@ impl LWEInfos for LWELayout { } #[derive(PartialEq, Eq, Clone)] pub struct LWE { - pub(crate) data: Zn, + pub(crate) data: VecZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, } @@ -90,13 +90,13 @@ impl SetLWEInfos for LWE { } impl LWE { - pub fn data(&self) -> &Zn { + pub fn data(&self) -> &VecZnx { &self.data } } impl LWE { - pub fn data_mut(&mut self) -> &Zn { + pub fn data_mut(&mut self) -> &VecZnx { &mut self.data } } @@ -121,7 +121,7 @@ impl fmt::Display for LWE { impl FillUniform for LWE where - Zn: FillUniform, + VecZnx: FillUniform, { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.data.fill_uniform(log_bound, source); @@ -138,7 +138,7 @@ impl LWE> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { LWE { - data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), k, base2k, } @@ -152,7 +152,7 @@ impl LWE> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { - Zn::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) + VecZnx::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) } } diff --git a/poulpy-core/src/layouts/lwe_plaintext.rs b/poulpy-core/src/layouts/lwe_plaintext.rs index 966ceb5..7c0d39a 100644 --- a/poulpy-core/src/layouts/lwe_plaintext.rs +++ b/poulpy-core/src/layouts/lwe_plaintext.rs @@ -1,6 +1,6 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef, ZnxInfos}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}; use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; @@ -29,7 +29,7 @@ impl LWEInfos for LWEPlaintextLayout { } pub struct LWEPlaintext { - pub(crate) data: Zn, + pub(crate) data: VecZnx, pub(crate) k: TorusPrecision, pub(crate) base2k: Base2K, } @@ -62,7 +62,7 @@ impl LWEPlaintext> { pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { LWEPlaintext { - data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), + data: VecZnx::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), k, base2k, } @@ -111,8 +111,14 @@ impl LWEPlaintextToMut for LWEPlaintext { } } +impl LWEPlaintext { + pub fn data(&self) -> &VecZnx { + &self.data + } +} + impl LWEPlaintext { - pub fn data_mut(&mut self) -> &mut Zn { + pub fn data_mut(&mut self) -> &mut VecZnx { &mut self.data } } diff --git a/poulpy-core/src/noise/mod.rs b/poulpy-core/src/noise/mod.rs index b4908a7..e65a2ef 100644 --- a/poulpy-core/src/noise/mod.rs +++ b/poulpy-core/src/noise/mod.rs @@ -42,7 +42,7 @@ pub(crate) fn var_noise_gglwe_product( #[allow(dead_code)] pub(crate) fn var_noise_gglwe_product_v2( n: f64, - logq: usize, + k_ksk: usize, dnum: usize, dsize: usize, base2k: usize, @@ -55,7 +55,7 @@ pub(crate) fn var_noise_gglwe_product_v2( ) -> f64 { let base: f64 = ((dsize * base2k) as f64).exp2(); let var_base: f64 = base * base / 12f64; - let scale: f64 = (logq as f64).exp2(); + let scale: f64 = (k_ksk as f64).exp2(); let mut noise: f64 = (dnum as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); noise += var_msg * var_a_err * var_base * n; diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 06b8a04..1d39a9b 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -23,7 +23,7 @@ where where A: LWEInfos, { - let (data, scratch) = self.take_zn(infos.n().into(), 1, infos.size()); + let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size()); ( LWE { k: infos.k(), diff --git a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs index 14aab05..00c0169 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/gglwe_atk.rs @@ -9,10 +9,10 @@ use crate::{ encryption::SIGMA, layouts::{ GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEPlaintext, - GLWESecret, GLWESecretPreparedFactory, + GLWESecret, GLWESecretPreparedFactory, LWEInfos, prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, - noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; #[allow(clippy::too_many_arguments)] @@ -29,26 +29,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let k_out: usize = 40; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); let p0: i64 = -1; let p1: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = (dsize + di) * base2k; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; + let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_in / (base2k * di); - let dnum_out: usize = k_out / (base2k * di); - let dnum_apply: usize = k_in.div_ceil(base2k * di); + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -57,19 +58,19 @@ where let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum_out.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), - k: k_apply.into(), - dnum: dnum_apply.into(), - dsize: di.into(), + base2k: base2k_key.into(), + k: k_ksk.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -83,13 +84,16 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) - | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos) - | GLWEAutomorphismKey::automorphism_tmp_bytes( + .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes( + module, + &auto_key_apply_infos, + )) + .max(GLWEAutomorphismKey::automorphism_tmp_bytes( module, &auto_key_out_infos, &auto_key_in_infos, &auto_key_apply_infos, - ), + )), ); let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_in); @@ -128,7 +132,7 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos); let mut sk_auto: GLWESecret> = GLWESecret::alloc_from_infos(&auto_key_out_infos); sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk @@ -145,41 +149,44 @@ where let mut sk_auto_dft: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto); sk_auto_dft.prepare(module, &sk_auto); - (0..auto_key_out.rank_in().into()).for_each(|col_i| { - (0..auto_key_out.dnum().into()).for_each(|row_i| { + for col_i in 0..auto_key_out.rank_in().into() { + for row_i in 0..auto_key_out.dnum().into() { auto_key_out .at(row_i, col_i) - .decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); + .decrypt(module, &mut pt_out, &sk_auto_dft, scratch.borrow()); module.vec_znx_sub_scalar_inplace( - &mut pt.data, + &mut pt_out.data, 0, (dsize_in - 1) + row_i * dsize_in, &sk.data, col_i, ); - let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - n as f64, - base2k * di, + let noise_have: f64 = pt_out.data.stats(pt_out.base2k().into(), 0).std().log2(); + let max_noise: f64 = var_noise_gglwe_product_v2( + module.n() as f64, + k_ksk, + dnum_ksk, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_out, - k_apply, - ); + ) + .sqrt() + .log2(); assert!( - noise_have < noise_want + 0.5, + noise_have < max_noise + 0.5, "{noise_have} {}", - noise_want + 0.5 + max_noise + 0.5 ); - }); - }); + } + } } } } @@ -198,25 +205,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + let p0: i64 = -1; let p1: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_apply: usize = (dsize + di) * base2k; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); let dsize_in: usize = 1; - let dnum_in: usize = k_in / (base2k * di); - let dnum_apply: usize = k_in.div_ceil(base2k * di); + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), - k: k_in.into(), + base2k: base2k_out.into(), + k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), @@ -224,10 +233,10 @@ where let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), - k: k_apply.into(), - dnum: dnum_apply.into(), - dsize: di.into(), + base2k: base2k_key.into(), + k: k_ksk.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -306,24 +315,27 @@ where col_i, ); - let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); - let noise_want: f64 = log2_std_noise_gglwe_product( - n as f64, - base2k * di, + let noise_have: f64 = pt.data.stats(pt.base2k().into(), 0).std().log2(); + let max_noise: f64 = var_noise_gglwe_product_v2( + module.n() as f64, + k_ksk, + dnum_ksk, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_in, - k_apply, - ); + ) + .sqrt() + .log2(); assert!( - noise_have < noise_want + 0.5, + noise_have < max_noise + 0.5, "{noise_have} {}", - noise_want + 0.5 + max_noise + 0.5 ); }); }); diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index 6e2a226..fcdd1ef 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -29,26 +29,28 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 54; - let dsize: usize = k_in.div_ceil(base2k); - let p: i64 = -5; + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + let p: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * di); - let dnum_in: usize = k_in.div_euclid(base2k * di); + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_in_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -57,7 +59,7 @@ where let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -66,19 +68,19 @@ where let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -154,7 +156,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, @@ -187,23 +189,25 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 54; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + let p: i64 = -1; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let k_tsk: usize = k_ksk; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(di * base2k); - let dnum_in: usize = k_out.div_euclid(base2k * di); + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_out_layout: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), dnum: dnum_in.into(), dsize: dsize_in.into(), @@ -212,19 +216,19 @@ where let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; @@ -293,7 +297,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, diff --git a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs index 58f737a..5b43f44 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/glwe_ct.rs @@ -5,14 +5,14 @@ use poulpy_hal::{ }; use crate::{ - GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, + GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, GLWENormalize, ScratchTakeCore, encryption::SIGMA, layouts::{ GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, - noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; pub fn test_glwe_automorphism(module: &Module) @@ -25,55 +25,59 @@ where + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory + GLWENoise - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace + + GLWENormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = 15; + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); let p: i64 = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(base2k * dsize); + let dnum: usize = k_in.div_ceil(base2k_key * dsize); let ct_in_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), rank: rank.into(), }; let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_out.into(), rank: rank.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), }; let mut autokey: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); let mut ct_in: GLWE> = GLWE::alloc_from_infos(&ct_in_infos); let mut ct_out: GLWE> = GLWE::alloc_from_infos(&ct_out_infos); - let mut pt_want: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); + let mut pt_in: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_in_infos); + let mut pt_out: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) @@ -99,7 +103,7 @@ where ct_in.encrypt_sk( module, - &pt_want, + &pt_in, &sk_prepared, &mut source_xa, &mut source_xe, @@ -112,22 +116,26 @@ where ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + max_dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_in, - k_ksk, - ); + ) + .sqrt() + .log2(); - module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); + module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow()); + module.vec_znx_automorphism_inplace(p, &mut pt_out.data, 0, scratch.borrow()); - ct_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); + ct_out.assert_noise(module, &sk_prepared, &pt_out, max_noise + 1.0); } } } @@ -147,31 +155,33 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + let p = -5; for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * dsize); + let dnum: usize = k_out.div_ceil(base2k_key * dsize); let ct_out_infos: GLWELayout = GLWELayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), rank: rank.into(), }; let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), rank: rank.into(), dnum: dnum.into(), - dsize: di.into(), + dsize: dsize.into(), }; let mut autokey: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); @@ -182,7 +192,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); + module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) @@ -221,18 +231,21 @@ where ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); - let max_noise: f64 = log2_std_noise_gglwe_product( + let max_noise: f64 = var_noise_gglwe_product_v2( module.n() as f64, - base2k * dsize, + k_ksk, + dnum, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank as f64, - k_out, - k_ksk, - ); + ) + .sqrt() + .log2(); module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index eecaaea..d4586e1 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform}, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform, VecZnxNormalize}, layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned, ZnxView}, source::Source, }; @@ -104,7 +104,8 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + LWEEncryptSk - + LWEToGLWEKeyPreparedFactory, + + LWEToGLWEKeyPreparedFactory + + VecZnxNormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -120,23 +121,23 @@ where let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout { n: n_glwe, - base2k: Base2K(17), - k: TorusPrecision(51), + base2k: Base2K(13), + k: TorusPrecision(92), dnum: Dnum(2), rank_out: rank, }; let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, - base2k: Base2K(17), - k: TorusPrecision(34), + base2k: Base2K(15), + k: TorusPrecision(75), rank, }; let lwe_infos: LWELayout = LWELayout { n: n_lwe, base2k: Base2K(17), - k: TorusPrecision(34), + k: TorusPrecision(75), }; let mut scratch: ScratchOwned = ScratchOwned::alloc( @@ -160,7 +161,14 @@ where lwe_pt.encode_i64(data, k_lwe_pt); let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); - lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); + lwe_ct.encrypt_sk( + module, + &lwe_pt, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let mut ksk: LWEToGLWEKey> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos); @@ -183,7 +191,19 @@ where let mut glwe_pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow()); - assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); + let mut lwe_pt_conv = LWEPlaintext::alloc(glwe_pt.base2k(), lwe_pt.k()); + + module.vec_znx_normalize( + glwe_pt.base2k().as_usize(), + lwe_pt_conv.data_mut(), + 0, + lwe_pt.base2k().as_usize(), + lwe_pt.data(), + 0, + scratch.borrow(), + ); + + assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt_conv.data.at(0, 0)[0]); } pub fn test_glwe_to_lwe(module: &Module) @@ -196,7 +216,8 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + GLWEToLWESwitchingKeyEncryptSk - + GLWEToLWEKeyPreparedFactory, + + GLWEToLWEKeyPreparedFactory + + VecZnxNormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -208,8 +229,8 @@ where let glwe_to_lwe_infos: GLWEToLWEKeyLayout = GLWEToLWEKeyLayout { n: n_glwe, - base2k: Base2K(17), - k: TorusPrecision(51), + base2k: Base2K(13), + k: TorusPrecision(91), dnum: Dnum(2), rank_in: rank, }; @@ -217,14 +238,14 @@ where let glwe_infos: GLWELayout = GLWELayout { n: n_glwe, base2k: Base2K(17), - k: TorusPrecision(34), + k: TorusPrecision(72), rank, }; let lwe_infos: LWELayout = LWELayout { n: n_lwe, - base2k: Base2K(17), - k: TorusPrecision(34), + base2k: Base2K(15), + k: TorusPrecision(72), }; let mut source_xs: Source = Source::new([0u8; 32]); @@ -284,7 +305,19 @@ where lwe_ct.from_glwe(module, &glwe_ct, a_idx, &ksk_prepared, scratch.borrow()); let mut lwe_pt: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_infos); - lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); + lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe, scratch.borrow()); - assert_eq!(glwe_pt.data.at(0, 0)[a_idx], lwe_pt.data.at(0, 0)[0]); + let mut glwe_pt_conv = GLWEPlaintext::alloc(glwe_ct.n(), lwe_pt.base2k(), lwe_pt.k()); + + module.vec_znx_normalize( + lwe_pt.base2k().as_usize(), + glwe_pt_conv.data_mut(), + 0, + glwe_ct.base2k().as_usize(), + glwe_pt.data(), + 0, + scratch.borrow(), + ); + + assert_eq!(glwe_pt_conv.data.at(0, 0)[a_idx], lwe_pt.data.at(0, 0)[0]); } diff --git a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs index 548d1f0..1590784 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/gglwe_ct.rs @@ -12,6 +12,7 @@ use crate::{ prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::log2_std_noise_gglwe_product, + var_noise_gglwe_product_v2, }; pub fn test_gglwe_switching_key_keyswitch(module: &Module) @@ -24,27 +25,29 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 60; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); - for rank_in_s0s1 in 1_usize..3 { + for rank_in_s0s1 in 1_usize..2 { for rank_out_s0s1 in 1_usize..3 { for rank_out_s1s2 in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in / base2k; - let dnum_apply: usize = k_in.div_ceil(base2k * di); let dsize_in: usize = 1; + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in_s0s1.into(), rank_out: rank_out_s0s1.into(), @@ -52,19 +55,19 @@ where let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum_apply.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank_out_s0s1.into(), rank_out: rank_out_s1s2.into(), }; let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum_apply.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in_s0s1.into(), rank_out: rank_out_s1s2.into(), @@ -85,8 +88,8 @@ where ); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( module, - &gglwe_s0s1_infos, &gglwe_s0s2_infos, + &gglwe_s0s1_infos, &gglwe_s1s2_infos, )); @@ -135,18 +138,21 @@ where scratch_apply.borrow(), ); - let max_noise: f64 = log2_std_noise_gglwe_product( - n as f64, - base2k * di, + let max_noise: f64 = var_noise_gglwe_product_v2( + module.n() as f64, + k_ksk, + dnum_ksk, + dsize, + base2k_key, 0.5, 0.5, 0f64, SIGMA * SIGMA, 0f64, rank_out_s0s1 as f64, - k_in, - k_ksk, - ); + ) + .sqrt() + .log2(); gglwe_s0s2 .key @@ -168,23 +174,27 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 60; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank_in in 1_usize..3 { for rank_out in 1_usize..3 { - for di in 1_usize..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1_usize..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(base2k * di); let dsize_in: usize = 1; + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); + let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank_in: rank_in.into(), rank_out: rank_out.into(), @@ -192,10 +202,10 @@ where let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank_out.into(), rank_out: rank_out.into(), }; @@ -263,7 +273,7 @@ where let max_noise: f64 = log2_std_noise_gglwe_product( n as f64, - base2k * di, + base2k_key * dsize, var_xs, var_xs, 0f64, diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index c4191fa..40ed030 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -30,53 +30,57 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_in: usize = 54; - let dsize: usize = k_in.div_ceil(base2k); + let base2k_in: usize = 17; + let base2k_key: usize = 13; + let base2k_out: usize = base2k_in; // MUST BE SAME + let k_in: usize = 102; + let max_dsize: usize = k_in.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_in + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_in + base2k_key * dsize; let k_tsk: usize = k_ksk; let k_out: usize = k_ksk; // Better capture noise. let n: usize = module.n(); - let dnum: usize = k_in.div_ceil(di * base2k); + let dnum_in: usize = k_in / base2k_in; + let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_in_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_in.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank.into(), rank_out: rank.into(), }; @@ -163,7 +167,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, @@ -195,43 +199,45 @@ where ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { - let base2k: usize = 12; - let k_out: usize = 54; - let dsize: usize = k_out.div_ceil(base2k); + let base2k_out: usize = 17; + let base2k_key: usize = 13; + let k_out: usize = 102; + let max_dsize: usize = k_out.div_ceil(base2k_key); + for rank in 1_usize..3 { - for di in 1..dsize + 1 { - let k_ksk: usize = k_out + base2k * di; + for dsize in 1..max_dsize + 1 { + let k_ksk: usize = k_out + base2k_key * dsize; let k_tsk: usize = k_ksk; let n: usize = module.n(); - let dnum: usize = k_out.div_ceil(di * base2k); - + let dnum_in: usize = k_out / base2k_out; + let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize); let dsize_in: usize = 1; let ggsw_out_infos: GGSWLayout = GGSWLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), k: k_out.into(), - dnum: dnum.into(), + dnum: dnum_in.into(), dsize: dsize_in.into(), rank: rank.into(), }; let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_tsk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank: rank.into(), }; let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), - dnum: dnum.into(), - dsize: di.into(), + dnum: dnum_ksk.into(), + dsize: dsize.into(), rank_in: rank.into(), rank_out: rank.into(), }; @@ -311,7 +317,7 @@ where let max_noise = |col_j: usize| -> f64 { noise_ggsw_keyswitch( n as f64, - base2k * di, + base2k_key * dsize, col_j, var_xs, 0f64, diff --git a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs index 4abffc6..76f5863 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/glwe_ct.rs @@ -12,7 +12,6 @@ use crate::{ GLWESwitchingKeyPreparedFactory, LWEInfos, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, - noise::log2_std_noise_gglwe_product, var_noise_gglwe_product_v2, }; diff --git a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs index 7617b09..cd06749 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/lwe_ct.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize}, layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView}, source::Source, }; @@ -14,21 +14,27 @@ use crate::{ pub fn test_lwe_keyswitch(module: &Module) where - Module: - LWEKeySwitch + LWESwitchingKeyEncrypt + LWEEncryptSk + LWESwitchingKeyPreparedFactory + LWEDecrypt, + Module: LWEKeySwitch + + LWESwitchingKeyEncrypt + + LWEEncryptSk + + LWESwitchingKeyPreparedFactory + + LWEDecrypt + + VecZnxNormalize, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { let n: usize = module.n(); - let base2k: usize = 17; + let base2k_in: usize = 17; + let base2k_out: usize = 15; + let base2k_key: usize = 13; - let n_lwe_in: usize = 22; - let n_lwe_out: usize = 30; - let k_lwe_ct: usize = 2 * base2k; + let n_lwe_in: usize = module.n() >> 1; + let n_lwe_out: usize = module.n() >> 1; + let k_lwe_ct: usize = 102; let k_lwe_pt: usize = 8; - let k_ksk: usize = k_lwe_ct + base2k; - let dnum: usize = k_lwe_ct.div_ceil(base2k); + let k_ksk: usize = k_lwe_ct + base2k_key; + let dnum: usize = k_lwe_ct.div_ceil(base2k_key); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); @@ -36,21 +42,21 @@ where let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout { n: n.into(), - base2k: base2k.into(), + base2k: base2k_key.into(), k: k_ksk.into(), dnum: dnum.into(), }; let lwe_in_infos: LWELayout = LWELayout { n: n_lwe_in.into(), - base2k: base2k.into(), + base2k: base2k_in.into(), k: k_lwe_ct.into(), }; let lwe_out_infos: LWELayout = LWELayout { n: n_lwe_out.into(), k: k_lwe_ct.into(), - base2k: base2k.into(), + base2k: base2k_out.into(), }; let mut scratch: ScratchOwned = ScratchOwned::alloc( @@ -66,7 +72,7 @@ where let data: i64 = 17; - let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); + let mut lwe_pt_in: LWEPlaintext> = LWEPlaintext::alloc(base2k_in.into(), k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into()); let mut lwe_ct_in: LWE> = LWE::alloc_from_infos(&lwe_in_infos); @@ -76,6 +82,7 @@ where &sk_lwe_in, &mut source_xa, &mut source_xe, + scratch.borrow(), ); let mut ksk: LWESwitchingKey> = LWESwitchingKey::alloc_from_infos(&key_apply_infos); @@ -97,7 +104,18 @@ where lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow()); let mut lwe_pt_out: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); - lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); + lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out, scratch.borrow()); - assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); + let mut lwe_pt_want: LWEPlaintext> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); + module.vec_znx_normalize( + base2k_out, + lwe_pt_want.data_mut(), + 0, + base2k_in, + lwe_pt_in.data(), + 0, + scratch.borrow(), + ); + + assert_eq!(lwe_pt_want.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); } diff --git a/poulpy-core/src/utils.rs b/poulpy-core/src/utils.rs index 858cd5d..7484fcd 100644 --- a/poulpy-core/src/utils.rs +++ b/poulpy-core/src/utils.rs @@ -37,16 +37,20 @@ impl GLWEPlaintext { impl LWEPlaintext { pub fn encode_i64(&mut self, data: i64, k: TorusPrecision) { let base2k: usize = self.base2k().into(); - self.data.encode_i64(base2k, k.into(), data); + self.data.encode_coeff_i64(base2k, 0, k.into(), 0, data); } } impl LWEPlaintext { pub fn decode_i64(&self, k: TorusPrecision) -> i64 { - self.data.decode_i64(self.base2k().into(), k.into()) + self.data + .decode_coeff_i64(self.base2k().into(), 0, k.into(), 0) } pub fn decode_float(&self) -> Float { - self.data.decode_float(self.base2k().into()) + let mut out: [Float; 1] = [Float::new(self.k().as_u32())]; + self.data + .decode_vec_float(self.base2k().into(), 0, &mut out); + out[0].clone() } } diff --git a/poulpy-hal/src/api/mod.rs b/poulpy-hal/src/api/mod.rs index b024a94..9af22de 100644 --- a/poulpy-hal/src/api/mod.rs +++ b/poulpy-hal/src/api/mod.rs @@ -6,7 +6,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; pub use convolution::*; pub use module::*; @@ -16,4 +15,3 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; -pub use zn::*; diff --git a/poulpy-hal/src/api/scratch.rs b/poulpy-hal/src/api/scratch.rs index aef02e1..4dbb14b 100644 --- a/poulpy-hal/src/api/scratch.rs +++ b/poulpy-hal/src/api/scratch.rs @@ -1,6 +1,6 @@ use crate::{ api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, - layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, Zn}, + layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat}, }; /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. @@ -69,11 +69,6 @@ where (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) } - fn take_zn(&mut self, n: usize, cols: usize, size: usize) -> (Zn<&mut [u8]>, &mut Self) { - let (take_slice, rem_slice) = self.take_slice(Zn::bytes_of(n, cols, size)); - (Zn::from_data(take_slice, n, cols, size), rem_slice) - } - fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); (VecZnx::from_data(take_slice, n, cols, size), rem_slice) diff --git a/poulpy-hal/src/api/zn.rs b/poulpy-hal/src/api/zn.rs deleted file mode 100644 index 9e8e308..0000000 --- a/poulpy-hal/src/api/zn.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::{ - layouts::{Backend, Scratch, ZnToMut}, - reference::zn::zn_normalize_tmp_bytes, - source::Source, -}; - -pub trait ZnNormalizeTmpBytes { - fn zn_normalize_tmp_bytes(&self, n: usize) -> usize { - zn_normalize_tmp_bytes(n) - } -} - -pub trait ZnNormalizeInplace { - /// Normalizes the selected column of `a`. - fn zn_normalize_inplace(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut; -} - -pub trait ZnFillUniform { - /// Fills the first `size` size with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\] - fn zn_fill_uniform(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -pub trait ZnFillNormal { - fn zn_fill_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -pub trait ZnAddNormal { - /// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\]. - fn zn_add_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} diff --git a/poulpy-hal/src/delegates/mod.rs b/poulpy-hal/src/delegates/mod.rs index 85de88d..595a641 100644 --- a/poulpy-hal/src/delegates/mod.rs +++ b/poulpy-hal/src/delegates/mod.rs @@ -5,4 +5,3 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; diff --git a/poulpy-hal/src/delegates/zn.rs b/poulpy-hal/src/delegates/zn.rs deleted file mode 100644 index 6a4c999..0000000 --- a/poulpy-hal/src/delegates/zn.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::{ - api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes}, - layouts::{Backend, Module, Scratch, ZnToMut}, - oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl}, - source::Source, -}; - -impl ZnNormalizeTmpBytes for Module -where - B: Backend + ZnNormalizeTmpBytesImpl, -{ - fn zn_normalize_tmp_bytes(&self, n: usize) -> usize { - B::zn_normalize_tmp_bytes_impl(n) - } -} - -impl ZnNormalizeInplace for Module -where - B: Backend + ZnNormalizeInplaceImpl, -{ - fn zn_normalize_inplace(&self, n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) - where - A: ZnToMut, - { - B::zn_normalize_inplace_impl(n, base2k, a, a_col, scratch) - } -} - -impl ZnFillUniform for Module -where - B: Backend + ZnFillUniformImpl, -{ - fn zn_fill_uniform(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut, - { - B::zn_fill_uniform_impl(n, base2k, res, res_col, source); - } -} - -impl ZnFillNormal for Module -where - B: Backend + ZnFillNormalImpl, -{ - fn zn_fill_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - B::zn_fill_normal_impl(n, base2k, res, res_col, k, source, sigma, bound); - } -} - -impl ZnAddNormal for Module -where - B: Backend + ZnAddNormalImpl, -{ - fn zn_add_normal( - &self, - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut, - { - B::zn_add_normal_impl(n, base2k, res, res_col, k, source, sigma, bound); - } -} diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index 37a4677..6934eec 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -2,7 +2,7 @@ use itertools::izip; use rug::{Assign, Float}; use crate::{ - layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut}, + layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut}, reference::znx::{ ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef, ZnxZero, get_carry_i128, get_digit_i128, znx_zero_ref, @@ -245,90 +245,6 @@ impl VecZnx { } } -impl Zn { - pub fn encode_i64(&mut self, base2k: usize, k: usize, data: i64) { - let size: usize = k.div_ceil(base2k); - - #[cfg(debug_assertions)] - { - let a: Zn<&mut [u8]> = self.to_mut(); - assert!( - size <= a.size(), - "invalid argument k.div_ceil(base2k)={} > a.size()={}", - size, - a.size() - ); - } - - let mut a: Zn<&mut [u8]> = self.to_mut(); - let a_size = a.size(); - - for j in 0..a_size { - a.at_mut(0, j)[0] = 0 - } - - a.at_mut(0, size - 1)[0] = data; - - let mut carry: Vec = vec![0i64; 1]; - let k_rem: usize = (base2k - (k % base2k)) % base2k; - - for j in (0..size).rev() { - let slice = &mut a.at_mut(0, j)[..1]; - - if j == size - 1 { - ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry); - } else if j == 0 { - ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry); - } else { - ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry); - } - } - } -} - -impl Zn { - pub fn decode_i64(&self, base2k: usize, k: usize) -> i64 { - let a: Zn<&[u8]> = self.to_ref(); - let size: usize = k.div_ceil(base2k); - let mut res: i64 = 0; - let rem: usize = base2k - (k % base2k); - (0..size).for_each(|j| { - let x: i64 = a.at(0, j)[0]; - if j == size - 1 && rem != base2k { - let k_rem: usize = (base2k - rem) % base2k; - let scale: i64 = 1 << rem as i64; - res = (res << k_rem) + div_round(x, scale); - } else { - res = (res << base2k) + x; - } - }); - res - } - - pub fn decode_float(&self, base2k: usize) -> Float { - let a: Zn<&[u8]> = self.to_ref(); - let size: usize = a.size(); - let prec: u32 = (base2k * size) as u32; - - // 2^{base2k} - let base: Float = Float::with_val(prec, (1 << base2k) as f64); - let mut res: Float = Float::with_val(prec, (1 << base2k) as f64); - - // y[i] = sum x[j][i] * 2^{-base2k*j} - (0..size).for_each(|i| { - if i == 0 { - res.assign(a.at(0, size - i - 1)[0]); - res /= &base; - } else { - res += Float::with_val(prec, a.at(0, size - i - 1)[0]); - res /= &base; - } - }); - - res - } -} - #[inline] pub fn div_round(a: i64, b: i64) -> i64 { assert!(b != 0, "division by zero"); diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index 7d4600b..d164234 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -10,7 +10,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; mod znx_base; pub use mat_znx::*; @@ -24,7 +23,6 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; -pub use zn::*; pub use znx_base::*; pub trait Data = PartialEq + Eq + Sized + Default; diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs deleted file mode 100644 index 00f8067..0000000 --- a/poulpy-hal/src/layouts/zn.rs +++ /dev/null @@ -1,273 +0,0 @@ -use std::{ - fmt, - hash::{DefaultHasher, Hasher}, -}; - -use crate::{ - alloc_aligned, - layouts::{ - Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos, - ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, - }, - source::Source, -}; - -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use rand::RngCore; - -#[repr(C)] -#[derive(PartialEq, Eq, Clone, Copy, Hash)] -pub struct Zn { - pub data: D, - pub n: usize, - pub cols: usize, - pub size: usize, - pub max_size: usize, -} - -impl DigestU64 for Zn { - fn digest_u64(&self) -> u64 { - let mut h: DefaultHasher = DefaultHasher::new(); - h.write(self.data.as_ref()); - h.write_usize(self.n); - h.write_usize(self.cols); - h.write_usize(self.size); - h.write_usize(self.max_size); - h.finish() - } -} - -impl ToOwnedDeep for Zn { - type Owned = Zn>; - fn to_owned_deep(&self) -> Self::Owned { - Zn { - data: self.data.as_ref().to_vec(), - n: self.n, - cols: self.cols, - size: self.size, - max_size: self.max_size, - } - } -} - -impl fmt::Debug for Zn { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{self}") - } -} - -impl ZnxInfos for Zn { - fn cols(&self) -> usize { - self.cols - } - - fn rows(&self) -> usize { - 1 - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for Zn { - fn sl(&self) -> usize { - self.n() * self.cols() - } -} - -impl DataView for Zn { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for Zn { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl ZnxView for Zn { - type Scalar = i64; -} - -impl Zn> { - pub fn rsh_tmp_bytes(n: usize) -> usize { - n * std::mem::size_of::() - } -} - -impl ZnxZero for Zn { - fn zero(&mut self) { - self.raw_mut().fill(0) - } - fn zero_at(&mut self, i: usize, j: usize) { - self.at_mut(i, j).fill(0); - } -} - -impl Zn> { - pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize { - n * cols * size * size_of::() - } - - pub fn alloc(n: usize, cols: usize, size: usize) -> Self { - let data: Vec = alloc_aligned::(Self::bytes_of(n, cols, size)); - Self { - data, - n, - cols, - size, - max_size: size, - } - } - - pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(n, cols, size)); - Self { - data, - n, - cols, - size, - max_size: size, - } - } -} - -impl Zn { - pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { - Self { - data, - n, - cols, - size, - max_size: size, - } - } -} - -impl fmt::Display for Zn { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "Zn(n={}, cols={}, size={})", - self.n, self.cols, self.size - )?; - - for col in 0..self.cols { - writeln!(f, "Column {col}:")?; - for size in 0..self.size { - let coeffs = self.at(col, size); - write!(f, " Size {size}: [")?; - - let max_show = 100; - let show_count = coeffs.len().min(max_show); - - for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{coeff}")?; - } - - if coeffs.len() > max_show { - write!(f, ", ... ({} more)", coeffs.len() - max_show)?; - } - - writeln!(f, "]")?; - } - } - Ok(()) - } -} - -impl FillUniform for Zn { - fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - match log_bound { - 64 => source.fill_bytes(self.data.as_mut()), - 0 => panic!("invalid log_bound, cannot be zero"), - _ => { - let mask: u64 = (1u64 << log_bound) - 1; - for x in self.raw_mut().iter_mut() { - let r = source.next_u64() & mask; - *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound); - } - } - } - } -} - -pub type ZnOwned = Zn>; -pub type ZnMut<'a> = Zn<&'a mut [u8]>; -pub type ZnRef<'a> = Zn<&'a [u8]>; - -pub trait ZnToRef { - fn to_ref(&self) -> Zn<&[u8]>; -} - -impl ZnToRef for Zn { - fn to_ref(&self) -> Zn<&[u8]> { - Zn { - data: self.data.as_ref(), - n: self.n, - cols: self.cols, - size: self.size, - max_size: self.max_size, - } - } -} - -pub trait ZnToMut { - fn to_mut(&mut self) -> Zn<&mut [u8]>; -} - -impl ZnToMut for Zn { - fn to_mut(&mut self) -> Zn<&mut [u8]> { - Zn { - data: self.data.as_mut(), - n: self.n, - cols: self.cols, - size: self.size, - max_size: self.max_size, - } - } -} - -impl ReaderFrom for Zn { - fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - self.n = reader.read_u64::()? as usize; - self.cols = reader.read_u64::()? as usize; - self.size = reader.read_u64::()? as usize; - self.max_size = reader.read_u64::()? as usize; - let len: usize = reader.read_u64::()? as usize; - let buf: &mut [u8] = self.data.as_mut(); - if buf.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - format!("self.data.len()={} != read len={}", buf.len(), len), - )); - } - reader.read_exact(&mut buf[..len])?; - Ok(()) - } -} - -impl WriterTo for Zn { - fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.n as u64)?; - writer.write_u64::(self.cols as u64)?; - writer.write_u64::(self.size as u64)?; - writer.write_u64::(self.max_size as u64)?; - let buf: &[u8] = self.data.as_ref(); - writer.write_u64::(buf.len() as u64)?; - writer.write_all(buf)?; - Ok(()) - } -} diff --git a/poulpy-hal/src/oep/mod.rs b/poulpy-hal/src/oep/mod.rs index dac0def..bc53c0e 100644 --- a/poulpy-hal/src/oep/mod.rs +++ b/poulpy-hal/src/oep/mod.rs @@ -5,7 +5,6 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; -mod zn; pub use module::*; pub use scratch::*; @@ -14,4 +13,3 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; -pub use zn::*; diff --git a/poulpy-hal/src/oep/zn.rs b/poulpy-hal/src/oep/zn.rs deleted file mode 100644 index d2e03ad..0000000 --- a/poulpy-hal/src/oep/zn.rs +++ /dev/null @@ -1,70 +0,0 @@ -use crate::{ - layouts::{Backend, Scratch, ZnToMut}, - source::Source, -}; - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnNormalizeTmpBytesImpl { - fn zn_normalize_tmp_bytes_impl(n: usize) -> usize; -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnNormalizeInplace] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnNormalizeInplaceImpl { - fn zn_normalize_inplace_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch) - where - R: ZnToMut; -} - -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnFillUniform] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnFillUniformImpl { - fn zn_fill_uniform_impl(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) - where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnFillNormal] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnFillNormalImpl { - fn zn_fill_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} - -#[allow(clippy::too_many_arguments)] -/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) -/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation. -/// * See [crate::api::ZnAddNormal] for corresponding public API. -/// # Safety [crate::doc::backend_safety] for safety contract. -pub unsafe trait ZnAddNormalImpl { - fn zn_add_normal_impl( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, - ) where - R: ZnToMut; -} diff --git a/poulpy-hal/src/reference/mod.rs b/poulpy-hal/src/reference/mod.rs index 9fd1500..9bd1154 100644 --- a/poulpy-hal/src/reference/mod.rs +++ b/poulpy-hal/src/reference/mod.rs @@ -1,4 +1,3 @@ pub mod fft64; pub mod vec_znx; -pub mod zn; pub mod znx; diff --git a/poulpy-hal/src/reference/vec_znx/normalize.rs b/poulpy-hal/src/reference/vec_znx/normalize.rs index 998e461..6392f84 100644 --- a/poulpy-hal/src/reference/vec_znx/normalize.rs +++ b/poulpy-hal/src/reference/vec_znx/normalize.rs @@ -53,6 +53,8 @@ pub fn vec_znx_normalize( let res_size: usize = res.size(); let a_size: usize = a.size(); + let carry = &mut carry[..2 * n]; + if res_base2k == a_base2k { if a_size > res_size { for j in (res_size..a_size).rev() { diff --git a/poulpy-hal/src/reference/zn/mod.rs b/poulpy-hal/src/reference/zn/mod.rs deleted file mode 100644 index d4838c3..0000000 --- a/poulpy-hal/src/reference/zn/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod normalization; -mod sampling; - -pub use normalization::*; -pub use sampling::*; diff --git a/poulpy-hal/src/reference/zn/normalization.rs b/poulpy-hal/src/reference/zn/normalization.rs deleted file mode 100644 index 83cfeb7..0000000 --- a/poulpy-hal/src/reference/zn/normalization.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes}, - layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut}, - reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef}, - source::Source, -}; - -pub fn zn_normalize_tmp_bytes(n: usize) -> usize { - n * size_of::() -} - -pub fn zn_normalize_inplace(n: usize, base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64]) -where - R: ZnToMut, - ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - - #[cfg(debug_assertions)] - { - assert_eq!(carry.len(), res.n()); - } - - let res_size: usize = res.size(); - - for j in (0..res_size).rev() { - let out = &mut res.at_mut(res_col, j)[..n]; - - if j == res_size - 1 { - ARI::znx_normalize_first_step_inplace(base2k, 0, out, carry); - } else if j == 0 { - ARI::znx_normalize_final_step_inplace(base2k, 0, out, carry); - } else { - ARI::znx_normalize_middle_step_inplace(base2k, 0, out, carry); - } - } -} - -pub fn test_zn_normalize_inplace(module: &Module) -where - Module: ZnNormalizeInplace + ZnNormalizeTmpBytes, - ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, -{ - let mut source: Source = Source::new([0u8; 32]); - let cols: usize = 2; - let base2k: usize = 12; - - let n = 33; - - let mut carry: Vec = vec![0i64; zn_normalize_tmp_bytes(n)]; - - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n())); - - for res_size in [1, 2, 6, 11] { - let mut res_0: Zn> = Zn::alloc(n, cols, res_size); - let mut res_1: Zn> = Zn::alloc(n, cols, res_size); - - res_0 - .raw_mut() - .iter_mut() - .for_each(|x| *x = source.next_i32() as i64); - res_1.raw_mut().copy_from_slice(res_0.raw()); - - // Reference - for i in 0..cols { - zn_normalize_inplace::<_, ZnxRef>(n, base2k, &mut res_0, i, &mut carry); - module.zn_normalize_inplace(n, base2k, &mut res_1, i, scratch.borrow()); - } - - assert_eq!(res_0.raw(), res_1.raw()); - } -} diff --git a/poulpy-hal/src/reference/zn/sampling.rs b/poulpy-hal/src/reference/zn/sampling.rs deleted file mode 100644 index b376dcc..0000000 --- a/poulpy-hal/src/reference/zn/sampling.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::{ - layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut}, - reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref}, - source::Source, -}; - -pub fn zn_fill_uniform(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source) -where - R: ZnToMut, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - for j in 0..res.size() { - znx_fill_uniform_ref(base2k, &mut res.at_mut(res_col, j)[..n], source) - } -} - -#[allow(clippy::too_many_arguments)] -pub fn zn_fill_normal( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, -) where - R: ZnToMut, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(base2k) - 1; - let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_fill_normal_f64_ref( - &mut res.at_mut(res_col, limb)[..n], - sigma * scale, - bound * scale, - source, - ) -} - -#[allow(clippy::too_many_arguments)] -pub fn zn_add_normal( - n: usize, - base2k: usize, - res: &mut R, - res_col: usize, - k: usize, - source: &mut Source, - sigma: f64, - bound: f64, -) where - R: ZnToMut, -{ - let mut res: Zn<&mut [u8]> = res.to_mut(); - assert!( - (bound.log2().ceil() as i64) < 64, - "invalid bound: ceil(log2(bound))={} > 63", - (bound.log2().ceil() as i64) - ); - - let limb: usize = k.div_ceil(base2k) - 1; - let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64; - znx_add_normal_f64_ref( - &mut res.at_mut(res_col, limb)[..n], - sigma * scale, - bound * scale, - source, - ) -} diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 55d779f..c58291a 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -2,7 +2,7 @@ use poulpy_core::{ GLWENormalize, layouts::{ GGLWEToGGSWKeyLayout, GGSW, GGSWLayout, GLWE, GLWEAutomorphismKeyLayout, GLWELayout, GLWEPlaintext, GLWESecret, LWE, - LWEInfos, LWELayout, LWEPlaintext, LWESecret, + LWELayout, LWEPlaintext, LWESecret, prepared::{GGSWPrepared, GLWESecretPrepared}, }, }; @@ -15,7 +15,7 @@ use poulpy_backend::FFT64Avx as BackendImpl; use poulpy_backend::FFT64Ref as BackendImpl; use poulpy_hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace}, layouts::{Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut}, source::Source, }; @@ -155,20 +155,21 @@ fn main() { pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit // Normalize plaintext to nicely print coefficients - module.zn_normalize_inplace( - pt_lwe.n().into(), - base2k, - pt_lwe.data_mut(), - 0, - scratch.borrow(), - ); + module.vec_znx_normalize_inplace(base2k, pt_lwe.data_mut(), 0, scratch.borrow()); println!("pt_lwe: {pt_lwe}"); // LWE ciphertext let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); // Encrypt LWE Plaintext - ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + ct_lwe.encrypt_sk( + &module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let now: Instant = Instant::now(); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 1651987..9f91538 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -111,7 +111,14 @@ pub fn test_blind_rotation( pt_lwe.encode_i64(x, (log_message_modulus + 1).into()); - lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let f = |x: i64| -> i64 { 2 * x + 1 }; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index a75d211..29ffac1 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -132,7 +132,14 @@ where println!("pt_lwe: {pt_lwe}"); let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); - ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + ct_lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let now: Instant = Instant::now(); let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); @@ -313,7 +320,14 @@ where println!("pt_lwe: {pt_lwe}"); let mut ct_lwe: LWE> = LWE::alloc_from_infos(&lwe_infos); - ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); + ct_lwe.encrypt_sk( + module, + &pt_lwe, + &sk_lwe, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); let now: Instant = Instant::now(); let mut cbt_key: CircuitBootstrappingKey, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos);