From bf513dc55579a1172a5a05d713aa03e94ad83a98 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Thu, 21 Aug 2025 12:16:53 +0200 Subject: [PATCH] Add Zn type --- poulpy-backend/examples/rlwe_encrypt.rs | 8 +- poulpy-backend/src/cpu_spqlios/ffi/mod.rs | 32 +-- poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs | 4 +- .../src/cpu_spqlios/ffi/vec_znx_big.rs | 6 +- .../src/cpu_spqlios/ffi/vec_znx_dft.rs | 2 +- poulpy-backend/src/cpu_spqlios/ffi/vmp.rs | 3 +- poulpy-backend/src/cpu_spqlios/ffi/zn64.rs | 13 + poulpy-backend/src/cpu_spqlios/fft64/mod.rs | 1 + .../src/cpu_spqlios/fft64/vec_znx.rs | 10 +- .../src/cpu_spqlios/fft64/vec_znx_big.rs | 7 +- .../src/cpu_spqlios/fft64/vec_znx_dft.rs | 6 +- .../src/cpu_spqlios/fft64/vmp_pmat.rs | 19 +- poulpy-backend/src/cpu_spqlios/fft64/zn.rs | 201 ++++++++++++++ .../src/cpu_spqlios/spqlios-arithmetic | 2 +- .../benches/external_product_glwe_fft64.rs | 19 +- poulpy-core/benches/keyswitch_glwe_fft64.rs | 11 +- poulpy-core/examples/encryption.rs | 4 +- poulpy-core/src/automorphism/gglwe_atk.rs | 6 +- poulpy-core/src/automorphism/ggsw_ct.rs | 11 +- poulpy-core/src/automorphism/glwe_ct.rs | 6 +- poulpy-core/src/conversion/glwe_to_lwe.rs | 5 +- poulpy-core/src/conversion/lwe_to_glwe.rs | 5 +- poulpy-core/src/decryption/glwe_ct.rs | 5 +- poulpy-core/src/decryption/lwe_ct.rs | 7 +- .../src/encryption/compressed/gglwe_atk.rs | 15 +- .../src/encryption/compressed/gglwe_ct.rs | 9 +- .../src/encryption/compressed/gglwe_ksk.rs | 9 +- .../src/encryption/compressed/gglwe_tsk.rs | 4 +- .../src/encryption/compressed/ggsw_ct.rs | 4 +- .../src/encryption/compressed/glwe_ct.rs | 4 +- poulpy-core/src/encryption/gglwe_atk.rs | 12 +- poulpy-core/src/encryption/gglwe_ct.rs | 12 +- poulpy-core/src/encryption/gglwe_ksk.rs | 12 +- poulpy-core/src/encryption/gglwe_tsk.rs | 14 +- poulpy-core/src/encryption/ggsw_ct.rs | 10 +- poulpy-core/src/encryption/glwe_ct.rs | 23 +- poulpy-core/src/encryption/glwe_pk.rs | 1 - poulpy-core/src/encryption/glwe_to_lwe_ksk.rs | 7 +- poulpy-core/src/encryption/lwe_ct.rs | 26 +- poulpy-core/src/encryption/lwe_ksk.rs | 8 +- poulpy-core/src/encryption/lwe_to_glwe_ksk.rs | 4 +- poulpy-core/src/external_product/gglwe_atk.rs | 6 +- poulpy-core/src/external_product/gglwe_ksk.rs | 6 +- poulpy-core/src/external_product/ggsw_ct.rs | 7 +- poulpy-core/src/external_product/glwe_ct.rs | 12 +- poulpy-core/src/glwe_packing.rs | 13 +- poulpy-core/src/glwe_trace.rs | 6 +- poulpy-core/src/keyswitching/gglwe_ct.rs | 14 +- poulpy-core/src/keyswitching/ggsw_ct.rs | 25 +- poulpy-core/src/keyswitching/glwe_ct.rs | 23 +- poulpy-core/src/keyswitching/lwe_ct.rs | 5 +- .../src/layouts/compressed/gglwe_atk.rs | 10 +- .../src/layouts/compressed/gglwe_ct.rs | 10 +- .../src/layouts/compressed/gglwe_ksk.rs | 10 +- .../src/layouts/compressed/gglwe_tsk.rs | 10 +- poulpy-core/src/layouts/compressed/ggsw_ct.rs | 10 +- poulpy-core/src/layouts/compressed/glwe_ct.rs | 10 +- .../src/layouts/compressed/glwe_to_lwe_ksk.rs | 4 +- poulpy-core/src/layouts/compressed/lwe_ct.rs | 23 +- poulpy-core/src/layouts/compressed/lwe_ksk.rs | 14 +- .../src/layouts/compressed/lwe_to_glwe_ksk.rs | 14 +- poulpy-core/src/layouts/compressed/mod.rs | 9 +- poulpy-core/src/layouts/lwe_ct.rs | 16 +- poulpy-core/src/layouts/lwe_pt.rs | 10 +- poulpy-core/src/layouts/prepared/gglwe_atk.rs | 9 +- poulpy-core/src/layouts/prepared/gglwe_ct.rs | 17 +- poulpy-core/src/layouts/prepared/gglwe_ksk.rs | 17 +- poulpy-core/src/layouts/prepared/gglwe_tsk.rs | 9 +- poulpy-core/src/layouts/prepared/ggsw_ct.rs | 9 +- poulpy-core/src/layouts/prepared/glwe_pk.rs | 10 +- poulpy-core/src/layouts/prepared/glwe_sk.rs | 10 +- .../src/layouts/prepared/glwe_to_lwe_ksk.rs | 9 +- poulpy-core/src/layouts/prepared/lwe_ksk.rs | 17 +- .../src/layouts/prepared/lwe_to_glwe_ksk.rs | 9 +- poulpy-core/src/noise/gglwe_ct.rs | 7 +- poulpy-core/src/noise/ggsw_ct.rs | 18 +- poulpy-core/src/noise/glwe_ct.rs | 1 - .../tests/generics/automorphism/gglwe_atk.rs | 12 +- .../tests/generics/automorphism/ggsw_ct.rs | 24 +- .../tests/generics/automorphism/glwe_ct.rs | 19 +- poulpy-core/src/tests/generics/conversion.rs | 22 +- .../tests/generics/encryption/gglwe_atk.rs | 4 +- .../src/tests/generics/encryption/gglwe_ct.rs | 4 +- .../src/tests/generics/encryption/ggsw_ct.rs | 4 +- .../src/tests/generics/encryption/glwe_ct.rs | 18 +- .../src/tests/generics/encryption/glwe_tsk.rs | 14 +- .../generics/external_product/gglwe_ksk.rs | 12 +- .../generics/external_product/ggsw_ct.rs | 8 +- .../generics/external_product/glwe_ct.rs | 11 +- .../src/tests/generics/keyswitch/gglwe_ct.rs | 5 +- .../src/tests/generics/keyswitch/ggsw_ct.rs | 16 +- .../src/tests/generics/keyswitch/glwe_ct.rs | 11 +- .../src/tests/generics/keyswitch/lwe_ct.rs | 11 +- poulpy-core/src/tests/generics/packing.rs | 6 +- poulpy-core/src/tests/generics/trace.rs | 8 +- poulpy-core/src/utils.rs | 9 +- poulpy-hal/src/api/mod.rs | 2 + poulpy-hal/src/api/svp_ppol.rs | 6 +- poulpy-hal/src/api/vec_znx.rs | 2 +- poulpy-hal/src/api/vec_znx_big.rs | 8 +- poulpy-hal/src/api/vec_znx_dft.rs | 8 +- poulpy-hal/src/api/vmp_pmat.rs | 18 +- poulpy-hal/src/api/zn.rs | 86 ++++++ poulpy-hal/src/delegates/mod.rs | 1 + poulpy-hal/src/delegates/svp_ppol.rs | 12 +- poulpy-hal/src/delegates/vec_znx.rs | 4 +- poulpy-hal/src/delegates/vec_znx_big.rs | 16 +- poulpy-hal/src/delegates/vec_znx_dft.rs | 16 +- poulpy-hal/src/delegates/vmp_pmat.rs | 30 +-- poulpy-hal/src/delegates/zn.rs | 114 ++++++++ poulpy-hal/src/layouts/encoding.rs | 89 +++++- poulpy-hal/src/layouts/mod.rs | 2 + poulpy-hal/src/layouts/zn.rs | 255 ++++++++++++++++++ poulpy-hal/src/oep/mod.rs | 2 + poulpy-hal/src/oep/vec_znx.rs | 2 +- poulpy-hal/src/oep/vec_znx_big.rs | 2 +- poulpy-hal/src/oep/vec_znx_dft.rs | 2 +- poulpy-hal/src/oep/vmp_pmat.rs | 11 +- poulpy-hal/src/oep/zn.rs | 97 +++++++ poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs | 11 +- .../examples/circuit_bootstrapping.rs | 4 +- .../src/tfhe/blind_rotation/cggi_algo.rs | 19 +- .../src/tfhe/blind_rotation/cggi_key.rs | 12 +- .../src/tfhe/blind_rotation/key_prepared.rs | 5 +- poulpy-schemes/src/tfhe/blind_rotation/lut.rs | 22 +- .../tests/generic_blind_rotation.rs | 13 +- .../tfhe/blind_rotation/tests/generic_lut.rs | 6 +- .../src/tfhe/circuit_bootstrapping/circuit.rs | 2 +- .../tests/circuit_bootstrapping.rs | 13 +- 129 files changed, 1400 insertions(+), 686 deletions(-) create mode 100644 poulpy-backend/src/cpu_spqlios/ffi/zn64.rs create mode 100644 poulpy-backend/src/cpu_spqlios/fft64/zn.rs create mode 100644 poulpy-hal/src/api/zn.rs create mode 100644 poulpy-hal/src/delegates/zn.rs create mode 100644 poulpy-hal/src/layouts/zn.rs create mode 100644 poulpy-hal/src/oep/zn.rs diff --git a/poulpy-backend/examples/rlwe_encrypt.rs b/poulpy-backend/examples/rlwe_encrypt.rs index 2e970dc..a300150 100644 --- a/poulpy-backend/examples/rlwe_encrypt.rs +++ b/poulpy-backend/examples/rlwe_encrypt.rs @@ -18,7 +18,7 @@ fn main() { let log_scale: usize = msg_size * basek - 5; let module: Module = Module::::new(n as u64); - let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes(n)); + let mut scratch: ScratchOwned = ScratchOwned::::alloc(module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); @@ -28,7 +28,7 @@ fn main() { s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: SvpPPol, FFT64> = module.svp_ppol_alloc(n, s.cols()); + let mut s_dft: SvpPPol, FFT64> = module.svp_ppol_alloc(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); @@ -43,7 +43,7 @@ fn main() { // Fill the second column with random values: ct = (0, a) module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source); - let mut buf_dft: VecZnxDft, FFT64> = module.vec_znx_dft_alloc(n, 1, ct_size); + let mut buf_dft: VecZnxDft, FFT64> = module.vec_znx_dft_alloc(1, ct_size); module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1); @@ -58,7 +58,7 @@ fn main() { // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - let mut buf_big: VecZnxBig, FFT64> = module.vec_znx_big_alloc(n, 1, ct_size); + let mut buf_big: VecZnxBig, FFT64> = module.vec_znx_big_alloc(1, ct_size); module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column diff --git a/poulpy-backend/src/cpu_spqlios/ffi/mod.rs b/poulpy-backend/src/cpu_spqlios/ffi/mod.rs index 6d40a1e..af417ec 100644 --- a/poulpy-backend/src/cpu_spqlios/ffi/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/mod.rs @@ -1,15 +1,17 @@ -#[allow(non_camel_case_types)] -pub mod module; -#[allow(non_camel_case_types)] -pub mod svp; -#[allow(non_camel_case_types)] -pub mod vec_znx; -#[allow(dead_code)] -#[allow(non_camel_case_types)] -pub mod vec_znx_big; -#[allow(non_camel_case_types)] -pub mod vec_znx_dft; -#[allow(non_camel_case_types)] -pub mod vmp; -#[allow(non_camel_case_types)] -pub mod znx; +#[allow(non_camel_case_types)] +pub mod module; +#[allow(non_camel_case_types)] +pub mod svp; +#[allow(non_camel_case_types)] +pub mod vec_znx; +#[allow(dead_code)] +#[allow(non_camel_case_types)] +pub mod vec_znx_big; +#[allow(non_camel_case_types)] +pub mod vec_znx_dft; +#[allow(non_camel_case_types)] +pub mod vmp; +#[allow(non_camel_case_types)] +pub mod zn64; +#[allow(non_camel_case_types)] +pub mod znx; diff --git a/poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs index 020fb9e..6ede6f1 100644 --- a/poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx.rs @@ -103,7 +103,6 @@ unsafe extern "C" { unsafe extern "C" { pub unsafe fn vec_znx_normalize_base2k( module: *const MODULE, - n: u64, base2k: u64, res: *mut i64, res_size: u64, @@ -114,6 +113,7 @@ unsafe extern "C" { tmp_space: *mut u8, ); } + unsafe extern "C" { - pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; + pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; } diff --git a/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs index 55b9ea7..e62153f 100644 --- a/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_big.rs @@ -93,13 +93,12 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; + pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; } unsafe extern "C" { pub unsafe fn vec_znx_big_normalize_base2k( module: *const MODULE, - n: u64, log2_base2k: u64, res: *mut i64, res_size: u64, @@ -113,7 +112,6 @@ unsafe extern "C" { unsafe extern "C" { pub unsafe fn vec_znx_big_range_normalize_base2k( module: *const MODULE, - n: u64, log2_base2k: u64, res: *mut i64, res_size: u64, @@ -127,7 +125,7 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64; + pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64; } unsafe extern "C" { diff --git a/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs index 9612f37..e786775 100644 --- a/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vec_znx_dft.rs @@ -43,7 +43,7 @@ unsafe extern "C" { ); } unsafe extern "C" { - pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64; + pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64; } unsafe extern "C" { pub unsafe fn vec_znx_idft_tmp_a( diff --git a/poulpy-backend/src/cpu_spqlios/ffi/vmp.rs b/poulpy-backend/src/cpu_spqlios/ffi/vmp.rs index 48c3a84..6cf8635 100644 --- a/poulpy-backend/src/cpu_spqlios/ffi/vmp.rs +++ b/poulpy-backend/src/cpu_spqlios/ffi/vmp.rs @@ -79,7 +79,6 @@ unsafe extern "C" { unsafe extern "C" { pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes( module: *const MODULE, - nn: u64, res_size: u64, a_size: u64, nrows: u64, @@ -99,5 +98,5 @@ unsafe extern "C" { } unsafe extern "C" { - pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64; + pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; } diff --git a/poulpy-backend/src/cpu_spqlios/ffi/zn64.rs b/poulpy-backend/src/cpu_spqlios/ffi/zn64.rs new file mode 100644 index 0000000..9228332 --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/ffi/zn64.rs @@ -0,0 +1,13 @@ +unsafe extern "C" { + pub unsafe fn zn64_normalize_base2k_ref( + n: u64, + base2k: u64, + res: *mut i64, + res_size: u64, + res_sl: u64, + a: *const i64, + a_size: u64, + a_sl: u64, + tmp_space: *mut u8, + ); +} diff --git a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs index b81e73d..3ca4713 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/mod.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/mod.rs @@ -5,6 +5,7 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; +mod zn; pub use module::FFT64; diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs index c5271cc..00143c0 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs @@ -25,8 +25,8 @@ use crate::cpu_spqlios::{ }; unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64 { - fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize } + fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize } } } @@ -54,12 +54,11 @@ where assert_eq!(res.n(), a.n()); } - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n())); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( module.ptr() as *const module_info_t, - a.n() as u64, basek as u64, res.at_mut_ptr(res_col, 0), res.size() as u64, @@ -88,12 +87,11 @@ where { let mut a: VecZnx<&mut [u8]> = a.to_mut(); - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n())); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( module.ptr() as *const module_info_t, - a.n() as u64, basek as u64, a.at_mut_ptr(a_col, 0), a.size() as u64, diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs index c571371..10496f4 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_big.rs @@ -569,8 +569,8 @@ unsafe impl VecZnxBigNegateInplaceImpl for FFT64 { } unsafe impl VecZnxBigNormalizeTmpBytesImpl for FFT64 { - fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize { - unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize } + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize { + unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize } } } @@ -598,11 +598,10 @@ where assert_eq!(res.n(), a.n()); } - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n())); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes()); unsafe { vec_znx::vec_znx_normalize_base2k( module.ptr(), - a.n() as u64, basek as u64, res.at_mut_ptr(res_col, 0), res.size() as u64, diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 669485b..a2487bd 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -36,8 +36,8 @@ unsafe impl VecZnxDftAllocImpl for FFT64 { } unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl for FFT64 { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module, n: usize) -> usize { - unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize } + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module) -> usize { + unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize } } } @@ -61,7 +61,7 @@ unsafe impl VecZnxDftToVecZnxBigImpl for FFT64 { assert_eq!(res.n(), a.n()) } - let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n())); + let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes()); let min_size: usize = res.size().min(a.size()); diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs index 68c7ccd..b5770f0 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vmp_pmat.rs @@ -41,18 +41,10 @@ unsafe impl VmpPMatAllocImpl for FFT64 { } unsafe impl VmpPrepareTmpBytesImpl for FFT64 { - fn vmp_prepare_tmp_bytes_impl( - module: &Module, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> usize { + fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { unsafe { vmp::vmp_prepare_tmp_bytes( module.ptr(), - n as u64, (rows * cols_in) as u64, (cols_out * size) as u64, ) as usize @@ -102,8 +94,7 @@ unsafe impl VmpPMatPrepareImpl for FFT64 { ); } - let (tmp_bytes, _) = - scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size())); + let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size())); unsafe { vmp::vmp_prepare_contiguous( @@ -121,7 +112,6 @@ unsafe impl VmpPMatPrepareImpl for FFT64 { unsafe impl VmpApplyTmpBytesImpl for FFT64 { fn vmp_apply_tmp_bytes_impl( module: &Module, - n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -132,7 +122,6 @@ unsafe impl VmpApplyTmpBytesImpl for FFT64 { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( module.ptr(), - n as u64, (res_size * b_cols_out) as u64, (a_size * b_cols_in) as u64, (b_rows * b_cols_in) as u64, @@ -174,7 +163,6 @@ unsafe impl VmpApplyImpl for FFT64 { } let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes( - res.n(), res.size(), a.size(), b.rows(), @@ -201,7 +189,6 @@ unsafe impl VmpApplyImpl for FFT64 { unsafe impl VmpApplyAddTmpBytesImpl for FFT64 { fn vmp_apply_add_tmp_bytes_impl( module: &Module, - n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -212,7 +199,6 @@ unsafe impl VmpApplyAddTmpBytesImpl for FFT64 { unsafe { vmp::vmp_apply_dft_to_dft_tmp_bytes( module.ptr(), - n as u64, (res_size * b_cols_out) as u64, (a_size * b_cols_in) as u64, (b_rows * b_cols_in) as u64, @@ -254,7 +240,6 @@ unsafe impl VmpApplyAddImpl for FFT64 { } let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes( - res.n(), res.size(), a.size(), b.rows(), diff --git a/poulpy-backend/src/cpu_spqlios/fft64/zn.rs b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs new file mode 100644 index 0000000..3e2726f --- /dev/null +++ b/poulpy-backend/src/cpu_spqlios/fft64/zn.rs @@ -0,0 +1,201 @@ +use poulpy_hal::{ + api::{TakeSlice, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut}, + layouts::{Scratch, Zn, ZnToMut}, + oep::{ + TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, + ZnNormalizeInplaceImpl, + }, + source::Source, +}; +use rand_distr::Normal; + +use crate::cpu_spqlios::{FFT64, ffi::zn64}; + +unsafe impl ZnNormalizeInplaceImpl for FFT64 +where + Self: TakeSliceImpl, +{ + fn zn_normalize_inplace_impl(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: ZnToMut, + { + let mut a: Zn<&mut [u8]> = a.to_mut(); + + let (tmp_bytes, _) = scratch.take_slice(n * size_of::()); + + unsafe { + zn64::zn64_normalize_base2k_ref( + n as u64, + basek as u64, + a.at_mut_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + a.at_ptr(a_col, 0), + a.size() as u64, + a.sl() as u64, + tmp_bytes.as_mut_ptr(), + ); + } + } +} + +unsafe impl ZnFillUniformImpl for FFT64 { + fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + where + R: ZnToMut, + { + let mut a: Zn<&mut [u8]> = res.to_mut(); + let base2k: u64 = 1 << basek; + let mask: u64 = base2k - 1; + let base2k_half: i64 = (base2k >> 1) as i64; + (0..k.div_ceil(basek)).for_each(|j| { + a.at_mut(res_col, j)[..n] + .iter_mut() + .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half); + }) + } +} + +unsafe impl ZnFillDistF64Impl for FFT64 { + fn zn_fill_dist_f64_impl>( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut, + { + let mut a: Zn<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let basek_rem: usize = (limb + 1) * basek - k; + + if basek_rem != 0 { + a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = (dist_f64.round() as i64) << basek_rem; + }); + } else { + a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a = dist_f64.round() as i64 + }); + } + } +} + +unsafe impl ZnAddDistF64Impl for FFT64 { + fn zn_add_dist_f64_impl>( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut, + { + let mut a: Zn<&mut [u8]> = res.to_mut(); + assert!( + (bound.log2().ceil() as i64) < 64, + "invalid bound: ceil(log2(bound))={} > 63", + (bound.log2().ceil() as i64) + ); + + let limb: usize = k.div_ceil(basek) - 1; + let basek_rem: usize = (limb + 1) * basek - k; + + if basek_rem != 0 { + a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += (dist_f64.round() as i64) << basek_rem; + }); + } else { + a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| { + let mut dist_f64: f64 = dist.sample(source); + while dist_f64.abs() > bound { + dist_f64 = dist.sample(source) + } + *a += dist_f64.round() as i64 + }); + } + } +} + +unsafe impl ZnFillNormalImpl for FFT64 +where + Self: ZnFillDistF64Impl, +{ + fn zn_fill_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + Self::zn_fill_dist_f64_impl( + n, + basek, + res, + res_col, + k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} + +unsafe impl ZnAddNormalImpl for FFT64 +where + Self: ZnAddDistF64Impl, +{ + fn zn_add_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + Self::zn_add_dist_f64_impl( + n, + basek, + res, + res_col, + k, + source, + Normal::new(0.0, sigma).unwrap(), + bound, + ); + } +} diff --git a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic index de62af3..708e5d7 160000 --- a/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic +++ b/poulpy-backend/src/cpu_spqlios/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit de62af3507776597231e0c0d2b26495a0c92d207 +Subproject commit 708e5d7e867abba60f029794eea58aa2735e1f15 diff --git a/poulpy-core/benches/external_product_glwe_fft64.rs b/poulpy-core/benches/external_product_glwe_fft64.rs index dc860e4..ab3959e 100644 --- a/poulpy-core/benches/external_product_glwe_fft64.rs +++ b/poulpy-core/benches/external_product_glwe_fft64.rs @@ -44,11 +44,10 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) { let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_glwe_in.k()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( &module, - n, basek, ct_glwe_out.k(), ct_glwe_in.k(), @@ -137,17 +136,9 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { let pt_rgsw: ScalarZnx> = ScalarZnx::alloc(n, 1); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space( - &module, - n, - basek, - ct_glwe.k(), - ct_ggsw.k(), - digits, - rank, - ), + GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); let mut source_xs = Source::new([0u8; 32]); diff --git a/poulpy-core/benches/keyswitch_glwe_fft64.rs b/poulpy-core/benches/keyswitch_glwe_fft64.rs index 537295e..37f3326 100644 --- a/poulpy-core/benches/keyswitch_glwe_fft64.rs +++ b/poulpy-core/benches/keyswitch_glwe_fft64.rs @@ -45,11 +45,10 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { let mut ct_out: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_rlwe_out, rank_out); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, n, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct_in.k()) + GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( &module, - n, basek, ct_out.k(), ct_in.k(), @@ -148,9 +147,9 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { let mut ct: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(&module, n, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(&module, n, basek, ct.k(), ksk.k(), digits, rank), + GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank) + | GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank), ); let mut source_xs: Source = Source::new([0u8; 32]); diff --git a/poulpy-core/examples/encryption.rs b/poulpy-core/examples/encryption.rs index aa0109e..df6e944 100644 --- a/poulpy-core/examples/encryption.rs +++ b/poulpy-core/examples/encryption.rs @@ -45,8 +45,8 @@ fn main() { // Scratch space let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(&module, n, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(&module, n, basek, ct.k()), + GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()), ); // Generate secret-key diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index 21934c1..bddfade 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -13,7 +13,6 @@ impl GGLWEAutomorphismKey> { #[allow(clippy::too_many_arguments)] pub fn automorphism_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -24,12 +23,11 @@ impl GGLWEAutomorphismKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank) + GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn automorphism_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -39,7 +37,7 @@ impl GGLWEAutomorphismKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGLWEAutomorphismKey::automorphism_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank) + GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank) } } diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index 7cd1a63..e85fe14 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -17,7 +17,6 @@ impl GGSWCiphertext> { #[allow(clippy::too_many_arguments)] pub fn automorphism_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -32,17 +31,16 @@ impl GGSWCiphertext> { VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { let out_size: usize = k_out.div_ceil(basek); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); let ks_internal: usize = - GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); - let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank); + GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); + let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); ci_dft + (ks_internal | expand) } #[allow(clippy::too_many_arguments)] pub fn automorphism_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -56,7 +54,7 @@ impl GGSWCiphertext> { VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { GGSWCiphertext::automorphism_scratch_space( - module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, ) } } @@ -117,7 +115,6 @@ impl GGSWCiphertext { scratch.available() >= GGSWCiphertext::automorphism_scratch_space( module, - self.n(), self.basek(), self.k(), lhs.k(), diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 1b4c344..cc94083 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -13,7 +13,6 @@ impl GLWECiphertext> { #[allow(clippy::too_many_arguments)] pub fn automorphism_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -24,12 +23,11 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - Self::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank) + Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn automorphism_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -39,7 +37,7 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - Self::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank) + Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } } diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index c99afda..2b01e74 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -15,7 +15,6 @@ use crate::{ impl LWECiphertext> { pub fn from_glwe_scratch_space( module: &Module, - n: usize, basek: usize, k_lwe: usize, k_glwe: usize, @@ -25,8 +24,8 @@ impl LWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GLWECiphertext::bytes_of(n, basek, k_lwe, 1) - + GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) + GLWECiphertext::bytes_of(module.n(), basek, k_lwe, 1) + + GLWECiphertext::keyswitch_scratch_space(module, basek, k_lwe, k_glwe, k_ksk, 1, rank, 1) } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index 83d8b99..43d0de2 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -15,7 +15,6 @@ use crate::{ impl GLWECiphertext> { pub fn from_lwe_scratch_space( module: &Module, - n: usize, basek: usize, k_lwe: usize, k_glwe: usize, @@ -25,8 +24,8 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) - + GLWECiphertext::bytes_of(n, basek, k_lwe, 1) + GLWECiphertext::keyswitch_scratch_space(module, basek, k_glwe, k_lwe, k_ksk, 1, 1, rank) + + GLWECiphertext::bytes_of(module.n(), basek, k_lwe, 1) } } diff --git a/poulpy-core/src/decryption/glwe_ct.rs b/poulpy-core/src/decryption/glwe_ct.rs index 2d9ee17..d590585 100644 --- a/poulpy-core/src/decryption/glwe_ct.rs +++ b/poulpy-core/src/decryption/glwe_ct.rs @@ -9,13 +9,12 @@ use poulpy_hal::{ use crate::layouts::{GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared}; impl GLWECiphertext> { - pub fn decrypt_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { let size: usize = k.div_ceil(basek); - (module.vec_znx_normalize_tmp_bytes(n) | module.vec_znx_dft_alloc_bytes(n, 1, size)) - + module.vec_znx_dft_alloc_bytes(n, 1, size) + (module.vec_znx_normalize_tmp_bytes() | module.vec_znx_dft_alloc_bytes(1, size)) + module.vec_znx_dft_alloc_bytes(1, size) } } diff --git a/poulpy-core/src/decryption/lwe_ct.rs b/poulpy-core/src/decryption/lwe_ct.rs index 50fb4d7..bd29947 100644 --- a/poulpy-core/src/decryption/lwe_ct.rs +++ b/poulpy-core/src/decryption/lwe_ct.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, ZnxView, ZnxViewMut}, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnxView, ZnxViewMut}, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, }; @@ -14,7 +14,7 @@ where where DataPt: DataMut, DataSk: DataRef, - Module: VecZnxNormalizeInplace, + Module: ZnNormalizeInplace, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { #[cfg(debug_assertions)] @@ -30,7 +30,8 @@ where .map(|(x, y)| x * y) .sum::(); }); - module.vec_znx_normalize_inplace( + module.zn_normalize_inplace( + pt.n(), self.basek(), &mut pt.data, 0, diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index c417028..414dbd1 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -18,11 +18,12 @@ use crate::{ }; impl GGLWEAutomorphismKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, { - GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecret::bytes_of(n, rank) + GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, basek, k, rank, rank) + + GLWESecret::bytes_of(module.n(), rank) } } @@ -66,18 +67,12 @@ impl GGLWEAutomorphismKeyCompressed { assert_eq!(sk.rank(), self.rank()); assert!( scratch.available() - >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space( - module, - sk.n(), - self.basek(), - self.k(), - self.rank() - ), + >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()) + GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) ) } diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index 9cb2c56..a672ca5 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -15,11 +15,11 @@ use crate::{ }; impl GGLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) + GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) } } @@ -71,13 +71,12 @@ impl GGLWECiphertextCompressed { assert_eq!(self.n(), sk.n()); assert_eq!(pt.n(), sk.n()); assert!( - scratch.available() - >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()), + scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()) + GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()) ); assert!( self.rows() * self.digits() * self.basek() <= self.k(), diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 1061db8..3bc93da 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -17,7 +17,6 @@ use crate::{ impl GGLWESwitchingKeyCompressed> { pub fn encrypt_sk_scratch_space( module: &Module, - n: usize, basek: usize, k: usize, rank_in: usize, @@ -26,9 +25,9 @@ impl GGLWESwitchingKeyCompressed> { where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1)) - + ScalarZnx::alloc_bytes(n, rank_in) - + GLWESecretPrepared::bytes_of(module, n, rank_out) + (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1)) + + ScalarZnx::alloc_bytes(module.n(), rank_in) + + GLWESecretPrepared::bytes_of(module, rank_out) } } @@ -72,7 +71,6 @@ impl GGLWESwitchingKeyCompressed { scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space( module, - sk_out.n(), self.basek(), self.k(), self.rank_in(), @@ -82,7 +80,6 @@ impl GGLWESwitchingKeyCompressed { scratch.available(), GGLWESwitchingKey::encrypt_sk_scratch_space( module, - sk_out.n(), self.basek(), self.k(), self.rank_in(), diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index fd7f385..e7d4f08 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -16,12 +16,12 @@ use crate::{ }; impl GGLWETensorKeyCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, { - GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k, rank) + GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k, rank) } } diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index fc5d262..85605fe 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -15,11 +15,11 @@ use crate::{ }; impl GGSWCiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) } } diff --git a/poulpy-core/src/encryption/compressed/glwe_ct.rs b/poulpy-core/src/encryption/compressed/glwe_ct.rs index ed97acf..1db80e9 100644 --- a/poulpy-core/src/encryption/compressed/glwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/glwe_ct.rs @@ -14,11 +14,11 @@ use crate::{ }; impl GLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) } } diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 464cce2..9a4e7a7 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -15,15 +15,15 @@ use crate::{ }; impl GGLWEAutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecret::bytes_of(n, rank) + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module.n(), rank) } - pub fn encrypt_pk_scratch_space(module: &Module, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize { - GGLWESwitchingKey::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank, _rank) + pub fn encrypt_pk_scratch_space(module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { + GGLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank) } } @@ -67,12 +67,12 @@ impl GGLWEAutomorphismKey { assert_eq!(sk.rank(), self.rank()); assert!( scratch.available() - >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()), + >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()), "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()) + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()) ) } diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index b98ec0e..0450ac3 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -14,15 +14,15 @@ use crate::{ }; impl GGLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) - + (GLWEPlaintext::byte_of(n, basek, k) | module.vec_znx_normalize_tmp_bytes(n)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + + (GLWEPlaintext::byte_of(module.n(), basek, k) | module.vec_znx_normalize_tmp_bytes()) } - pub fn encrypt_pk_scratch_space(_module: &Module, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize { + pub fn encrypt_pk_scratch_space(_module: &Module, _basek: usize, _k: usize, _rank: usize) -> usize { unimplemented!() } } @@ -75,12 +75,12 @@ impl GGLWECiphertext { assert_eq!(self.n(), sk.n()); assert_eq!(pt.n(), sk.n()); assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()), + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", scratch.available(), self.rank(), self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()) + GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) ); assert!( self.rows() * self.digits() * self.basek() <= self.k(), diff --git a/poulpy-core/src/encryption/gglwe_ksk.rs b/poulpy-core/src/encryption/gglwe_ksk.rs index efd186a..f199e0a 100644 --- a/poulpy-core/src/encryption/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/gglwe_ksk.rs @@ -17,7 +17,6 @@ use crate::{ impl GGLWESwitchingKey> { pub fn encrypt_sk_scratch_space( module: &Module, - n: usize, basek: usize, k: usize, rank_in: usize, @@ -26,20 +25,19 @@ impl GGLWESwitchingKey> { where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1)) - + ScalarZnx::alloc_bytes(n, rank_in) - + GLWESecretPrepared::bytes_of(module, n, rank_out) + (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1)) + + ScalarZnx::alloc_bytes(module.n(), rank_in) + + GLWESecretPrepared::bytes_of(module, rank_out) } pub fn encrypt_pk_scratch_space( module: &Module, - _n: usize, _basek: usize, _k: usize, _rank_in: usize, _rank_out: usize, ) -> usize { - GGLWECiphertext::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank_out) + GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out) } } @@ -83,7 +81,6 @@ impl GGLWESwitchingKey { scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space( module, - sk_out.n(), self.basek(), self.k(), self.rank_in(), @@ -93,7 +90,6 @@ impl GGLWESwitchingKey { scratch.available(), GGLWESwitchingKey::encrypt_sk_scratch_space( module, - sk_out.n(), self.basek(), self.k(), self.rank_in(), diff --git a/poulpy-core/src/encryption/gglwe_tsk.rs b/poulpy-core/src/encryption/gglwe_tsk.rs index 7fad00b..871c4fc 100644 --- a/poulpy-core/src/encryption/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/gglwe_tsk.rs @@ -18,17 +18,17 @@ use crate::{ }; impl GGLWETensorKey> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes, { - GLWESecretPrepared::bytes_of(module, n, rank) - + module.vec_znx_dft_alloc_bytes(n, rank, 1) - + module.vec_znx_big_alloc_bytes(n, 1, 1) - + module.vec_znx_dft_alloc_bytes(n, 1, 1) - + GLWESecret::bytes_of(n, 1) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecretPrepared::bytes_of(module, rank) + + module.vec_znx_dft_alloc_bytes(rank, 1) + + module.vec_znx_big_alloc_bytes(1, 1) + + module.vec_znx_dft_alloc_bytes(1, 1) + + GLWESecret::bytes_of(module.n(), 1) + + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) } } diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index d69a198..26dd683 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -14,15 +14,15 @@ use crate::{ }; impl GGSWCiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { let size = k.div_ceil(basek); - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) - + VecZnx::alloc_bytes(n, rank + 1, size) - + VecZnx::alloc_bytes(n, 1, size) - + module.vec_znx_dft_alloc_bytes(n, rank + 1, size) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + + VecZnx::alloc_bytes(module.n(), rank + 1, size) + + VecZnx::alloc_bytes(module.n(), 1, size) + + module.vec_znx_dft_alloc_bytes(rank + 1, size) } } diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index a8dafe0..d4bb05f 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -19,21 +19,24 @@ use crate::{ }; impl GLWECiphertext> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { let size: usize = k.div_ceil(basek); - module.vec_znx_normalize_tmp_bytes(n) + 2 * VecZnx::alloc_bytes(n, 1, size) + module.vec_znx_dft_alloc_bytes(n, 1, size) + module.vec_znx_normalize_tmp_bytes() + + 2 * VecZnx::alloc_bytes(module.n(), 1, size) + + module.vec_znx_dft_alloc_bytes(1, size) } - pub fn encrypt_pk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes, { let size: usize = k.div_ceil(basek); - ((module.vec_znx_dft_alloc_bytes(n, 1, size) + module.vec_znx_big_alloc_bytes(n, 1, size)) | ScalarZnx::alloc_bytes(n, 1)) - + module.svp_ppol_alloc_bytes(n, 1) - + module.vec_znx_normalize_tmp_bytes(n) + ((module.vec_znx_dft_alloc_bytes(1, size) + module.vec_znx_big_alloc_bytes(1, size)) + | ScalarZnx::alloc_bytes(module.n(), 1)) + + module.svp_ppol_alloc_bytes(1) + + module.vec_znx_normalize_tmp_bytes() } } @@ -69,10 +72,10 @@ impl GLWECiphertext { assert_eq!(sk.n(), self.n()); assert_eq!(pt.n(), self.n()); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()) + GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) ) } @@ -107,10 +110,10 @@ impl GLWECiphertext { assert_eq!(self.rank(), sk.rank()); assert_eq!(sk.n(), self.n()); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self.n(), self.basek(), self.k()) + GLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()) ) } self.encrypt_sk_internal( diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs index 7b05296..8129a27 100644 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ b/poulpy-core/src/encryption/glwe_pk.rs @@ -54,7 +54,6 @@ impl GLWEPublicKey { // Its ok to allocate scratch space here since pk is usually generated only once. let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space( module, - self.n(), self.basek(), self.k(), )); diff --git a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs index 377955c..e3c4e2f 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_ksk.rs @@ -16,12 +16,13 @@ use crate::{ }; impl GLWEToLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank_in: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize) -> usize where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GLWESecretPrepared::bytes_of(module, n, rank_in) - + (GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank_in, 1) | GLWESecret::bytes_of(n, rank_in)) + GLWESecretPrepared::bytes_of(module, rank_in) + + (GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_in, 1) + | GLWESecret::bytes_of(module.n(), rank_in)) } } diff --git a/poulpy-core/src/encryption/lwe_ct.rs b/poulpy-core/src/encryption/lwe_ct.rs index 4ed15af..808d67d 100644 --- a/poulpy-core/src/encryption/lwe_ct.rs +++ b/poulpy-core/src/encryption/lwe_ct.rs @@ -1,8 +1,6 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxView, ZnxViewMut, - }, - layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, VecZnx}, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, ZnxView, ZnxViewMut}, + layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, Zn}, oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl}, source::Source, }; @@ -23,7 +21,7 @@ impl LWECiphertext { ) where DataPt: DataRef, DataSk: DataRef, - Module: VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace, + Module: ZnFillUniform + ZnAddNormal + ZnNormalizeInplace, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { #[cfg(debug_assertions)] @@ -34,9 +32,9 @@ impl LWECiphertext { let basek: usize = self.basek(); let k: usize = self.k(); - module.vec_znx_fill_uniform(basek, &mut self.data, 0, k, source_xa); + module.zn_fill_uniform(self.n() + 1, basek, &mut self.data, 0, k, source_xa); - let mut tmp_znx: VecZnx> = VecZnx::alloc(1, 1, self.size()); + let mut tmp_znx: Zn> = Zn::alloc(1, 1, self.size()); let min_size = self.size().min(pt.size()); @@ -57,9 +55,19 @@ impl LWECiphertext { .sum::(); }); - module.vec_znx_add_normal(basek, &mut self.data, 0, k, source_xe, SIGMA, SIGMA_BOUND); + module.zn_add_normal( + 1, + basek, + &mut self.data, + 0, + k, + source_xe, + SIGMA, + SIGMA_BOUND, + ); - module.vec_znx_normalize_inplace( + module.zn_normalize_inplace( + 1, basek, &mut tmp_znx, 0, diff --git a/poulpy-core/src/encryption/lwe_ksk.rs b/poulpy-core/src/encryption/lwe_ksk.rs index 7f262b5..aebb2fb 100644 --- a/poulpy-core/src/encryption/lwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_ksk.rs @@ -15,13 +15,13 @@ use crate::{ }; impl LWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GLWESecret::bytes_of(n, 1) - + GLWESecretPrepared::bytes_of(module, n, 1) - + GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, 1, 1) + GLWESecret::bytes_of(module.n(), 1) + + GLWESecretPrepared::bytes_of(module, 1) + + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, 1) } } diff --git a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs index 3b9f749..5c2f570 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_ksk.rs @@ -15,11 +15,11 @@ use crate::{ }; impl LWEToGLWESwitchingKey> { - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank_out: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_out: usize) -> usize where Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, 1, rank_out) + GLWESecret::bytes_of(n, 1) + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, 1, rank_out) + GLWESecret::bytes_of(module.n(), 1) } } diff --git a/poulpy-core/src/external_product/gglwe_atk.rs b/poulpy-core/src/external_product/gglwe_atk.rs index 0eed7e6..8e5e4f5 100644 --- a/poulpy-core/src/external_product/gglwe_atk.rs +++ b/poulpy-core/src/external_product/gglwe_atk.rs @@ -12,7 +12,6 @@ impl GGLWEAutomorphismKey> { #[allow(clippy::too_many_arguments)] pub fn external_product_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -23,12 +22,11 @@ impl GGLWEAutomorphismKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::external_product_scratch_space(module, n, basek, k_out, k_in, ggsw_k, digits, rank) + GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, ggsw_k, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, ggsw_k: usize, @@ -38,7 +36,7 @@ impl GGLWEAutomorphismKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - GGLWESwitchingKey::external_product_inplace_scratch_space(module, n, basek, k_out, ggsw_k, digits, rank) + GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_out, ggsw_k, digits, rank) } } diff --git a/poulpy-core/src/external_product/gglwe_ksk.rs b/poulpy-core/src/external_product/gglwe_ksk.rs index a11479b..a9caa9e 100644 --- a/poulpy-core/src/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/external_product/gglwe_ksk.rs @@ -12,7 +12,6 @@ impl GGLWESwitchingKey> { #[allow(clippy::too_many_arguments)] pub fn external_product_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -23,12 +22,11 @@ impl GGLWESwitchingKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank) + GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ggsw: usize, @@ -38,7 +36,7 @@ impl GGLWESwitchingKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_inplace_scratch_space(module, n, basek, k_out, k_ggsw, digits, rank) + GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) } } diff --git a/poulpy-core/src/external_product/ggsw_ct.rs b/poulpy-core/src/external_product/ggsw_ct.rs index 03d080f..8becf1b 100644 --- a/poulpy-core/src/external_product/ggsw_ct.rs +++ b/poulpy-core/src/external_product/ggsw_ct.rs @@ -12,7 +12,6 @@ impl GGSWCiphertext> { #[allow(clippy::too_many_arguments)] pub fn external_product_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -23,12 +22,11 @@ impl GGSWCiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank) + GLWECiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank) } pub fn external_product_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ggsw: usize, @@ -38,7 +36,7 @@ impl GGSWCiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - GLWECiphertext::external_product_inplace_scratch_space(module, n, basek, k_out, k_ggsw, digits, rank) + GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank) } } @@ -86,7 +84,6 @@ impl GGSWCiphertext { scratch.available() >= GGSWCiphertext::external_product_scratch_space( module, - self.n(), self.basek(), self.k(), lhs.k(), diff --git a/poulpy-core/src/external_product/glwe_ct.rs b/poulpy-core/src/external_product/glwe_ct.rs index 6348d6a..9f725c6 100644 --- a/poulpy-core/src/external_product/glwe_ct.rs +++ b/poulpy-core/src/external_product/glwe_ct.rs @@ -12,7 +12,6 @@ impl GLWECiphertext> { #[allow(clippy::too_many_arguments)] pub fn external_product_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -26,10 +25,9 @@ impl GLWECiphertext> { let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ggsw_size: usize = k_ggsw.div_ceil(basek); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, ggsw_size); - let a_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, in_size); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, ggsw_size); + let a_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, in_size); let vmp: usize = module.vmp_apply_tmp_bytes( - n, out_size, in_size, in_size, // rows @@ -37,13 +35,12 @@ impl GLWECiphertext> { rank + 1, // cols out ggsw_size, ); - let normalize: usize = module.vec_znx_normalize_tmp_bytes(n); + let normalize: usize = module.vec_znx_normalize_tmp_bytes(); res_dft + a_dft + (vmp | normalize) } pub fn external_product_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ggsw: usize, @@ -53,7 +50,7 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxNormalizeTmpBytes, { - Self::external_product_scratch_space(module, n, basek, k_out, k_out, k_ggsw, digits, rank) + Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank) } } @@ -91,7 +88,6 @@ impl GLWECiphertext { scratch.available() >= GLWECiphertext::external_product_scratch_space( module, - self.n(), self.basek(), self.k(), lhs.k(), diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index e93ce11..30ad712 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -89,7 +89,6 @@ impl GLWEPacker { /// Number of scratch space bytes required to call [Self::add]. pub fn scratch_space( module: &Module, - n: usize, basek: usize, ct_k: usize, k_ksk: usize, @@ -99,7 +98,7 @@ impl GLWEPacker { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - pack_core_scratch_space(module, n, basek, ct_k, k_ksk, digits, rank) + pack_core_scratch_space(module, basek, ct_k, k_ksk, digits, rank) } pub fn galois_elements(module: &Module) -> Vec { @@ -180,7 +179,6 @@ impl GLWEPacker { fn pack_core_scratch_space( module: &Module, - n: usize, basek: usize, ct_k: usize, k_ksk: usize, @@ -190,7 +188,7 @@ fn pack_core_scratch_space( where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - combine_scratch_space(module, n, basek, ct_k, k_ksk, digits, rank) + combine_scratch_space(module, basek, ct_k, k_ksk, digits, rank) } fn pack_core( @@ -275,7 +273,6 @@ fn pack_core( fn combine_scratch_space( module: &Module, - n: usize, basek: usize, ct_k: usize, k_ksk: usize, @@ -285,9 +282,9 @@ fn combine_scratch_space( where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GLWECiphertext::bytes_of(n, basek, ct_k, rank) - + (GLWECiphertext::rsh_scratch_space(n) - | GLWECiphertext::automorphism_scratch_space(module, n, basek, ct_k, ct_k, k_ksk, digits, rank)) + GLWECiphertext::bytes_of(module.n(), basek, ct_k, rank) + + (GLWECiphertext::rsh_scratch_space(module.n()) + | GLWECiphertext::automorphism_scratch_space(module, basek, ct_k, ct_k, k_ksk, digits, rank)) } /// [combine] merges two ciphertexts together. diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index c06f53c..4ab0116 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -30,7 +30,6 @@ impl GLWECiphertext> { #[allow(clippy::too_many_arguments)] pub fn trace_scratch_space( module: &Module, - n: usize, basek: usize, out_k: usize, in_k: usize, @@ -41,12 +40,11 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - Self::automorphism_inplace_scratch_space(module, n, basek, out_k.min(in_k), ksk_k, digits, rank) + Self::automorphism_inplace_scratch_space(module, basek, out_k.min(in_k), ksk_k, digits, rank) } pub fn trace_inplace_scratch_space( module: &Module, - n: usize, basek: usize, out_k: usize, ksk_k: usize, @@ -56,7 +54,7 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - Self::automorphism_inplace_scratch_space(module, n, basek, out_k, ksk_k, digits, rank) + Self::automorphism_inplace_scratch_space(module, basek, out_k, ksk_k, digits, rank) } } diff --git a/poulpy-core/src/keyswitching/gglwe_ct.rs b/poulpy-core/src/keyswitching/gglwe_ct.rs index bf9bc70..9aad51d 100644 --- a/poulpy-core/src/keyswitching/gglwe_ct.rs +++ b/poulpy-core/src/keyswitching/gglwe_ct.rs @@ -15,7 +15,6 @@ impl GGLWEAutomorphismKey> { #[allow(clippy::too_many_arguments)] pub fn keyswitch_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -26,12 +25,11 @@ impl GGLWEAutomorphismKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGLWESwitchingKey::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank) + GGLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank) } pub fn keyswitch_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -41,7 +39,7 @@ impl GGLWEAutomorphismKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank) + GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } } @@ -92,7 +90,6 @@ impl GGLWESwitchingKey> { #[allow(clippy::too_many_arguments)] pub fn keyswitch_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -104,14 +101,11 @@ impl GGLWESwitchingKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GLWECiphertext::keyswitch_scratch_space( - module, n, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out, - ) + GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out) } pub fn keyswitch_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -121,7 +115,7 @@ impl GGLWESwitchingKey> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank) + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank) } } diff --git a/poulpy-core/src/keyswitching/ggsw_ct.rs b/poulpy-core/src/keyswitching/ggsw_ct.rs index 1d8a4c0..e90bb77 100644 --- a/poulpy-core/src/keyswitching/ggsw_ct.rs +++ b/poulpy-core/src/keyswitching/ggsw_ct.rs @@ -19,7 +19,6 @@ use crate::{ impl GGSWCiphertext> { pub(crate) fn expand_row_scratch_space( module: &Module, - n: usize, basek: usize, self_k: usize, k_tsk: usize, @@ -33,10 +32,9 @@ impl GGSWCiphertext> { let self_size_out: usize = self_k.div_ceil(basek); let self_size_in: usize = self_size_out.div_ceil(digits); - let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, tsk_size); - let tmp_a: usize = module.vec_znx_dft_alloc_bytes(n, 1, self_size_in); + let tmp_dft_i: usize = module.vec_znx_dft_alloc_bytes(rank + 1, tsk_size); + let tmp_a: usize = module.vec_znx_dft_alloc_bytes(1, self_size_in); let vmp: usize = module.vmp_apply_tmp_bytes( - n, self_size_out, self_size_in, self_size_in, @@ -44,15 +42,14 @@ impl GGSWCiphertext> { rank, tsk_size, ); - let tmp_idft: usize = module.vec_znx_big_alloc_bytes(n, 1, tsk_size); - let norm: usize = module.vec_znx_normalize_tmp_bytes(n); + let tmp_idft: usize = module.vec_znx_big_alloc_bytes(1, tsk_size); + let norm: usize = module.vec_znx_normalize_tmp_bytes(); tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) } #[allow(clippy::too_many_arguments)] pub fn keyswitch_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -67,18 +64,17 @@ impl GGSWCiphertext> { VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { let out_size: usize = k_out.div_ceil(basek); - let res_znx: usize = VecZnx::alloc_bytes(n, rank + 1, out_size); - let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size); - let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); - let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size); + let res_znx: usize = VecZnx::alloc_bytes(module.n(), rank + 1, out_size); + let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); + let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank); + let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size); res_znx + ci_dft + (ks | expand_rows | res_dft) } #[allow(clippy::too_many_arguments)] pub fn keyswitch_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -92,7 +88,7 @@ impl GGSWCiphertext> { VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes, { GGSWCiphertext::keyswitch_scratch_space( - module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, + module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank, ) } } @@ -217,7 +213,6 @@ impl GGSWCiphertext { scratch.available() >= GGSWCiphertext::expand_row_scratch_space( module, - self.n(), self.basek(), self.k(), tsk.k(), diff --git a/poulpy-core/src/keyswitching/glwe_ct.rs b/poulpy-core/src/keyswitching/glwe_ct.rs index 5dd402b..c20f7ce 100644 --- a/poulpy-core/src/keyswitching/glwe_ct.rs +++ b/poulpy-core/src/keyswitching/glwe_ct.rs @@ -12,7 +12,6 @@ impl GLWECiphertext> { #[allow(clippy::too_many_arguments)] pub fn keyswitch_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_in: usize, @@ -27,24 +26,16 @@ impl GLWECiphertext> { let in_size: usize = k_in.div_ceil(basek).div_ceil(digits); let out_size: usize = k_out.div_ceil(basek); let ksk_size: usize = k_ksk.div_ceil(basek); - let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_out + 1, ksk_size); // TODO OPTIMIZE - let ai_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_in, in_size); - let vmp: usize = module.vmp_apply_tmp_bytes( - n, - out_size, - in_size, - in_size, - rank_in, - rank_out + 1, - ksk_size, - ) + module.vec_znx_dft_alloc_bytes(n, rank_in, in_size); - let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(n); + let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); // TODO OPTIMIZE + let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size); + let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size) + + module.vec_znx_dft_alloc_bytes(rank_in, in_size); + let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); res_dft + ((ai_dft + vmp) | normalize) } pub fn keyswitch_inplace_scratch_space( module: &Module, - n: usize, basek: usize, k_out: usize, k_ksk: usize, @@ -54,7 +45,7 @@ impl GLWECiphertext> { where Module: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes, { - Self::keyswitch_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank, rank) + Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank) } } @@ -95,7 +86,6 @@ impl GLWECiphertext { scratch.available() >= GLWECiphertext::keyswitch_scratch_space( module, - self.n(), self.basek(), self.k(), lhs.k(), @@ -117,7 +107,6 @@ impl GLWECiphertext { scratch.available(), GLWECiphertext::keyswitch_scratch_space( module, - self.n(), self.basek(), self.k(), lhs.k(), diff --git a/poulpy-core/src/keyswitching/lwe_ct.rs b/poulpy-core/src/keyswitching/lwe_ct.rs index d0aca35..f3bcd93 100644 --- a/poulpy-core/src/keyswitching/lwe_ct.rs +++ b/poulpy-core/src/keyswitching/lwe_ct.rs @@ -15,7 +15,6 @@ use crate::{ impl LWECiphertext> { pub fn keyswitch_scratch_space( module: &Module, - n: usize, basek: usize, k_lwe_out: usize, k_lwe_in: usize, @@ -33,8 +32,8 @@ impl LWECiphertext> { + VecZnxBigAddSmallInplace + VecZnxBigNormalize, { - GLWECiphertext::bytes_of(n, basek, k_lwe_out.max(k_lwe_in), 1) - + GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_lwe_out, k_ksk, 1, 1) + GLWECiphertext::bytes_of(module.n(), basek, k_lwe_out.max(k_lwe_in), 1) + + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1) } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs index 27c32cc..5a4daef 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_atk.rs @@ -106,11 +106,11 @@ impl WriterTo for GGLWEAutomorphismKeyCompressed { } } -impl Decompress> for GGLWEAutomorphismKey { - fn decompress(&mut self, module: &Module, other: &GGLWEAutomorphismKeyCompressed) - where - Module: VecZnxFillUniform + VecZnxCopy, - { +impl Decompress> for GGLWEAutomorphismKey +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &GGLWEAutomorphismKeyCompressed) { self.key.decompress(module, &other.key); self.p = other.p; } diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe_ct.rs index 2216937..3cfc1a9 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ct.rs @@ -194,11 +194,11 @@ impl WriterTo for GGLWECiphertextCompressed { } } -impl Decompress> for GGLWECiphertext { - fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) - where - Module: VecZnxFillUniform + VecZnxCopy, - { +impl Decompress> for GGLWECiphertext +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &GGLWECiphertextCompressed) { #[cfg(debug_assertions)] { use poulpy_hal::api::ZnxInfos; diff --git a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs index 82f1213..821f2b5 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs @@ -115,11 +115,11 @@ impl WriterTo for GGLWESwitchingKeyCompressed { } } -impl Decompress> for GGLWESwitchingKey { - fn decompress(&mut self, module: &Module, other: &GGLWESwitchingKeyCompressed) - where - Module: VecZnxFillUniform + VecZnxCopy, - { +impl Decompress> for GGLWESwitchingKey +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &GGLWESwitchingKeyCompressed) { self.key.decompress(module, &other.key); self.sk_in_n = other.sk_in_n; self.sk_out_n = other.sk_out_n; diff --git a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs index 9c9084b..cb50a2f 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs @@ -139,11 +139,11 @@ impl GGLWETensorKeyCompressed { } } -impl Decompress> for GGLWETensorKey { - fn decompress(&mut self, module: &Module, other: &GGLWETensorKeyCompressed) - where - Module: VecZnxFillUniform + VecZnxCopy, - { +impl Decompress> for GGLWETensorKey +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &GGLWETensorKeyCompressed) { #[cfg(debug_assertions)] { assert_eq!( diff --git a/poulpy-core/src/layouts/compressed/ggsw_ct.rs b/poulpy-core/src/layouts/compressed/ggsw_ct.rs index 2ba0757..42a6bd9 100644 --- a/poulpy-core/src/layouts/compressed/ggsw_ct.rs +++ b/poulpy-core/src/layouts/compressed/ggsw_ct.rs @@ -185,11 +185,11 @@ impl WriterTo for GGSWCiphertextCompressed { } } -impl Decompress> for GGSWCiphertext { - fn decompress(&mut self, module: &Module, other: &GGSWCiphertextCompressed) - where - Module: VecZnxFillUniform + VecZnxCopy, - { +impl Decompress> for GGSWCiphertext +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &GGSWCiphertextCompressed) { #[cfg(debug_assertions)] { assert_eq!(self.rank(), other.rank()) diff --git a/poulpy-core/src/layouts/compressed/glwe_ct.rs b/poulpy-core/src/layouts/compressed/glwe_ct.rs index 466211f..bd1f8a5 100644 --- a/poulpy-core/src/layouts/compressed/glwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/glwe_ct.rs @@ -111,11 +111,11 @@ impl WriterTo for GLWECiphertextCompressed { } } -impl Decompress> for GLWECiphertext { - fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) - where - Module: VecZnxCopy + VecZnxFillUniform, - { +impl Decompress> for GLWECiphertext +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &GLWECiphertextCompressed) { #[cfg(debug_assertions)] { use poulpy_hal::api::ZnxInfos; diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs index 8eff364..9c412f4 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs @@ -93,7 +93,7 @@ impl GLWEToLWESwitchingKeyCompressed> { )) } - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank_in: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_in: usize) -> usize where Module: VecZnxDftAllocBytes + VecZnxBigNormalize @@ -112,6 +112,6 @@ impl GLWEToLWESwitchingKeyCompressed> { + SvpPPolAllocBytes + SvpPPolAlloc, { - GLWEToLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank_in) + GLWEToLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_in) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_ct.rs b/poulpy-core/src/layouts/compressed/lwe_ct.rs index ed25f85..19e2b80 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ct.rs @@ -1,7 +1,7 @@ use std::fmt; use poulpy_hal::{ - api::{FillUniform, Reset, VecZnxFillUniform, ZnxInfos, ZnxView, ZnxViewMut}, + api::{FillUniform, Reset, ZnFillUniform, ZnxInfos, ZnxView, ZnxViewMut}, layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, VecZnx, WriterTo}, source::Source, }; @@ -117,13 +117,20 @@ impl WriterTo for LWECiphertextCompressed { } } -impl Decompress> for LWECiphertext { - fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) - where - Module: VecZnxFillUniform, - { - let mut source = Source::new(other.seed); - module.vec_znx_fill_uniform(other.basek(), &mut self.data, 0, other.k(), &mut source); +impl Decompress> for LWECiphertext +where + Module: ZnFillUniform, +{ + fn decompress(&mut self, module: &Module, other: &LWECiphertextCompressed) { + let mut source: Source = Source::new(other.seed); + module.zn_fill_uniform( + self.n(), + other.basek(), + &mut self.data, + 0, + other.k(), + &mut source, + ); (0..self.size()).for_each(|i| { self.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; }); diff --git a/poulpy-core/src/layouts/compressed/lwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_ksk.rs index b240c4b..450d57a 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ksk.rs @@ -94,7 +94,7 @@ impl LWESwitchingKeyCompressed> { )) } - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize where Module: VecZnxDftAllocBytes + VecZnxBigNormalize @@ -113,15 +113,15 @@ impl LWESwitchingKeyCompressed> { + SvpPPolAllocBytes + SvpPPolAlloc, { - LWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k) + LWESwitchingKey::encrypt_sk_scratch_space(module, basek, k) } } -impl Decompress> for LWESwitchingKey { - fn decompress(&mut self, module: &Module, other: &LWESwitchingKeyCompressed) - where - Module: VecZnxCopy + VecZnxFillUniform, - { +impl Decompress> for LWESwitchingKey +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &LWESwitchingKeyCompressed) { self.0.decompress(module, &other.0); } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs index bb1cbe0..0a9e2f9 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs @@ -95,7 +95,7 @@ impl LWEToGLWESwitchingKeyCompressed> { )) } - pub fn encrypt_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank_out: usize) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank_out: usize) -> usize where Module: VecZnxDftAllocBytes + VecZnxBigNormalize @@ -114,15 +114,15 @@ impl LWEToGLWESwitchingKeyCompressed> { + SvpPPolAllocBytes + SvpPPolAlloc, { - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank_out) + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank_out) } } -impl Decompress> for LWEToGLWESwitchingKey { - fn decompress(&mut self, module: &Module, other: &LWEToGLWESwitchingKeyCompressed) - where - Module: VecZnxCopy + VecZnxFillUniform, - { +impl Decompress> for LWEToGLWESwitchingKey +where + Module: VecZnxFillUniform + VecZnxCopy, +{ + fn decompress(&mut self, module: &Module, other: &LWEToGLWESwitchingKeyCompressed) { self.0.decompress(module, &other.0); } } diff --git a/poulpy-core/src/layouts/compressed/mod.rs b/poulpy-core/src/layouts/compressed/mod.rs index 36157bc..c1fcacf 100644 --- a/poulpy-core/src/layouts/compressed/mod.rs +++ b/poulpy-core/src/layouts/compressed/mod.rs @@ -20,13 +20,8 @@ pub use lwe_ct::*; pub use lwe_ksk::*; pub use lwe_to_glwe_ksk::*; -use poulpy_hal::{ - api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Module}, -}; +use poulpy_hal::layouts::{Backend, Module}; pub trait Decompress { - fn decompress(&mut self, module: &Module, other: &C) - where - Module: VecZnxFillUniform + VecZnxCopy; + fn decompress(&mut self, module: &Module, other: &C); } diff --git a/poulpy-core/src/layouts/lwe_ct.rs b/poulpy-core/src/layouts/lwe_ct.rs index 8126d44..4496c54 100644 --- a/poulpy-core/src/layouts/lwe_ct.rs +++ b/poulpy-core/src/layouts/lwe_ct.rs @@ -2,25 +2,25 @@ use std::fmt; use poulpy_hal::{ api::{FillUniform, Reset, ZnxInfos}, - layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo}, + layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef}, source::Source, }; #[derive(PartialEq, Eq, Clone)] pub struct LWECiphertext { - pub(crate) data: VecZnx, + pub(crate) data: Zn, pub(crate) k: usize, pub(crate) basek: usize, } impl LWECiphertext { - pub fn data(&self) -> &VecZnx { + pub fn data(&self) -> &Zn { &self.data } } impl LWECiphertext { - pub fn data_mut(&mut self) -> &VecZnx { + pub fn data_mut(&mut self) -> &Zn { &mut self.data } } @@ -53,7 +53,7 @@ impl Reset for LWECiphertext { impl FillUniform for LWECiphertext where - VecZnx: FillUniform, + Zn: FillUniform, { fn fill_uniform(&mut self, source: &mut Source) { self.data.fill_uniform(source); @@ -63,7 +63,7 @@ where impl LWECiphertext> { pub fn alloc(n: usize, basek: usize, k: usize) -> Self { Self { - data: VecZnx::alloc(n + 1, 1, k.div_ceil(basek)), + data: Zn::alloc(n + 1, 1, k.div_ceil(basek)), k, basek, } @@ -72,9 +72,9 @@ impl LWECiphertext> { impl Infos for LWECiphertext where - VecZnx: ZnxInfos, + Zn: ZnxInfos, { - type Inner = VecZnx; + type Inner = Zn; fn n(&self) -> usize { &self.inner().n() - 1 diff --git a/poulpy-core/src/layouts/lwe_pt.rs b/poulpy-core/src/layouts/lwe_pt.rs index 3dd805c..f7a5cba 100644 --- a/poulpy-core/src/layouts/lwe_pt.rs +++ b/poulpy-core/src/layouts/lwe_pt.rs @@ -1,11 +1,11 @@ use std::fmt; -use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef}; use crate::layouts::{Infos, SetMetaData}; pub struct LWEPlaintext { - pub(crate) data: VecZnx, + pub(crate) data: Zn, pub(crate) k: usize, pub(crate) basek: usize, } @@ -13,7 +13,7 @@ pub struct LWEPlaintext { impl LWEPlaintext> { pub fn alloc(basek: usize, k: usize) -> Self { Self { - data: VecZnx::alloc(1, 1, k.div_ceil(basek)), + data: Zn::alloc(1, 1, k.div_ceil(basek)), k, basek, } @@ -33,7 +33,7 @@ impl fmt::Display for LWEPlaintext { } impl Infos for LWEPlaintext { - type Inner = VecZnx; + type Inner = Zn; fn inner(&self) -> &Self::Inner { &self.data @@ -89,7 +89,7 @@ impl LWEPlaintextToMut for LWEPlaintext { } impl LWEPlaintext { - pub fn data_mut(&mut self) -> &mut VecZnx { + pub fn data_mut(&mut self) -> &mut Zn { &mut self.data } } diff --git a/poulpy-core/src/layouts/prepared/gglwe_atk.rs b/poulpy-core/src/layouts/prepared/gglwe_atk.rs index ce5d5d0..2470075 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_atk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_atk.rs @@ -15,21 +15,21 @@ pub struct GGLWEAutomorphismKeyPrepared { } impl GGLWEAutomorphismKeyPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: VmpPMatAlloc, { GGLWEAutomorphismKeyPrepared::, B> { - key: GGLWESwitchingKeyPrepared::alloc(module, n, basek, k, rows, digits, rank, rank), + key: GGLWESwitchingKeyPrepared::alloc(module, basek, k, rows, digits, rank, rank), p: 0, } } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::bytes_of(module, n, basek, k, rows, digits, rank, rank) + GGLWESwitchingKeyPrepared::bytes_of(module, basek, k, rows, digits, rank, rank) } } @@ -88,7 +88,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWEAutomorphismKeyPrepared, B> { let mut atk_prepared: GGLWEAutomorphismKeyPrepared, B> = GGLWEAutomorphismKeyPrepared::alloc( module, - self.n(), self.basek(), self.k(), self.rows(), diff --git a/poulpy-core/src/layouts/prepared/gglwe_ct.rs b/poulpy-core/src/layouts/prepared/gglwe_ct.rs index 2003d4c..51fe2c0 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_ct.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_ct.rs @@ -18,16 +18,7 @@ pub struct GGLWECiphertextPrepared { impl GGLWECiphertextPrepared, B> { #[allow(clippy::too_many_arguments)] - pub fn alloc( - module: &Module, - n: usize, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self where Module: VmpPMatAlloc, { @@ -48,7 +39,7 @@ impl GGLWECiphertextPrepared, B> { ); Self { - data: module.vmp_pmat_alloc(n, rows, rank_in, rank_out + 1, size), + data: module.vmp_pmat_alloc(rows, rank_in, rank_out + 1, size), basek, k, digits, @@ -58,7 +49,6 @@ impl GGLWECiphertextPrepared, B> { #[allow(clippy::too_many_arguments)] pub fn bytes_of( module: &Module, - n: usize, basek: usize, k: usize, rows: usize, @@ -85,7 +75,7 @@ impl GGLWECiphertextPrepared, B> { size ); - module.vmp_pmat_alloc_bytes(n, rows, rank_in, rank_out + 1, rows) + module.vmp_pmat_alloc_bytes(rows, rank_in, rank_out + 1, rows) } } @@ -142,7 +132,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWECiphertextPrepared, B> { let mut atk_prepared: GGLWECiphertextPrepared, B> = GGLWECiphertextPrepared::alloc( module, - self.n(), self.basek(), self.k(), self.rows(), diff --git a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs index 4e5f0b0..f5174d0 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_ksk.rs @@ -17,21 +17,12 @@ pub struct GGLWESwitchingKeyPrepared { impl GGLWESwitchingKeyPrepared, B> { #[allow(clippy::too_many_arguments)] - pub fn alloc( - module: &Module, - n: usize, - basek: usize, - k: usize, - rows: usize, - digits: usize, - rank_in: usize, - rank_out: usize, - ) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self where Module: VmpPMatAlloc, { GGLWESwitchingKeyPrepared::, B> { - key: GGLWECiphertextPrepared::alloc(module, n, basek, k, rows, digits, rank_in, rank_out), + key: GGLWECiphertextPrepared::alloc(module, basek, k, rows, digits, rank_in, rank_out), sk_in_n: 0, sk_out_n: 0, } @@ -40,7 +31,6 @@ impl GGLWESwitchingKeyPrepared, B> { #[allow(clippy::too_many_arguments)] pub fn bytes_of( module: &Module, - n: usize, basek: usize, k: usize, rows: usize, @@ -51,7 +41,7 @@ impl GGLWESwitchingKeyPrepared, B> { where Module: VmpPMatAllocBytes, { - GGLWECiphertextPrepared::bytes_of(module, n, basek, k, rows, digits, rank_in, rank_out) + GGLWECiphertextPrepared::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) } } @@ -115,7 +105,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWESwitchingKeyPrepared, B> { let mut atk_prepared: GGLWESwitchingKeyPrepared, B> = GGLWESwitchingKeyPrepared::alloc( module, - self.n(), self.basek(), self.k(), self.rows(), diff --git a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs index 3b4500c..0e00702 100644 --- a/poulpy-core/src/layouts/prepared/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/prepared/gglwe_tsk.rs @@ -14,7 +14,7 @@ pub struct GGLWETensorKeyPrepared { } impl GGLWETensorKeyPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: VmpPMatAlloc, { @@ -22,18 +22,18 @@ impl GGLWETensorKeyPrepared, B> { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { keys.push(GGLWESwitchingKeyPrepared::alloc( - module, n, basek, k, rows, digits, 1, rank, + module, basek, k, rows, digits, 1, rank, )); }); Self { keys } } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: VmpPMatAllocBytes, { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GGLWESwitchingKeyPrepared::bytes_of(module, n, basek, k, rows, digits, 1, rank) + pairs * GGLWESwitchingKeyPrepared::bytes_of(module, basek, k, rows, digits, 1, rank) } } @@ -118,7 +118,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGLWETensorKeyPrepared, B> { let mut tsk_prepared: GGLWETensorKeyPrepared, B> = GGLWETensorKeyPrepared::alloc( module, - self.n(), self.basek(), self.k(), self.rows(), diff --git a/poulpy-core/src/layouts/prepared/ggsw_ct.rs b/poulpy-core/src/layouts/prepared/ggsw_ct.rs index dc88e06..09f06da 100644 --- a/poulpy-core/src/layouts/prepared/ggsw_ct.rs +++ b/poulpy-core/src/layouts/prepared/ggsw_ct.rs @@ -17,7 +17,7 @@ pub struct GGSWCiphertextPrepared { } impl GGSWCiphertextPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self where Module: VmpPMatAlloc, { @@ -40,14 +40,14 @@ impl GGSWCiphertextPrepared, B> { ); Self { - data: module.vmp_pmat_alloc(n, rows, rank + 1, rank + 1, k.div_ceil(basek)), + data: module.vmp_pmat_alloc(rows, rank + 1, rank + 1, k.div_ceil(basek)), basek, k, digits, } } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize where Module: VmpPMatAllocBytes, { @@ -67,7 +67,7 @@ impl GGSWCiphertextPrepared, B> { size ); - module.vmp_pmat_alloc_bytes(n, rows, rank + 1, rank + 1, size) + module.vmp_pmat_alloc_bytes(rows, rank + 1, rank + 1, size) } } @@ -122,7 +122,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GGSWCiphertextPrepared, B> { let mut ggsw_prepared: GGSWCiphertextPrepared, B> = GGSWCiphertextPrepared::alloc( module, - self.n(), self.basek(), self.k(), self.rows(), diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs index 7bc287b..d0a59e2 100644 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_pk.rs @@ -42,23 +42,23 @@ impl GLWEPublicKeyPrepared { } impl GLWEPublicKeyPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self where Module: VecZnxDftAlloc, { Self { - data: module.vec_znx_dft_alloc(n, rank + 1, k.div_ceil(basek)), + data: module.vec_znx_dft_alloc(rank + 1, k.div_ceil(basek)), basek, k, dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxDftAllocBytes, { - module.vec_znx_dft_alloc_bytes(n, rank + 1, k.div_ceil(basek)) + module.vec_znx_dft_alloc_bytes(rank + 1, k.div_ceil(basek)) } } @@ -68,7 +68,7 @@ where { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEPublicKeyPrepared, B> { let mut pk_prepared: GLWEPublicKeyPrepared, B> = - GLWEPublicKeyPrepared::alloc(module, self.n(), self.basek(), self.k(), self.rank()); + GLWEPublicKeyPrepared::alloc(module, self.basek(), self.k(), self.rank()); pk_prepared.prepare(module, self, scratch); pk_prepared } diff --git a/poulpy-core/src/layouts/prepared/glwe_sk.rs b/poulpy-core/src/layouts/prepared/glwe_sk.rs index 27e7fcc..595b155 100644 --- a/poulpy-core/src/layouts/prepared/glwe_sk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_sk.rs @@ -17,21 +17,21 @@ pub struct GLWESecretPrepared { } impl GLWESecretPrepared, B> { - pub fn alloc(module: &Module, n: usize, rank: usize) -> Self + pub fn alloc(module: &Module, rank: usize) -> Self where Module: SvpPPolAlloc, { Self { - data: module.svp_ppol_alloc(n, rank), + data: module.svp_ppol_alloc(rank), dist: Distribution::NONE, } } - pub fn bytes_of(module: &Module, n: usize, rank: usize) -> usize + pub fn bytes_of(module: &Module, rank: usize) -> usize where Module: SvpPPolAllocBytes, { - module.svp_ppol_alloc_bytes(n, rank) + module.svp_ppol_alloc_bytes(rank) } } @@ -54,7 +54,7 @@ where Module: SvpPrepare + SvpPPolAlloc, { fn prepare_alloc(&self, module: &Module, scratch: &mut poulpy_hal::layouts::Scratch) -> GLWESecretPrepared, B> { - let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self.n(), self.rank()); + let mut sk_dft: GLWESecretPrepared, B> = GLWESecretPrepared::alloc(module, self.rank()); sk_dft.prepare(module, self, scratch); sk_dft } diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs index bce3954..199ffcf 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_ksk.rs @@ -46,20 +46,20 @@ impl GLWEToLWESwitchingKeyPrepared { } impl GLWEToLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize) -> Self where Module: VmpPMatAlloc, { Self(GGLWESwitchingKeyPrepared::alloc( - module, n, basek, k, rows, 1, rank_in, 1, + module, basek, k, rows, 1, rank_in, 1, )) } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::, B>::bytes_of(module, n, basek, k, rows, digits, rank_in, 1) + GGLWESwitchingKeyPrepared::, B>::bytes_of(module, basek, k, rows, digits, rank_in, 1) } } @@ -70,7 +70,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> GLWEToLWESwitchingKeyPrepared, B> { let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, B> = GLWEToLWESwitchingKeyPrepared::alloc( module, - self.0.n(), self.0.basek(), self.0.k(), self.0.rows(), diff --git a/poulpy-core/src/layouts/prepared/lwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_ksk.rs index 3177671..72e6be7 100644 --- a/poulpy-core/src/layouts/prepared/lwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_ksk.rs @@ -46,20 +46,20 @@ impl LWESwitchingKeyPrepared { } impl LWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize) -> Self where Module: VmpPMatAlloc, { Self(GGLWESwitchingKeyPrepared::alloc( - module, n, basek, k, rows, 1, 1, 1, + module, basek, k, rows, 1, 1, 1, )) } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::, B>::bytes_of(module, n, basek, k, rows, digits, 1, 1) + GGLWESwitchingKeyPrepared::, B>::bytes_of(module, basek, k, rows, digits, 1, 1) } } @@ -68,13 +68,8 @@ where Module: VmpPrepare + VmpPMatAlloc, { fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWESwitchingKeyPrepared, B> { - let mut ksk_prepared: LWESwitchingKeyPrepared, B> = LWESwitchingKeyPrepared::alloc( - module, - self.0.n(), - self.0.basek(), - self.0.k(), - self.0.rows(), - ); + let mut ksk_prepared: LWESwitchingKeyPrepared, B> = + LWESwitchingKeyPrepared::alloc(module, self.0.basek(), self.0.k(), self.0.rows()); ksk_prepared.prepare(module, self, scratch); ksk_prepared } diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs index 806fa7f..de5c5e1 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_ksk.rs @@ -47,20 +47,20 @@ impl LWEToGLWESwitchingKeyPrepared { } impl LWEToGLWESwitchingKeyPrepared, B> { - pub fn alloc(module: &Module, n: usize, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_out: usize) -> Self where Module: VmpPMatAlloc, { Self(GGLWESwitchingKeyPrepared::alloc( - module, n, basek, k, rows, 1, 1, rank_out, + module, basek, k, rows, 1, 1, rank_out, )) } - pub fn bytes_of(module: &Module, n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank_out: usize) -> usize + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_out: usize) -> usize where Module: VmpPMatAllocBytes, { - GGLWESwitchingKeyPrepared::, B>::bytes_of(module, n, basek, k, rows, digits, 1, rank_out) + GGLWESwitchingKeyPrepared::, B>::bytes_of(module, basek, k, rows, digits, 1, rank_out) } } @@ -71,7 +71,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> LWEToGLWESwitchingKeyPrepared, B> { let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, B> = LWEToGLWESwitchingKeyPrepared::alloc( module, - self.0.n(), self.0.basek(), self.0.k(), self.0.rows(), diff --git a/poulpy-core/src/noise/gglwe_ct.rs b/poulpy-core/src/noise/gglwe_ct.rs index 6fe8af2..829b497 100644 --- a/poulpy-core/src/noise/gglwe_ct.rs +++ b/poulpy-core/src/noise/gglwe_ct.rs @@ -36,12 +36,7 @@ impl GGLWECiphertext { let basek: usize = self.basek(); let k: usize = self.k(); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space( - module, - self.n(), - basek, - k, - )); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k)); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); (0..self.rank_in()).for_each(|col_i| { diff --git a/poulpy-core/src/noise/ggsw_ct.rs b/poulpy-core/src/noise/ggsw_ct.rs index 4b5bfb8..d1237ce 100644 --- a/poulpy-core/src/noise/ggsw_ct.rs +++ b/poulpy-core/src/noise/ggsw_ct.rs @@ -45,12 +45,11 @@ impl GGSWCiphertext { let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); - let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(self.n(), 1, self.size()); - let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(self.n(), 1, self.size()); + let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); + let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, self.n(), basek, k) | module.vec_znx_normalize_tmp_bytes(self.n()), - ); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes()); (0..self.rank() + 1).for_each(|col_j| { (0..self.rows()).for_each(|row_i| { @@ -112,12 +111,11 @@ impl GGSWCiphertext { let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::alloc(self.n(), basek, k); - let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(self.n(), 1, self.size()); - let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(self.n(), 1, self.size()); + let mut pt_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(1, self.size()); + let mut pt_big: VecZnxBig, B> = module.vec_znx_big_alloc(1, self.size()); - let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, self.n(), basek, k) | module.vec_znx_normalize_tmp_bytes(module.n()), - ); + let mut scratch: ScratchOwned = + ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes()); (0..self.rank() + 1).for_each(|col_j| { (0..self.rows()).for_each(|row_i| { diff --git a/poulpy-core/src/noise/glwe_ct.rs b/poulpy-core/src/noise/glwe_ct.rs index 6f9b0e3..793e732 100644 --- a/poulpy-core/src/noise/glwe_ct.rs +++ b/poulpy-core/src/noise/glwe_ct.rs @@ -41,7 +41,6 @@ impl GLWECiphertext { let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space( module, - self.n(), self.basek(), self.k(), )); diff --git a/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs b/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs index 1bb0fcc..9b63478 100644 --- a/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/tests/generics/automorphism/gglwe_atk.rs @@ -92,8 +92,8 @@ pub fn test_gglwe_automorphism_key_automorphism( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_apply, rank) - | GGLWEAutomorphismKey::automorphism_scratch_space(module, n, basek, k_out, k_in, k_apply, digits, rank), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) + | GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_in, k_apply, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -120,7 +120,7 @@ pub fn test_gglwe_automorphism_key_automorphism( ); let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, n, basek, k_apply, rows_apply, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, digits, rank); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); @@ -266,8 +266,8 @@ pub fn test_gglwe_automorphism_key_automorphism_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_apply, rank) - | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, n, basek, k_in, k_apply, digits, rank), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_apply, rank) + | GGLWEAutomorphismKey::automorphism_inplace_scratch_space(module, basek, k_in, k_apply, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -294,7 +294,7 @@ pub fn test_gglwe_automorphism_key_automorphism_inplace( ); let mut auto_key_apply_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, n, basek, k_apply, rows_apply, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_apply, rows_apply, digits, rank); auto_key_apply_prepared.prepare(module, &auto_key_apply, scratch.borrow()); diff --git a/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs index 653fba9..269cebc 100644 --- a/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/automorphism/ggsw_ct.rs @@ -102,11 +102,11 @@ pub fn test_ggsw_automorphism( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_in, rank) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) | GGSWCiphertext::automorphism_scratch_space( - module, n, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, + module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), ); @@ -144,11 +144,11 @@ pub fn test_ggsw_automorphism( ); let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, n, basek, k_ksk, rows, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); let mut tsk_prepared: GGLWETensorKeyPrepared, B> = - GGLWETensorKeyPrepared::alloc(module, n, basek, k_tsk, rows, digits, rank); + GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, digits, rank); tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); ct_out.automorphism( @@ -255,10 +255,10 @@ pub fn test_ggsw_automorphism_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct, rank) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) - | GGSWCiphertext::automorphism_inplace_scratch_space(module, n, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::automorphism_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; @@ -295,11 +295,11 @@ pub fn test_ggsw_automorphism_inplace( ); let mut auto_key_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, n, basek, k_ksk, rows, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); let mut tsk_prepared: GGLWETensorKeyPrepared, B> = - GGLWETensorKeyPrepared::alloc(module, n, basek, k_tsk, rows, digits, rank); + GGLWETensorKeyPrepared::alloc(module, basek, k_tsk, rows, digits, rank); tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs b/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs index 39a33d7..b04ac0d 100644 --- a/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/automorphism/glwe_ct.rs @@ -89,12 +89,11 @@ pub fn test_glwe_automorphism( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct_out.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_in.k()) + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct_out.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) | GLWECiphertext::automorphism_scratch_space( module, - n, basek, ct_out.k(), ct_in.k(), @@ -127,7 +126,7 @@ pub fn test_glwe_automorphism( ); let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, n, basek, k_ksk, rows, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); @@ -213,10 +212,10 @@ pub fn test_glwe_automorphism_inplace( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, autokey.k(), rank) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) - | GLWECiphertext::automorphism_inplace_scratch_space(module, n, basek, ct.k(), autokey.k(), digits, rank), + GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, autokey.k(), rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::automorphism_inplace_scratch_space(module, basek, ct.k(), autokey.k(), digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -242,7 +241,7 @@ pub fn test_glwe_automorphism_inplace( ); let mut autokey_prepared: GGLWEAutomorphismKeyPrepared, B> = - GGLWEAutomorphismKeyPrepared::alloc(module, n, basek, k_ksk, rows, digits, rank); + GGLWEAutomorphismKeyPrepared::alloc(module, basek, k_ksk, rows, digits, rank); autokey_prepared.prepare(module, &autokey, scratch.borrow()); ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/generics/conversion.rs b/poulpy-core/src/tests/generics/conversion.rs index 34ad4e5..9f0aedc 100644 --- a/poulpy-core/src/tests/generics/conversion.rs +++ b/poulpy-core/src/tests/generics/conversion.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, - ZnxView, + ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, ZnxView, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -50,7 +50,10 @@ where + VmpApplyAdd + VecZnxBigNormalizeTmpBytes + VecZnxSwithcDegree - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace + + ZnNormalizeInplace + + ZnFillUniform + + ZnAddNormal, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl @@ -79,9 +82,9 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) - | GLWECiphertext::from_lwe_scratch_space(module, n, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, k_glwe_ct), + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWECiphertext::from_lwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), ); let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); @@ -152,7 +155,8 @@ where + VmpApplyAdd + VecZnxBigNormalizeTmpBytes + VecZnxSwithcDegree - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace + + ZnNormalizeInplace, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl @@ -181,9 +185,9 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) - | LWECiphertext::from_glwe_scratch_space(module, n, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, k_glwe_ct), + LWEToGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | LWECiphertext::from_glwe_scratch_space(module, basek, k_lwe_ct, k_glwe_ct, k_ksk, rank) + | GLWECiphertext::decrypt_scratch_space(module, basek, k_glwe_ct), ); let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs b/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs index 7945281..8a71612 100644 --- a/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs +++ b/poulpy-core/src/tests/generics/encryption/gglwe_atk.rs @@ -77,7 +77,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( - module, n, basek, k_ksk, rank, + module, basek, k_ksk, rank, )); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -169,7 +169,7 @@ pub fn test_gglwe_automorphisk_key_compressed_encrypt_sk( let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEAutomorphismKey::encrypt_sk_scratch_space( - module, n, basek, k_ksk, rank, + module, basek, k_ksk, rank, )); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs b/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs index bf46961..5675b85 100644 --- a/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/generics/encryption/gglwe_ct.rs @@ -78,7 +78,7 @@ pub fn test_gglwe_switching_key_encrypt_sk( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( - module, n, basek, k_ksk, rank_in, rank_out, + module, basek, k_ksk, rank_in, rank_out, )); let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); @@ -156,7 +156,7 @@ pub fn test_gglwe_switching_key_compressed_encrypt_sk( let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( - module, n, basek, k_ksk, rank_in, rank_out, + module, basek, k_ksk, rank_in, rank_out, )); let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank_in); diff --git a/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs b/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs index d16cbf5..d2f5a74 100644 --- a/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/encryption/ggsw_ct.rs @@ -79,7 +79,7 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertext::encrypt_sk_scratch_space( - module, n, basek, k, rank, + module, basek, k, rank, )); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -154,7 +154,7 @@ where pt_scalar.fill_ternary_hw(0, n, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::alloc(GGSWCiphertextCompressed::encrypt_sk_scratch_space( - module, n, basek, k, rank, + module, basek, k, rank, )); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/encryption/glwe_ct.rs b/poulpy-core/src/tests/generics/encryption/glwe_ct.rs index 39e99f6..6db2f46 100644 --- a/poulpy-core/src/tests/generics/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/encryption/glwe_ct.rs @@ -81,8 +81,8 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()), + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -169,8 +169,8 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertextCompressed::encrypt_sk_scratch_space(module, n, basek, k_ct) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, k_ct), + GLWECiphertextCompressed::encrypt_sk_scratch_space(module, basek, k_ct) + | GLWECiphertext::decrypt_scratch_space(module, basek, k_ct), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -263,8 +263,8 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::decrypt_scratch_space(module, n, basek, k_ct) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct), + GLWECiphertext::decrypt_scratch_space(module, basek, k_ct) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -331,9 +331,9 @@ where let mut source_xu: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()) - | GLWECiphertext::encrypt_pk_scratch_space(module, n, basek, k_pk), + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GLWECiphertext::encrypt_pk_scratch_space(module, basek, k_pk), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs b/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs index a87dd0d..e69fb96 100644 --- a/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/generics/encryption/glwe_tsk.rs @@ -75,7 +75,6 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKey::encrypt_sk_scratch_space( module, - n, basek, tensor_key.k(), rank, @@ -95,10 +94,10 @@ where let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - let mut sk_ij_dft = module.vec_znx_dft_alloc(n, 1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(n, 1, 1); + let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, rank, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); (0..rank).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); @@ -185,7 +184,6 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWETensorKeyCompressed::encrypt_sk_scratch_space( module, - n, basek, tensor_key_compressed.k(), rank, @@ -204,10 +202,10 @@ where let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(n, basek, k); - let mut sk_ij_dft = module.vec_znx_dft_alloc(n, 1, 1); - let mut sk_ij_big = module.vec_znx_big_alloc(n, 1, 1); + let mut sk_ij_dft = module.vec_znx_dft_alloc(1, 1); + let mut sk_ij_big = module.vec_znx_big_alloc(1, 1); let mut sk_ij: GLWESecret> = GLWESecret::alloc(n, 1); - let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, rank, 1); + let mut sk_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(rank, 1); (0..rank).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); diff --git a/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs b/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs index e1d5ea0..565df12 100644 --- a/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs +++ b/poulpy-core/src/tests/generics/external_product/gglwe_ksk.rs @@ -93,9 +93,9 @@ pub fn test_gglwe_switching_key_external_product( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_in, rank_in, rank_out) - | GGLWESwitchingKey::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank_out), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_in, rank_in, rank_out) + | GGLWESwitchingKey::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), ); let r: usize = 1; @@ -231,9 +231,9 @@ pub fn test_gglwe_switching_key_external_product_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ct, rank_in, rank_out) - | GGLWESwitchingKey::external_product_inplace_scratch_space(module, n, basek, k_ct, k_ggsw, digits, rank_out) - | GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank_out), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ct, rank_in, rank_out) + | GGLWESwitchingKey::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank_out) + | GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank_out), ); let r: usize = 1; diff --git a/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs b/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs index eb6b8c3..04e99a4 100644 --- a/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/external_product/ggsw_ct.rs @@ -99,8 +99,8 @@ pub fn test_ggsw_external_product( pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_scratch_space(module, n, basek, k_out, k_in, k_ggsw, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); @@ -231,8 +231,8 @@ pub fn test_ggsw_external_product_inplace( pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ggsw, rank) - | GGSWCiphertext::external_product_inplace_scratch_space(module, n, basek, k_ct, k_ggsw, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ggsw, rank) + | GGSWCiphertext::external_product_inplace_scratch_space(module, basek, k_ct, k_ggsw, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/external_product/glwe_ct.rs b/poulpy-core/src/tests/generics/external_product/glwe_ct.rs index a164550..25d7d5c 100644 --- a/poulpy-core/src/tests/generics/external_product/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/external_product/glwe_ct.rs @@ -92,11 +92,10 @@ pub fn test_glwe_external_product( pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_glwe_in.k()) + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe_in.k()) | GLWECiphertext::external_product_scratch_space( module, - n, basek, ct_glwe_out.k(), ct_glwe_in.k(), @@ -225,9 +224,9 @@ pub fn test_glwe_external_product_inplace( pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, ct_ggsw.k(), rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_glwe.k()) - | GLWECiphertext::external_product_inplace_scratch_space(module, n, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, ct_ggsw.k(), rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) + | GLWECiphertext::external_product_inplace_scratch_space(module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs index 34619e8..5330418 100644 --- a/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/gglwe_ct.rs @@ -97,7 +97,6 @@ pub fn test_gglwe_switching_key_keyswitch( let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( module, - n, basek, k_ksk, rank_in_s0s1 | rank_out_s0s1, @@ -105,7 +104,6 @@ pub fn test_gglwe_switching_key_keyswitch( )); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_scratch_space( module, - n, basek, k_out, k_in, @@ -237,14 +235,13 @@ pub fn test_gglwe_switching_key_keyswitch_inplace( let mut scratch_enc: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::encrypt_sk_scratch_space( module, - n, basek, k_ksk, rank_in | rank_out, rank_out, )); let mut scratch_apply: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKey::keyswitch_inplace_scratch_space( - module, n, basek, k_ct, k_ksk, digits, rank_out, + module, basek, k_ct, k_ksk, digits, rank_out, )); let var_xs: f64 = 0.5; diff --git a/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs index 4b7a873..d82d517 100644 --- a/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/ggsw_ct.rs @@ -94,11 +94,11 @@ pub fn test_ggsw_keyswitch( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_in, rank) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_in, rank) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) | GGSWCiphertext::keyswitch_scratch_space( - module, n, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, + module, basek, k_out, k_in, k_ksk, digits, k_tsk, digits, rank, ), ); @@ -237,10 +237,10 @@ pub fn test_ggsw_keyswitch_inplace( let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct, rank) - | GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank, rank) - | GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k_tsk, rank) - | GGSWCiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k_ct, rank) + | GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank, rank) + | GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k_tsk, rank) + | GGSWCiphertext::keyswitch_inplace_scratch_space(module, basek, k_ct, k_ksk, digits, k_tsk, digits, rank), ); let var_xs: f64 = 0.5; diff --git a/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs index 11328fc..b8530dc 100644 --- a/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/glwe_ct.rs @@ -86,11 +86,10 @@ pub fn test_glwe_keyswitch( module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_in, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, ksk.k(), rank_in, rank_out) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_in.k()) + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank_in, rank_out) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_in.k()) | GLWECiphertext::keyswitch_scratch_space( module, - n, basek, ct_out.k(), ct_in.k(), @@ -200,9 +199,9 @@ where module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_ct, &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, ksk.k(), rank, rank) - | GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct_glwe.k()) - | GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, ct_glwe.k(), ksk.k(), digits, rank), + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, ksk.k(), rank, rank) + | GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct_glwe.k()) + | GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, ct_glwe.k(), ksk.k(), digits, rank), ); let mut sk_in: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs b/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs index ba7010e..a5f05d6 100644 --- a/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs +++ b/poulpy-core/src/tests/generics/keyswitch/lwe_ct.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, - ZnxView, + ZnAddNormal, ZnFillUniform, ZnNormalizeInplace, ZnxView, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -49,7 +49,10 @@ where + VmpApplyAdd + VecZnxBigNormalizeTmpBytes + VecZnxSwithcDegree - + VecZnxAutomorphismInplace, + + VecZnxAutomorphismInplace + + ZnNormalizeInplace + + ZnFillUniform + + ZnAddNormal, B: Backend + TakeVecZnxDftImpl + TakeVecZnxBigImpl @@ -75,8 +78,8 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k_ksk) - | LWECiphertext::keyswitch_scratch_space(module, n, basek, k_lwe_ct, k_lwe_ct, k_ksk), + LWESwitchingKey::encrypt_sk_scratch_space(module, basek, k_ksk) + | LWECiphertext::keyswitch_scratch_space(module, basek, k_lwe_ct, k_lwe_ct, k_ksk), ); let mut sk_lwe_in: LWESecret> = LWESecret::alloc(n_lwe_in); diff --git a/poulpy-core/src/tests/generics/packing.rs b/poulpy-core/src/tests/generics/packing.rs index 0d7e559..8562193 100644 --- a/poulpy-core/src/tests/generics/packing.rs +++ b/poulpy-core/src/tests/generics/packing.rs @@ -89,9 +89,9 @@ where let rows: usize = k_ct.div_ceil(basek * digits); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k_ct) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_ksk, rank) - | GLWEPacker::scratch_space(module, n, basek, k_ct, k_ksk, digits, rank), + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k_ct) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_ksk, rank) + | GLWEPacker::scratch_space(module, basek, k_ct, k_ksk, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/tests/generics/trace.rs b/poulpy-core/src/tests/generics/trace.rs index 25bd28d..a4fced0 100644 --- a/poulpy-core/src/tests/generics/trace.rs +++ b/poulpy-core/src/tests/generics/trace.rs @@ -87,10 +87,10 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, ct.k()) - | GLWECiphertext::decrypt_scratch_space(module, n, basek, ct.k()) - | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, n, basek, k_autokey, rank) - | GLWECiphertext::trace_inplace_scratch_space(module, n, basek, ct.k(), k_autokey, digits, rank), + GLWECiphertext::encrypt_sk_scratch_space(module, basek, ct.k()) + | GLWECiphertext::decrypt_scratch_space(module, basek, ct.k()) + | GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, basek, k_autokey, rank) + | GLWECiphertext::trace_inplace_scratch_space(module, basek, ct.k(), k_autokey, digits, rank), ); let mut sk: GLWESecret> = GLWESecret::alloc(n, rank); diff --git a/poulpy-core/src/utils.rs b/poulpy-core/src/utils.rs index 3d54753..6b8e8f5 100644 --- a/poulpy-core/src/utils.rs +++ b/poulpy-core/src/utils.rs @@ -37,19 +37,16 @@ impl GLWEPlaintext { impl LWEPlaintext { pub fn encode_i64(&mut self, data: i64, k: usize) { let basek: usize = self.basek(); - self.data - .encode_coeff_i64(basek, 0, k, 0, data, i64::BITS as usize); + self.data.encode_i64(basek, k, data, i64::BITS as usize); } } impl LWEPlaintext { pub fn decode_i64(&self, k: usize) -> i64 { - self.data.decode_coeff_i64(self.basek(), 0, k, 0) + self.data.decode_i64(self.basek(), k) } pub fn decode_float(&self) -> Float { - let mut data: Vec = vec![Float::new(self.k() as u32)]; - self.data.decode_vec_float(self.basek(), 0, &mut data); - data[0].clone() + self.data.decode_float(self.basek()) } } diff --git a/poulpy-hal/src/api/mod.rs b/poulpy-hal/src/api/mod.rs index e518fcd..6e631bd 100644 --- a/poulpy-hal/src/api/mod.rs +++ b/poulpy-hal/src/api/mod.rs @@ -5,6 +5,7 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; +mod zn; mod znx_base; pub use module::*; @@ -14,4 +15,5 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; +pub use zn::*; pub use znx_base::*; diff --git a/poulpy-hal/src/api/svp_ppol.rs b/poulpy-hal/src/api/svp_ppol.rs index f8a1cfc..bc915ae 100644 --- a/poulpy-hal/src/api/svp_ppol.rs +++ b/poulpy-hal/src/api/svp_ppol.rs @@ -2,18 +2,18 @@ use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPo /// Allocates as [crate::layouts::SvpPPol]. pub trait SvpPPolAlloc { - fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned; + fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned; } /// Returns the size in bytes to allocate a [crate::layouts::SvpPPol]. pub trait SvpPPolAllocBytes { - fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize; + fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize; } /// Consume a vector of bytes into a [crate::layouts::MatZnx]. /// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes]. pub trait SvpPPolFromBytes { - fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned; + fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec) -> SvpPPolOwned; } /// Prepare a [crate::layouts::ScalarZnx] into an [crate::layouts::SvpPPol]. diff --git a/poulpy-hal/src/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs index d5ab4bc..f69a7a3 100644 --- a/poulpy-hal/src/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -7,7 +7,7 @@ use crate::{ pub trait VecZnxNormalizeTmpBytes { /// Returns the minimum number of bytes necessary for normalization. - fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize; + fn vec_znx_normalize_tmp_bytes(&self) -> usize; } pub trait VecZnxNormalize { diff --git a/poulpy-hal/src/api/vec_znx_big.rs b/poulpy-hal/src/api/vec_znx_big.rs index 11b52b2..cefaaa4 100644 --- a/poulpy-hal/src/api/vec_znx_big.rs +++ b/poulpy-hal/src/api/vec_znx_big.rs @@ -7,18 +7,18 @@ use crate::{ /// Allocates as [crate::layouts::VecZnxBig]. pub trait VecZnxBigAlloc { - fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned; + fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned; } /// Returns the size in bytes to allocate a [crate::layouts::VecZnxBig]. pub trait VecZnxBigAllocBytes { - fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize; + fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize; } /// Consume a vector of bytes into a [crate::layouts::VecZnxBig]. /// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes]. pub trait VecZnxBigFromBytes { - fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; + fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned; } #[allow(clippy::too_many_arguments)] @@ -187,7 +187,7 @@ pub trait VecZnxBigNegateInplace { } pub trait VecZnxBigNormalizeTmpBytes { - fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize; + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize; } pub trait VecZnxBigNormalize { diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 8fef8d7..b4590fb 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -3,19 +3,19 @@ use crate::layouts::{ }; pub trait VecZnxDftAlloc { - fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned; + fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned; } pub trait VecZnxDftFromBytes { - fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; + fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; } pub trait VecZnxDftAllocBytes { - fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize; + fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize; } pub trait VecZnxDftToVecZnxBigTmpBytes { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize; + fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize; } pub trait VecZnxDftToVecZnxBig { diff --git a/poulpy-hal/src/api/vmp_pmat.rs b/poulpy-hal/src/api/vmp_pmat.rs index 7b3b732..088b773 100644 --- a/poulpy-hal/src/api/vmp_pmat.rs +++ b/poulpy-hal/src/api/vmp_pmat.rs @@ -1,27 +1,19 @@ use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef}; pub trait VmpPMatAlloc { - fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; + fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned; } pub trait VmpPMatAllocBytes { - fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } pub trait VmpPMatFromBytes { - fn vmp_pmat_from_bytes( - &self, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: Vec, - ) -> VmpPMatOwned; + fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned; } pub trait VmpPrepareTmpBytes { - fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } pub trait VmpPrepare { @@ -35,7 +27,6 @@ pub trait VmpPrepare { pub trait VmpApplyTmpBytes { fn vmp_apply_tmp_bytes( &self, - n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -81,7 +72,6 @@ pub trait VmpApply { pub trait VmpApplyAddTmpBytes { fn vmp_apply_add_tmp_bytes( &self, - n: usize, res_size: usize, a_size: usize, b_rows: usize, diff --git a/poulpy-hal/src/api/zn.rs b/poulpy-hal/src/api/zn.rs new file mode 100644 index 0000000..60e8ac9 --- /dev/null +++ b/poulpy-hal/src/api/zn.rs @@ -0,0 +1,86 @@ +use rand_distr::Distribution; + +use crate::{ + layouts::{Backend, Scratch, ZnToMut}, + source::Source, +}; + +pub trait ZnNormalizeInplace { + /// Normalizes the selected column of `a`. + fn zn_normalize_inplace(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: ZnToMut; +} + +pub trait ZnFillUniform { + /// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\] + fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +pub trait ZnFillDistF64 { + fn zn_fill_dist_f64>( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +pub trait ZnAddDistF64 { + /// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\]. + fn zn_add_dist_f64>( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +pub trait ZnFillNormal { + fn zn_fill_normal( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +pub trait ZnAddNormal { + /// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\]. + fn zn_add_normal( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut; +} diff --git a/poulpy-hal/src/delegates/mod.rs b/poulpy-hal/src/delegates/mod.rs index 595a641..85de88d 100644 --- a/poulpy-hal/src/delegates/mod.rs +++ b/poulpy-hal/src/delegates/mod.rs @@ -5,3 +5,4 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; +mod zn; diff --git a/poulpy-hal/src/delegates/svp_ppol.rs b/poulpy-hal/src/delegates/svp_ppol.rs index af76dd7..86e4cc0 100644 --- a/poulpy-hal/src/delegates/svp_ppol.rs +++ b/poulpy-hal/src/delegates/svp_ppol.rs @@ -8,8 +8,8 @@ impl SvpPPolFromBytes for Module where B: Backend + SvpPPolFromBytesImpl, { - fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec) -> SvpPPolOwned { - B::svp_ppol_from_bytes_impl(n, cols, bytes) + fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec) -> SvpPPolOwned { + B::svp_ppol_from_bytes_impl(self.n(), cols, bytes) } } @@ -17,8 +17,8 @@ impl SvpPPolAlloc for Module where B: Backend + SvpPPolAllocImpl, { - fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned { - B::svp_ppol_alloc_impl(n, cols) + fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned { + B::svp_ppol_alloc_impl(self.n(), cols) } } @@ -26,8 +26,8 @@ impl SvpPPolAllocBytes for Module where B: Backend + SvpPPolAllocBytesImpl, { - fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize { - B::svp_ppol_alloc_bytes_impl(n, cols) + fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize { + B::svp_ppol_alloc_bytes_impl(self.n(), cols) } } diff --git a/poulpy-hal/src/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs index 4ff3964..e1e3b92 100644 --- a/poulpy-hal/src/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -22,8 +22,8 @@ impl VecZnxNormalizeTmpBytes for Module where B: Backend + VecZnxNormalizeTmpBytesImpl, { - fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize { - B::vec_znx_normalize_tmp_bytes_impl(self, n) + fn vec_znx_normalize_tmp_bytes(&self) -> usize { + B::vec_znx_normalize_tmp_bytes_impl(self) } } diff --git a/poulpy-hal/src/delegates/vec_znx_big.rs b/poulpy-hal/src/delegates/vec_znx_big.rs index e78092a..953949f 100644 --- a/poulpy-hal/src/delegates/vec_znx_big.rs +++ b/poulpy-hal/src/delegates/vec_znx_big.rs @@ -24,8 +24,8 @@ impl VecZnxBigAlloc for Module where B: Backend + VecZnxBigAllocImpl, { - fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned { - B::vec_znx_big_alloc_impl(n, cols, size) + fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned { + B::vec_znx_big_alloc_impl(self.n(), cols, size) } } @@ -33,8 +33,8 @@ impl VecZnxBigFromBytes for Module where B: Backend + VecZnxBigFromBytesImpl, { - fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { - B::vec_znx_big_from_bytes_impl(n, cols, size, bytes) + fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxBigOwned { + B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes) } } @@ -42,8 +42,8 @@ impl VecZnxBigAllocBytes for Module where B: Backend + VecZnxBigAllocBytesImpl, { - fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize { - B::vec_znx_big_alloc_bytes_impl(n, cols, size) + fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize { + B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size) } } @@ -283,8 +283,8 @@ impl VecZnxBigNormalizeTmpBytes for Module where B: Backend + VecZnxBigNormalizeTmpBytesImpl, { - fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize { - B::vec_znx_big_normalize_tmp_bytes_impl(self, n) + fn vec_znx_big_normalize_tmp_bytes(&self) -> usize { + B::vec_znx_big_normalize_tmp_bytes_impl(self) } } diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index 7cf602b..8d4c098 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -20,8 +20,8 @@ impl VecZnxDftFromBytes for Module where B: Backend + VecZnxDftFromBytesImpl, { - fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { - B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes) + fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned { + B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes) } } @@ -29,8 +29,8 @@ impl VecZnxDftAllocBytes for Module where B: Backend + VecZnxDftAllocBytesImpl, { - fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize { - B::vec_znx_dft_alloc_bytes_impl(n, cols, size) + fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize { + B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size) } } @@ -38,8 +38,8 @@ impl VecZnxDftAlloc for Module where B: Backend + VecZnxDftAllocImpl, { - fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned { - B::vec_znx_dft_alloc_impl(n, cols, size) + fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned { + B::vec_znx_dft_alloc_impl(self.n(), cols, size) } } @@ -47,8 +47,8 @@ impl VecZnxDftToVecZnxBigTmpBytes for Module where B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl, { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize { - B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n) + fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize { + B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self) } } diff --git a/poulpy-hal/src/delegates/vmp_pmat.rs b/poulpy-hal/src/delegates/vmp_pmat.rs index 9465d9e..a34e412 100644 --- a/poulpy-hal/src/delegates/vmp_pmat.rs +++ b/poulpy-hal/src/delegates/vmp_pmat.rs @@ -14,8 +14,8 @@ impl VmpPMatAlloc for Module where B: Backend + VmpPMatAllocImpl, { - fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { - B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size) + fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned { + B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size) } } @@ -23,8 +23,8 @@ impl VmpPMatAllocBytes for Module where B: Backend + VmpPMatAllocBytesImpl, { - fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size) + fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size) } } @@ -32,16 +32,8 @@ impl VmpPMatFromBytes for Module where B: Backend + VmpPMatFromBytesImpl, { - fn vmp_pmat_from_bytes( - &self, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: Vec, - ) -> VmpPMatOwned { - B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes) + fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec) -> VmpPMatOwned { + B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes) } } @@ -49,8 +41,8 @@ impl VmpPrepareTmpBytes for Module where B: Backend + VmpPrepareTmpBytesImpl, { - fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size) + fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size) } } @@ -73,7 +65,6 @@ where { fn vmp_apply_tmp_bytes( &self, - n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -82,7 +73,7 @@ where b_size: usize, ) -> usize { B::vmp_apply_tmp_bytes_impl( - self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, ) } } @@ -107,7 +98,6 @@ where { fn vmp_apply_add_tmp_bytes( &self, - n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -116,7 +106,7 @@ where b_size: usize, ) -> usize { B::vmp_apply_add_tmp_bytes_impl( - self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, + self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size, ) } } diff --git a/poulpy-hal/src/delegates/zn.rs b/poulpy-hal/src/delegates/zn.rs new file mode 100644 index 0000000..e5311bb --- /dev/null +++ b/poulpy-hal/src/delegates/zn.rs @@ -0,0 +1,114 @@ +use crate::{ + api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace}, + layouts::{Backend, Module, Scratch, ZnToMut}, + oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl}, + source::Source, +}; + +impl ZnNormalizeInplace for Module +where + B: Backend + ZnNormalizeInplaceImpl, +{ + fn zn_normalize_inplace(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: ZnToMut, + { + B::zn_normalize_inplace_impl(n, basek, a, a_col, scratch) + } +} + +impl ZnFillUniform for Module +where + B: Backend + ZnFillUniformImpl, +{ + fn zn_fill_uniform(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + where + R: ZnToMut, + { + B::zn_fill_uniform_impl(n, basek, res, res_col, k, source); + } +} + +impl ZnFillDistF64 for Module +where + B: Backend + ZnFillDistF64Impl, +{ + fn zn_fill_dist_f64>( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut, + { + B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound); + } +} + +impl ZnAddDistF64 for Module +where + B: Backend + ZnAddDistF64Impl, +{ + fn zn_add_dist_f64>( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut, + { + B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound); + } +} + +impl ZnFillNormal for Module +where + B: Backend + ZnFillNormalImpl, +{ + fn zn_fill_normal( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + B::zn_fill_normal_impl(n, basek, res, res_col, k, source, sigma, bound); + } +} + +impl ZnAddNormal for Module +where + B: Backend + ZnAddNormalImpl, +{ + fn zn_add_normal( + &self, + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut, + { + B::zn_add_normal_impl(n, basek, res, res_col, k, source, sigma, bound); + } +} diff --git a/poulpy-hal/src/layouts/encoding.rs b/poulpy-hal/src/layouts/encoding.rs index 717d90a..cde5e8a 100644 --- a/poulpy-hal/src/layouts/encoding.rs +++ b/poulpy-hal/src/layouts/encoding.rs @@ -3,7 +3,7 @@ use rug::{Assign, Float}; use crate::{ api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero}, - layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef}, + layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef}, }; impl VecZnx { @@ -202,3 +202,90 @@ impl VecZnx { }); } } + +impl Zn { + pub fn encode_i64(&mut self, basek: usize, k: usize, data: i64, log_max: usize) { + let size: usize = k.div_ceil(basek); + + #[cfg(debug_assertions)] + { + let a: Zn<&mut [u8]> = self.to_mut(); + assert!( + size <= a.size(), + "invalid argument k.div_ceil(basek)={} > a.size()={}", + size, + a.size() + ); + } + + let k_rem: usize = basek - (k % basek); + let mut a: Zn<&mut [u8]> = self.to_mut(); + (0..a.size()).for_each(|j| a.at_mut(0, j)[0] = 0); + + // If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. + if log_max + k_rem < 63 || k_rem == basek { + a.at_mut(0, size - 1)[0] = data; + } else { + let mask: i64 = (1 << basek) - 1; + let steps: usize = size.min(log_max.div_ceil(basek)); + (size - steps..size) + .rev() + .enumerate() + .for_each(|(j, j_rev)| { + a.at_mut(0, j_rev)[0] = (data >> (j * basek)) & mask; + }) + } + + // Case where prec % k != 0. + if k_rem != basek { + let steps: usize = size.min(log_max.div_ceil(basek)); + (size - steps..size).rev().for_each(|j| { + a.at_mut(0, j)[0] <<= k_rem; + }) + } + } +} + +impl Zn { + pub fn decode_i64(&self, basek: usize, k: usize) -> i64 { + let a: Zn<&[u8]> = self.to_ref(); + let size: usize = k.div_ceil(basek); + let mut res: i64 = 0; + let rem: usize = basek - (k % basek); + (0..size).for_each(|j| { + let x: i64 = a.at(0, j)[0]; + if j == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + res = (res << k_rem) + (x >> rem); + } else { + res = (res << basek) + x; + } + }); + res + } + + pub fn decode_float(&self, basek: usize) -> Float { + let a: Zn<&[u8]> = self.to_ref(); + let size: usize = a.size(); + let prec: u32 = (basek * size) as u32; + + // 2^{basek} + let base: Float = Float::with_val(prec, (1 << basek) as f64); + let mut res: Float = Float::with_val(prec, (1 << basek) as f64); + + // y[i] = sum x[j][i] * 2^{-basek*j} + (0..size).for_each(|i| { + if i == 0 { + res.assign(a.at(0, size - i - 1)[0]); + res /= &base; + } else { + res += Float::with_val(prec, a.at(0, size - i - 1)[0]); + res /= &base; + } + }); + + res + } +} diff --git a/poulpy-hal/src/layouts/mod.rs b/poulpy-hal/src/layouts/mod.rs index 505dd74..ced5a3e 100644 --- a/poulpy-hal/src/layouts/mod.rs +++ b/poulpy-hal/src/layouts/mod.rs @@ -10,6 +10,7 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; +mod zn; pub use mat_znx::*; pub use module::*; @@ -21,6 +22,7 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; +pub use zn::*; pub trait Data = PartialEq + Eq + Sized; pub trait DataRef = Data + AsRef<[u8]>; diff --git a/poulpy-hal/src/layouts/zn.rs b/poulpy-hal/src/layouts/zn.rs new file mode 100644 index 0000000..33ef6ea --- /dev/null +++ b/poulpy-hal/src/layouts/zn.rs @@ -0,0 +1,255 @@ +use std::fmt; + +use crate::{ + alloc_aligned, + api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, WriterTo}, + source::Source, +}; + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use rand::RngCore; + +#[derive(PartialEq, Eq, Clone, Copy)] +pub struct Zn { + pub data: D, + pub n: usize, + pub cols: usize, + pub size: usize, + pub max_size: usize, +} + +impl ToOwnedDeep for Zn { + type Owned = Zn>; + fn to_owned_deep(&self) -> Self::Owned { + Zn { + data: self.data.as_ref().to_vec(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + } + } +} + +impl fmt::Debug for Zn { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl ZnxInfos for Zn { + fn cols(&self) -> usize { + self.cols + } + + fn rows(&self) -> usize { + 1 + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for Zn { + fn sl(&self) -> usize { + self.n() * self.cols() + } +} + +impl DataView for Zn { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for Zn { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl ZnxView for Zn { + type Scalar = i64; +} + +impl Zn> { + pub fn rsh_scratch_space(n: usize) -> usize { + n * std::mem::size_of::() + } +} + +impl ZnxZero for Zn { + fn zero(&mut self) { + self.raw_mut().fill(0) + } + fn zero_at(&mut self, i: usize, j: usize) { + self.at_mut(i, j).fill(0); + } +} + +impl Zn> { + pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize { + n * cols * size * size_of::() + } + + pub fn alloc(n: usize, cols: usize, size: usize) -> Self { + let data: Vec = alloc_aligned::(Self::alloc_bytes(n, cols, size)); + Self { + data, + n, + cols, + size, + max_size: size, + } + } + + pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into>) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::alloc_bytes(n, cols, size)); + Self { + data, + n, + cols, + size, + max_size: size, + } + } +} + +impl Zn { + pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self { + Self { + data, + n, + cols, + size, + max_size: size, + } + } +} + +impl fmt::Display for Zn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "Zn(n={}, cols={}, size={})", + self.n, self.cols, self.size + )?; + + for col in 0..self.cols { + writeln!(f, "Column {}:", col)?; + for size in 0..self.size { + let coeffs = self.at(col, size); + write!(f, " Size {}: [", size)?; + + let max_show = 100; + let show_count = coeffs.len().min(max_show); + + for (i, &coeff) in coeffs.iter().take(show_count).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", coeff)?; + } + + if coeffs.len() > max_show { + write!(f, ", ... ({} more)", coeffs.len() - max_show)?; + } + + writeln!(f, "]")?; + } + } + Ok(()) + } +} + +impl FillUniform for Zn { + fn fill_uniform(&mut self, source: &mut Source) { + source.fill_bytes(self.data.as_mut()); + } +} + +impl Reset for Zn { + fn reset(&mut self) { + self.zero(); + self.n = 0; + self.cols = 0; + self.size = 0; + self.max_size = 0; + } +} + +pub type ZnOwned = Zn>; +pub type ZnMut<'a> = Zn<&'a mut [u8]>; +pub type ZnRef<'a> = Zn<&'a [u8]>; + +pub trait ZnToRef { + fn to_ref(&self) -> Zn<&[u8]>; +} + +impl ZnToRef for Zn { + fn to_ref(&self) -> Zn<&[u8]> { + Zn { + data: self.data.as_ref(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + } + } +} + +pub trait ZnToMut { + fn to_mut(&mut self) -> Zn<&mut [u8]>; +} + +impl ZnToMut for Zn { + fn to_mut(&mut self) -> Zn<&mut [u8]> { + Zn { + data: self.data.as_mut(), + n: self.n, + cols: self.cols, + size: self.size, + max_size: self.max_size, + } + } +} + +impl ReaderFrom for Zn { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + self.n = reader.read_u64::()? as usize; + self.cols = reader.read_u64::()? as usize; + self.size = reader.read_u64::()? as usize; + self.max_size = reader.read_u64::()? as usize; + let len: usize = reader.read_u64::()? as usize; + let buf: &mut [u8] = self.data.as_mut(); + if buf.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!("self.data.len()={} != read len={}", buf.len(), len), + )); + } + reader.read_exact(&mut buf[..len])?; + Ok(()) + } +} + +impl WriterTo for Zn { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.n as u64)?; + writer.write_u64::(self.cols as u64)?; + writer.write_u64::(self.size as u64)?; + writer.write_u64::(self.max_size as u64)?; + let buf: &[u8] = self.data.as_ref(); + writer.write_u64::(buf.len() as u64)?; + writer.write_all(buf)?; + Ok(()) + } +} diff --git a/poulpy-hal/src/oep/mod.rs b/poulpy-hal/src/oep/mod.rs index bc53c0e..dac0def 100644 --- a/poulpy-hal/src/oep/mod.rs +++ b/poulpy-hal/src/oep/mod.rs @@ -5,6 +5,7 @@ mod vec_znx; mod vec_znx_big; mod vec_znx_dft; mod vmp_pmat; +mod zn; pub use module::*; pub use scratch::*; @@ -13,3 +14,4 @@ pub use vec_znx::*; pub use vec_znx_big::*; pub use vec_znx_dft::*; pub use vmp_pmat::*; +pub use zn::*; diff --git a/poulpy-hal/src/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs index 0aac0dc..ddfe6fe 100644 --- a/poulpy-hal/src/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -10,7 +10,7 @@ use crate::{ /// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxNormalizeTmpBytesImpl { - fn vec_znx_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize; + fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vec_znx_big.rs b/poulpy-hal/src/oep/vec_znx_big.rs index 3c13393..2764ef2 100644 --- a/poulpy-hal/src/oep/vec_znx_big.rs +++ b/poulpy-hal/src/oep/vec_znx_big.rs @@ -263,7 +263,7 @@ pub unsafe trait VecZnxBigNegateInplaceImpl { /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxBigNormalizeTmpBytesImpl { - fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module, n: usize) -> usize; + fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index 4a962ea..0a28dce 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -32,7 +32,7 @@ pub unsafe trait VecZnxDftAllocBytesImpl { /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl { - fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module, n: usize) -> usize; + fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) diff --git a/poulpy-hal/src/oep/vmp_pmat.rs b/poulpy-hal/src/oep/vmp_pmat.rs index b7971cf..5781ce1 100644 --- a/poulpy-hal/src/oep/vmp_pmat.rs +++ b/poulpy-hal/src/oep/vmp_pmat.rs @@ -38,14 +38,7 @@ pub unsafe trait VmpPMatFromBytesImpl { /// * See TODO for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VmpPrepareTmpBytesImpl { - fn vmp_prepare_tmp_bytes_impl( - module: &Module, - n: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - ) -> usize; + fn vmp_prepare_tmp_bytes_impl(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize; } /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) @@ -67,7 +60,6 @@ pub unsafe trait VmpPMatPrepareImpl { pub unsafe trait VmpApplyTmpBytesImpl { fn vmp_apply_tmp_bytes_impl( module: &Module, - n: usize, res_size: usize, a_size: usize, b_rows: usize, @@ -97,7 +89,6 @@ pub unsafe trait VmpApplyImpl { pub unsafe trait VmpApplyAddTmpBytesImpl { fn vmp_apply_add_tmp_bytes_impl( module: &Module, - n: usize, res_size: usize, a_size: usize, b_rows: usize, diff --git a/poulpy-hal/src/oep/zn.rs b/poulpy-hal/src/oep/zn.rs new file mode 100644 index 0000000..4a35185 --- /dev/null +++ b/poulpy-hal/src/oep/zn.rs @@ -0,0 +1,97 @@ +use rand_distr::Distribution; + +use crate::{ + layouts::{Backend, Scratch, ZnToMut}, + source::Source, +}; + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code. +/// * See [crate::api::ZnxNormalizeInplace] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnNormalizeInplaceImpl { + fn zn_normalize_inplace_impl(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch) + where + A: ZnToMut; +} + +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::api::ZnFillUniform] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnFillUniformImpl { + fn zn_fill_uniform_impl(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source) + where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::api::ZnFillDistF64] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnFillDistF64Impl { + fn zn_fill_dist_f64_impl>( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::api::ZnAddDistF64] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnAddDistF64Impl { + fn zn_add_dist_f64_impl>( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + dist: D, + bound: f64, + ) where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::api::ZnFillNormal] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnFillNormalImpl { + fn zn_fill_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut; +} + +#[allow(clippy::too_many_arguments)] +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See [crate::api::ZnAddNormal] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait ZnAddNormalImpl { + fn zn_add_normal_impl( + n: usize, + basek: usize, + res: &mut R, + res_col: usize, + k: usize, + source: &mut Source, + sigma: f64, + bound: f64, + ) where + R: ZnToMut; +} diff --git a/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs b/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs index a81d7be..6b77654 100644 --- a/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs +++ b/poulpy-hal/src/tests/vmp_pmat/vmp_apply.rs @@ -51,14 +51,13 @@ where let mut scratch = ScratchOwned::alloc( module.vmp_apply_tmp_bytes( - n, res_size, a_size, mat_rows, mat_cols_in, mat_cols_out, mat_size, - ) | module.vec_znx_big_normalize_tmp_bytes(n), + ) | module.vec_znx_big_normalize_tmp_bytes(), ); let mut a: VecZnx> = VecZnx::alloc(n, a_cols, a_size); @@ -67,10 +66,10 @@ where a.at_mut(i, a_size - 1)[i + 1] = 1; }); - let mut vmp: VmpPMat, B> = module.vmp_pmat_alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size); + let mut vmp: VmpPMat, B> = module.vmp_pmat_alloc(mat_rows, mat_cols_in, mat_cols_out, mat_size); - let mut c_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, mat_cols_out, mat_size); - let mut c_big: VecZnxBig, B> = module.vec_znx_big_alloc(n, mat_cols_out, mat_size); + let mut c_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(mat_cols_out, mat_size); + let mut c_big: VecZnxBig, B> = module.vec_znx_big_alloc(mat_cols_out, mat_size); let mut mat: MatZnx> = MatZnx::alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size); @@ -86,7 +85,7 @@ where module.vmp_prepare(&mut vmp, &mat, scratch.borrow()); - let mut a_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(n, a_cols, a_size); + let mut a_dft: VecZnxDft, B> = module.vec_znx_dft_alloc(a_cols, a_size); (0..a_cols).for_each(|i| { module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, i, &a, i); }); diff --git a/poulpy-schemes/examples/circuit_bootstrapping.rs b/poulpy-schemes/examples/circuit_bootstrapping.rs index 1f2720d..40ca143 100644 --- a/poulpy-schemes/examples/circuit_bootstrapping.rs +++ b/poulpy-schemes/examples/circuit_bootstrapping.rs @@ -8,7 +8,7 @@ use poulpy_core::{ use std::time::Instant; use poulpy_hal::{ - api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace, ZnxView, ZnxViewMut}, + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnxView, ZnxViewMut}, layouts::{Module, ScalarZnx, ScratchOwned}, source::Source, }; @@ -109,7 +109,7 @@ fn main() { pt_lwe.encode_i64(data, k_lwe_pt + 1); // +1 for padding bit // Normalize plaintext to nicely print coefficients - module.vec_znx_normalize_inplace(basek, pt_lwe.data_mut(), 0, scratch.borrow()); + module.zn_normalize_inplace(pt_lwe.n(), basek, pt_lwe.data_mut(), 0, scratch.borrow()); println!("pt_lwe: {}", pt_lwe); // LWE ciphertext diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs index aabc58f..7c13387 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_algo.rs @@ -23,7 +23,6 @@ use crate::tfhe::blind_rotation::{ #[allow(clippy::too_many_arguments)] pub fn cggi_blind_rotate_scratch_space( module: &Module, - n: usize, block_size: usize, extension_factor: usize, basek: usize, @@ -44,14 +43,14 @@ where if block_size > 1 { let cols: usize = rank + 1; - let acc_dft: usize = module.vec_znx_dft_alloc_bytes(n, cols, rows) * extension_factor; - let acc_big: usize = module.vec_znx_big_alloc_bytes(n, 1, brk_size); - let vmp_res: usize = module.vec_znx_dft_alloc_bytes(n, cols, brk_size) * extension_factor; - let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(n, 1, brk_size); + let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor; + let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size); + let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor; + let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size); let acc_dft_add: usize = vmp_res; - let vmp: usize = module.vmp_apply_tmp_bytes(n, brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) let acc: usize = if extension_factor > 1 { - VecZnx::alloc_bytes(n, cols, k_res.div_ceil(basek)) * extension_factor + VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor } else { 0 }; @@ -60,10 +59,10 @@ where + acc_dft_add + vmp_res + vmp_xai - + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes(n) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes(n)))) + + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_dft_to_vec_znx_big_tmp_bytes()))) } else { - GLWECiphertext::bytes_of(n, basek, k_res, rank) - + GLWECiphertext::external_product_scratch_space(module, n, basek, k_res, k_res, k_brk, 1, rank) + GLWECiphertext::bytes_of(module.n(), basek, k_res, rank) + + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank) } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs index 2c50fd9..a262fd4 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/cggi_key.rs @@ -38,11 +38,11 @@ impl BlindRotationKeyAlloc for BlindRotationKey, CGGI> { } impl BlindRotationKey, CGGI> { - pub fn generate_from_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k, rank) + GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank) } } @@ -108,11 +108,11 @@ impl BlindRotationKeyPreparedAlloc for BlindRotationKeyPrepared: VmpPMatAlloc + VmpPrepare, { - fn alloc(module: &Module, n_glwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self { let mut data: Vec, B>> = Vec::with_capacity(n_lwe); (0..n_lwe).for_each(|_| { data.push(GGSWCiphertextPrepared::alloc( - module, n_glwe, basek, k, rows, 1, rank, + module, basek, k, rows, 1, rank, )) }); Self { @@ -139,11 +139,11 @@ impl BlindRotationKeyCompressed, CGGI> { } } - pub fn generate_from_sk_scratch_space(module: &Module, n: usize, basek: usize, k: usize, rank: usize) -> usize + pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize where Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, { - GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, n, basek, k, rank) + GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, basek, k, rank) } } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs index 5c61867..6431167 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/key_prepared.rs @@ -16,7 +16,7 @@ use poulpy_core::{ use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, utils::set_xai_plus_y}; pub trait BlindRotationKeyPreparedAlloc { - fn alloc(module: &Module, n_glwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self; + fn alloc(module: &Module, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self; } #[derive(PartialEq, Eq)] @@ -74,7 +74,6 @@ where fn prepare_alloc(&self, module: &Module, scratch: &mut Scratch) -> BlindRotationKeyPrepared, BRA, B> { let mut brk: BlindRotationKeyPrepared, BRA, B> = BlindRotationKeyPrepared::alloc( module, - self.n(), self.keys.len(), self.basek(), self.k(), @@ -112,7 +111,7 @@ where let mut x_pow_a: Vec, B>> = Vec::with_capacity(n << 1); let mut buf: ScalarZnx> = ScalarZnx::alloc(n, 1); (0..n << 1).for_each(|i| { - let mut res: SvpPPol, B> = module.svp_ppol_alloc(n, 1); + let mut res: SvpPPol, B> = module.svp_ppol_alloc(1); set_xai_plus_y(module, i, 0, &mut res, &mut buf); x_pow_a.push(res); }); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs index 52f590c..cab01b4 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/lut.rs @@ -22,7 +22,7 @@ pub struct LookUpTable { } impl LookUpTable { - pub fn alloc(n: usize, basek: usize, k: usize, extension_factor: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { #[cfg(debug_assertions)] { assert!( @@ -34,7 +34,7 @@ impl LookUpTable { let size: usize = k.div_ceil(basek); let mut data: Vec>> = Vec::with_capacity(extension_factor); (0..extension_factor).for_each(|_| { - data.push(VecZnx::alloc(n, 1, size)); + data.push(VecZnx::alloc(module.n(), 1, size)); }); Self { data, @@ -121,16 +121,6 @@ impl LookUpTable { let drift: usize = step >> 1; // Rotates half the step to the left - module.vec_znx_rotate_inplace(-(drift as i64), &mut lut_full, 0); - - let n_large: usize = lut_full.n(); - - module.vec_znx_normalize_inplace( - self.basek, - &mut lut_full, - 0, - ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes(n_large)).borrow(), - ); if self.extension_factor() > 1 { (0..self.extension_factor()).for_each(|i| { @@ -143,6 +133,14 @@ impl LookUpTable { module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); } + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes()); + + self.data.iter_mut().for_each(|a| { + module.vec_znx_normalize_inplace(self.basek, a, 0, scratch.borrow()); + }); + + self.rotate(module, -(drift as i64)); + self.drift = drift } diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs index 80915a4..a8649a0 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_blind_rotation.rs @@ -6,7 +6,8 @@ use poulpy_hal::{ VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxFillUniform, VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace, - VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, ZnxView, + VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, + ZnNormalizeInplace, ZnxView, }, layouts::{Backend, Module, ScratchOwned}, oep::{ @@ -65,7 +66,10 @@ where + VmpPMatAlloc + VmpPrepare + VmpApply - + VmpApplyAdd, + + VmpApplyAdd + + ZnFillUniform + + ZnAddNormal + + ZnNormalizeInplace, B: Backend + VecZnxDftAllocBytesImpl + VecZnxBigAllocBytesImpl @@ -96,7 +100,7 @@ where let mut source_xa: Source = Source::new([1u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::::alloc(BlindRotationKey::generate_from_sk_scratch_space( - module, n, basek, k_brk, rank, + module, basek, k_brk, rank, )); let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); @@ -108,7 +112,6 @@ where let mut scratch_br: ScratchOwned = ScratchOwned::::alloc(cggi_blind_rotate_scratch_space( module, - n, block_size, extension_factor, basek, @@ -148,7 +151,7 @@ where .enumerate() .for_each(|(i, x)| *x = f(i as i64)); - let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); lut.set(module, &f_vec, log_message_modulus + 1); let mut res: GLWECiphertext> = GLWECiphertext::alloc(n, basek, k_res, rank); diff --git a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs index d4ef98e..6beb131 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/tests/generic_lut.rs @@ -13,7 +13,6 @@ where Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let n: usize = module.n(); let basek: usize = 20; let k_lut: usize = 40; let message_modulus: usize = 16; @@ -26,7 +25,7 @@ where .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; @@ -49,7 +48,6 @@ where Module: VecZnxRotateInplace + VecZnxNormalizeInplace + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl, { - let n: usize = module.n(); let basek: usize = 20; let k_lut: usize = 40; let message_modulus: usize = 16; @@ -62,7 +60,7 @@ where .enumerate() .for_each(|(i, x)| *x = (i as i64) - 8); - let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor); lut.set(module, &f, log_scale); let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index 331ab98..80b5e22 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -191,7 +191,7 @@ pub fn circuit_bootstrap_core( } // Lut precision, basically must be able to hold the decomposition power basis of the GGSW - let mut lut: LookUpTable = LookUpTable::alloc(n, basek, basek * rows, extension_factor); + let mut lut: LookUpTable = LookUpTable::alloc(module, basek, basek * rows, extension_factor); lut.set(module, &f, basek * rows); if to_exponent { diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs index 54d37ce..834d9a3 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/tests/circuit_bootstrapping.rs @@ -9,7 +9,8 @@ use poulpy_hal::{ VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, - VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, ZnxView, ZnxViewMut, + VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform, + ZnNormalizeInplace, ZnxView, ZnxViewMut, }, layouts::{Backend, Module, ScalarZnx, ScratchOwned}, oep::{ @@ -80,7 +81,10 @@ where + VecZnxBigSubSmallBInplace + VecZnxBigAllocBytes + VecZnxDftAddInplace - + VecZnxRotate, + + VecZnxRotate + + ZnFillUniform + + ZnAddNormal + + ZnNormalizeInplace, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl @@ -258,7 +262,10 @@ where + VecZnxBigSubSmallBInplace + VecZnxBigAllocBytes + VecZnxDftAddInplace - + VecZnxRotate, + + VecZnxRotate + + ZnFillUniform + + ZnAddNormal + + ZnNormalizeInplace, B: Backend + ScratchOwnedAllocImpl + ScratchOwnedBorrowImpl